-
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathgeneration_utils.py
58 lines (46 loc) · 1.79 KB
/
generation_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from pathlib import Path
import torch
from diffusers import (
PNDMScheduler
)
import numpy as np
# -------------------------------------------------------------------------------
# Helper functions directly copied from
# https://github.com/andreasjansson/cog-stable-diffusion/blob/animate/predict_animate.py
# -------------------------------------------------------------------------------
def make_scheduler(num_inference_steps, from_scheduler=None):
    """Create a PNDMScheduler configured for ``num_inference_steps`` steps.

    The scheduler is always built with the same beta settings
    (``beta_start=0.00085``, ``beta_end=0.012``, ``"scaled_linear"``) and
    has ``set_timesteps(num_inference_steps)`` applied before it is returned.

    Args:
        num_inference_steps: Value forwarded to ``set_timesteps``.
        from_scheduler: Optional existing scheduler. When truthy, its
            ``cur_model_output``, ``counter``, ``cur_sample`` and ``ets``
            attributes are carried over onto the new scheduler.

    Returns:
        The newly constructed scheduler.
    """
    fresh = PNDMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
    )
    fresh.set_timesteps(num_inference_steps)

    # Nothing to inherit: hand back the pristine scheduler.
    if not from_scheduler:
        return fresh

    # These attributes are shared by reference with the source scheduler.
    for attr in ("cur_model_output", "counter", "cur_sample"):
        setattr(fresh, attr, getattr(from_scheduler, attr))

    # ``ets`` is slice-copied instead of aliased (presumably a list, so later
    # appends on one scheduler do not affect the other -- confirm).
    fresh.ets = from_scheduler.ets[:]
    return fresh
def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    """Spherically interpolate between two arrays ``v0`` and ``v1``.

    Adapted from
    https://gist.github.com/nateraw/c989468b74c616ebbc6474aa8cdd9e53

    Args:
        t: Interpolation factor; ``0`` yields ``v0`` and ``1`` yields ``v1``.
        v0: Start array. Either a ``numpy.ndarray`` or an object treated as a
            torch tensor (it must provide ``.device`` and ``.cpu().numpy()``).
        v1: End array, of the same kind as ``v0``.
        DOT_THRESHOLD: When the absolute value of the normalized dot product
            of the inputs exceeds this, plain linear interpolation is used
            instead of the spherical formula.

    Returns:
        The interpolated array: a ``numpy.ndarray`` when ``v0`` is one,
        otherwise a torch tensor moved to ``v0``'s original device.
    """
    # Must be initialized up front: NumPy inputs skip the branch below, and
    # the flag is read unconditionally before returning.
    inputs_are_torch = False
    input_device = None
    if not isinstance(v0, np.ndarray):
        inputs_are_torch = True
        input_device = v0.device
        v0 = v0.cpu().numpy()
        v1 = v1.cpu().numpy()

    # Cosine of the angle between the two inputs, each taken as one flat vector.
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        # Nearly (anti-)parallel: sin(theta_0) would be close to zero, so
        # fall back to linear interpolation.
        v2 = (1 - t) * v0 + t * v1
    else:
        theta_0 = np.arccos(dot)
        sin_theta_0 = np.sin(theta_0)
        theta_t = theta_0 * t
        sin_theta_t = np.sin(theta_t)
        s0 = np.sin(theta_0 - theta_t) / sin_theta_0
        s1 = sin_theta_t / sin_theta_0
        v2 = s0 * v0 + s1 * v1

    if inputs_are_torch:
        v2 = torch.from_numpy(v2).to(input_device)
    return v2
def save_pil_image(image, path):
    """Save ``image`` to ``path`` and return that location as a ``Path``.

    Args:
        image: Object exposing a ``save(path)`` method.
        path: Destination passed unchanged to ``image.save``.

    Returns:
        ``pathlib.Path`` built from ``path``.
    """
    # Write first; the Path is only constructed once saving has not raised.
    image.save(path)
    saved_to = Path(path)
    return saved_to