From 66207bafe45dc6c438c16cae6b0e767ba4475676 Mon Sep 17 00:00:00 2001
From: Phil Wang <lucidrains@gmail.com>
Date: Wed, 13 Nov 2024 06:35:51 -0800
Subject: [PATCH] e2e test with vit-pytorch

---
 .github/workflows/test.yaml | 24 +++++++++++++++++++
 pyproject.toml              |  3 ++-
 tests/test_pi_zero.py       | 44 +++++++++++++++++++++++++++++++++++
 3 files changed, 70 insertions(+), 1 deletion(-)
 create mode 100644 .github/workflows/test.yaml
 create mode 100644 tests/test_pi_zero.py

diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
new file mode 100644
index 0000000..593e558
--- /dev/null
+++ b/.github/workflows/test.yaml
@@ -0,0 +1,24 @@
+name: Tests the examples in README
+on: push
+
+env:
+  TYPECHECK: True
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - name: Install Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.11"
+      - name: Install dependencies
+        run: |
+          python -m pip install uv
+          python -m uv pip install --upgrade pip
+          python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
+          python -m uv pip install -e .[test]
+      - name: Test with pytest
+        run: |
+          python -m pytest tests/
diff --git a/pyproject.toml b/pyproject.toml
index 66978f8..7748747 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -42,7 +42,8 @@ Repository = "https://github.com/lucidrains/pi-zero-pytorch"
 [project.optional-dependencies]
 examples = []
 test = [
-    "pytest"
+    "pytest",
+    "vit-pytorch"
 ]
 
 [tool.pytest.ini_options]
diff --git a/tests/test_pi_zero.py b/tests/test_pi_zero.py
new file mode 100644
index 0000000..8e103a4
--- /dev/null
+++ b/tests/test_pi_zero.py
@@ -0,0 +1,44 @@
+import torch
+from pi_zero_pytorch import π0
+
+def test_pi_zero_with_vit():
+    from vit_pytorch import ViT
+    from vit_pytorch.extractor import Extractor
+
+    v = ViT(
+        image_size = 256,
+        patch_size = 32,
+        num_classes = 1000,
+        dim = 1024,
+        depth = 6,
+        heads = 16,
+        mlp_dim = 2048,
+        dropout = 0.1,
+        emb_dropout = 0.1
+    )
+
+    v = Extractor(v, return_embeddings_only = True)  # wrap the ViT so it returns patch embeddings instead of logits
+
+    model = π0(
+        dim = 512,
+        vit = v,
+        vit_dim = 1024,
+        dim_action_input = 6,
+        dim_joint_state = 12,
+        num_tokens = 20_000
+    )
+
+    images = torch.randn(1, 3, 2, 256, 256)
+
+    commands = torch.randint(0, 20_000, (1, 1024))  # tokenized language command
+    joint_state = torch.randn(1, 12)                # robot joint state
+    actions = torch.randn(1, 32, 6)                 # 32-step trajectory of 6-dim actions
+
+    loss, _ = model(images, commands, joint_state, actions)
+    loss.backward()
+
+    # after much training
+
+    sampled_actions = model(images, commands, joint_state, trajectory_length = 32) # (1, 32, 6)
+
+    assert sampled_actions.shape == (1, 32, 6)
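
For reference, a minimal sketch of exercising π0 without a ViT, under the assumption (taken from the repository README at this revision) that the model can be constructed without `vit` and fed pre-computed visual token embeddings of width `dim` in place of raw images; treat the exact argument names and shapes below as assumptions, not part of this patch:

import torch
from pi_zero_pytorch import π0

# assumption: without a vit, π0 accepts visual token embeddings of shape
# (batch, num visual tokens, dim) directly as its first argument
model = π0(
    dim = 512,
    dim_action_input = 6,
    dim_joint_state = 12,
    num_tokens = 20_000
)

visual_embeds = torch.randn(1, 1024, 512)        # (batch, visual tokens, dim)
commands = torch.randint(0, 20_000, (1, 1024))   # tokenized language command
joint_state = torch.randn(1, 12)                 # robot joint state
actions = torch.randn(1, 32, 6)                  # 32-step trajectory of 6-dim actions

loss, _ = model(visual_embeds, commands, joint_state, actions)
loss.backward()

sampled_actions = model(visual_embeds, commands, joint_state, trajectory_length = 32)
assert sampled_actions.shape == (1, 32, 6)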