diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py
index 18f9ab6..0bd2692 100644
--- a/pi_zero_pytorch/pi_zero.py
+++ b/pi_zero_pytorch/pi_zero.py
@@ -33,6 +33,8 @@
 
 from pi_zero_pytorch.tensor_typing import Float, Int, Bool
 
+from hyper_connections import HyperConnections
+
 import tqdm
 
 # ein notation
@@ -218,97 +220,6 @@ def noise_assignment(data, noise):
     _, assign = linear_sum_assignment(dist.cpu())
     return torch.from_numpy(assign).to(device)
 
-# hyper connections - multiple residual streams
-
-class Residual(Module):
-    def __init__(self, **kwargs):
-        super().__init__()
-
-    def prepare_with_inverse(self, residuals):
-        branch_input, residuals, residual_kwargs = self.prepare(residuals)
-
-        def inverse(branch_out):
-            return self(branch_out, residuals, **residual_kwargs)
-
-        return branch_input, inverse
-
-    def prepare(self, residuals):
-        return residuals, residuals, dict()
-
-    def forward(self, branch_out, residuals, **kwargs):
-        return branch_out + residuals
-
-class HyperConnections(Module):
-    def __init__(
-        self,
-        dim,
-        *,
-        num_residual_streams,
-        layer_index = None,
-        tanh = True,
-        **kwargs
-    ):
-        """
-        https://arxiv.org/abs/2409.19606
-        Appendix J - Algorithm 2, Dynamic only
-        """
-        super().__init__()
-
-        self.act = nn.Tanh() if tanh else nn.Identity()
-
-        self.norm = nn.RMSNorm(dim)
-
-        self.num_residual_streams = num_residual_streams
-        layer_index = default(layer_index, randrange(num_residual_streams)) # just choose one random residual stream if layer index not given
-
-        self.static_beta = nn.Parameter(torch.ones(num_residual_streams))
-
-        init_alpha0 = torch.zeros((num_residual_streams, 1))
-        init_alpha0[layer_index % num_residual_streams, 0] = 1.
-
-        self.static_alpha = nn.Parameter(torch.cat([init_alpha0, torch.eye(num_residual_streams)], dim = 1))
-
-        self.dynamic_alpha_fn = nn.Parameter(torch.zeros(dim, num_residual_streams + 1))
-        self.dynamic_alpha_scale = nn.Parameter(torch.ones(()) * 1e-2)
-        self.dynamic_beta_fn = nn.Parameter(torch.zeros(dim))
-        self.dynamic_beta_scale = nn.Parameter(torch.ones(()) * 1e-2)
-
-    def prepare_with_inverse(self, residuals):
-        branch_input, residuals, residual_kwargs = self.prepare(residuals)
-
-        def inverse(branch_out):
-            return self(branch_out, residuals, **residual_kwargs)
-
-        return branch_input, inverse
-
-    def prepare(self, residuals):
-
-        residuals = rearrange(residuals, '(b s) n d -> b n s d', s = self.num_residual_streams)
-
-        normed = self.norm(residuals)
-
-        wc_weight = self.act(normed @ self.dynamic_alpha_fn)
-        dynamic_alpha = wc_weight * self.dynamic_alpha_scale
-        alpha = dynamic_alpha + self.static_alpha
-
-        dc_weight = self.act(normed @ self.dynamic_beta_fn)
-        dynamic_beta = dc_weight * self.dynamic_beta_scale
-        beta = dynamic_beta + self.static_beta
-
-        # width connection
-
-        mix_h = einsum(alpha, residuals, '... s t, ... s d -> ... t d')
-
-        branch_input, residuals = mix_h[..., 0, :], mix_h[..., 1:, :]
-
-        return branch_input, residuals, dict(beta = beta)
-
-    def forward(self, branch_output, residuals, *, beta):
-        # 'depth' connection
-
-        residuals = einsum(branch_output, beta, 'b n d, b n s -> b n s d') + residuals
-        return rearrange(residuals, 'b n s d -> (b s) n d')
-
 # attention
 
 class Attention(Module):
@@ -789,19 +700,11 @@ def __init__(
         # residual functions, with maybe hyper connections
 
         assert num_residual_streams >= 1
-        is_hyper_connection = num_residual_streams > 1
-        residual_klass = Residual if not is_hyper_connection else HyperConnections
+        init_residual_fn, self.maybe_expand_residuals, self.maybe_reduce_residuals = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
 
         residual_fns = []
         counter = count()
 
-        self.maybe_expand_residuals = identity
-        self.maybe_reduce_residuals = identity
-
-        if is_hyper_connection:
-            self.maybe_expand_residuals = maybe(partial(repeat, pattern = 'b n d -> (b s) n d', s = num_residual_streams))
-            self.maybe_reduce_residuals = maybe(partial(reduce, reduction = 'sum', pattern = '(b s) n d -> b n d', s = num_residual_streams))
-
         # attention and feedforward
 
         layers = []
@@ -818,8 +721,8 @@ def __init__(
             ]))
 
             residual_fns.append(ModuleList([
-                residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = next(counter)),
-                residual_klass(dim = dim, num_residual_streams = num_residual_streams, layer_index = next(counter)),
+                init_residual_fn(dim = dim, layer_index = next(counter)),
+                init_residual_fn(dim = dim, layer_index = next(counter)),
             ]))
 
             cond_layers.append(ModuleList([
@@ -1336,7 +1239,7 @@ def forward(
 
             # joint attention
 
-            action_tokens, add_action_residual = attn_residual.prepare_with_inverse(action_tokens)
+            action_tokens, add_action_residual = attn_residual(action_tokens)
 
             action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)
 
@@ -1374,7 +1277,7 @@ def forward(
 
             # action feedforward
 
-            action_tokens, add_action_ff_residual = actions_ff_residual.prepare_with_inverse(action_tokens)
+            action_tokens, add_action_ff_residual = actions_ff_residual(action_tokens)
 
             action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
 
@@ -1403,7 +1306,7 @@ def forward(
 
             # actions attention
 
-            action_tokens, add_action_residual = attn_residual.prepare_with_inverse(action_tokens)
+            action_tokens, add_action_residual = attn_residual(action_tokens)
 
             action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)
 
@@ -1424,7 +1327,7 @@ def forward(
 
            # actions feed forward
 
-            action_tokens, add_action_ff_residual = actions_ff_residual.prepare_with_inverse(action_tokens)
+            action_tokens, add_action_ff_residual = actions_ff_residual(action_tokens)
 
            action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
 
diff --git a/pyproject.toml b/pyproject.toml
index a00e2dc..749202f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "pi-zero-pytorch"
-version = "0.1.2"
+version = "0.1.4"
 description = "π0 in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -30,6 +30,7 @@ dependencies = [
     "einops>=0.8.0",
     "ema-pytorch>=0.7.3",
     "jaxtyping",
+    'hyper-connections>=0.0.10',
     "rotary-embedding-torch>=0.8.5",
     'scipy',
     "torch>=2.5",
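
For reference, the residual-stream pattern this patch migrates to can be sketched roughly as below. This is a minimal example assuming the `hyper-connections>=0.0.10` interface used in the hunks above (`get_init_and_expand_reduce_stream_functions`, the returned factory, and calling the resulting module to get back the branch input plus an add-residual closure); the dimensions, stream count, and the `expand_streams` / `reduce_streams` names here are illustrative placeholders, not taken from `pi_zero.py`.

```python
# Hedged sketch of the hyper-connections usage adopted in this diff.
# Assumes the hyper-connections >= 0.0.10 API shown in the hunks above;
# dim, stream count, and variable names are illustrative only.

import torch
from hyper_connections import HyperConnections

dim = 512
num_residual_streams = 4

# factory plus expand / reduce helpers; all become pass-throughs when disable = True
init_residual_fn, expand_streams, reduce_streams = HyperConnections.get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1
)

# one residual module per branch, as in residual_fns above
residual = init_residual_fn(dim = dim, layer_index = 0)

tokens = torch.randn(2, 16, dim)

# repeat the batch into residual streams: 'b n d -> (b s) n d'
tokens = expand_streams(tokens)

# calling the module replaces the old prepare_with_inverse: it returns the
# branch input (width connection) and a closure that adds the depth connection
branch_input, add_residual = residual(tokens)

branch_output = branch_input  # stand-in for the attention / feedforward branch

tokens = add_residual(branch_output)

# sum the streams back down: '(b s) n d -> b n d'
tokens = reduce_streams(tokens)

assert tokens.shape == (2, 16, dim)
```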