
Commit

fix a bug where actions branch feedforward is rmsnormed, as it is already adaptive rmsnormed externally
lucidrains committed Dec 9, 2024
1 parent d35d035 commit dae4cf7
Showing 2 changed files with 53 additions and 20 deletions.
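
Context for the fix: the action-branch tokens already arrive at their feedforward normalized by the external, time-conditioned `ff_ada_rmsnorm`, so the `nn.RMSNorm` inside the feedforward normalized them a second time. RMSNorm is invariant to positive per-token rescaling, so that second norm is at best redundant and at worst discards part of the modulation the conditioning just applied. A standalone sketch of that invariance in PyTorch (the `rmsnorm` helper below is local to this example, not repository code):

import torch

def rmsnorm(x, eps = 1e-6):
    # plain RMSNorm without a learned gain, for illustration only
    return x * torch.rsqrt(x.pow(2).mean(dim = -1, keepdim = True) + eps)

x = torch.randn(2, 8, 16)            # (batch, tokens, dim)
gain = torch.rand(2, 8, 1) + 0.5     # positive per-token scale, standing in for a conditioning signal

# renormalizing right after the scaling throws the scale away
print(torch.allclose(rmsnorm(x * gain), rmsnorm(x), atol = 1e-4))  # True

The hunks below thread an `rmsnorm` flag through `SwiGLUFeedForward` and swap the internal norm for `nn.Identity()` on the action-branch feedforward, whose input is already normalized externally.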
71 changes: 52 additions & 19 deletions pi_zero_pytorch/pi_zero.py
@@ -435,12 +435,13 @@ def __init__(
self,
dim,
expand_factor = 4.,
- dim_inner = None
+ dim_inner = None,
+ rmsnorm = True
):
super().__init__()
dim_inner = default(dim_inner, int(dim * expand_factor * 2 / 3))

- self.rmsnorm = nn.RMSNorm(dim)
+ self.rmsnorm = nn.RMSNorm(dim) if rmsnorm else nn.Identity()
self.proj_in = LinearNoBias(dim, dim_inner * 2)
self.proj_out = LinearNoBias(dim_inner, dim)

@@ -629,7 +630,8 @@ def __init__(
layers.append(ModuleList([
Attention(dim = dim, dim_head = dim_head, heads = heads, num_recurrent_memory_tokens = num_recurrent_memory_tokens, learned_value_action_residual_mix = not is_first_block, **attn_kwargs),
SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
- SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs),
+ SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, rmsnorm = False, **ff_kwargs),
+ SwiGLUFeedForward(dim = dim, expand_factor = ff_expand_factor, **ff_kwargs) if self.has_recurrent_memories else None
]))

cond_layers.append(ModuleList([
@@ -888,7 +890,7 @@ def forward_only_vision_language(

# transformer

- for attn, ff, _ in self.layers:
+ for attn, ff, _, _ in self.layers:

state_attn_out = attn.forward_only_vision_language(state_tokens, rotary_emb = rotary_emb)

@@ -962,10 +964,13 @@ def forward(

# take care of maybe recurrent memory tokens

+ if not exists(past_recurrent_memory_tokens):
+ past_recurrent_memory_tokens = actions.new_empty((batch, 0, self.dim))

if self.has_recurrent_memories:
- memory_tokens = repeat(self.memory_tokens, 'nm d -> b nm d', b = batch)
+ write_memory_tokens = repeat(self.memory_tokens, 'nm d -> b nm d', b = batch)
else:
- memory_tokens = actions.new_empty((batch, 0, self.dim))
+ write_memory_tokens = actions.new_empty((batch, 0, self.dim))

# joint state + additional internal states

@@ -978,14 +983,17 @@

internal_state_tokens = self.to_internal_state_tokens(internal_state_tokens)

- # pack into [action registers] [internal + joint states] [actions] [memory tokens (write)]
+ # handle memory tokens, both read and write as a tuple of two tensors
+
+ memory_tokens = (past_recurrent_memory_tokens, write_memory_tokens)
+
+ # pack into [action registers] [internal + joint states] [actions]

action_tokens, inverse_pack_action_registers = pack_with_inverse([
action_register_tokens,
joint_state_tokens,
internal_state_tokens,
- action_tokens,
- memory_tokens
+ action_tokens
], 'b * d')

action_with_registers_length = action_tokens.shape[-2]
@@ -1045,13 +1053,9 @@ def forward(
assert self.has_recurrent_memories or not exists(past_recurrent_memory_tokens), 'you are asking for memories to be read, but `num_recurrent_memory_tokens` is 0'
assert self.has_recurrent_memories or not record_and_return_memory_tokens, 'you are asking for memories to be written, but `num_recurrent_memory_tokens` is 0'

- if not exists(past_recurrent_memory_tokens):
- past_recurrent_memory_tokens = visual_tokens.new_empty((batch, 0, self.dim))

# concat visual rep with language

state_tokens, inverse_packed_states = pack_with_inverse([
- past_recurrent_memory_tokens,
external_state_tokens,
visual_tokens,
language_tokens,
@@ -1121,13 +1125,21 @@

if not inferencing:
for (
- (attn, state_ff, actions_ff),
+ (attn, state_ff, actions_ff, memories_ff),
(attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale)
) in zip(self.layers, self.cond_layers):

action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)

- (state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, rotary_emb = rotary_emb, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, mask = mask, return_keys_values = True)
+ (state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(
+ state_tokens,
+ action_tokens,
+ rotary_emb = rotary_emb,
+ flex_attn_fn = flex_attn_fn,
+ actions_value_residual = actions_value_residual,
+ mask = mask,
+ return_keys_values = True
+ )

state_cached_keys_values.append((state_keys, state_values))

@@ -1144,16 +1156,27 @@

action_tokens = ff_ada_layerscale(action_tokens, time_cond)

+ memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
+
+ memory_tokens = memories_ff(memory_tokens) + memory_tokens
+
+ memory_tokens = unpack_memory(memory_tokens)
else:

for (
- (attn, state_ff, actions_ff),
+ (attn, state_ff, actions_ff, memories_ff),
(attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale)
) in zip(self.layers, self.cond_layers):

action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)

- actions_attn_out, (state_keys, state_values, action_keys, action_values) = attn.forward_actions_with_cached_state(action_tokens, cached_state_keys_values = next(cached_state_key_values_iter), rotary_emb = rotary_emb, mask = mask, return_keys_values = True)
+ actions_attn_out, (state_keys, state_values, action_keys, action_values) = attn.forward_actions_with_cached_state(
+ action_tokens,
+ cached_state_keys_values = next(cached_state_key_values_iter),
+ rotary_emb = rotary_emb,
+ mask = mask,
+ return_keys_values = True
+ )

state_cached_keys_values.append((state_keys, state_values))

@@ -1167,19 +1190,29 @@

action_tokens = ff_ada_layerscale(action_tokens, time_cond)

+ memory_tokens, unpack_memory = pack_with_inverse(memory_tokens, 'b * d')
+
+ memory_tokens = memories_ff(memory_tokens) + memory_tokens
+
+ memory_tokens = unpack_memory(memory_tokens)

if not inferencing:
# unpack and unembed to predictions

- _, _, visual_tokens, tokens, *_ = inverse_packed_states(state_tokens, 'b * d')
+ _, visual_tokens, tokens, *_ = inverse_packed_states(state_tokens, 'b * d')

# gemma uses a final softclamp before norm

tokens = self.final_norm_softclamp(tokens)

- *_, action_tokens, written_memory_tokens = inverse_pack_action_registers(action_tokens)
+ *_, action_tokens = inverse_pack_action_registers(action_tokens)

action_tokens = self.final_norm_softclamp(action_tokens)

+ # memories
+
+ read_memories, written_memory_tokens = memory_tokens

# writeable memories norm

if self.has_recurrent_memories:
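The other half of the change gives the recurrent memory tokens their own path: rather than being packed into the action sequence, the read (past) and write (fresh) memories are carried as a tuple and pushed through a dedicated feedforward with a residual connection inside each block. A rough sketch of the pattern, using einops `pack`/`unpack` directly in place of the repository's `pack_with_inverse` helper and a generic feedforward standing in for `SwiGLUFeedForward` (both stand-ins and the sizes are assumptions for illustration):

import torch
from torch import nn
from einops import pack, unpack

batch, num_read, num_write, dim = 2, 4, 4, 32    # hypothetical sizes

read_memories  = torch.randn(batch, num_read, dim)    # past recurrent memories being read
write_memories = torch.randn(batch, num_write, dim)   # memory tokens to be written this step

# stand-in for the per-block memories feedforward
memories_ff = nn.Sequential(
    nn.RMSNorm(dim),
    nn.Linear(dim, dim * 4, bias = False),
    nn.SiLU(),
    nn.Linear(dim * 4, dim, bias = False)
)

# pack the (read, write) tuple into one sequence, apply the feedforward with a residual, unpack again
memory_tokens, packed_shape = pack([read_memories, write_memories], 'b * d')
memory_tokens = memories_ff(memory_tokens) + memory_tokens
read_memories, write_memories = unpack(memory_tokens, packed_shape, 'b * d')

print(read_memories.shape, write_memories.shape)   # torch.Size([2, 4, 32]) torch.Size([2, 4, 32])

In the diff above, the memories feedforward is only instantiated when `num_recurrent_memory_tokens` is set; otherwise that slot in the layer holds `None`.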
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
version = "0.0.43"
version = "0.0.44"
description = "π0 in Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
