write a working self contained nano version less than 75 loc, for use in other projects as smaller subsystem
1 parent 566882b · commit 10c2090
Showing 5 changed files with 136 additions and 6 deletions.
New file (73 additions): the self-contained nano flow module.

```python
import torch
from torch.nn import Module
import torch.nn.functional as F

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def append_dims(t, dims):
    # right-pad the shape of t with singleton dimensions so it broadcasts against the data
    shape = t.shape
    ones = ((1,) * dims)
    return t.reshape(*shape, *ones)

class NanoFlow(Module):
    def __init__(
        self,
        model: Module,
        times_cond_kwarg = None,
        data_shape = None
    ):
        super().__init__()
        self.model = model
        self.times_cond_kwarg = times_cond_kwarg
        self.data_shape = data_shape

    @torch.no_grad()
    def sample(
        self,
        steps = 16,
        batch_size = 1,
        data_shape = None
    ):
        data_shape = default(data_shape, self.data_shape)
        assert exists(data_shape), 'shape of the data must be passed in, or set at init or during training'
        device = next(self.model.parameters()).device

        noise = torch.randn((batch_size, *data_shape), device = device)
        times = torch.linspace(0., 1., steps, device = device)
        delta = 1. / steps

        denoised = noise

        # Euler integration of the predicted flow, from noise (t = 0) to data (t = 1)

        for time in times:
            time = time.expand(batch_size)
            time_kwarg = {self.times_cond_kwarg: time} if exists(self.times_cond_kwarg) else dict()

            pred_flow = self.model(denoised, **time_kwarg)
            denoised = denoised + delta * pred_flow

        return denoised

    def forward(self, data):
        # shapes and variables

        shape, ndim = data.shape, data.ndim
        self.data_shape = default(self.data_shape, shape[1:]) # store last data shape for inference
        batch, device = shape[0], data.device

        # flow logic

        times = torch.rand(batch, device = device)
        noise = torch.randn_like(data)
        flow = data - noise # flow is the velocity from noise to data, also what the model is trained to predict

        padded_times = append_dims(times, ndim - 1)
        noised_data = noise * (1. - padded_times) + data * padded_times # noise the data with random amounts of noise (time) - pure noise at t = 0, data at t = 1

        time_kwarg = {self.times_cond_kwarg: times} if exists(self.times_cond_kwarg) else dict() # maybe time conditioning, could work without it
        pred_flow = self.model(noised_data, **time_kwarg)

        return F.mse_loss(flow, pred_flow)
```
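As a quick sanity check of the API above, here is a minimal sketch that runs one training step and then samples, with `NanoFlow` from the module above in scope. The `FlowMLP` toy model is hypothetical, not part of the commit; it stands in for any module that accepts the noised data plus the `times` keyword.

```python
import torch
from torch import nn

# hypothetical toy model for illustration only (not part of this commit)
class FlowMLP(nn.Module):
    def __init__(self, dim = 32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 1, 256),
            nn.SiLU(),
            nn.Linear(256, dim)
        )

    def forward(self, x, times):
        # append the time conditioning as one extra feature
        return self.net(torch.cat((x, times[:, None]), dim = -1))

model = FlowMLP(dim = 32)
flow = NanoFlow(model, times_cond_kwarg = 'times')

data = torch.randn(8, 32)              # stand-in for a real training batch
loss = flow(data)                      # mse between predicted and true velocity
loss.backward()

sampled = flow.sample(batch_size = 4)  # (4, 32); data_shape was cached during forward
```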
New file (48 additions): a training example on Oxford Flowers.
```python
import torch

# hf datasets for easy oxford flowers training

import torchvision.transforms as T
from torch.utils.data import Dataset
from datasets import load_dataset

class OxfordFlowersDataset(Dataset):
    def __init__(
        self,
        image_size
    ):
        self.ds = load_dataset('nelorth/oxford-flowers')['train']

        self.transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.PILToTensor()
        ])

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        pil = self.ds[idx]['image']
        tensor = self.transform(pil)
        return tensor / 255.

flowers_dataset = OxfordFlowersDataset(
    image_size = 128
)

# models and trainer

from rectified_flow_pytorch import NanoFlow, Unet, Trainer

model = Unet(dim = 64)

rectified_flow = NanoFlow(model, times_cond_kwarg = 'times')

trainer = Trainer(
    rectified_flow,
    dataset = flowers_dataset,
    num_train_steps = 70_000,
    results_folder = './results' # samples will be saved periodically to this folder
)

trainer()
```
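Besides the samples the Trainer writes to `./results` periodically, images can also be drawn straight from the wrapped flow once training finishes; a minimal sketch, assuming the training script above has run:

```python
# sample a small batch of images from the trained flow; data_shape (3, 128, 128)
# was cached on the first training batch, so it need not be passed again
images = rectified_flow.sample(steps = 16, batch_size = 4)
print(images.shape) # torch.Size([4, 3, 128, 128])
```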