From 657fcced21695632954e08a10bb13ae4bac7aafb Mon Sep 17 00:00:00 2001 From: lucidrains Date: Wed, 6 Nov 2024 10:50:18 -0800 Subject: [PATCH] wrap up the trickiest part of the project --- pi_zero_pytorch/pi_zero.py | 128 ++++++++++++++++++++++++------------- pyproject.toml | 2 +- 2 files changed, 83 insertions(+), 47 deletions(-) diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py index 8abe930..ba682d3 100644 --- a/pi_zero_pytorch/pi_zero.py +++ b/pi_zero_pytorch/pi_zero.py @@ -107,6 +107,7 @@ def forward_actions_with_cached_state( actions, cached_state_keys_values: tuple[Tensor, Tensor], actions_value_residual: Tensor | None = None, + return_keys_values = False ): aq, ak, av = self.to_actions_qkv(actions).chunk(3, dim = -1) @@ -137,12 +138,12 @@ def forward_actions_with_cached_state( out = self.merge_heads(out) - output = (None, self.to_actions_out(out)) + actions_out = self.to_actions_out(out) if not return_keys_values: - return output + return actions_out - return output, (mk, mv, ak, av) + return actions_out, (mk, mv, ak, av) def forward( self, @@ -432,6 +433,7 @@ def ode_fn(timestep, denoised_actions): joint_states, denoised_actions, times = timestep, + cached_state_keys_values = cached_state_kv, return_actions_flow = True, return_state_keys_values = True ) @@ -473,6 +475,7 @@ def forward( **kwargs ): received_state_cache = exists(cached_state_keys_values) + assert not (received_state_cache and not return_actions_flow), 'must be generating action trajectory if receiving cached state key values' if not exists(actions): return self.sample(images, token_ids, joint_state, **kwargs) @@ -502,46 +505,47 @@ def forward( time_cond = self.to_time_cond(times) action_tokens = self.to_action_tokens(actions) - # language + if not received_state_cache: + # language - labels = token_ids[:, 1:] + labels = token_ids[:, 1:] - language_tokens = self.token_emb(token_ids) + language_tokens = self.token_emb(token_ids) - # vision + # vision - if exists(self.vit): - assert images.ndim in {4, 5} - is_multiple_images = images.ndim == 5 + if exists(self.vit): + assert images.ndim in {4, 5} + is_multiple_images = images.ndim == 5 - if is_multiple_images: - images, images_frames_packed_shape = pack([images], '* c h w') + if is_multiple_images: + images, images_frames_packed_shape = pack([images], '* c h w') - with torch.no_grad(): - self.vit.eval() - visual_tokens = self.vit(images) + with torch.no_grad(): + self.vit.eval() + visual_tokens = self.vit(images) - if is_multiple_images: - visual_tokens = unpack(visual_tokens, images_frames_packed_shape, '* n d') - visual_tokens = rearrange(visual_tokens, 'b f n d -> b (f n) d') + if is_multiple_images: + visual_tokens = unpack(visual_tokens, images_frames_packed_shape, '* n d') + visual_tokens = rearrange(visual_tokens, 'b f n d -> b (f n) d') - else: - assert images.ndim == 3, 'images must be already encoded as (batch, seq, feature dimension)' - visual_tokens = images + else: + assert images.ndim == 3, 'images must be already encoded as (batch, seq, feature dimension)' + visual_tokens = images - # joint state + # joint state - joint_state_tokens = self.to_joint_state_tokens(joint_state) + joint_state_tokens = self.to_joint_state_tokens(joint_state) - # concat visual rep with language + # concat visual rep with language - state_tokens, packed_shape = pack([visual_tokens, language_tokens, joint_state_tokens], 'b * d') + state_tokens, packed_shape = pack([visual_tokens, language_tokens, joint_state_tokens], 'b * d') # prepare maybe flex attention flex_attn_fn = None - if self.use_flex_attn and state_tokens.is_cuda: + if self.use_flex_attn and state_tokens.is_cuda and not received_state_cache: prefix_length = state_tokens.shape[-2] seq_len = prefix_length + action_tokens.shape[-2] @@ -563,6 +567,8 @@ def forward( # state keys and values for caching during inference + cached_state_key_values_iter = iter(default(cached_state_keys_values, [])) + state_cached_keys_values = [] # value residual learning @@ -571,43 +577,71 @@ def forward( # transformer - for ( - (attn, state_ff, actions_ff), - (attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale) - ) in zip(self.layers, self.cond_layers): + if not received_state_cache: + for ( + (attn, state_ff, actions_ff), + (attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale) + ) in zip(self.layers, self.cond_layers): + + action_tokens = attn_ada_rmsnorm(action_tokens, time_cond) + + (state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, return_keys_values = True) + + state_cached_keys_values.append((state_keys, state_values)) + + actions_value_residual = default(actions_value_residual, action_values) + + action_tokens = attn_ada_layerscale(action_tokens, time_cond) + + state_tokens = state_tokens + state_attn_out + action_tokens = action_tokens + actions_attn_out - action_tokens = attn_ada_rmsnorm(action_tokens, time_cond) + state_tokens = state_ff(state_tokens) + state_tokens - (state_attn_out, actions_attn_out), (state_keys, state_values, action_keys, action_values) = attn(state_tokens, action_tokens, flex_attn_fn = flex_attn_fn, actions_value_residual = actions_value_residual, return_keys_values = True) + action_tokens = ff_ada_rmsnorm(action_tokens, time_cond) - state_cached_keys_values.append((state_keys, state_values)) + action_tokens = actions_ff(action_tokens) + action_tokens - actions_value_residual = default(actions_value_residual, action_values) + action_tokens = ff_ada_rmsnorm(action_tokens, time_cond) - action_tokens = attn_ada_layerscale(action_tokens, time_cond) + else: + + for ( + (attn, state_ff, actions_ff), + (attn_ada_rmsnorm, attn_ada_layerscale, ff_ada_rmsnorm, ff_ada_layerscale) + ) in zip(self.layers, self.cond_layers): + + action_tokens = attn_ada_rmsnorm(action_tokens, time_cond) + + actions_attn_out, (state_keys, state_values, action_keys, action_values) = attn.forward_actions_with_cached_state(action_tokens, cached_state_keys_values = next(cached_state_key_values_iter), return_keys_values = True) + + state_cached_keys_values.append((state_keys, state_values)) - state_tokens = state_tokens + state_attn_out - action_tokens = action_tokens + actions_attn_out + actions_value_residual = default(actions_value_residual, action_values) - state_tokens = state_ff(state_tokens) + state_tokens + action_tokens = attn_ada_layerscale(action_tokens, time_cond) - action_tokens = ff_ada_rmsnorm(action_tokens, time_cond) + action_tokens = action_tokens + actions_attn_out - action_tokens = actions_ff(action_tokens) + action_tokens + action_tokens = ff_ada_rmsnorm(action_tokens, time_cond) - action_tokens = ff_ada_rmsnorm(action_tokens, time_cond) + action_tokens = actions_ff(action_tokens) + action_tokens - # unpack and unembed to predictions + action_tokens = ff_ada_rmsnorm(action_tokens, time_cond) - visual_tokens, tokens, _ = unpack(state_tokens, packed_shape, 'b * d') + if not received_state_cache: + # unpack and unembed to predictions - # gemma uses a final softclamp before norm + visual_tokens, tokens, _ = unpack(state_tokens, packed_shape, 'b * d') - tokens, action_tokens = tuple(self.final_norm_softclamp(t) for t in (tokens, action_tokens)) + # gemma uses a final softclamp before norm + + tokens = self.final_norm_softclamp(tokens) + + action_tokens = self.final_norm_softclamp(action_tokens) # projection - tokens = self.final_norm(tokens) actions = self.final_actions_norm(action_tokens) # flow loss for actions tokens @@ -631,6 +665,8 @@ def forward( # language cross entropy loss + tokens = self.final_norm(tokens) + language_logits = self.state_to_logits(tokens) language_loss = F.cross_entropy( diff --git a/pyproject.toml b/pyproject.toml index 79fb99b..5005964 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pi-zero-pytorch" -version = "0.0.2" +version = "0.0.3" description = "π0 in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }