diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py
index 29c19dc..8abe930 100644
--- a/pi_zero_pytorch/pi_zero.py
+++ b/pi_zero_pytorch/pi_zero.py
@@ -102,11 +102,54 @@ def __init__(
 
         self.softclamp_value = softclamp_value
 
+    def forward_actions_with_cached_state(
+        self,
+        actions,
+        cached_state_keys_values: tuple[Tensor, Tensor],
+        actions_value_residual: Tensor | None = None,
+        return_keys_values = False
+    ):
+        aq, ak, av = self.to_actions_qkv(actions).chunk(3, dim = -1)
+
+        aq, ak, av = tuple(self.split_heads(t) for t in (aq, ak, av))
+
+        if exists(actions_value_residual):
+            av = 0.5 * (av + actions_value_residual)
+
+        q = aq
+        mk, mv = cached_state_keys_values
+
+        k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mk, mv), (ak, av)))
+
+        if exists(self.rotary_emb):
+            q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # attention
+
+        q = q * self.scale
+
+        sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
+
+        attn = sim.softmax(dim = -1)
+
+        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
+
+        # merge attention heads
+
+        out = self.merge_heads(out)
+
+        output = (None, self.to_actions_out(out))
+
+        if not return_keys_values:
+            return output
+
+        return output, (mk, mv, ak, av)
+
     def forward(
         self,
         multimodal_seq,
         actions,
-        actions_value_residual = None,
+        actions_value_residual: Tensor | None = None,
         return_keys_values = False,
         flex_attn_fn: callable | None = None
     ):
@@ -427,7 +470,7 @@ def forward(
         times = None,
         return_actions_flow = False,
         return_state_keys_values = False,
-        cached_state_keys_values = None,
+        cached_state_keys_values: list[tuple[Tensor, Tensor]] | None = None,
         **kwargs
     ):
         received_state_cache = exists(cached_state_keys_values)
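
For context, a minimal usage sketch of the new cached-state path (not part of the diff): it assumes `attn` is an instance of the attention module patched above, and the helper name, tensor shapes, and comments are illustrative assumptions rather than library API.

```python
import torch
from torch import Tensor

def actions_attend_with_cache(
    attn,                                             # assumed: instance of the attention module patched above
    actions: Tensor,                                  # (batch, num_action_tokens, dim) - illustrative shape
    cached_state_keys_values: tuple[Tensor, Tensor],  # (mk, mv) cached from an earlier full forward pass
    actions_value_residual: Tensor | None = None      # optional value residual from an earlier layer
) -> tuple[Tensor, Tensor]:
    # only the action tokens are projected to q / k / v here; the multimodal state
    # keys / values come from the cache and are concatenated along the sequence dim
    (_, actions_out), (mk, mv, ak, av) = attn.forward_actions_with_cached_state(
        actions,
        cached_state_keys_values = cached_state_keys_values,
        actions_value_residual = actions_value_residual,
        return_keys_values = True
    )

    # av can be passed back as actions_value_residual on subsequent calls
    return actions_out, av
```

The apparent intent is that the multimodal state prefix is encoded once and its keys / values cached, so each subsequent action flow step only pays for attention of the (much shorter) action queries against the frozen `(mk, mv)`.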