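"""Streamlit app for CLIP guided diffusion image generation.

Presents a form for a text prompt and generation settings, runs CLIPGuidedDiffusion
from diffusion_logic, streams intermediate images to the dashboard, and saves the
final image, an animation and run metadata under output/.

Typical launch (assuming the Streamlit CLI): `streamlit run diffusion_app.py`,
optionally with `-- --gpu 0` appended to pin the run to a specific GPU.
"""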
import streamlit as st
from pathlib import Path
import sys
import datetime
import shutil
import json
import os
import torch
import traceback
import base64
from PIL import Image
from typing import Optional
import argparse
sys.path.append("./taming-transformers")
import imageio
import numpy as np
from diffusion_logic import CLIPGuidedDiffusion, DIFFUSION_METHODS_AND_WEIGHTS

# Optional dependency, only used to record the commit SHA in run metadata
try:
    import git
except ModuleNotFoundError:
    pass


def generate_image(
    diffusion_weights: str,
    prompt: str,
    seed=0,
    num_steps=500,
    continue_prev_run=True,
    init_image: Optional[Image.Image] = None,
    skip_timesteps: int = 0,
    use_cutout_augmentations: bool = False,
    device: Optional[torch.device] = None,
) -> None:
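    """Run one CLIP guided diffusion generation and stream results to the dashboard.

    Relies on module-level Streamlit placeholders (status_text, step_progress_bar,
    im_display_slot) and globals (diffusion_method, imsize, outputdir) defined under
    __main__, so it is only meant to be called from this app.

    Args:
        diffusion_weights: diffusion method selected in the UI; resolved to a
            checkpoint via DIFFUSION_METHODS_AND_WEIGHTS
        prompt: text prompt guiding the generation
        seed: random seed passed to CLIPGuidedDiffusion
        num_steps: number of steps to run; -1 runs until stopped
        continue_prev_run: reuse models already cached in st.session_state
        init_image: optional reference image to initialize from
        skip_timesteps: timesteps to skip when starting from a reference image
        use_cutout_augmentations: enable cutout augmentations during guidance
        device: torch device to run on, or None for the default
    """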
    ### Init -------------------------------------------------------------------
    run = CLIPGuidedDiffusion(
        prompt=prompt,
        ckpt=diffusion_weights,
        seed=seed,
        num_steps=num_steps,
        continue_prev_run=continue_prev_run,
        skip_timesteps=skip_timesteps,
        use_cutout_augmentations=use_cutout_augmentations,
        device=device,
    )

    # Generate random run ID
    # Used to link runs chained w/ continue_prev_run
    # ref: https://stackoverflow.com/a/42703382/13095028
    # Use URL and filesystem safe version since we're using this as a folder name
    run_id = st.session_state["run_id"] = base64.urlsafe_b64encode(
        os.urandom(6)
    ).decode("ascii")

    if "loaded_wt" not in st.session_state:
        st.session_state["loaded_wt"] = None

    run_start_dt = datetime.datetime.now()

    ### Load model -------------------------------------------------------------
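    # Reuse the model, diffusion and CLIP objects cached in st.session_state when
    # continuing a previous run with the same weights; otherwise load from assets/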
    if (
        continue_prev_run
        and ("model" in st.session_state)
        and ("clip_model" in st.session_state)
        and ("diffusion" in st.session_state)
        and st.session_state["loaded_wt"] == diffusion_weights
    ):
        run.load_model(
            prev_model=st.session_state["model"],
            prev_diffusion=st.session_state["diffusion"],
            prev_clip_model=st.session_state["clip_model"],
        )
    else:
        (
            st.session_state["model"],
            st.session_state["diffusion"],
            st.session_state["clip_model"],
        ) = run.load_model(
            model_file_loc="assets/"
            + DIFFUSION_METHODS_AND_WEIGHTS.get(diffusion_method)
        )
        st.session_state["loaded_wt"] = diffusion_method

    ### Model init -------------------------------------------------------------
    # if continue_prev_run is True:
    #     run.model_init(init_image=st.session_state["prev_im"])
    # elif init_image is not None:
    if init_image is not None:
        run.model_init(init_image=init_image)
    else:
        run.model_init()

    ### Iterate ----------------------------------------------------------------
    step_counter = 0 + skip_timesteps
    frames = []

    try:
        # Try block catches st.script_runner.StopException, no need of a dedicated stop button
        # Reason is st.form is meant to be self-contained either within sidebar, or in main body
        # The way the form is implemented in this app splits the form across both regions
        # This is intended to prevent the model settings from crowding the main body
        # However, touching any button resets the app state, making it impossible to
        # implement a stop button that can still dump output
        # Thankfully there's a built-in stop button :)
        while True:
            # While loop to accommodate running a predetermined number of steps or running indefinitely
            status_text.text(f"Running step {step_counter}")

            ims = run.iterate()
            im = ims[0]

            if num_steps > 0:  # skip when num_steps = -1
                step_progress_bar.progress((step_counter + 1) / num_steps)
            else:
                step_progress_bar.progress(100)

            # At every step, display and save image
            im_display_slot.image(im, caption="Output image", output_format="PNG")
            st.session_state["prev_im"] = im

            # ref: https://stackoverflow.com/a/33117447/13095028
            # im_byte_arr = io.BytesIO()
            # im.save(im_byte_arr, format="JPEG")
            # frames.append(im_byte_arr.getvalue())  # read()
            frames.append(np.asarray(im))

            step_counter += 1

            if (step_counter == num_steps) and num_steps > 0:
                break

        # Stitch into video using imageio
        writer = imageio.get_writer("temp.mp4", fps=24)
        for frame in frames:
            writer.append_data(frame)
        writer.close()

        # Save to output folder if run completed
        runoutputdir = outputdir / (
            run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id
        )
        runoutputdir.mkdir()

        im.save(runoutputdir / "output.PNG", format="PNG")
        shutil.copy("temp.mp4", runoutputdir / "anim.mp4")

        details = {
            "run_id": run_id,
            "diffusion_method": diffusion_method,
            "ckpt": DIFFUSION_METHODS_AND_WEIGHTS.get(diffusion_method),
            "num_steps": step_counter,
            "planned_num_steps": num_steps,
            "text_input": prompt,
            "continue_prev_run": continue_prev_run,
            "seed": seed,
            "Xdim": imsize,
            "ydim": imsize,
            "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"),
            "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"),
        }

        if use_cutout_augmentations:
            details["use_cutout_augmentations"] = True

        if "git" in sys.modules:
            try:
                repo = git.Repo(search_parent_directories=True)
                commit_sha = repo.head.object.hexsha
                details["commit_sha"] = commit_sha[:6]
            except Exception as e:
                print("GitPython detected but not able to write commit SHA to file")
                print(f"raised Exception {e}")

        with open(runoutputdir / "details.json", "w") as f:
            json.dump(details, f, indent=4)

        status_text.text("Done!")  # End of run

    except st.script_runner.StopException as e:
        # Dump output to dashboard
        print("Received Streamlit StopException")
        status_text.text("Execution interrupted, dumping outputs ...")
        writer = imageio.get_writer("temp.mp4", fps=24)
        for frame in frames:
            writer.append_data(frame)
        writer.close()

        # Save partial results to output folder
        runoutputdir = outputdir / (
            run_start_dt.strftime("%Y%m%dT%H%M%S") + "-" + run_id
        )
        runoutputdir.mkdir()

        im.save(runoutputdir / "output.PNG", format="PNG")
        shutil.copy("temp.mp4", runoutputdir / "anim.mp4")

        details = {
            "run_id": run_id,
            "diffusion_method": diffusion_method,
            "ckpt": DIFFUSION_METHODS_AND_WEIGHTS.get(diffusion_method),
            "num_steps": step_counter,
            "planned_num_steps": num_steps,
            "text_input": prompt,
            "continue_prev_run": continue_prev_run,
            "seed": seed,
            "Xdim": imsize,
            "ydim": imsize,
            "start_time": run_start_dt.strftime("%Y%m%dT%H%M%S"),
            "end_time": datetime.datetime.now().strftime("%Y%m%dT%H%M%S"),
        }

        if use_cutout_augmentations:
            details["use_cutout_augmentations"] = True

        if "git" in sys.modules:
            try:
                repo = git.Repo(search_parent_directories=True)
                commit_sha = repo.head.object.hexsha
                details["commit_sha"] = commit_sha[:6]
            except Exception as e:
                print("GitPython detected but not able to write commit SHA to file")
                print(f"raised Exception {e}")

        with open(runoutputdir / "details.json", "w") as f:
            json.dump(details, f, indent=4)

        status_text.text("Done!")  # End of run


if __name__ == "__main__":
    # Argparse to capture GPU num
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--gpu", type=str, default=None, help="Specify GPU number. Defaults to None."
    )
    args = parser.parse_args()

    # Select specific GPU if chosen
    if args.gpu is not None:
        for i in args.gpu.split(","):
            assert (
                int(i) < torch.cuda.device_count()
            ), f"You specified --gpu {args.gpu} but torch.cuda.device_count() returned {torch.cuda.device_count()}"

        try:
            device = torch.device(f"cuda:{args.gpu}")
        except RuntimeError:
            print(traceback.format_exc())
    else:
        device = None
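
    # Each run is saved by generate_image() under output/<timestamp>-<run_id>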
    outputdir = Path("output")
    if not outputdir.exists():
        outputdir.mkdir()

    st.set_page_config(page_title="CLIP guided diffusion playground")
    st.title("CLIP guided diffusion playground")

    # Determine what weights are available in `assets/`
    weights_dir = Path("assets").resolve()
    available_diffusion_weights = list(weights_dir.glob("*.pt"))
    available_diffusion_weights = [i.name for i in available_diffusion_weights]
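
    # DIFFUSION_METHODS_AND_WEIGHTS maps method name -> checkpoint filename; invert it
    # so checkpoints found on disk can be mapped back to the methods they enable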
    diffusion_weights_and_methods = {
        j: i for i, j in DIFFUSION_METHODS_AND_WEIGHTS.items()
    }
    available_diffusion_methods = [
        diffusion_weights_and_methods.get(i) for i in available_diffusion_weights
    ]

    # i.e. no weights found, ask user to download weights
    if len(available_diffusion_methods) == 0:
        st.warning(
            "No weights found, please download the diffusion weights using `download-diffusion-weights.sh`."
        )
        st.stop()

    # Start of input form
    with st.form("form-inputs"):
        # Only element not in the sidebar, but in the form
        text_input = st.text_input(
            "Text prompt",
            help="CLIP-guided diffusion will generate an image that best fits the prompt",
        )

        diffusion_method = st.sidebar.radio(
            "Method",
            available_diffusion_methods,
            index=0,
            help="Choose the diffusion image generation method, corresponding to the notebooks in Eleuther's repo",
        )

        if diffusion_method.startswith("256"):
            image_size_notice = st.sidebar.text("Image size: fixed to 256x256")
            imsize = 256
        elif diffusion_method.startswith("512"):
            image_size_notice = st.sidebar.text("Image size: fixed to 512x512")
            imsize = 512

        set_seed = st.sidebar.checkbox(
            "Set seed",
            value=False,
            help="Check to set a random seed for reproducibility. Checking this adds an input to specify the seed",
        )

        num_steps = st.sidebar.number_input(
            "Num steps",
            value=1000,
            min_value=0,
            max_value=None,
            step=1,
            # help="Specify -1 to run indefinitely. Use Streamlit's stop button in the top right corner to terminate execution. The exception is caught so the most recent output will be dumped to dashboard",
        )

        seed_widget = st.sidebar.empty()
        if set_seed is True:
            seed = seed_widget.number_input("Seed", value=0, help="Random seed to use")
        else:
            seed = None

        use_custom_reference_image = st.sidebar.checkbox(
            "Use reference image",
            value=False,
            help="Check to add a reference image. The network will attempt to match the generated image to the provided reference",
        )

        reference_image_widget = st.sidebar.empty()
        skip_timesteps_widget = st.sidebar.empty()
        if use_custom_reference_image is True:
            reference_image = reference_image_widget.file_uploader(
                "Upload reference image",
                type=["png", "jpeg", "jpg"],
                accept_multiple_files=False,
                help="Reference image for the network, will be resized to fit the specified dimensions",
            )
            # Convert from UploadedFile object to PIL Image
            if reference_image is not None:
                reference_image: Image.Image = Image.open(reference_image).convert(
                    "RGB"
                )  # just to be sure
            skip_timesteps = skip_timesteps_widget.number_input(
                "Skip timesteps (suggested 200-500)",
                value=200,
                help="Higher values make the output look more like the reference image",
            )
        else:
            reference_image = None
            skip_timesteps = 0

        continue_prev_run = st.sidebar.checkbox(
            "Skip init if models are loaded",
            value=True,
            help="Skips the lengthy model init",
        )

        use_cutout_augmentations = st.sidebar.checkbox(
            "Use cutout augmentations",
            value=False,
            help="Adds cutout augmentations to the image generation process, using an additional 1-2 GiB of GPU memory. Tends to increase image quality, though the gain may not be noticeable for guided diffusion since its output is already high quality; feel free to experiment. Toggling this on vs off significantly changes the image composition. Off by default.",
        )

        submitted = st.form_submit_button("Run!")

    # End of form
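
    # Placeholders below are filled in at run time by generate_image() via module scope:
    # status_text, step_progress_bar and im_display_slot are updated on every step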
    status_text = st.empty()
    status_text.text("Pending input prompt")
    step_progress_bar = st.progress(0)

    im_display_slot = st.empty()
    vid_display_slot = st.empty()
    debug_slot = st.empty()

    if "prev_im" in st.session_state:
        im_display_slot.image(
            st.session_state["prev_im"], caption="Output image", output_format="PNG"
        )

    with st.expander("Expand for README"):
        with open("README.md", "r") as f:
            # Preprocess links to redirect to github
            # Thank you https://discuss.streamlit.io/u/asehmi, works like a charm!
            # ref: https://discuss.streamlit.io/t/image-in-markdown/13274/8
            markdown_links = [str(i) for i in Path("docs/").glob("*.md")]
            images = [str(i) for i in Path("docs/images/").glob("*")]

            readme_lines = f.readlines()
            readme_buffer = []

            for line in readme_lines:
                for md_link in markdown_links:
                    if md_link in line:
                        line = line.replace(
                            md_link,
                            "https://github.com/tnwei/vqgan-clip-app/tree/main/"
                            + md_link,
                        )

                readme_buffer.append(line)

                for image in images:
                    if image in line:
                        st.markdown(" ".join(readme_buffer[:-1]))
                        st.image(
                            f"https://raw.githubusercontent.com/tnwei/vqgan-clip-app/main/{image}"
                        )
                        readme_buffer.clear()

            st.markdown(" ".join(readme_buffer))

    with st.expander("Expand for CHANGELOG"):
        with open("CHANGELOG.md", "r") as f:
            st.markdown(f.read())
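
    # Pressing "Run!" submits the form: generate_image() blocks until the run finishes
    # or is stopped, after which the stitched animation is shown below the output image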
    if submitted:
        # debug_slot.write(st.session_state) # DEBUG
        status_text.text("Loading weights ...")
        generate_image(
            diffusion_weights=diffusion_method,
            prompt=text_input,
            seed=seed,
            num_steps=num_steps,
            continue_prev_run=continue_prev_run,
            init_image=reference_image,
            skip_timesteps=skip_timesteps,
            use_cutout_augmentations=use_cutout_augmentations,
            device=device,
        )
        vid_display_slot.video("temp.mp4")
        # debug_slot.write(st.session_state) # DEBUG