move hyper connections only to the middle of the unet
lucidrains committed Jan 10, 2025
1 parent ed236d0 commit 60bcbac
Showing 2 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "rectified-flow-pytorch"
version = "0.2.1"
version = "0.2.2"
description = "Rectified Flow in Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
24 changes: 12 additions & 12 deletions rectified_flow_pytorch/rectified_flow.py
@@ -21,7 +21,7 @@
from einops import einsum, reduce, rearrange, repeat
from einops.layers.torch import Rearrange

-from hyper_connections.hyper_connections_channel_first import get_init_and_expand_reduce_stream_functions
+from hyper_connections.hyper_connections_channel_first import get_init_and_expand_reduce_stream_functions, Residual

from scipy.optimize import linear_sum_assignment

@@ -726,9 +726,9 @@ def __init__(
attn_klass = FullAttention if layer_full_attn else LinearAttention

self.downs.append(ModuleList([
-init_hyper_conn(dim = dim_in, branch = resnet_block(dim_in, dim_in)),
-init_hyper_conn(dim = dim_in, branch = resnet_block(dim_in, dim_in)),
-init_hyper_conn(dim = dim_in, branch = attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads)),
+Residual(branch = resnet_block(dim_in, dim_in)),
+Residual(branch = resnet_block(dim_in, dim_in)),
+Residual(branch = attn_klass(dim_in, dim_head = layer_attn_dim_head, heads = layer_attn_heads)),
Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1)
]))

@@ -743,16 +743,16 @@ def __init__(
attn_klass = FullAttention if layer_full_attn else LinearAttention

self.ups.append(ModuleList([
-init_hyper_conn(dim = dim_out + dim_in, branch = resnet_block(dim_out + dim_in, dim_out), residual_transform = res_conv(dim_out + dim_in, dim_out)),
-init_hyper_conn(dim = dim_out + dim_in, branch = resnet_block(dim_out + dim_in, dim_out), residual_transform = res_conv(dim_out + dim_in, dim_out)),
-init_hyper_conn(dim = dim_out, branch = attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads)),
+Residual(branch = resnet_block(dim_out + dim_in, dim_out), residual_transform = res_conv(dim_out + dim_in, dim_out)),
+Residual(branch = resnet_block(dim_out + dim_in, dim_out), residual_transform = res_conv(dim_out + dim_in, dim_out)),
+Residual(branch = attn_klass(dim_out, dim_head = layer_attn_dim_head, heads = layer_attn_heads)),
Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1)
]))

default_out_dim = channels * (1 if not learned_variance else 2)
self.out_dim = default(out_dim, default_out_dim)

-self.final_res_block = init_hyper_conn(dim = init_dim * 2, branch = resnet_block(init_dim * 2, init_dim), residual_transform = res_conv(init_dim * 2, init_dim))
+self.final_res_block = Residual(branch = resnet_block(init_dim * 2, init_dim), residual_transform = res_conv(init_dim * 2, init_dim))
self.final_conv = nn.Conv2d(init_dim, self.out_dim, 1)

@property
@@ -764,8 +764,6 @@ def forward(self, x, times):

x = self.init_conv(x)

-x = self.expand_streams(x)
-
r = x.clone()

t = self.time_mlp(times)
@@ -782,10 +780,14 @@

x = downsample(x)

+x = self.expand_streams(x)
+
x = self.mid_block1(x, t)
x = self.mid_attn(x)
x = self.mid_block2(x, t)

+x = self.reduce_streams(x)
+
for block1, block2, attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim = 1)
x = block1(x, t)
@@ -800,8 +802,6 @@

x = self.final_res_block(x, t)

-x = self.reduce_streams(x)
-
return self.final_conv(x)

# dataset classes
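In short: the down- and up-path blocks, previously wrapped with init_hyper_conn, now use the plain Residual wrapper from the same module, and expand_streams / reduce_streams are called only around the three mid blocks instead of at the start and end of forward, so the multi-stream hyper connections act on the UNet bottleneck alone. Below is a minimal sketch of that pattern, not code from this repo: it assumes the hyper-connections package's get_init_and_expand_reduce_stream_functions takes the stream count as its first argument (as in that project's README), and the stream count of 4, the plain conv branches, and the outer / mid names are placeholders for illustration.

    import torch
    from torch import nn

    from hyper_connections.hyper_connections_channel_first import get_init_and_expand_reduce_stream_functions, Residual

    # hypothetical stream count of 4, chosen only for illustration
    init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(4)

    dim = 32

    # outer (down / up path) block: ordinary single-stream residual
    outer = Residual(branch = nn.Conv2d(dim, dim, 3, padding = 1))

    # middle (bottleneck) block: wrapped in a hyper connection
    mid = init_hyper_conn(dim = dim, branch = nn.Conv2d(dim, dim, 3, padding = 1))

    x = torch.randn(1, dim, 16, 16)

    x = outer(x)            # plain residual on the way down
    x = expand_streams(x)   # fan the activations out into the residual streams
    x = mid(x)              # the hyper connection mixes information across streams
    x = reduce_streams(x)   # collapse back to a single stream
    x = outer(x)            # plain residual on the way up

    print(x.shape)          # expected: torch.Size([1, 32, 16, 16])

This mirrors the updated forward pass above, where expand_streams now immediately precedes mid_block1 and reduce_streams immediately follows mid_block2.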
