Skip to content

Commit

Permalink
use a hyper connection specific to 2d channel first and default to 2 …
Browse files Browse the repository at this point in the history
…streams
  • Loading branch information
lucidrains committed Jan 5, 2025
1 parent 3427d61 commit ed236d0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "rectified-flow-pytorch"
version = "0.2.0"
version = "0.2.1"
description = "Rectified Flow in Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand All @@ -27,7 +27,7 @@ dependencies = [
'einops>=0.8.0',
'einx>=0.3.0',
'ema-pytorch>=0.5.2',
'hyper-connections>=0.1.7',
'hyper-connections>=0.1.8',
'pillow',
'scipy',
'torch>=2.0',
Expand Down
6 changes: 2 additions & 4 deletions rectified_flow_pytorch/rectified_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from einops import einsum, reduce, rearrange, repeat
from einops.layers.torch import Rearrange

from hyper_connections import get_init_and_expand_reduce_stream_functions
from hyper_connections.hyper_connections_channel_first import get_init_and_expand_reduce_stream_functions

from scipy.optimize import linear_sum_assignment

Expand Down Expand Up @@ -658,7 +658,7 @@ def __init__(
attn_heads = 4,
full_attn = None, # defaults to full attention only for inner most layer
flash_attn = False,
num_residual_streams = 4
num_residual_streams = 2
):
super().__init__()

Expand Down Expand Up @@ -712,8 +712,6 @@ def __init__(
# hyper connections

init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

init_hyper_conn = partial(init_hyper_conn, channel_first = True)
res_conv = partial(nn.Conv2d, kernel_size = 1, bias = False)

# layers
Expand Down

0 comments on commit ed236d0

Please sign in to comment.