generated from fofr/cog-comfyui
-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathweights_downloader.py
85 lines (75 loc) · 3.03 KB
/
weights_downloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import subprocess
import time
import os
from weights_manifest import WeightsManifest
class WeightsDownloader:
    """Download model weights listed in the weights manifest on demand.

    Each manifest entry maps a weight name to either a single
    ``{"url": ..., "dest": ...}`` dict or a list of such dicts. Files that
    already exist locally are skipped; missing ones are fetched by running the
    ``pget`` command-line tool.
    """

    # File extensions treated as weight files. Kept as a list so existing
    # callers relying on list semantics keep working.
    supported_filetypes = [
        ".ckpt",
        ".safetensors",
        ".pt",
        ".pth",
        ".bin",
        ".onnx",
        ".torchscript",
        ".engine",
        ".patch",
    ]

    def __init__(self):
        self.weights_manifest = WeightsManifest()
        # Name -> entry mapping (entry is a dict, or a list of dicts, each
        # with "url" and "dest" keys) taken from the manifest.
        self.weights_map = self.weights_manifest.weights_map

    def get_weights_by_type(self, type):
        """Return the manifest's weights for the given category.

        Delegates to ``WeightsManifest.get_weights_by_type``. The parameter
        name shadows the builtin ``type`` but is kept so keyword callers
        continue to work.
        """
        return self.weights_manifest.get_weights_by_type(type)

    def download_weights(self, weight_str):
        """Ensure every file registered under ``weight_str`` is present locally.

        Names that are not in the manifest are skipped without error. A
        license warning is printed for weights flagged non-commercial-only.

        Raises:
            subprocess.CalledProcessError: if a ``pget`` download fails.
        """
        if weight_str not in self.weights_map:
            # Unknown names are deliberately ignored rather than raising.
            return

        if self.weights_manifest.is_non_commercial_only(weight_str):
            print(
                f"⚠️ {weight_str} is for non-commercial use only. Unless you have obtained a commercial license.\nDetails: https://github.com/fofr/cog-comfyui/blob/main/weights_licenses.md"
            )

        entry = self.weights_map[weight_str]
        # Normalise: an entry is either one {"url", "dest"} dict or a list of them.
        files = entry if isinstance(entry, list) else [entry]
        for weight in files:
            self.download_if_not_exists(weight_str, weight["url"], weight["dest"])

    def check_if_file_exists(self, weight_str, dest):
        """Return True if ``weight_str`` already exists at or under ``dest``.

        ``dest`` may be the full path of the file itself (when it ends with
        ``weight_str`` on a path-component boundary) or the directory that
        should contain ``weight_str`` (which may itself include a subfolder,
        e.g. ``"sub/name.safetensors"``).
        """
        # Match on a path-component boundary only, so a dest such as
        # ".../data.pt" is not mistaken for the file "a.pt".
        if dest == weight_str or dest.endswith(
            ("/" + weight_str, os.sep + weight_str)
        ):
            path_string = dest
        else:
            path_string = os.path.join(dest, weight_str)
        return os.path.exists(path_string)

    def download_if_not_exists(self, weight_str, url, dest):
        """Download ``url`` into ``dest`` unless ``weight_str`` is already there."""
        if self.check_if_file_exists(weight_str, dest):
            print(f"✅ {weight_str} exists in {dest}")
            return
        WeightsDownloader.download(weight_str, url, dest)

    @staticmethod
    def download(weight_str, url, dest):
        """Fetch ``url`` into ``dest`` with ``pget`` and report time and size.

        If ``weight_str`` contains a subfolder (``"sub/name.ext"``), that
        subfolder is appended to ``dest`` and created if missing.

        Raises:
            subprocess.CalledProcessError: if ``pget`` exits non-zero.
        """
        if "/" in weight_str:
            subfolder = weight_str.rsplit("/", 1)[0]
            dest = os.path.join(dest, subfolder)
            os.makedirs(dest, exist_ok=True)

        print(f"⏳ Downloading {weight_str} to {dest}")
        # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
        start = time.monotonic()
        # NOTE(review): "-xf" flags are passed straight to pget (presumably
        # extract + force — confirm against pget docs). close_fds=False lets
        # pget inherit this process's file descriptors; intent not shown here.
        subprocess.check_call(
            ["pget", "--log-level", "warn", "-xf", url, dest], close_fds=False
        )
        elapsed_time = time.monotonic() - start

        try:
            file_size_bytes = os.path.getsize(
                os.path.join(dest, os.path.basename(weight_str))
            )
        except FileNotFoundError:
            # The fetched content may not land under this exact file name,
            # so report without a size instead of failing.
            print(f"✅ {weight_str} downloaded to {dest} in {elapsed_time:.2f}s")
        else:
            file_size_megabytes = file_size_bytes / (1024 * 1024)
            print(
                f"✅ {weight_str} downloaded to {dest} in {elapsed_time:.2f}s, size: {file_size_megabytes:.2f}MB"
            )