
Commit

craft out a separate forward for flow inference of actions with cached state key values
lucidrains committed Nov 6, 2024
1 parent 05eb2b1 commit 852ccb3
Showing 1 changed file with 44 additions and 2 deletions.
46 changes: 44 additions & 2 deletions pi_zero_pytorch/pi_zero.py
@@ -102,11 +102,53 @@ def __init__(

        self.softclamp_value = softclamp_value

+    def forward_actions_with_cached_state(
+        self,
+        actions,
+        cached_state_keys_values: tuple[Tensor, Tensor],
+        actions_value_residual: Tensor | None = None,
+        return_keys_values = False
+    ):
+        # project action tokens to queries, keys, values and split out heads
+
+        aq, ak, av = self.to_actions_qkv(actions).chunk(3, dim = -1)
+
+        aq, ak, av = tuple(self.split_heads(t) for t in (aq, ak, av))
+
+        # value residual learning - average in the action values from the first layer
+
+        if exists(actions_value_residual):
+            av = 0.5 * (av + actions_value_residual)
+
+        q = aq
+        mk, mv = cached_state_keys_values
+
+        # action queries attend to the cached state keys / values as well as their own
+
+        k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mk, mv), (ak, av)))
+
+        if exists(self.rotary_emb):
+            q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # attention
+
+        q = q * self.scale
+
+        sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
+
+        attn = sim.softmax(dim = -1)
+
+        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
+
+        # merge attention heads
+
+        out = self.merge_heads(out)
+
+        # no state branch output on this cached inference path
+
+        output = (None, self.to_actions_out(out))
+
+        if not return_keys_values:
+            return output
+
+        return output, (mk, mv, ak, av)

    def forward(
        self,
        multimodal_seq,
        actions,
-        actions_value_residual = None,
+        actions_value_residual: Tensor | None = None,
        return_keys_values = False,
        flex_attn_fn: callable | None = None
    ):
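To see the attention pattern the new method implements in isolation, here is a minimal, self-contained sketch; tensor names, shapes, and sizes are illustrative, not the library's API. The multimodal state keys and values are computed once and cached, and at each flow step only the action tokens are projected, their queries attending to the cached state keys/values concatenated with their own.

import torch
import torch.nn.functional as F

b, h, d = 1, 8, 64            # batch, attention heads, head dimension
n_state, n_actions = 256, 32  # multimodal state tokens vs action tokens

# computed once from the multimodal state sequence, then cached
mk = torch.randn(b, h, n_state, d)
mv = torch.randn(b, h, n_state, d)

# fresh projections for the action tokens at the current flow step
aq = torch.randn(b, h, n_actions, d)
ak = torch.randn(b, h, n_actions, d)
av = torch.randn(b, h, n_actions, d)

# queries come only from actions; keys / values are the cached state
# keys / values with the action keys / values appended on the end
k = torch.cat((mk, ak), dim = -2)
v = torch.cat((mv, av), dim = -2)

out = F.scaled_dot_product_attention(aq, k, v)
assert out.shape == (b, h, n_actions, d)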
@@ -427,7 +469,7 @@ def forward(
        times = None,
        return_actions_flow = False,
        return_state_keys_values = False,
-        cached_state_keys_values = None,
+        cached_state_keys_values: list[tuple[Tensor, Tensor]] | None = None,
        **kwargs
    ):
        received_state_cache = exists(cached_state_keys_values)
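Taken together, the two hunks enable a two-phase sampling loop: one full pass over the multimodal sequence builds the per-layer state key/value caches, and the remaining integration steps are cheap action-only passes. The sketch below shows the intended shape of such a loop; the model's return values (assumed here to be the action flow, plus the list of per-layer caches when requested) and the helper name `sample_actions` are assumptions, not taken from this diff.

import torch

@torch.no_grad()
def sample_actions(model, multimodal_seq, action_shape, num_steps = 16):
    # start the actions at pure noise, at flow time t = 0
    actions = torch.randn(action_shape)
    times = torch.zeros(action_shape[0])
    delta = 1. / num_steps

    # the first step pays for the full multimodal state pass once,
    # asking the model to return the per-layer state (key, value) caches
    flow, cached_kv = model(
        multimodal_seq,
        actions,
        times = times,
        return_actions_flow = True,
        return_state_keys_values = True
    )
    actions = actions + delta * flow

    # remaining steps reuse the caches - only action tokens run through attention
    for _ in range(num_steps - 1):
        times = times + delta
        flow = model(
            multimodal_seq,
            actions,
            times = times,
            return_actions_flow = True,
            cached_state_keys_values = cached_kv
        )
        actions = actions + delta * flow

    return actions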
