Commit

write a working self contained nano version less than 75 loc, for use in other projects as smaller subsystem
lucidrains committed Feb 9, 2025
1 parent 566882b commit 10c2090
Showing 5 changed files with 136 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "rectified-flow-pytorch"
version = "0.2.2"
version = "0.2.3"
description = "Rectified Flow in Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
2 changes: 2 additions & 0 deletions rectified_flow_pytorch/__init__.py
@@ -5,6 +5,8 @@
    Trainer,
)

+from rectified_flow_pytorch.nano_flow import NanoFlow

from rectified_flow_pytorch.reflow import (
    Reflow,
    ReflowTrainer
73 changes: 73 additions & 0 deletions rectified_flow_pytorch/nano_flow.py
@@ -0,0 +1,73 @@
import torch
from torch.nn import Module
import torch.nn.functional as F

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

def append_dims(t, dims):
    shape = t.shape
    ones = (1,) * dims
    return t.reshape(*shape, *ones)

class NanoFlow(Module):
    def __init__(
        self,
        model: Module,
        times_cond_kwarg = None,
        data_shape = None
    ):
        super().__init__()
        self.model = model
        self.times_cond_kwarg = times_cond_kwarg
        self.data_shape = data_shape

    @torch.no_grad()
    def sample(
        self,
        steps = 16,
        batch_size = 1,
        data_shape = None
    ):
        data_shape = default(data_shape, self.data_shape)
        assert exists(data_shape), 'shape of the data must be passed in, or set at init or during training'
        device = next(self.model.parameters()).device

        noise = torch.randn((batch_size, *data_shape), device = device)
        times = torch.linspace(0., 1., steps, device = device)
        delta = 1. / steps

        denoised = noise

        # euler integration from noise (t = 0) towards data (t = 1)

        for time in times:
            time = time.expand(batch_size)
            time_kwarg = {self.times_cond_kwarg: time} if exists(self.times_cond_kwarg) else dict()

            pred_flow = self.model(denoised, **time_kwarg)
            denoised = denoised + delta * pred_flow

        return denoised

    def forward(self, data):
        # shapes and variables

        shape, ndim = data.shape, data.ndim
        self.data_shape = default(self.data_shape, shape[1:]) # store last data shape for inference
        batch, device = shape[0], data.device

        # flow logic

        times = torch.rand(batch, device = device)
        noise = torch.randn_like(data)
        flow = data - noise # flow is the velocity from noise to data, also what the model is trained to predict

        padded_times = append_dims(times, ndim - 1)
        noised_data = noise * (1. - padded_times) + data * padded_times # interpolate from pure noise at t = 0 to data at t = 1

        time_kwarg = {self.times_cond_kwarg: times} if exists(self.times_cond_kwarg) else dict() # maybe time conditioning, could work without it
        pred_flow = self.model(noised_data, **time_kwarg)

        return F.mse_loss(flow, pred_flow)
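
For reference, a minimal sketch of using the new NanoFlow standalone (not part of the commit — the toy convolutional model and variable names below are hypothetical; any Module that maps noised data to a same-shaped predicted flow works):

import torch
from torch import nn
from rectified_flow_pytorch import NanoFlow

# hypothetical stand-in for a real network - predicts flow with the same shape as its input
toy_model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding = 1),
    nn.SiLU(),
    nn.Conv2d(16, 3, 3, padding = 1)
)

nano_flow = NanoFlow(toy_model) # no times_cond_kwarg, as the toy model accepts no time argument

data = torch.randn(4, 3, 32, 32)

loss = nano_flow(data) # mse between true and predicted flow; also stores the data shape
loss.backward()

sampled = nano_flow.sample(batch_size = 4) # (4, 3, 32, 32), euler integrated over 16 steps by default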
17 changes: 12 additions & 5 deletions rectified_flow_pytorch/rectified_flow.py
@@ -25,6 +25,8 @@

from scipy.optimize import linear_sum_assignment

+from rectified_flow_pytorch.nano_flow import NanoFlow

# helpers

def exists(v):
@@ -872,7 +874,7 @@ def cycle(dl):
class Trainer(Module):
    def __init__(
        self,
-        rectified_flow: dict | RectifiedFlow,
+        rectified_flow: dict | RectifiedFlow | NanoFlow,
        *,
        dataset: dict | Dataset,
        num_train_steps = 70_000,
@@ -902,7 +904,7 @@ def __init__(
        # determine whether to keep track of EMA (if not using consistency FM)
        # which will determine which model to use for sampling

-        use_ema &= not self.model.use_consistency
+        use_ema &= not getattr(self.model, 'use_consistency', False)

        self.use_ema = use_ema
        self.ema_model = None
@@ -925,6 +927,8 @@ def __init__(

self.num_train_steps = num_train_steps

self.return_loss_breakdown = isinstance(rectified_flow, RectifiedFlow)

# folders

self.checkpoints_folder = Path(checkpoints_folder)
@@ -1000,17 +1004,20 @@ def forward(self):
        self.model.train()

        data = next(dl)
-        loss, loss_breakdown = self.model(data, return_loss_breakdown = True)
-
-        self.log(loss_breakdown._asdict(), step = step)
+        if self.return_loss_breakdown:
+            loss, loss_breakdown = self.model(data, return_loss_breakdown = True)
+            self.log(loss_breakdown._asdict(), step = step)
+        else:
+            loss = self.model(data)

        self.accelerator.print(f'[{step}] loss: {loss.item():.3f}')
        self.accelerator.backward(loss)

        self.optimizer.step()
        self.optimizer.zero_grad()

-        if self.model.use_consistency:
+        if getattr(self.model, 'use_consistency', False):
            self.model.ema_model.update()

        if self.is_main and self.use_ema:
48 changes: 48 additions & 0 deletions train_nano_rf.py
@@ -0,0 +1,48 @@
import torch

# hf datasets for easy oxford flowers training

import torchvision.transforms as T
from torch.utils.data import Dataset
from datasets import load_dataset

class OxfordFlowersDataset(Dataset):
    def __init__(
        self,
        image_size
    ):
        self.ds = load_dataset('nelorth/oxford-flowers')['train']

        self.transform = T.Compose([
            T.Resize((image_size, image_size)),
            T.PILToTensor()
        ])

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        pil = self.ds[idx]['image']
        tensor = self.transform(pil)
        return tensor / 255.

flowers_dataset = OxfordFlowersDataset(
    image_size = 128
)

# models and trainer

from rectified_flow_pytorch import NanoFlow, Unet, Trainer

model = Unet(dim = 64)

rectified_flow = NanoFlow(model, times_cond_kwarg = 'times')

trainer = Trainer(
    rectified_flow,
    dataset = flowers_dataset,
    num_train_steps = 70_000,
    results_folder = './results' # samples will be saved periodically to this folder
)

trainer()
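
After training, samples can be drawn directly from the NanoFlow wrapper; a sketch, assuming the run above has completed (or at least taken one forward pass, so the data shape is stored):

sampled = rectified_flow.sample(batch_size = 16, steps = 16) # (16, 3, 128, 128), same shape as the training images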
