wrap up the trickiest part of the project
lucidrains committed Nov 6, 2024
1 parent 852ccb3 commit 657fcce
Showing 2 changed files with 83 additions and 47 deletions.
pi_zero_pytorch/pi_zero.py (128 changes: 82 additions & 46 deletions)
@@ -107,6 +107,7 @@ def forward_actions_with_cached_state(
        actions,
        cached_state_keys_values: tuple[Tensor, Tensor],
        actions_value_residual: Tensor | None = None,
+        return_keys_values = False
    ):
        aq, ak, av = self.to_actions_qkv(actions).chunk(3, dim = -1)

@@ -137,12 +138,12 @@ def forward_actions_with_cached_state(

        out = self.merge_heads(out)

-        output = (None, self.to_actions_out(out))
+        actions_out = self.to_actions_out(out)

        if not return_keys_values:
-            return output
+            return actions_out

-        return output, (mk, mv, ak, av)
+        return actions_out, (mk, mv, ak, av)

    def forward(
        self,
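The two hunks above are the heart of the change: forward_actions_with_cached_state now takes a return_keys_values flag and returns the action output directly, so during sampling only the action tokens are recomputed while the state keys/values stay frozen. As a rough mental model (not the library's implementation), single-head prefix key/value caching looks like the sketch below; the class name, dim = 512, the unmasked softmax, and the returned (ak, av) pair are assumptions for illustration only.

import torch
from torch import nn

class CachedPrefixAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = dim ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim, bias = False)

    def forward(
        self,
        actions,                    # (batch, action_len, dim)
        cached_state_keys_values   # tuple of (batch, state_len, dim) tensors, computed once
    ):
        sk, sv = cached_state_keys_values

        # only the action tokens are projected on this call
        aq, ak, av = self.to_qkv(actions).chunk(3, dim = -1)

        # action queries attend over the frozen state keys/values followed by their own
        k = torch.cat((sk, ak), dim = -2)
        v = torch.cat((sv, av), dim = -2)

        sim = aq @ k.transpose(-1, -2) * self.scale
        out = sim.softmax(dim = -1) @ v

        return self.to_out(out), (ak, av)

# example: 128 cached state tokens, 32 action tokens, model dimension 512
attn = CachedPrefixAttention(dim = 512)
state_kv = (torch.randn(2, 128, 512), torch.randn(2, 128, 512))
actions_out, action_kv = attn(torch.randn(2, 32, 512), state_kv)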
@@ -432,6 +433,7 @@ def ode_fn(timestep, denoised_actions):
                joint_states,
                denoised_actions,
                times = timestep,
+                cached_state_keys_values = cached_state_kv,
                return_actions_flow = True,
                return_state_keys_values = True
            )
@@ -473,6 +475,7 @@ def forward(
        **kwargs
    ):
        received_state_cache = exists(cached_state_keys_values)
+        assert not (received_state_cache and not return_actions_flow), 'must be generating action trajectory if receiving cached state key values'

        if not exists(actions):
            return self.sample(images, token_ids, joint_state, **kwargs)
@@ -502,46 +505,47 @@ def forward(
        time_cond = self.to_time_cond(times)
        action_tokens = self.to_action_tokens(actions)

-        # language
-
-        labels = token_ids[:, 1:]
-
-        language_tokens = self.token_emb(token_ids)
-
-        # vision
-
-        if exists(self.vit):
-            assert images.ndim in {4, 5}
-            is_multiple_images = images.ndim == 5
-
-            if is_multiple_images:
-                images, images_frames_packed_shape = pack([images], '* c h w')
-
-            with torch.no_grad():
-                self.vit.eval()
-                visual_tokens = self.vit(images)
-
-            if is_multiple_images:
-                visual_tokens = unpack(visual_tokens, images_frames_packed_shape, '* n d')
-                visual_tokens = rearrange(visual_tokens, 'b f n d -> b (f n) d')
-
-        else:
-            assert images.ndim == 3, 'images must be already encoded as (batch, seq, feature dimension)'
-            visual_tokens = images
-
-        # joint state
-
-        joint_state_tokens = self.to_joint_state_tokens(joint_state)
-
-        # concat visual rep with language
-
-        state_tokens, packed_shape = pack([visual_tokens, language_tokens, joint_state_tokens], 'b * d')
+        if not received_state_cache:
+            # language
+
+            labels = token_ids[:, 1:]
+
+            language_tokens = self.token_emb(token_ids)
+
+            # vision
+
+            if exists(self.vit):
+                assert images.ndim in {4, 5}
+                is_multiple_images = images.ndim == 5
+
+                if is_multiple_images:
+                    images, images_frames_packed_shape = pack([images], '* c h w')
+
+                with torch.no_grad():
+                    self.vit.eval()
+                    visual_tokens = self.vit(images)
+
+                if is_multiple_images:
+                    visual_tokens = unpack(visual_tokens, images_frames_packed_shape, '* n d')
+                    visual_tokens = rearrange(visual_tokens, 'b f n d -> b (f n) d')
+
+            else:
+                assert images.ndim == 3, 'images must be already encoded as (batch, seq, feature dimension)'
+                visual_tokens = images
+
+            # joint state
+
+            joint_state_tokens = self.to_joint_state_tokens(joint_state)
+
+            # concat visual rep with language
+
+            state_tokens, packed_shape = pack([visual_tokens, language_tokens, joint_state_tokens], 'b * d')

        # prepare maybe flex attention

        flex_attn_fn = None

-        if self.use_flex_attn and state_tokens.is_cuda:
+        if self.use_flex_attn and state_tokens.is_cuda and not received_state_cache:

            prefix_length = state_tokens.shape[-2]
            seq_len = prefix_length + action_tokens.shape[-2]
@@ -563,6 +567,8 @@

        # state keys and values for caching during inference

+        cached_state_key_values_iter = iter(default(cached_state_keys_values, []))
+
        state_cached_keys_values = []

        # value residual learning
@@ -571,43 +577,71 @@ def forward(

        # transformer

-        for (
-            (attn, state_ff, actions_ff),
-            (attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale)
-        ) in zip(self.layers, self.cond_layers):
-
-            action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)
-
-            (state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, return_keys_values = True)
-
-            state_cached_keys_values.append((state_keys, state_values))
-
-            actions_value_residual = default(actions_value_residual, action_values)
-
-            action_tokens = attn_ada_layerscale(action_tokens, time_cond)
-
-            state_tokens = state_tokens + state_attn_out
-            action_tokens = action_tokens + actions_attn_out
-
-            state_tokens = state_ff(state_tokens) + state_tokens
-
-            action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
-
-            action_tokens = actions_ff(action_tokens) + action_tokens
-
-            action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
-
-        # unpack and unembed to predictions
-
-        visual_tokens, tokens, _ = unpack(state_tokens, packed_shape, 'b * d')
-
-        # gemma uses a final softclamp before norm
-
-        tokens, action_tokens = tuple(self.final_norm_softclamp(t) for t in (tokens, action_tokens))
+        if not received_state_cache:
+            for (
+                (attn, state_ff, actions_ff),
+                (attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale)
+            ) in zip(self.layers, self.cond_layers):
+
+                action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)
+
+                (state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, return_keys_values = True)
+
+                state_cached_keys_values.append((state_keys, state_values))
+
+                actions_value_residual = default(actions_value_residual, action_values)
+
+                action_tokens = attn_ada_layerscale(action_tokens, time_cond)
+
+                state_tokens = state_tokens + state_attn_out
+                action_tokens = action_tokens + actions_attn_out
+
+                state_tokens = state_ff(state_tokens) + state_tokens
+
+                action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
+
+                action_tokens = actions_ff(action_tokens) + action_tokens
+
+                action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
+
+        else:
+
+            for (
+                (attn, state_ff, actions_ff),
+                (attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale)
+            ) in zip(self.layers, self.cond_layers):
+
+                action_tokens = attn_ada_rmsnorm(action_tokens, time_cond)
+
+                actions_attn_out, (state_keys, state_values, action_keys, action_values) = attn.forward_actions_with_cached_state(action_tokens, cached_state_keys_values = next(cached_state_key_values_iter), return_keys_values = True)
+
+                state_cached_keys_values.append((state_keys, state_values))
+
+                actions_value_residual = default(actions_value_residual, action_values)
+
+                action_tokens = attn_ada_layerscale(action_tokens, time_cond)
+
+                action_tokens = action_tokens + actions_attn_out
+
+                action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
+
+                action_tokens = actions_ff(action_tokens) + action_tokens
+
+                action_tokens = ff_ada_rmsnorm(action_tokens, time_cond)
+
+        if not received_state_cache:
+            # unpack and unembed to predictions
+
+            visual_tokens, tokens, _ = unpack(state_tokens, packed_shape, 'b * d')
+
+            # gemma uses a final softclamp before norm
+
+            tokens = self.final_norm_softclamp(tokens)
+
+        action_tokens = self.final_norm_softclamp(action_tokens)

        # projection

-        tokens = self.final_norm(tokens)
        actions = self.final_actions_norm(action_tokens)

        # flow loss for actions tokens
@@ -631,6 +665,8 @@ def forward(

        # language cross entropy loss

+        tokens = self.final_norm(tokens)
+
        language_logits = self.state_to_logits(tokens)

        language_loss = F.cross_entropy(
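Taken together, the pi_zero.py changes let a caller build the state key/value cache on the first denoising call and reuse it on every later step of the action trajectory. The loop below is a guess at that usage, not the repository's sample method: the positional argument order and keyword names are taken from the ode_fn hunk above, the (flow, state keys/values) return pair is assumed from how that hunk sets both return flags, and the Euler integrator, horizon, action_dim, and step count are placeholder values.

import torch

@torch.no_grad()
def sample_actions(model, images, token_ids, joint_state, horizon = 32, action_dim = 6, steps = 16):
    batch = token_ids.shape[0]

    actions = torch.randn(batch, horizon, action_dim)    # flow matching starts from noise
    times = torch.zeros(batch)
    cached_state_kv = None                               # filled in by the first call

    for _ in range(steps):
        flow, cached_state_kv = model(
            images,
            token_ids,
            joint_state,
            actions,
            times = times,
            cached_state_keys_values = cached_state_kv,  # None on the first step, reused afterwards
            return_actions_flow = True,
            return_state_keys_values = True
        )

        actions = actions + flow / steps                 # simple Euler step along the predicted flow
        times = times + 1. / steps

    return actions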
pyproject.toml (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
[project]
name = "pi-zero-pytorch"
-version = "0.0.2"
+version = "0.0.3"
description = "π0 in Pytorch"
authors = [
    { name = "Phil Wang", email = "lucidrains@gmail.com" }
