diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py
index 29c19dc..8abe930 100644
--- a/pi_zero_pytorch/pi_zero.py
+++ b/pi_zero_pytorch/pi_zero.py
@@ -102,11 +102,53 @@ def __init__(
 
         self.softclamp_value = softclamp_value
 
+    def forward_actions_with_cached_state(
+        self,
+        actions,
+        cached_state_keys_values: tuple[Tensor, Tensor],
+        actions_value_residual: Tensor | None = None,
+        return_keys_values = False
+    ):
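+        # project actions to queries, keys, values and split out the attention heads
+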
+        aq, ak, av = self.to_actions_qkv(actions).chunk(3, dim = -1)
+
+        aq, ak, av = tuple(self.split_heads(t) for t in (aq, ak, av))
+
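+        # mix in the action value residual, if provided (simple average)
+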
+        if exists(actions_value_residual):
+            av = 0.5 * (av + actions_value_residual)
+
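+        # action queries attend to the cached state keys / values concatenated with the action keys / values
+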
+        q = aq
+        mk, mv = cached_state_keys_values
+
+        k, v = tuple(torch.cat(tensors, dim = -2) for tensors in zip((mk, mv), (ak, av)))
+
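+        # rotary positional embeddings, accounting for the cached key positions
+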
+        if exists(self.rotary_emb):
+            q, k = self.rotary_emb.rotate_queries_with_cached_keys(q, k)
+
+        # attention
+
+        q = q * self.scale
+
+        sim = einsum(q, k, 'b h i d, b h j d -> b h i j')
+
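+        # assumption: soft clamp the attention logits with softclamp_value, mirroring the main forward pass
+        sim = (sim / self.softclamp_value).tanh() * self.softclamp_value
+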
+        attn = sim.softmax(dim = -1)
+
+        out = einsum(attn, v, 'b h i j, b h j d -> b h i d')
+
+        # merge attention heads
+
+        out = self.merge_heads(out)
+
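+        # only the action branch is recomputed; the state output slot is None when using the cache
+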
+        output = (None, self.to_actions_out(out))
+
+        if not return_keys_values:
+            return output
+
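+        # additionally return the state and action keys / values, mirroring return_keys_values in forward
+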
+        return output, (mk, mv, ak, av)
+
     def forward(
         self,
         multimodal_seq,
         actions,
-        actions_value_residual = None,
+        actions_value_residual: Tensor | None = None,
         return_keys_values = False,
         flex_attn_fn: callable | None = None
     ):
@@ -427,7 +469,7 @@ def forward(
         times = None,
         return_actions_flow = False,
         return_state_keys_values = False,
-        cached_state_keys_values = None,
+        cached_state_keys_values: list[tuple[Tensor, Tensor]] | None = None,
         **kwargs
     ):
         received_state_cache = exists(cached_state_keys_values)