diff --git a/pi_zero_pytorch/pi_zero.py b/pi_zero_pytorch/pi_zero.py index 94a1b51..03d7285 100644 --- a/pi_zero_pytorch/pi_zero.py +++ b/pi_zero_pytorch/pi_zero.py @@ -1203,7 +1203,8 @@ def forward( ), Q_LEN = seq_len, KV_LEN = seq_len, - device = state_tokens.device + device = state_tokens.device, + _compile = True, ) score_mod_fn = softclamp_score_mod(self.attn_softclamp_value) diff --git a/pyproject.toml b/pyproject.toml index 0c565f6..92d638e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "pi-zero-pytorch" -version = "0.1.5" +version = "0.1.6" description = "π0 in Pytorch" authors = [ { name = "Phil Wang", email = "lucidrains@gmail.com" }