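"""Helpers for caching large model weights and restoring them with pget.

make_manifest() scans the working tree for files larger than SIZE_THRESHOLD,
uploads them to CACHE_URI (S3, Cloudflare R2, or GCS), writes a manifest of
"<url> <local path>" lines, adds the cached files to .dockerignore, and wires
pget into cog.yaml and the predictor's setup(). pget_manifest() re-downloads
any missing files at runtime.
"""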
import fnmatch
import os
import subprocess
import time
from urllib.parse import urlparse

import yaml
from tqdm import tqdm

SIZE_THRESHOLD = 50  # MB
CACHE_URI = "r2://weights"  # s3://bucket/path, r2://bucket/path, or gs://bucket/path
CDN = "https://weights.govpro.xyz"


def parse_dockerignore(fileobj):
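    """Return the non-empty, non-comment lines from a .dockerignore file object."""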
    return [line.strip() for line in fileobj if line.strip() and not line.startswith('#')]


def should_ignore(file_path, dockerignore_patterns):
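    """Return True if file_path matches any of the .dockerignore patterns."""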
    # Ensure the file_path is relative to the current directory
    rel_path = os.path.normpath(file_path)
    if rel_path.startswith(os.path.sep):
        rel_path = rel_path[1:]
    return any(fnmatch.fnmatch(rel_path, pattern) for pattern in dockerignore_patterns)


def add_to_dockerignore(files):
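    """Append the given paths to .dockerignore, one per line."""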
    with open('.dockerignore', 'a') as f:
        for file in files:
            f.write(f"\n{file}")


def make_manifest(manifest_filename: str = 'manifest.pget'):
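    """Interactively cache large files and generate a pget manifest.

    Finds files over SIZE_THRESHOLD MB (honoring .dockerignore), uploads them
    to CACHE_URI, writes "<url> <local path>" lines to manifest_filename, and
    updates .dockerignore, cog.yaml, and the predictor so the files are
    restored with pget at setup time.
    """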
    large_files = []

    # Load .dockerignore patterns
    dockerignore_patterns = []
    if os.path.exists('.dockerignore'):
        with open('.dockerignore', 'r') as f:
            dockerignore_patterns = parse_dockerignore(f)

    # Step 1: Find all files larger than SIZE_THRESHOLD
    for root, dirs, files in os.walk('.', topdown=True):
        # Modify dirs in-place to exclude ignored directories
        dirs[:] = [d for d in dirs
                   if not should_ignore(os.path.relpath(os.path.join(root, d), '.'), dockerignore_patterns)]
        for file in files:
            filepath = os.path.join(root, file)
            rel_filepath = os.path.relpath(filepath, '.')
            if not should_ignore(rel_filepath, dockerignore_patterns):
                try:
                    size = os.path.getsize(filepath)
                    if size > SIZE_THRESHOLD * 1024 * 1024:
                        large_files.append((filepath, size))
                except OSError as e:
                    print(f"Error accessing {filepath}: {e}")

    if not large_files:
        print("No files larger than the size threshold were found; nothing to cache.")
        return

    # Step 2: List relative filepaths and their sizes
    print("Large files found:")
    for filepath, size in large_files:
        print(f"{filepath}: {size / (1024 * 1024):.2f} MB")

    # Step 3: Confirm with user
    user_input = input("Please confirm you would like to cache these [Y/n]: ").strip().lower()
    if user_input == 'n':
        print("Ok, I won't generate a manifest at this time.")
        return

    # Step 4: Copy files to cache
    if CACHE_URI.startswith('s3://'):
        cp_command = ['aws', 's3', 'cp']
        bucket = CACHE_URI
    elif CACHE_URI.startswith('r2://'):
        # Cloudflare R2 is S3-compatible, so use the aws CLI with an s3:// URI
        cp_command = ['aws', 's3', 'cp']
        bucket = "s3://" + CACHE_URI[5:]
    elif CACHE_URI.startswith('gs://'):
        cp_command = ['gcloud', 'storage', 'cp']
        bucket = CACHE_URI
    else:
        raise ValueError("Invalid CACHE_URI. Must start with 's3://', 'r2://', or 'gs://'")
    for filepath, _ in tqdm(large_files, desc="Copying files to cache"):
        # Use removeprefix rather than lstrip: lstrip('./') would also strip
        # the leading dot from hidden files such as ./.cache/foo
        dest_path = f"{bucket.rstrip('/')}/{filepath.removeprefix('./')}"
        if CACHE_URI.startswith('r2://'):
            subprocess.run(cp_command + [filepath, dest_path, '--endpoint-url',
                                         'https://3309f63723c6de8a36dab1a22068e3aa.r2.cloudflarestorage.com'],
                           capture_output=True, text=True, check=True)
        else:
            subprocess.run(cp_command + [filepath, dest_path], check=True)

    # Step 5: Generate manifest file, one "<url> <local path>" pair per line
    with open(manifest_filename, 'w') as f:
        for filepath, _ in large_files:
            rel = filepath.removeprefix('./')
            if CDN:
                # Join only non-empty parts so a path-less CACHE_URI does not
                # produce a double slash in the URL
                path = urlparse(CACHE_URI).path.strip('/')
                url = "/".join(part for part in (CDN.rstrip('/'), path, rel) if part)
            elif CACHE_URI.startswith('s3://'):
                # partition() tolerates a bucket-only URI, unlike split('/', 1)
                bucket, _, path = CACHE_URI[5:].partition('/')
                url = "/".join(part for part in (f"https://{bucket}.s3.amazonaws.com", path.strip('/'), rel) if part)
            else:  # gs://
                bucket, _, path = CACHE_URI[5:].partition('/')
                url = "/".join(part for part in (f"https://storage.googleapis.com/{bucket}", path.strip('/'), rel) if part)
            f.write(f"{url} {filepath}\n")

    # Add cached files to .dockerignore
    add_to_dockerignore([filepath for filepath, _ in large_files])
    print("Added cached files to .dockerignore")

    # Step 6: Update cog.yaml to install the pget binary at build time
    with open('cog.yaml', 'r') as f:
        cog_config = yaml.safe_load(f)
    build_config = cog_config.get('build', {})
    run_commands = build_config.get('run', [])
    pget_commands = [
        'curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/latest/download/pget_$(uname -s)_$(uname -m)"',
        'chmod +x /usr/local/bin/pget'
    ]
    # Only append commands that are not already present, to avoid duplicates
    missing_commands = [cmd for cmd in pget_commands if cmd not in run_commands]
    if missing_commands:
        run_commands.extend(missing_commands)
        build_config['run'] = run_commands
        cog_config['build'] = build_config
        with open('cog.yaml', 'w') as f:
            yaml.dump(cog_config, f)
        print("Updated cog.yaml to install pget.")

    # Step 7: Update predictor file
    predict_config = cog_config.get('predict', '')
    if predict_config:
        predictor_file, predictor_class = predict_config.split(':')
        with open(predictor_file, 'r') as f:
            predictor_content = f.read()
        if 'from pget import pget_manifest' not in predictor_content:
            predictor_content = f"from pget import pget_manifest\n{predictor_content}"
        if 'def setup(self):' in predictor_content:
            predictor_content = predictor_content.replace(
                'def setup(self):',
                f"def setup(self):\n        pget_manifest('{manifest_filename}')"
            )
        else:
            predictor_content += f"\n    def setup(self):\n        pget_manifest('{manifest_filename}')\n"
        with open(predictor_file, 'w') as f:
            f.write(predictor_content)
        print(f"Updated {predictor_file} to include pget_manifest in setup method.")


def pget_url(url: str, output_path: str):
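    """Download a single URL to output_path with the pget CLI."""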
    return subprocess.check_call(["pget", url, output_path])


def pget_manifest(manifest_filename: str='manifest.pget'):
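    """Download any files from the manifest that are missing locally.

    Each manifest line is "<url> <local path>", as written by make_manifest.
    Missing files are fetched in one batch via `pget multifile`.
    """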
    start = time.time()
    with open(manifest_filename, 'r') as f:
        manifest = f.read()
    to_dl = []
    # Ensure target directories exist and collect files not yet on disk
    for line in manifest.splitlines():
        _, path = line.split(" ", 1)
        if os.path.dirname(path):
            os.makedirs(os.path.dirname(path), exist_ok=True)
        if not os.path.exists(path):
            to_dl.append(line)
    # Write a new manifest containing only the missing files
    with open("tmp.pget", 'w') as f:
        f.write("\n".join(to_dl))
    # Download the missing files in one batch using pget
    if to_dl:
        subprocess.check_call(["pget", "multifile", "tmp.pget"])
    # Log metrics
    timing = time.time() - start
    print(f"Downloaded weights in {timing:.2f} seconds")