From 3b89fffb2da081b1fce4e7f59d1aaa0a7f4ca779 Mon Sep 17 00:00:00 2001
From: Matthew Wood <matthew@helical-ai.com>
Date: Thu, 5 Dec 2024 10:54:20 +0100
Subject: [PATCH 01/27] Add HyenaDNA tests back

---
 ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py | 1 -
 ci/tests/test_hyena_dna/test_hyena_dna_model.py       | 6 ------
 2 files changed, 7 deletions(-)

diff --git a/ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py b/ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py
index 5eedae9f..e76b9f5f 100644
--- a/ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py
+++ b/ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py
@@ -2,7 +2,6 @@
 import torch
 from helical import HyenaDNAConfig, HyenaDNAFineTuningModel
 
-@pytest.mark.skip(reason="Work in progress.")
 class TestHyenaDNAFineTuning:
     @pytest.fixture(params=["hyenadna-tiny-1k-seqlen", "hyenadna-tiny-1k-seqlen-d256"])
     def hyenaDNAFineTune(self, request):
diff --git a/ci/tests/test_hyena_dna/test_hyena_dna_model.py b/ci/tests/test_hyena_dna/test_hyena_dna_model.py
index 4df1b9d9..a9394140 100644
--- a/ci/tests/test_hyena_dna/test_hyena_dna_model.py
+++ b/ci/tests/test_hyena_dna/test_hyena_dna_model.py
@@ -4,8 +4,6 @@
 from helical.models.hyena_dna.model import HyenaDNA
 from helical.models.hyena_dna.hyena_dna_utils import HyenaDNADataset
 
-
-@pytest.mark.skip(reason="Work in progress.")
 @pytest.mark.parametrize("model_name, d_model, d_inner", [
     ("hyenadna-tiny-1k-seqlen", 128, 512),
     ("hyenadna-tiny-1k-seqlen-d256", 256, 1024)
@@ -24,8 +22,6 @@ def test_hyena_dna__valid_model_names(model_name, d_model, d_inner):
     assert configurer.config["d_model"] == d_model
     assert configurer.config["d_inner"] == d_inner
 
-
-@pytest.mark.skip(reason="Work in progress.")
 @pytest.mark.parametrize("model_name", [
     ("wrong_name")
 ])
@@ -43,8 +39,6 @@ def test_hyena_dna__invalid_model_names(model_name):
     with pytest.raises(ValueError):
         HyenaDNAConfig(model_name=model_name)
 
-
-@pytest.mark.skip(reason="Work in progress.")
 @pytest.mark.parametrize("input_sequence, expected_output", [
     # Valid DNA sequences
     ("", [0, 1]),

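With the skip markers removed, the suite runs every test once per checkpoint through the parametrized fixture. A minimal sketch of that pattern (constructor arguments follow the signature used later in this series, and output_size=1 matches the assertion adjusted in patch 04):

```python
# Sketch: pytest expands the parametrized fixture into one run per checkpoint.
import pytest
from helical import HyenaDNAConfig, HyenaDNAFineTuningModel

@pytest.fixture(params=["hyenadna-tiny-1k-seqlen", "hyenadna-tiny-1k-seqlen-d256"])
def hyenaDNAFineTune(request):
    config = HyenaDNAConfig(model_name=request.param, batch_size=1)
    return HyenaDNAFineTuningModel(hyena_config=config, fine_tuning_head="classification", output_size=1)
```
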
From bd3b59ee0371f42093526b809b1a2150620f68a0 Mon Sep 17 00:00:00 2001
From: Matthew Wood <matthew@helical-ai.com>
Date: Thu, 5 Dec 2024 11:04:44 +0100
Subject: [PATCH 02/27] Add HyenaDNA notebook tests

---
 .github/workflows/main.yml    | 16 ++++++++++++----
 .github/workflows/release.yml |  2 +-
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 21542afb..3fb8a56b 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -61,9 +61,17 @@ jobs:
         run: |
           python examples/run_models/run_uce.py
 
-      # - name: Execute Hyena
-      #   run: |
-      #     python examples/run_models/run_hyena_dna.py
+      - name: Execute Hyena
+        run: |
+          python examples/run_models/run_hyena_dna.py
+
+      - name: Execute Helix-mRNA
+        run: |
+          python examples/run_models/run_helix_mrna.py
+
+      - name: Execute Mamba2-mRNA
+        run: |
+          python examples/run_models/run_mamba2_mrna.py
 
       - name: Execute benchmarking
         run: |
@@ -97,4 +105,4 @@ jobs:
 
       - name: Run Notebooks
         run: |
-          pytest --durations=0 --nbmake $(find ./examples/notebooks -name "*.ipynb" ! -name "Hyena*")
+          pytest --durations=0 --nbmake ./examples/notebooks/*.ipynb
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index bd39750f..4fa373ad 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -100,7 +100,7 @@ jobs:
 
       - name: Run Notebooks
         run: |
-          pytest --durations=0 --nbmake $(find ./examples/notebooks -name "*.ipynb" ! -name "Hyena*")
+          pytest --durations=0 --nbmake ./examples/notebooks/*.ipynb
 
   release:
     needs: notebooks

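The notebook step now collects every notebook with a shell glob instead of a find invocation that excluded the Hyena ones. A local equivalent through pytest's Python entry point (a sketch, assuming the nbmake plugin the workflows already use is installed):

```python
# Run all example notebooks, Hyena included, exactly as the updated CI step does.
import glob
import pytest

notebooks = glob.glob("./examples/notebooks/*.ipynb")
pytest.main(["--durations=0", "--nbmake", *notebooks])
```
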
From a2516c3f31afee08ff82712af1ae1d1e40543f7b Mon Sep 17 00:00:00 2001
From: Matthew Wood <matthew@helical-ai.com>
Date: Thu, 5 Dec 2024 12:21:27 +0100
Subject: [PATCH 03/27] Clean up HyenaDNA and add fine-tuning example

---
 .../fine_tune_models/fine_tune_hyena_dna.py   | 22 +++++
 .../run_models/configs/hyena_dna_config.yaml  |  1 +
 examples/run_models/run_hyena_dna.py          |  2 +-
 helical/models/hyena_dna/fine_tuning_model.py | 82 +++++++++----------
 helical/models/hyena_dna/hyena_dna_config.py  |  1 +
 helical/models/hyena_dna/hyena_dna_utils.py   | 16 +++-
 helical/models/hyena_dna/model.py             | 10 +--
 .../models/hyena_dna/standalone_hyenadna.py   | 27 +++---
 8 files changed, 95 insertions(+), 66 deletions(-)
 create mode 100644 examples/fine_tune_models/fine_tune_hyena_dna.py

diff --git a/examples/fine_tune_models/fine_tune_hyena_dna.py b/examples/fine_tune_models/fine_tune_hyena_dna.py
new file mode 100644
index 00000000..b940cd73
--- /dev/null
+++ b/examples/fine_tune_models/fine_tune_hyena_dna.py
@@ -0,0 +1,22 @@
+from helical import HyenaDNAFineTuningModel, HyenaDNAConfig
+import hydra
+from omegaconf import DictConfig
+
+@hydra.main(version_base=None, config_path="../run_models/configs", config_name="hyena_dna_config")
+def run_fine_tuning(cfg: DictConfig):
+    input_sequences = ["ACT"*20, "ATG"*20, "ATG"*20, "ACT"*20, "ATT"*20]
+    labels = [0, 2, 2, 0, 1]
+
+    hyena_dna_config = HyenaDNAConfig(**cfg)
+    hyena_dna_fine_tune = HyenaDNAFineTuningModel(hyena_config=hyena_dna_config, fine_tuning_head="classification", output_size=3)
+
+    train_dataset = hyena_dna_fine_tune.process_data(input_sequences)
+
+    hyena_dna_fine_tune.train(train_dataset=train_dataset, train_labels=labels)
+
+    outputs = hyena_dna_fine_tune.get_outputs(train_dataset)
+    
+    print(outputs.shape)
+
+if __name__ == "__main__":
+    run_fine_tuning()
diff --git a/examples/run_models/configs/hyena_dna_config.yaml b/examples/run_models/configs/hyena_dna_config.yaml
index 7aedc280..3f42367c 100644
--- a/examples/run_models/configs/hyena_dna_config.yaml
+++ b/examples/run_models/configs/hyena_dna_config.yaml
@@ -1,6 +1,7 @@
 model_name: "hyenadna-tiny-1k-seqlen-d256"
 n_layer: 2
 vocab_size: 12
+batch_size: 2
 resid_dropout: 0.0
 embed_dropout: 0.1
 fused_mlp: False,   
diff --git a/examples/run_models/run_hyena_dna.py b/examples/run_models/run_hyena_dna.py
index cd703114..c1a0d1d6 100644
--- a/examples/run_models/run_hyena_dna.py
+++ b/examples/run_models/run_hyena_dna.py
@@ -7,7 +7,7 @@ def run(cfg: DictConfig):
 
     hyena_config = HyenaDNAConfig(**cfg)
     model = HyenaDNA(configurer = hyena_config)   
-    sequence = 'ACTG' * int(1024/4)
+    sequence = ['ACTG' * int(1024/4)]
     tokenized_sequence = model.process_data(sequence)
     embeddings = model.get_embeddings(tokenized_sequence)
     print(embeddings.shape)
diff --git a/helical/models/hyena_dna/fine_tuning_model.py b/helical/models/hyena_dna/fine_tuning_model.py
index 90dbc18f..16336d14 100644
--- a/helical/models/hyena_dna/fine_tuning_model.py
+++ b/helical/models/hyena_dna/fine_tuning_model.py
@@ -21,26 +21,24 @@ class HyenaDNAFineTuningModel(HelicalBaseFineTuningModel, HyenaDNA):
     Example
     ----------
     ```python
-    from datasets import load_dataset
-    from helical import HyenaDNAConfig, HyenaDNAFineTuningModel
+    from helical import HyenaDNAFineTuningModel, HyenaDNAConfig
     import torch
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    # Load a Hugging Face dataset and task type
-    ds = load_dataset("dataset", "task")
+    input_sequences = ["ACT"*20, "ATG"*20, "ATG"*20, "ACT"*20, "ATT"*20]
+    labels = [0, 2, 2, 0, 1]
 
-    # Define the desired configs
-    config = HyenaDNAConfig(device=device, batch_size=10)
+    hyena_dna_config = HyenaDNAConfig(batch_size=1, device=device)
+    hyena_dna_fine_tune = HyenaDNAFineTuningModel(hyena_config=hyena_dna_config, fine_tuning_head="classification", output_size=3)
 
-    # Define the fine-tuning model with the configs we instantiated above
-    hyena_fine_tune = HyenaDNAFineTuningModel(config, "classification", number_unique_outputs)
+    train_dataset = hyena_dna_fine_tune.process_data(input_sequences)
 
-    # Prepare the sequences for input to the model
-    input_dataset = hyena_fine_tune.process_data(ds["train"]["sequence"])
+    hyena_dna_fine_tune.train(train_dataset=train_dataset, train_labels=labels)
 
-    # train the fine-tuning model on some downstream task
-    hyena_fine_tune.train(input_dataset, ds["train"]["label"])
+    outputs = hyena_dna_fine_tune.get_outputs(train_dataset)
+    
+    print(outputs.shape)
     ```
     
     Parameters
@@ -54,7 +52,7 @@ class HyenaDNAFineTuningModel(HelicalBaseFineTuningModel, HyenaDNA):
     
     Methods
     -------
-    train(train_input_data: HyenaDNADataset, train_labels: list[int], validation_input_data: HyenaDNADataset = None, validation_labels: list[int] = None, optimizer: torch.optim, optimizer_params: dict, loss_function: torch.nn.modules.loss, epochs: int, lr_scheduler_params: dict = None)
+    train(train_dataset: HyenaDNADataset, train_labels: list[int], validation_dataset: HyenaDNADataset = None, validation_labels: list[int] = None, optimizer: torch.optim, optimizer_params: dict, loss_function: torch.nn.modules.loss, epochs: int, lr_scheduler_params: dict = None)
         Fine-tunes the Hyena-DNA model with different head modules.
     get_outputs(input_data: HyenaDNADataset) -> np.ndarray
         Get the outputs of the fine-tuned model.
@@ -77,9 +75,9 @@ def _forward(self, x):
 
     def train(        
         self,
-        train_input_data: HyenaDNADataset,
+        train_dataset: HyenaDNADataset,
         train_labels: list[int],     
-        validation_input_data: HyenaDNADataset = None,
+        validation_dataset: HyenaDNADataset = None,
         validation_labels: list[int] = None,
         optimizer: optim = optim.AdamW,
         optimizer_params: dict = {'lr': 0.0001}, 
@@ -90,11 +88,11 @@ def train(
 
         Parameters
         ----------
-        train_input_data : HyenaDNADataset
+        train_dataset : HyenaDNADataset
             A helical Hyena-DNA processed dataset for fine-tuning
         train_labels : list[int]
             The labels for the training data. These should be stored as unique per class integers.
-        validation_input_data : HyenaDNADataset, default=None
+        validation_dataset : HyenaDNADataset, default=None
             A helical Hyena-DNA processed dataset for per epoch validation. If this is not specified, no validation will be performed.
         validation_labels : list[int], default=None
             The labels for the validation data. These should be stored as unique per class integers.
@@ -108,15 +106,15 @@ def train(
         epochs : int, optional, default=10
             The number of epochs to train the model for.
         lr_scheduler_params : dict, default=None
-            The learning rate scheduler parameters for the transformers get_scheduler method. The optimizer will be taken from the optimizer input and should not be included in the learning scheduler parameters. If not specified, no scheduler will be used.
-            e.g. lr_scheduler_params = { 'name': 'linear', 'num_warmup_steps': 0, 'num_training_steps': 5 }
+            The learning rate scheduler parameters for the transformers get_scheduler method. The optimizer will be taken from the optimizer input and should not be included in the learning scheduler parameters. If not specified, a constant learning rate will be used.
+            e.g. lr_scheduler_params = { 'name': 'linear', 'num_warmup_steps': 0 }. num_training_steps will be calculated based on the number of epochs and the length of the training dataset.
         """
-        train_input_data.set_labels(train_labels)
-        train_data_loader = DataLoader(train_input_data, batch_size=self.config["batch_size"])
+        train_dataset.set_labels(train_labels)
+        train_dataloader = DataLoader(train_dataset, batch_size=self.config["batch_size"])
      
-        if validation_input_data is not None and validation_labels is not None:
-            validation_input_data.set_labels(validation_labels)
-            validation_data_loader = DataLoader(validation_input_data, batch_size=self.config["batch_size"])
+        if validation_dataset is not None and validation_labels is not None:
+            validation_dataset.set_labels(validation_labels)
+            validation_dataloader = DataLoader(validation_dataset, batch_size=self.config["batch_size"])
 
         self.to(self.config["device"])
         self.model.train()
@@ -125,16 +123,16 @@ def train(
 
         lr_scheduler = None
         if lr_scheduler_params is not None: 
-            lr_scheduler = get_scheduler(optimizer=optimizer, **lr_scheduler_params)
+            lr_scheduler = get_scheduler(optimizer=optimizer, num_training_steps=epochs*len(train_dataloader),  **lr_scheduler_params)
 
         logger.info("Starting Fine-Tuning")
         for i in range(epochs):
             batch_loss = 0.0
             batches_processed = 0
-            training_loop = tqdm(train_data_loader)
-            for input_data, labels in training_loop:
-                input_data = input_data.to(self.config["device"])
-                labels = labels.to(self.config["device"])
+            training_loop = tqdm(train_dataloader)
+            for batch in training_loop:
+                input_data = batch["input_ids"].to(self.config["device"])
+                labels = batch["labels"].to(self.config["device"])
                 optimizer.zero_grad()
                 output = self._forward(input_data)
                 loss = loss_function(output, labels)
@@ -146,17 +144,17 @@ def train(
                 training_loop.set_postfix({"loss": batch_loss/batches_processed})
                 training_loop.set_description(f"Fine-Tuning: epoch {i+1}/{epochs}")
             
-            if lr_scheduler is not None:
-                lr_scheduler.step()
+                if lr_scheduler is not None:
+                    lr_scheduler.step()
 
-            if validation_input_data is not None and validation_labels is not None:
+            if validation_dataset is not None and validation_labels is not None:
                 with torch.no_grad():
                     validation_batches_processed = 0
                     val_loss = 0.0
-                    validation_loop = tqdm(validation_data_loader, desc="Fine-Tuning Validation")
-                    for input_data, val_labels in validation_loop:
-                        input_data = input_data.to(self.config["device"])
-                        val_labels = val_labels.to(self.config["device"])
+                    validation_loop = tqdm(validation_dataloader, desc="Fine-Tuning Validation")
+                    for batch in validation_loop:
+                        input_data = batch["input_ids"].to(self.config["device"])
+                        val_labels = batch["labels"].to(self.config["device"])
                         output = self._forward(input_data)
                         validation_batches_processed += 1
                         val_loss += loss_function(output, val_labels).item()
@@ -165,12 +163,12 @@ def train(
             
     def get_outputs(
             self, 
-            input_data: HyenaDNADataset) -> np.ndarray:
+            dataset: HyenaDNADataset) -> np.ndarray:
         """Get the outputs of the fine-tuned model.
         
         Parameters
         ----------
-        input_data : HyenaDNADataset
+        dataset : HyenaDNADataset
             The input data to get the outputs for.
 
         Returns
@@ -178,7 +176,7 @@ def get_outputs(
         np.ndarray
             The outputs of the model
         """
-        data_loader = DataLoader(input_data, batch_size=self.config["batch_size"])
+        data_loader = DataLoader(dataset, batch_size=self.config["batch_size"])
 
         self.to(self.config["device"])
         self.model.eval()
@@ -186,9 +184,9 @@ def get_outputs(
 
         batch_loop = tqdm(data_loader)
         outputs = []
-        for batch, *labels in batch_loop:
-            batch.to(self.config["device"])
-            output = self._forward(batch)
+        for batch in batch_loop:
+            input_data = batch["input_ids"].to(self.config["device"])
+            output = self._forward(input_data)
             outputs.append(output.detach().cpu().numpy())
         
         return np.vstack(outputs)
\ No newline at end of file
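
Two scheduler changes land together in train(): num_training_steps is now derived as epochs * len(train_dataloader), and lr_scheduler.step() moves inside the batch loop so the schedule advances once per batch. A self-contained sketch with toy values:

```python
# Sketch of the scheduler wiring after this patch (toy model and step counts).
import torch
from transformers import get_scheduler

model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
epochs, batches_per_epoch = 10, 25
lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epochs * batches_per_epoch,  # epochs * len(train_dataloader)
)
for _ in range(epochs):
    for _ in range(batches_per_epoch):
        optimizer.step()
        lr_scheduler.step()  # per batch, matching the moved call
```
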
diff --git a/helical/models/hyena_dna/hyena_dna_config.py b/helical/models/hyena_dna/hyena_dna_config.py
index 086f1452..3a3e5f9a 100644
--- a/helical/models/hyena_dna/hyena_dna_config.py
+++ b/helical/models/hyena_dna/hyena_dna_config.py
@@ -10,6 +10,7 @@ class HyenaDNAConfig():
     model_name : Literal["hyenadna-tiny-1k-seqlen", "hyenadna-tiny-1k-seqlen-d256"], optional, default="hyenadna-tiny-1k-seqlen"
         The name of the model.
     batch_size : int, optional, default=5
+        The batch size to use for all tasks.
     n_layer : int, optional, default=2
         The number of layers in the model.
     vocab_size : int, optional, default=12
diff --git a/helical/models/hyena_dna/hyena_dna_utils.py b/helical/models/hyena_dna/hyena_dna_utils.py
index a08b7ae9..5bd1878d 100644
--- a/helical/models/hyena_dna/hyena_dna_utils.py
+++ b/helical/models/hyena_dna/hyena_dna_utils.py
@@ -18,10 +18,18 @@ def __len__(self):
         return len(self.sequences)
 
     def __getitem__(self, idx):
-        if self.labels is None:
-            return self.sequences[idx]
-        else:
-            return self.sequences[idx], self.labels[idx]
+        seqs = self.sequences[idx]
+        
+        # Prepare output dictionary
+        output = {
+            'input_ids': seqs,
+        }
+        
+        # Add labels if they exist
+        if self.labels is not None:
+            output['labels'] = self.labels[idx]
+
+        return output
 
     def set_labels(self, labels):
         self.labels = labels
\ No newline at end of file
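
The dict-returning __getitem__ leans on PyTorch's default collation: DataLoader stacks each key of the per-sample dicts into a batched tensor, which is what lets the loops above read batch["input_ids"] and batch["labels"]. A toy sketch:

```python
# Sketch: default collate_fn turns per-sample dicts into a dict of batched tensors.
import torch
from torch.utils.data import DataLoader, Dataset

class ToySeqDataset(Dataset):
    def __init__(self, seqs, labels=None):
        self.seqs, self.labels = seqs, labels
    def __len__(self):
        return len(self.seqs)
    def __getitem__(self, idx):
        out = {"input_ids": self.seqs[idx]}
        if self.labels is not None:
            out["labels"] = self.labels[idx]
        return out

ds = ToySeqDataset(torch.zeros(4, 8, dtype=torch.long), torch.tensor([0, 1, 0, 1]))
for batch in DataLoader(ds, batch_size=2):
    print(batch["input_ids"].shape, batch["labels"].shape)  # [2, 8] and [2]
```
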
diff --git a/helical/models/hyena_dna/model.py b/helical/models/hyena_dna/model.py
index f770eae4..fe1c2887 100644
--- a/helical/models/hyena_dna/model.py
+++ b/helical/models/hyena_dna/model.py
@@ -59,7 +59,7 @@ def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None:
         self.tokenizer = CharacterTokenizer(
             characters=['A', 'C', 'G', 'T', 'N'],  # add DNA characters, N is uncertain
             model_max_length=self.config['max_length'] + 2,  # to account for special tokens, like EOS
-            add_special_tokens=False,  # we handle special tokens elsewhere
+            # add_special_tokens=False,  # we handle special tokens elsewhere
             padding_side='left', # since HyenaDNA is causal, we pad on the left
         )
 
@@ -117,12 +117,8 @@ def get_embeddings(self, dataset: HyenaDNADataset) -> torch.Tensor:
         train_data_loader = DataLoader(dataset, batch_size=self.config["batch_size"])
         with torch.inference_mode():
             embeddings = []
-            for input_data in tqdm(train_data_loader, desc="Getting embeddings"):
-                input_data = input_data.to(self.device)
+            for batch in tqdm(train_data_loader, desc="Getting embeddings"):
+                input_data = batch["input_ids"].to(self.device)
                 embeddings.append(self.model(input_data).detach().cpu().numpy())
         
-        # output = torch.stack(embeddings)
-        # other_dims = output.shape[2:]
-
-        # reshaped_tensor = output.view(-1, *other_dims)
         return np.vstack(embeddings)
diff --git a/helical/models/hyena_dna/standalone_hyenadna.py b/helical/models/hyena_dna/standalone_hyenadna.py
index a5100a0d..f49bd244 100644
--- a/helical/models/hyena_dna/standalone_hyenadna.py
+++ b/helical/models/hyena_dna/standalone_hyenadna.py
@@ -963,6 +963,18 @@ def __init__(self, characters: Sequence[str], model_max_length: int, padding_sid
 
         mask_token = AddedToken("[MASK]", lstrip=True, rstrip=False)
 
+        self._vocab_str_to_int = {
+            "[CLS]": 0,
+            "[SEP]": 1,
+            "[BOS]": 2,
+            "[MASK]": 3,
+            "[PAD]": 4,
+            "[RESERVED]": 5,
+            "[UNK]": 6,
+            **{ch: i + 7 for i, ch in enumerate(characters)},
+        }
+        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
+
         super().__init__(
             bos_token=bos_token,
             eos_token=sep_token,
@@ -977,18 +989,6 @@ def __init__(self, characters: Sequence[str], model_max_length: int, padding_sid
             **kwargs,
         )
 
-        self._vocab_str_to_int = {
-            "[CLS]": 0,
-            "[SEP]": 1,
-            "[BOS]": 2,
-            "[MASK]": 3,
-            "[PAD]": 4,
-            "[RESERVED]": 5,
-            "[UNK]": 6,
-            **{ch: i + 7 for i, ch in enumerate(characters)},
-        }
-        self._vocab_int_to_str = {v: k for k, v in self._vocab_str_to_int.items()}
-
     @property
     def vocab_size(self) -> int:
         return len(self._vocab_str_to_int)
@@ -1063,6 +1063,9 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs):
         with open(cfg_file, "w") as f:
             json.dump(cfg, f, indent=4)
 
+    def get_vocab(self) -> Dict[str, int]:
+        return self._vocab_str_to_int
+
     @classmethod
     def from_pretrained(cls, save_directory: Union[str, os.PathLike], **kwargs):
         cfg_file = Path(save_directory) / "tokenizer_config.json"

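Moving _vocab_str_to_int above super().__init__() matters because Hugging Face tokenizer base constructors can call back into subclass vocabulary methods, for example while registering special tokens. An illustrative sketch of the ordering constraint (not the transformers internals verbatim):

```python
# Sketch: the base constructor calls back into the subclass, so the vocab
# mapping must exist before super().__init__() runs.
class Base:
    def __init__(self):
        self.size_at_init = len(self.get_vocab())  # callback into the subclass

class CharTok(Base):
    def __init__(self, characters):
        self._vocab_str_to_int = {ch: i for i, ch in enumerate(characters)}  # before super()
        super().__init__()

    def get_vocab(self):
        return self._vocab_str_to_int

print(CharTok("ACGTN").size_at_init)  # 5; reversing the two lines raises AttributeError
```
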
From d730f88c9154fa059b8b32f711352fe38ceec0a4 Mon Sep 17 00:00:00 2001
From: Matthew Wood <matthew@helical-ai.com>
Date: Thu, 5 Dec 2024 12:23:14 +0100
Subject: [PATCH 04/27] Adjust HyenaDNA tests to reflect model changes

---
 ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py | 2 +-
 examples/notebooks/HyenaDNA-Fine-Tuning.ipynb         | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py b/ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py
index e76b9f5f..48d10a70 100644
--- a/ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py
+++ b/ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py
@@ -18,6 +18,6 @@ def mock_data(self, hyenaDNAFineTune):
 
     def test_output_dimensionality_of_fine_tuned_model(self, hyenaDNAFineTune, mock_data):
         input_sequences, labels = mock_data
-        hyenaDNAFineTune.train(train_input_data=input_sequences, train_labels=labels, validation_input_data=input_sequences, validation_labels=labels)
+        hyenaDNAFineTune.train(train_dataset=input_sequences, train_labels=labels, validation_dataset=input_sequences, validation_labels=labels)
         outputs = hyenaDNAFineTune.get_outputs(input_sequences)
         assert outputs.shape == (len(input_sequences), 1)
\ No newline at end of file
diff --git a/examples/notebooks/HyenaDNA-Fine-Tuning.ipynb b/examples/notebooks/HyenaDNA-Fine-Tuning.ipynb
index b0a7ff78..0aab8100 100644
--- a/examples/notebooks/HyenaDNA-Fine-Tuning.ipynb
+++ b/examples/notebooks/HyenaDNA-Fine-Tuning.ipynb
@@ -208,7 +208,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -239,7 +239,7 @@
     }
    ],
    "source": [
-    "hyena_fine_tune.train(train_input_data=train_dataset, train_labels=dataset_train[\"label\"], validation_input_data=test_dataset, validation_labels=dataset_test[\"label\"], epochs=10, optimizer_params={\"lr\": 2e-6}, lr_scheduler_params={\"name\": \"linear\", \"num_warmup_steps\": 0, \"num_training_steps\": 10})"
+    "hyena_fine_tune.train(train_dataset=train_dataset, train_labels=dataset_train[\"label\"], validation_dataset=test_dataset, validation_labels=dataset_test[\"label\"], epochs=10, optimizer_params={\"lr\": 2e-6}, lr_scheduler_params={\"name\": \"linear\", \"num_warmup_steps\": 0})"
    ]
   },
   {

From 46de7ac8e434535a7f6b477877a7aacd5f84fdad Mon Sep 17 00:00:00 2001
From: Matthew Wood <matthew@helical-ai.com>
Date: Thu, 5 Dec 2024 16:48:01 +0100
Subject: [PATCH 05/27] Add emb_mode to scGPT

---
 examples/run_models/configs/scgpt_config.yaml |  1 +
 examples/run_models/run_scgpt.py              |  7 +++--
 helical/__init__.py                           | 30 +++++++++----------
 helical/models/scgpt/model.py                 | 28 +++++++++++++----
 helical/models/scgpt/scgpt_config.py          |  3 ++
 5 files changed, 45 insertions(+), 24 deletions(-)

diff --git a/examples/run_models/configs/scgpt_config.yaml b/examples/run_models/configs/scgpt_config.yaml
index 4add9a58..cd518a6f 100644
--- a/examples/run_models/configs/scgpt_config.yaml
+++ b/examples/run_models/configs/scgpt_config.yaml
@@ -1,5 +1,6 @@
 pad_token: "<pad>"
 batch_size: 24
+emb_mode: "gene"
 fast_transformer: True
 nlayers: 12
 nheads: 8
diff --git a/examples/run_models/run_scgpt.py b/examples/run_models/run_scgpt.py
index 7ce27793..c3091b57 100644
--- a/examples/run_models/run_scgpt.py
+++ b/examples/run_models/run_scgpt.py
@@ -10,12 +10,13 @@ def run(cfg: DictConfig):
     scgpt_config = scGPTConfig(**cfg)
     scgpt = scGPT(configurer = scgpt_config)
 
+    # print(scgpt.model)
     # either load via huggingface
-    # hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
-    # ann_data = get_anndata_from_hf_dataset(hf_dataset)
+    hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
+    ann_data = get_anndata_from_hf_dataset(hf_dataset)
 
     # or load directly
-    ann_data = ad.read_h5ad("./yolksac_human.h5ad")
+    # ann_data = ad.read_h5ad("./yolksac_human.h5ad")
 
     data = scgpt.process_data(ann_data[:10])
     embeddings = scgpt.get_embeddings(data)
diff --git a/helical/__init__.py b/helical/__init__.py
index fa1a55ee..35a5999f 100644
--- a/helical/__init__.py
+++ b/helical/__init__.py
@@ -1,29 +1,29 @@
 import os
 import logging
 
-logging.captureWarnings(True)
+# logging.captureWarnings(True)
 
-class InfoAndErrorFilter(logging.Filter):
-    def filter(self, record):
-        return record.levelno in (logging.INFO, logging.ERROR)
+# class InfoAndErrorFilter(logging.Filter):
+#     def filter(self, record):
+#         return record.levelno in (logging.INFO, logging.ERROR)
 
-for handler in logging.root.handlers[:]:
-    logging.root.removeHandler(handler)
+# for handler in logging.root.handlers[:]:
+#     logging.root.removeHandler(handler)
 
-logger = logging.getLogger()
-logger.setLevel(logging.INFO)
+# logger = logging.getLogger()
+# logger.setLevel(logging.INFO)
 
-handler = logging.StreamHandler()
-handler.setLevel(logging.INFO) 
+# handler = logging.StreamHandler()
+# handler.setLevel(logging.INFO) 
 
-handler.addFilter(InfoAndErrorFilter())
+# handler.addFilter(InfoAndErrorFilter())
 
-formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
-handler.setFormatter(formatter)
+# formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
+# handler.setFormatter(formatter)
 
-logger.addHandler(handler)
+# logger.addHandler(handler)
 
-os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
+# os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
 
 
 from .models.uce.model import UCEConfig, UCE
diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py
index 2f7ac888..69e9b83b 100644
--- a/helical/models/scgpt/model.py
+++ b/helical/models/scgpt/model.py
@@ -101,7 +101,7 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
             pad_value=self.config["pad_value"],
             do_mlm=False,
             do_binning=True,
-            max_length=1200,
+            max_length=self.config["MAX_LENGTH"],
             sampling=True,
             keep_first_n_tokens=1,
         )
@@ -116,9 +116,16 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
 
         device = next(self.model.parameters()).device
 
-        cell_embeddings = np.zeros(
-            (len(dataset), self.config["embsize"]), dtype=np.float32
-        )
+        # provision numpy ndarray for gene, cell and cls embeddings
+        if self.config["emb_mode"] == "gene":
+            cell_embeddings = np.zeros(
+                (len(dataset), self.config["MAX_LENGTH"]-1, self.config["embsize"]), dtype=np.float32
+            )
+        else:
+            cell_embeddings = np.zeros(
+                (len(dataset), self.config["embsize"]), dtype=np.float32
+            )
+        
         with torch.no_grad(), torch.cuda.amp.autocast(enabled=True): #torch.autocast(device_type=str(device),enabled=True): # torch.cuda.amp.autocast(enabled=True):
             count = 0
             for data_dict in tqdm(data_loader, desc="Embedding cells"):
@@ -135,8 +142,17 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
                     else None,
                 )
 
-                embeddings = embeddings[:, 0, :]  # get the <cls> position embedding
-                embeddings = embeddings.cpu().numpy()
+                if self.config["emb_mode"] == "cls":
+                    embeddings = embeddings[:, 0, :]  # get the <cls> position embedding
+                    embeddings = embeddings.cpu().numpy()
+                elif self.config["emb_mode"] == "cell":
+                    embeddings = embeddings[:, 1:, :] # get all embeddings except the <cls> position
+                    embeddings = torch.mean(embeddings, dim=1) # mean embeddings to get cell embedding
+                    embeddings = embeddings.cpu().numpy()
+                elif self.config["emb_mode"] == "gene":
+                    embeddings = embeddings[:, 1:, :] # get all embeddings except the <cls> position
+                    embeddings = embeddings.cpu().numpy() # keep all gene embeddings
+                    
                 cell_embeddings[count : count + len(embeddings)] = embeddings
                 count += len(embeddings)
         cell_embeddings = cell_embeddings / np.linalg.norm(
diff --git a/helical/models/scgpt/scgpt_config.py b/helical/models/scgpt/scgpt_config.py
index 776c262f..60dcda85 100644
--- a/helical/models/scgpt/scgpt_config.py
+++ b/helical/models/scgpt/scgpt_config.py
@@ -54,6 +54,7 @@ def __init__(
             self, 
             pad_token: str = "<pad>",
             batch_size: int = 24,
+            emb_mode: Literal["cls", "cell", "gene"] = "cls",
             fast_transformer: bool = True,
             nlayers: int = 12,
             nheads: int = 8,
@@ -81,6 +82,7 @@ def __init__(
             "list_of_files_to_download": list_of_files_to_download,
             "pad_token": pad_token,
             "batch_size": batch_size,
+            "emb_mode": emb_mode,
             "fast_transformer": fast_transformer,
             "nlayers": nlayers,
             "nheads": nheads,
@@ -94,4 +96,5 @@ def __init__(
             "accelerator": accelerator,
             "device": device,
             "use_fast_transformer": use_fast_transformer,
+            "MAX_LENGTH": 1200
             }
\ No newline at end of file

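For reference, the three emb_mode reductions side by side on a [batch, tokens, embsize] tensor, assuming position 0 holds the <cls> token as in the patch (toy shapes):

```python
import torch

embeddings = torch.randn(2, 5, 512)          # [batch, tokens, embsize]
cls_emb = embeddings[:, 0, :]                # "cls": the <cls> vector       -> [2, 512]
cell_emb = embeddings[:, 1:, :].mean(dim=1)  # "cell": mean over gene tokens -> [2, 512]
gene_emb = embeddings[:, 1:, :]              # "gene": one vector per gene   -> [2, 4, 512]
```
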
From 226839af206fdf7cb331f6267c8ab4cd4aa8be9e Mon Sep 17 00:00:00 2001
From: giogix2 <giovanni@Helical-AI-A6000>
Date: Fri, 6 Dec 2024 12:15:06 +0100
Subject: [PATCH 06/27] Modify pyproject.toml and make mamba_ssm optional.

---
 README.md      | 17 +++++++++++++++++
 pyproject.toml | 10 ++++++++--
 setup.py       | 38 --------------------------------------
 3 files changed, 25 insertions(+), 40 deletions(-)
 delete mode 100644 setup.py

diff --git a/README.md b/README.md
index 5dbe5d7e..c6462d67 100644
--- a/README.md
+++ b/README.md
@@ -47,6 +47,23 @@ To install the latest Helical package, you can run the command below:
 pip install --upgrade git+https://github.com/helicalAI/helical.git
 ```
 
+Alternatively, clone the repo and install it:
+```
+git clone https://github.com/helicalAI/helical.git
+pip install .
+```
+
+[Optional] To install mamba-ssm and causal-conv1d, use the command below:
+```
+pip install helical[mamba-ssm]
+```
+or, if you're installing from a local clone of the Helical repo:
+```
+pip install .[mamba-ssm]
+```
+
+Note: make sure your machine has one or more GPUs and CUDA installed; this is currently a requirement for the mamba-ssm and causal-conv1d packages.
+
 ### Singularity (Optional)
 If you desire to run your code in a singularity file, you can use the [singularity.def](./singularity.def) file and build an apptainer with it:
 ```
diff --git a/pyproject.toml b/pyproject.toml
index 5dd2e19b..ccfa5f81 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -45,13 +45,19 @@ dependencies = [
     'louvain==0.8.2',
     'pyensembl',
     'datasets==2.20.0',
-    'mamba-ssm==2.2.2',
+]
+
+[project.optional-dependencies]
+mamba-ssm = [
+    'mamba-ssm @ git+https://github.com/state-spaces/mamba.git@v2.2.3',
     'causal-conv1d==1.4.0',
 ]
 
+[tool.hatch.metadata]
+allow-direct-references = true # Allows installing package from git
 
 [project.urls]
 Homepage = "https://github.com/helicalAI/helical"
 Issues = "https://github.com/helicalAI/helical/issues"
 Documentation = "https://helical.readthedocs.io/"
-Repository = "https://github.com/helicalAI/helical"
+Repository = "https://github.com/helicalAI/helical"
\ No newline at end of file
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 73a5c5b1..00000000
--- a/setup.py
+++ /dev/null
@@ -1,38 +0,0 @@
-from setuptools import setup, find_packages
-
-setup(
-    name='helical',
-    url='https://github.com/helicalAI/helical-package.git',
-    author='Benoit Putzeys, Maxime Allard',
-    author_email='benoit@helical-ai.com, maxime@helical-ai.com',
-    packages=find_packages(),
-    install_requires=[
-        'requests==2.32.2',
-        'pandas==2.2.2',
-        'anndata==0.10.7',
-        'numpy==1.26.4',
-        'scikit-learn>=1.2.2',
-        'scipy==1.13.1',
-        'gitpython==3.1.43',
-        'torch>=2.0.0,<=2.3.0',
-        'torchvision>=0.15.0,<=0.18.0',
-        'accelerate==0.29.3',
-        'transformers==4.45.1',
-        'loompy==3.0.7',
-        'scib==1.1.5',
-        'scikit-misc==0.3.1',
-        'torchtext>=0.15.0,<=0.18.0',
-        'azure-identity==1.16.0',
-        'azure-storage-blob==12.19.1',
-        'azure-core==1.30.1',
-        'einops==0.8.0',
-        'omegaconf==2.3.0',
-        'hydra-core==1.3.2',
-        'tensorflow>=2.15.0,<=2.17.0',
-        'louvain==0.8.2',
-        'pyensembl',
-        'datasets==2.20.0',
-        'mamba-ssm==2.2.2',
-        'causal-conv1d==1.4.0'
-    ],  
-)

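With mamba-ssm demoted to an optional extra, code that touches it needs an import guard. A sketch of the usual pattern (helical's actual guard is not shown in this patch, and the Mamba2 import is an assumption about the mamba-ssm 2.x API):

```python
# Guarded import for an optional dependency installed via `pip install helical[mamba-ssm]`.
try:
    from mamba_ssm import Mamba2  # assumed export; only present with the extra installed
    HAS_MAMBA = True
except ImportError:
    HAS_MAMBA = False
```
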
From b2b599d7b8fbbb6bc91762b05eda4a5d7269d6d5 Mon Sep 17 00:00:00 2001
From: Matthew Wood <matthew@helical-ai.com>
Date: Fri, 6 Dec 2024 14:48:55 +0100
Subject: [PATCH 07/27] Add gene, cls and cell embedding to scGPT

---
 .github/workflows/release.yml           |  6 ++--
 ci/tests/test_scgpt/test_scgpt_model.py | 35 ++++++++++++--------
 examples/run_models/run_scgpt.py        |  8 ++---
 helical/models/scgpt/model.py           | 43 ++++++++++++++++++-------
 4 files changed, 59 insertions(+), 33 deletions(-)

diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 4fa373ad..f964c1b3 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -64,9 +64,9 @@ jobs:
         run: |
           python examples/run_models/run_uce.py
 
-      # - name: Execute Hyena
-      #   run: |
-      #     python examples/run_models/run_hyena_dna.py
+      - name: Execute Hyena
+        run: |
+          python examples/run_models/run_hyena_dna.py
 
       - name: Execute benchmarking
         run: |
diff --git a/ci/tests/test_scgpt/test_scgpt_model.py b/ci/tests/test_scgpt/test_scgpt_model.py
index c9ba15dc..41b5eaaa 100644
--- a/ci/tests/test_scgpt/test_scgpt_model.py
+++ b/ci/tests/test_scgpt/test_scgpt_model.py
@@ -12,7 +12,7 @@ class TestSCGPTModel:
 
     # Create a dummy AnnData object
     data = AnnData()
-    data.var["gene_names"] = ['SAMD11', 'PLEKHN1', "NOT_IN_VOCAB", "<pad>", 'HES4']
+    data.var["gene_names"] = ['SAMD11', 'PLEKHN1', "NOT_IN_VOCAB", 'HES4']
     data.obs["cell_type"] = ["CD4 T cells"]
     
     vocab = {
@@ -23,7 +23,7 @@ class TestSCGPTModel:
     }
     scgpt.vocab = GeneVocab.from_dict(vocab)
     
-    data.X = [[1, 2, 5, 6, 0]]
+    data.X = [[1, 2, 5, 6]]
     
     def test_process_data(self):
         dataset = self.scgpt.process_data(self.data, gene_names = "gene_names")
@@ -33,16 +33,16 @@ def test_process_data(self):
 
         # asserts that all the genes in gene_names have been correctly translated 
         # to the corresponding ids based on the vocabulary
-        assert (self.data.var["id_in_vocab"] == [0, 1, -1, 3, 2]).all()
+        assert (self.data.var["id_in_vocab"] == [0, 1, -1, 2]).all()
 
         # make sure that the genes not present in the vocabulary are filtered out
         # meaning -1 is not present in the gene_ids
-        assert (dataset.gene_ids == [0, 1, 3, 2]).all()
+        assert (dataset.gene_ids == [0, 1, 2]).all()
 
         # ensure that the default index of the vocabulary is set to the id of the pad token
         assert self.scgpt.vocab.get_default_index() == 3
 
-        assert (dataset.count_matrix == [1, 2, 6, 0]).all()
+        assert (dataset.count_matrix == [1, 2, 6]).all()
 
     def test_correct_handling_of_batch_ids(self):
         batch_id_array = [1]
@@ -51,20 +51,27 @@ def test_correct_handling_of_batch_ids(self):
         assert (dataset.batch_ids == batch_id_array).all()
 
     def test_direct_assignment_of_genes_to_index(self):
-        self.data.var.index = ['SAMD11', 'PLEKHN1', "NOT_IN_VOCAB", "<pad>", 'HES4']
+        self.data.var.index = ['SAMD11', 'PLEKHN1', "NOT_IN_VOCAB", 'HES4']
         self.scgpt.process_data(self.data, gene_names = "index")
         
         # as set above, the gene column can also be directly assigned to the index column
         assert self.scgpt.gene_names == "index"
         assert self.scgpt.gene_names in self.data.var
 
-
-    def test_get_embeddings(self):
+    # test get_embeddings for all three embedding modes
+    @pytest.mark.parametrize("emb_mode", ["cell", "gene", "cls"])
+    def test_get_embeddings(self, emb_mode):
+        self.scgpt.config["emb_mode"] = emb_mode
         dataset = self.scgpt.process_data(self.data, gene_names = "gene_names")
         embeddings = self.scgpt.get_embeddings(dataset)
-        assert embeddings.shape == (1, 512)
+        if emb_mode == "gene":
+            assert list(embeddings[0].keys()) == ["SAMD11", "PLEKHN1", "HES4"]
+            for key in embeddings[0].keys():
+                assert len(embeddings[0][key]) == 512
+        else:
+            assert embeddings.shape == (1, 512)
 
-    dummy_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
+    dummy_data = ad.read_h5ad("data/cell_type_sample.h5ad")
     @pytest.mark.parametrize("data, gene_names, batch_labels", 
                              [
                                 #  missing gene_names in data.var
@@ -77,11 +84,11 @@ def test_ensure_data_validity__key_error(self, data, gene_names, batch_labels):
         with pytest.raises(KeyError):
             self.scgpt.ensure_data_validity(data, gene_names, batch_labels)
     
-    err_np_arr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
+    err_np_arr_data = ad.read_h5ad("data/cell_type_sample.h5ad")
     err_np_arr_data.X.dtype=float
     err_np_arr_data.X[0,0] = 0.5
 
-    err_csr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
+    err_csr_data = ad.read_h5ad("data/cell_type_sample.h5ad")
     err_csr_data.X = csr_matrix(np.random.rand(100, 5), dtype=np.float32)
     @pytest.mark.parametrize("data",
                              [
@@ -95,8 +102,8 @@ def test_ensure_data_validity__value_error(self, data):
             self.scgpt.ensure_data_validity(data, "index", False)
         assert "total_counts" in data.obs
 
-    np_arr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
-    csr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
+    np_arr_data = ad.read_h5ad("data/cell_type_sample.h5ad")
+    csr_data = ad.read_h5ad("data/cell_type_sample.h5ad")
     csr_data.X = csr_matrix(np.random.poisson(1, size=(100, 5)), dtype=np.float32)
     @pytest.mark.parametrize("data",
                              [
diff --git a/examples/run_models/run_scgpt.py b/examples/run_models/run_scgpt.py
index c3091b57..666bad23 100644
--- a/examples/run_models/run_scgpt.py
+++ b/examples/run_models/run_scgpt.py
@@ -12,16 +12,16 @@ def run(cfg: DictConfig):
 
     # print(scgpt.model)
     # either load via huggingface
-    hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
-    ann_data = get_anndata_from_hf_dataset(hf_dataset)
+    # hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
+    # ann_data = get_anndata_from_hf_dataset(hf_dataset)
 
     # or load directly
-    # ann_data = ad.read_h5ad("./yolksac_human.h5ad")
+    ann_data = ad.read_h5ad("./10k_pbmcs_proc.h5ad")
 
     data = scgpt.process_data(ann_data[:10])
     embeddings = scgpt.get_embeddings(data)
 
-    print(embeddings.shape)
+    print(embeddings)
 
 if __name__ == "__main__":
     run()
\ No newline at end of file
diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py
index 69e9b83b..5a56d15d 100644
--- a/helical/models/scgpt/model.py
+++ b/helical/models/scgpt/model.py
@@ -4,6 +4,7 @@
 from helical.models.base_models import HelicalRNAModel
 from helical.models.scgpt.scgpt_config import scGPTConfig
 import numpy as np
+import pandas as pd
 from anndata import AnnData
 import logging
 from accelerate import Accelerator
@@ -118,9 +119,9 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
 
         # provision numpy ndarray for gene, cell and cls embeddings
         if self.config["emb_mode"] == "gene":
-            cell_embeddings = np.zeros(
-                (len(dataset), self.config["MAX_LENGTH"]-1, self.config["embsize"]), dtype=np.float32
-            )
+            # create a dictionary mapping gene id to gene name (the vocab object does not expose one directly)
+            id_gene_dict = {i: gene for i, gene in enumerate(self.vocab.get_itos())}
+            gene_embs = []
         else:
             cell_embeddings = np.zeros(
                 (len(dataset), self.config["embsize"]), dtype=np.float32
@@ -130,6 +131,7 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
             count = 0
             for data_dict in tqdm(data_loader, desc="Embedding cells"):
                 input_gene_ids = data_dict["gene"].to(device)
+
                 src_key_padding_mask = input_gene_ids.eq(
                     self.vocab[self.config["pad_token"]]
                 )
@@ -150,16 +152,32 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
                     embeddings = torch.mean(embeddings, dim=1) # mean embeddings to get cell embedding
                     embeddings = embeddings.cpu().numpy()
                 elif self.config["emb_mode"] == "gene":
-                    embeddings = embeddings[:, 1:, :] # get all embeddings except the <cls> position
-                    embeddings = embeddings.cpu().numpy() # keep all gene embeddings
-                    
-                cell_embeddings[count : count + len(embeddings)] = embeddings
-                count += len(embeddings)
-        cell_embeddings = cell_embeddings / np.linalg.norm(
-            cell_embeddings, axis=1, keepdims=True
-        )
+                    embeddings = embeddings[:, 1:, :].cpu().numpy() # get all embeddings except the <cls> position
+                    gene_ids = data_dict["gene"].cpu().numpy()
+                    series = []
+
+                    # for each cell in the batch, build a gene name -> embedding dict and store it as a pd.Series
+                    for i, embedding in enumerate(embeddings):
+                        gene_dict = {}
+                        for j, gene in enumerate(embedding, 1):
+                            if data_dict["gene"][i][j] != self.vocab[self.config["pad_token"]]:
+                                gene_dict[id_gene_dict[gene_ids[i][j]]] = gene / np.linalg.norm(gene)
+                        
+                    series.append(pd.Series(gene_dict))
+                if self.config["emb_mode"] != "gene":   
+                    cell_embeddings[count : count + len(embeddings)] = embeddings
+                    count += len(embeddings)
+                else:
+                    gene_embs.extend(series)
+        if self.config["emb_mode"] != "gene":
+            cell_embeddings = cell_embeddings / np.linalg.norm(
+                cell_embeddings, axis=1, keepdims=True
+            )
+
+            return cell_embeddings
+        else:
+            return gene_embs
 
-        return cell_embeddings
     
     def process_data(self,
                      adata: AnnData, 

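In "gene" mode the return value is now one pd.Series per cell, mapping gene name to an L2-normalized embedding, with pad positions dropped. A toy sketch of that shape (gene names and sizes here are assumptions):

```python
import numpy as np
import pandas as pd

id_gene_dict = {0: "SAMD11", 1: "PLEKHN1", 2: "HES4", 3: "<pad>"}
gene_ids = np.array([[0, 1, 3]])       # one cell; the last position is padding
embeddings = np.random.rand(1, 3, 4)   # [batch, genes, embsize]

gene_embs = []
for i, cell in enumerate(embeddings):
    mapping = {}
    for j, vec in enumerate(cell):
        if gene_ids[i][j] != 3:        # skip the pad token id
            mapping[id_gene_dict[gene_ids[i][j]]] = vec / np.linalg.norm(vec)
    gene_embs.append(pd.Series(mapping))

print(gene_embs[0].index.tolist())     # ['SAMD11', 'PLEKHN1']
```
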
From 169faf8ad5b7fcb34b076e10c9ede42cbdbe31f6 Mon Sep 17 00:00:00 2001
From: Matthew Wood <matthew@helical-ai.com>
Date: Fri, 6 Dec 2024 17:19:30 +0100
Subject: [PATCH 08/27] Change unit tests to be CPU-bound and integration tests
 to be GPU-bound

---
 .github/workflows/main.yml                    | 32 ++++++++++++-
 .github/workflows/release.yml                 | 46 +++++++++++++++++--
 .../run_models/configs/geneformer_config.yaml |  2 +-
 .../run_models/configs/hyena_dna_config.yaml  |  2 +-
 examples/run_models/configs/scgpt_config.yaml |  2 +-
 examples/run_models/configs/uce_config.yaml   |  2 +-
 6 files changed, 77 insertions(+), 9 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 3fb8a56b..d3007d7b 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -8,6 +8,8 @@ on:
 jobs:
   tests:
     runs-on: self-hosted
+    env:
+      CUDA_VISIBLE_DEVICES: 0
     steps:
       - name: Checkout repository
         uses: actions/checkout@v2
@@ -29,7 +31,7 @@ jobs:
 
       - name: Execute unittests
         run: |
-          pytest --cov-report=html:html_cov --cov-branch --cov-report term --cov=helical ci/
+          CUDA_VISIBLE_DEVICES=-1 pytest --cov-report=html:html_cov --cov-branch --cov-report term --cov=helical ci/
       
       - name: Upload coverage report
         uses: actions/upload-artifact@v4
@@ -49,30 +51,58 @@ jobs:
         run: |
           python examples/run_models/run_geneformer.py ++model_name="gf-12L-30M-i2048"
 
+      - name: Fine-tune Geneformer v1
+        run: |
+          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-30M-i2048"
+
       - name: Execute Geneformer v2
         run: |
           python examples/run_models/run_geneformer.py ++model_name="gf-12L-95M-i4096"
 
+      - name: Fine-tune Geneformer v2
+        run: |
+          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-95M-i4096"
+
       - name: Execute scGPT
         run: |
           python examples/run_models/run_scgpt.py
 
+      - name: Fine-tune scGPT
+        run: |
+          python examples/fine_tune_models/fine_tune_scgpt.py
+
       - name: Execute UCE
         run: |
           python examples/run_models/run_uce.py
 
+      - name: Fine-tune UCE
+        run: |
+          python examples/fine_tune_models/fine_tune_UCE.py
+
       - name: Execute Hyena
         run: |
           python examples/run_models/run_hyena_dna.py
 
+      - name: Fine-tune Hyena
+        run: |
+          python examples/fine_tune_models/fine_tune_hyena_dna.py
+
       - name: Execute Helix-mRNA
         run: |
           python examples/run_models/run_helix_mrna.py
 
+      - name: Fine-tune Helix-mRNA
+        run: |
+          python examples/fine_tune_models/fine_tune_helix_mrna.py
+
       - name: Execute Mamba2-mRNA
         run: |
           python examples/run_models/run_mamba2_mrna.py
 
+      - name: Fine-tune Mamba2-mRNA
+        run: |
+          python examples/fine_tune_models/fine_tune_mamba2_mrna.py
+
       - name: Execute benchmarking
         run: |
           pip install scanorama
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 4fa373ad..b7138f60 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -11,6 +11,8 @@ permissions:
 jobs:
   tests:
     runs-on: self-hosted
+    env:
+      CUDA_VISIBLE_DEVICES: 0
     steps:
       - name: Checkout repository
         uses: actions/checkout@v2
@@ -32,7 +34,7 @@ jobs:
 
       - name: Execute unittests
         run: |
-          pytest --cov-report=html:html_cov --cov-branch --cov-report term --cov=helical ci/
+          CUDA_VISIBLE_DEVICES=-1 pytest --cov-report=html:html_cov --cov-branch --cov-report term --cov=helical ci/
       
       - name: Upload coverage report
         uses: actions/upload-artifact@v4
@@ -52,21 +54,57 @@ jobs:
         run: |
           python examples/run_models/run_geneformer.py ++model_name="gf-12L-30M-i2048"
 
+      - name: Fine-tune Geneformer v1
+        run: |
+          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-30M-i2048"
+
       - name: Execute Geneformer v2
         run: |
           python examples/run_models/run_geneformer.py ++model_name="gf-12L-95M-i4096"
 
+      - name: Fine-tune Geneformer v2
+        run: |
+          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-95M-i4096"
+
       - name: Execute scGPT
         run: |
           python examples/run_models/run_scgpt.py
 
+      - name: Fine-tune scGPT
+        run: |
+          python examples/fine_tune_models/fine_tune_scgpt.py
+
       - name: Execute UCE
         run: |
           python examples/run_models/run_uce.py
 
-      # - name: Execute Hyena
-      #   run: |
-      #     python examples/run_models/run_hyena_dna.py
+      - name: Fine-tune UCE
+        run: |
+          python examples/fine_tune_models/fine_tune_UCE.py
+
+      - name: Execute Hyena
+        run: |
+          python examples/run_models/run_hyena_dna.py
+
+      - name: Fine-tune Hyena
+        run: |
+          python examples/fine_tune_models/fine_tune_hyena_dna.py
+
+      - name: Execute Helix-mRNA
+        run: |
+          python examples/run_models/run_helix_mrna.py
+
+      - name: Fine-tune Helix-mRNA
+        run: |
+          python examples/fine_tune_models/fine_tune_helix_mrna.py
+
+      - name: Execute Mamba2-mRNA
+        run: |
+          python examples/run_models/run_mamba2_mrna.py
+
+      - name: Fine-tune Mamba2-mRNA
+        run: |
+          python examples/fine_tune_models/fine_tune_mamba2_mrna.py
 
       - name: Execute benchmarking
         run: |
diff --git a/examples/run_models/configs/geneformer_config.yaml b/examples/run_models/configs/geneformer_config.yaml
index 7ee91173..24cbc512 100644
--- a/examples/run_models/configs/geneformer_config.yaml
+++ b/examples/run_models/configs/geneformer_config.yaml
@@ -2,5 +2,5 @@ model_name: "gf-12L-30M-i2048"
 batch_size: 24
 emb_layer: -1
 emb_mode: "cell"
-device: "cpu"
+device: "cuda"
 accelerator: False
\ No newline at end of file
diff --git a/examples/run_models/configs/hyena_dna_config.yaml b/examples/run_models/configs/hyena_dna_config.yaml
index 3f42367c..bd3967c4 100644
--- a/examples/run_models/configs/hyena_dna_config.yaml
+++ b/examples/run_models/configs/hyena_dna_config.yaml
@@ -11,7 +11,7 @@ checkpoint_mixer: False
 checkpoint_mlp: False
 pad_vocab_size_multiple: 8
 return_hidden_state: True
-device: "cpu"
+device: "cuda"
 layer: 
     "_name_": "hyena"
     "emb_dim": 5
diff --git a/examples/run_models/configs/scgpt_config.yaml b/examples/run_models/configs/scgpt_config.yaml
index 4add9a58..518ac78b 100644
--- a/examples/run_models/configs/scgpt_config.yaml
+++ b/examples/run_models/configs/scgpt_config.yaml
@@ -11,5 +11,5 @@ mask_value: -1
 pad_value: -2
 world_size: 8
 accelerator: False
-device: "cpu"
+device: "cuda"
 use_fast_transformer: False
\ No newline at end of file
diff --git a/examples/run_models/configs/uce_config.yaml b/examples/run_models/configs/uce_config.yaml
index 8d8ad18c..4f9d2827 100644
--- a/examples/run_models/configs/uce_config.yaml
+++ b/examples/run_models/configs/uce_config.yaml
@@ -14,5 +14,5 @@ output_dim: 1280
 d_hid: 5120
 token_dim: 5120
 multi_gpu: False
-device: "cpu"
+device: "cuda"
 accelerator: False
\ No newline at end of file

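The split relies on standard CUDA device masking: CUDA_VISIBLE_DEVICES=-1 hides every GPU from the unit-test process, while the job-level CUDA_VISIBLE_DEVICES: 0 pins the integration runs to the first GPU. A quick sanity check (the variable must be set before the first CUDA call in the process):

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # hide all GPUs, as the unittest step does

import torch
print(torch.cuda.is_available())  # False: unit tests stay on the CPU
```
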
From 967cf8323ce0aec7d5421e08b338856e8fc93593 Mon Sep 17 00:00:00 2001
From: Matthew Wood <matthew@helical-ai.com>
Date: Fri, 6 Dec 2024 14:54:26 +0100
Subject: [PATCH 09/27] Update docstrings

---
 ci/tests/test_scgpt/test_scgpt_model.py       | 10 +++----
 helical/__init__.py                           | 30 +++++++++----------
 .../models/helix_mrna/fine_tuning_model.py    |  6 ++--
 helical/models/scgpt/fine_tuning_model.py     |  9 ++++--
 helical/models/scgpt/model.py                 |  8 +++--
 5 files changed, 36 insertions(+), 27 deletions(-)

diff --git a/ci/tests/test_scgpt/test_scgpt_model.py b/ci/tests/test_scgpt/test_scgpt_model.py
index 41b5eaaa..c96397fa 100644
--- a/ci/tests/test_scgpt/test_scgpt_model.py
+++ b/ci/tests/test_scgpt/test_scgpt_model.py
@@ -71,7 +71,7 @@ def test_get_embeddings(self, emb_mode):
         else:
             assert embeddings.shape == (1, 512)
 
-    dummy_data = ad.read_h5ad("data/cell_type_sample.h5ad")
+    dummy_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
     @pytest.mark.parametrize("data, gene_names, batch_labels", 
                              [
                                 #  missing gene_names in data.var
@@ -84,11 +84,11 @@ def test_ensure_data_validity__key_error(self, data, gene_names, batch_labels):
         with pytest.raises(KeyError):
             self.scgpt.ensure_data_validity(data, gene_names, batch_labels)
     
-    err_np_arr_data = ad.read_h5ad("data/cell_type_sample.h5ad")
+    err_np_arr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
     err_np_arr_data.X.dtype=float
     err_np_arr_data.X[0,0] = 0.5
 
-    err_csr_data = ad.read_h5ad("data/cell_type_sample.h5ad")
+    err_csr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
     err_csr_data.X = csr_matrix(np.random.rand(100, 5), dtype=np.float32)
     @pytest.mark.parametrize("data",
                              [
@@ -102,8 +102,8 @@ def test_ensure_data_validity__value_error(self, data):
             self.scgpt.ensure_data_validity(data, "index", False)
         assert "total_counts" in data.obs
 
-    np_arr_data = ad.read_h5ad("data/cell_type_sample.h5ad")
-    csr_data = ad.read_h5ad("data/cell_type_sample.h5ad")
+    np_arr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
+    csr_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
     csr_data.X = csr_matrix(np.random.poisson(1, size=(100, 5)), dtype=np.float32)
     @pytest.mark.parametrize("data",
                              [
diff --git a/helical/__init__.py b/helical/__init__.py
index 35a5999f..fa1a55ee 100644
--- a/helical/__init__.py
+++ b/helical/__init__.py
@@ -1,29 +1,29 @@
 import os
 import logging
 
-# logging.captureWarnings(True)
+logging.captureWarnings(True)
 
-# class InfoAndErrorFilter(logging.Filter):
-#     def filter(self, record):
-#         return record.levelno in (logging.INFO, logging.ERROR)
+class InfoAndErrorFilter(logging.Filter):
+    def filter(self, record):
+        return record.levelno in (logging.INFO, logging.ERROR)
 
-# for handler in logging.root.handlers[:]:
-#     logging.root.removeHandler(handler)
+for handler in logging.root.handlers[:]:
+    logging.root.removeHandler(handler)
 
-# logger = logging.getLogger()
-# logger.setLevel(logging.INFO)
+logger = logging.getLogger()
+logger.setLevel(logging.INFO)
 
-# handler = logging.StreamHandler()
-# handler.setLevel(logging.INFO) 
+handler = logging.StreamHandler()
+handler.setLevel(logging.INFO) 
 
-# handler.addFilter(InfoAndErrorFilter())
+handler.addFilter(InfoAndErrorFilter())
 
-# formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
-# handler.setFormatter(formatter)
+formatter = logging.Formatter('%(levelname)s:%(name)s:%(message)s')
+handler.setFormatter(formatter)
 
-# logger.addHandler(handler)
+logger.addHandler(handler)
 
-# os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
+os.environ['TRANSFORMERS_VERBOSITY'] = 'error'
 
 
 from .models.uce.model import UCEConfig, UCE
diff --git a/helical/models/helix_mrna/fine_tuning_model.py b/helical/models/helix_mrna/fine_tuning_model.py
index 7099530d..485620b5 100644
--- a/helical/models/helix_mrna/fine_tuning_model.py
+++ b/helical/models/helix_mrna/fine_tuning_model.py
@@ -117,6 +117,8 @@ def train(self,
         ----------
         train_dataset : Dataset
             A helical processed dataset for fine-tuning
+        train_labels : np.ndarray
+            The labels for the training dataset
         optimizer : torch.optim, default=torch.optim.AdamW
             The optimizer to be used for training.
         optimizer_params : dict, optional, default={'lr': 0.0001}
@@ -124,14 +126,14 @@ def train(self,
             e.g. optimizer_params = {'lr': 0.0001}
         loss_function : torch.nn.modules.loss, default=torch.nn.modules.loss.CrossEntropyLoss()
             The loss function to be used.
-        label : str, optional, default="cell_types"
-            The column in the dataset containing the training labels. These should be stored as unique per class integers.
         epochs : int, optional, default=10
             The number of epochs to train the model
         trainable_layers : int, optional, default=2
             The number of layers to train in the model. The last n layers will be trained and the rest will be frozen.
         validation_dataset : Dataset, default=None
             A helical processed dataset for per epoch validation. If this is not specified, no validation will be performed.
+        validation_labels : np.ndarray, default=None
+            The labels for the validation dataset. This is required if a validation dataset is specified.
         lr_scheduler_params : dict, default=None
             The learning rate scheduler parameters for the transformers get_scheduler method. The optimizer will be taken from the optimizer input and should not be included in the learning scheduler parameters. If not specified, no scheduler will be used.
             e.g. lr_scheduler_params = { 'name': 'linear', 'num_warmup_steps': 0 }. num_steps will be calculated based on the number of epochs and the length of the training dataset.
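To make the new train_labels/validation_labels parameters concrete, a hedged usage sketch follows. Here fine_tune_model, train_dataset and validation_dataset are assumed to already exist (e.g. produced by process_data), and the label arrays are dummy values:

```python
import numpy as np
import torch

fine_tune_model.train(
    train_dataset=train_dataset,            # helical processed dataset (assumed)
    train_labels=np.array([0, 1, 0, 2]),    # unique-per-class integer labels
    optimizer_params={"lr": 0.0001},
    loss_function=torch.nn.CrossEntropyLoss(),
    epochs=10,
    trainable_layers=2,
    validation_dataset=validation_dataset,  # optional (assumed)
    validation_labels=np.array([1, 0]),     # required once validation_dataset is set
    lr_scheduler_params={"name": "linear", "num_warmup_steps": 0},
)
```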
diff --git a/helical/models/scgpt/fine_tuning_model.py b/helical/models/scgpt/fine_tuning_model.py
index 6fa3d513..b2ab76fd 100644
--- a/helical/models/scgpt/fine_tuning_model.py
+++ b/helical/models/scgpt/fine_tuning_model.py
@@ -110,8 +110,13 @@ def _forward(self,
             if use_batch_labels
             else None,
         )
-        cls_emb = embeddings[:, 0, :]
-        output = self.fine_tuning_head(cls_emb)
+
+        if self.config["emb_mode"] == "cls":
+            embeddings = embeddings[:, 0, :]
+        else:
+            embeddings = embeddings[:, 1:, 0].mean(dim=1)
+
+        output = self.fine_tuning_head(embeddings)
         return output
     
     def train(
diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py
index 5a56d15d..70de14a8 100644
--- a/helical/models/scgpt/model.py
+++ b/helical/models/scgpt/model.py
@@ -81,13 +81,15 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
 
         Parameters 
         ----------
-        dataset: Dataset
+        dataset : Dataset
             The processed dataset to get the embeddings from.
 
         Returns
         -------
-        np.ndarray
-            The gene embeddings in the form of a numpy array
+        np.ndarray | List[pd.Series]
+            The embeddings produced by the model. 
+            The return type depends on the `emb_mode` parameter in the configuration.
+            If `emb_mode` is set to "gene", the embeddings are returned as a list of pd.Series which contain a mapping of gene_name:embedding for each cell.
         """
         LOGGER.info(f"Inference started:")
 

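Since get_embeddings now documents two possible return types, a short hedged sketch of how a caller might branch on them; scgpt and dataset are assumed to be an initialized scGPT model and a processed dataset:

```python
import numpy as np
import pandas as pd

embeddings = scgpt.get_embeddings(dataset)

if isinstance(embeddings, np.ndarray):
    # emb_mode "cell" or "cls": one vector per cell
    print(embeddings.shape)
else:
    # emb_mode "gene": a list of pd.Series, one per cell,
    # each mapping gene_name -> embedding vector
    first_cell: pd.Series = embeddings[0]
    print(first_cell.index.tolist()[:5])
```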
From cdcac2b9ec1bf188e59e46d7490c77052ed65851 Mon Sep 17 00:00:00 2001
From: Matthew Wood <matthew@helical-ai.com>
Date: Fri, 6 Dec 2024 18:20:52 +0100
Subject: [PATCH 10/27] Fix typo in forward of scGPT fine-tuning

---
 helical/models/scgpt/fine_tuning_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/helical/models/scgpt/fine_tuning_model.py b/helical/models/scgpt/fine_tuning_model.py
index b2ab76fd..015842cd 100644
--- a/helical/models/scgpt/fine_tuning_model.py
+++ b/helical/models/scgpt/fine_tuning_model.py
@@ -114,7 +114,7 @@ def _forward(self,
         if self.config["emb_mode"] == "cls":
             embeddings = embeddings[:, 0, :]
         else:
-            embeddings = embeddings[:, 1:, 0].mean(dim=1)
+            embeddings = embeddings[:, 1:, :].mean(dim=1)
 
         output = self.fine_tuning_head(embeddings)
         return output

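The one-character typo changed the pooling semantics entirely. A self-contained sketch with a dummy tensor shows the difference (shapes are illustrative, not taken from the model):

```python
import torch

batch, seq_len, embsize = 2, 4, 8
embeddings = torch.randn(batch, seq_len, embsize)

# Buggy: picks hidden dimension 0, then averages over tokens -> shape (batch,)
buggy = embeddings[:, 1:, 0].mean(dim=1)

# Fixed: averages all non-<cls> token embeddings -> shape (batch, embsize)
fixed = embeddings[:, 1:, :].mean(dim=1)

assert buggy.shape == (batch,)
assert fixed.shape == (batch, embsize)
```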
From dc3deed6d152c7f0565934ae943991cc917c9104 Mon Sep 17 00:00:00 2001
From: Matthew Wood <matthew@helical-ai.com>
Date: Mon, 9 Dec 2024 11:37:52 +0100
Subject: [PATCH 11/27] Update model cards, docstrings and .readthedocs.yaml

---
 .readthedocs.yaml                             |  3 --
 README.md                                     | 17 ++++++---
 docs/index.md                                 | 36 ++++++++++++++++---
 docs/model_cards/helix_mrna.md                |  2 ++
 docs/model_cards/mamba2_mrna.md               |  2 ++
 .../models/helix_mrna/fine_tuning_model.py    | 13 +++++++
 6 files changed, 61 insertions(+), 12 deletions(-)

diff --git a/.readthedocs.yaml b/.readthedocs.yaml
index 181d71aa..264775cb 100644
--- a/.readthedocs.yaml
+++ b/.readthedocs.yaml
@@ -8,9 +8,6 @@ build:
     pre_build:
       - "mkdir -p ./docs/notebooks"
       - "cp ./examples/notebooks/*.ipynb ./docs/notebooks"
-      - "pip install -r ./docs/requirements.txt"
-      # Build the MkDocs site
-      # - "mike deploy --push --update-aliases 1.0 latest"
 
 mkdocs:
   configuration: mkdocs.yml
diff --git a/README.md b/README.md
index c6462d67..bd7b0b64 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ We will update this repo on a regular basis with new models, benchmarks, modalit
 Let’s build the most exciting AI-for-Bio community together!
 <div align="center">
 
-![Workflow](https://github.com/helicalAI/helical/actions/workflows/main.yml/badge.svg) &nbsp;
+![Workflow](https://github.com/helicalAI/helical/actions/workflows/release.yml/badge.svg) &nbsp;
 ![Workflow](https://github.com/helicalAI/helical/actions/workflows/github-code-scanning/codeql/badge.svg) &nbsp;
 [![Docs](https://img.shields.io/badge/docs-available-brightgreen)](https://helical.readthedocs.io/) &nbsp;
 [![PyPI version](https://badge.fury.io/py/helical.svg)](https://badge.fury.io/py/helical) &nbsp;
@@ -29,6 +29,16 @@ Let’s build the most exciting AI-for-Bio community together!
 
 </div>
 
+## What's new?
+### 🧬 Introducing Helix-mRNA-v0: Unlocking new frontiers & use cases in mRNA therapy 🧬
+We’re thrilled to announce the release of our first-ever mRNA Bio Foundation Model, designed to:
+
+1) Be Efficient, handling long sequence lengths effortlessly
+2) Balance Diversity & Specificity, leveraging a 2-step pre-training approach
+3) Deliver High-Resolution, operating at single-nucleotide resolution
+
+Check out our <a href="https://www.helical-ai.com/blog/helix-mrna-v0" target="_blank">blog post</a> to learn more about our approach and read the <a href="https://helical.readthedocs.io/en/latest/model_cards/helix_mrna/" target="_blank">model card</a> to get started.
+
 ## Installation
 
 We recommend installing Helical within a conda environment with the commands below (run them in your terminal) - this step is optional:
@@ -75,10 +85,9 @@ and then shell into the sandbox container (use the --nv flag if you have a GPU a
 apptainer shell --nv --fakeroot singularity/helical/
 ```
 
-## Installation
 ### RNA models:
-- [Helix-mRNA]((https://helical.readthedocs.io/en/latest/model_cards/helix_mrna/))
-- [Mamba2-mRNA]((https://helical.readthedocs.io/en/latest/model_cards/mamba2_mrna/))
+- [Helix-mRNA](https://helical.readthedocs.io/en/latest/model_cards/helix_mrna/)
+- [Mamba2-mRNA](https://helical.readthedocs.io/en/latest/model_cards/mamba2_mrna/)
 - [Geneformer](https://helical.readthedocs.io/en/latest/model_cards/geneformer/)
 - [scGPT](https://helical.readthedocs.io/en/latest/model_cards/scgpt/)
 - [Universal Cell Embedding (UCE)](https://helical.readthedocs.io/en/latest/model_cards/uce/)
diff --git a/docs/index.md b/docs/index.md
index 749d1d52..d43ff629 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -9,23 +9,51 @@ Helical simplifies the entire application lifecycle when building with bio found
 We will update this repo on a regular basis with new models, benchmarks, modalities and functions - so stay tuned.
 Let’s build the most exciting AI-for-Bio community together!
 
+## What's new?
+### 🧬 Introducing Helix-mRNA-v0: Unlocking new frontiers & use cases in mRNA therapy 🧬
+We’re thrilled to announce the release of our first-ever mRNA Bio Foundation Model, designed to:
+
+1) Be Efficient, handling long sequence lengths effortlessly
+2) Balance Diversity & Specificity, leveraging a 2-step pre-training approach
+3) Deliver High-Resolution, operating at single-nucleotide resolution
+
+Check out our <a href="https://www.helical-ai.com/blog/helix-mrna-v0" target="_blank">blog post</a> to learn more about our approach and read the <a href="https://helical.readthedocs.io/en/latest/model_cards/helix_mrna/" target="_blank">model card</a> to get started.
+
 ## Installation
 
 We recommend installing Helical within a conda environment with the commands below (run them in your terminal) - this step is optional:
-```shell
+```bash
 conda create --name helical-package python=3.11.8
 conda activate helical-package
 ```
+
 To install the latest pip release of our Helical package, you can run the command below:
-```shell
+```bash
 pip install helical
 ```
 
 To install the latest Helical package, you can run the command below:
-```
+```bash
 pip install --upgrade git+https://github.com/helicalAI/helical.git
 ```
 
+Alternatively, clone the repo and install it:
+```bash
+git clone https://github.com/helicalAI/helical.git
+pip install .
+```
+
+[Optional] To install mamba-ssm and causal-conv1d use the command below:
+```bash
+pip install helical[mamba-ssm]
+```
+or, if you are installing from a locally cloned Helical repo:
+```bash
+pip install .[mamba-ssm]
+```
+
+Note: make sure your machine has GPU(s) and CUDA installed. Currently, this is a requirement for the packages mamba-ssm and causal-conv1d.
+
 ### Singularity (Optional)
 If you desire to run your code in a singularity file, you can use the <a href="https://github.com/helicalAI/helical/blob/release/singularity.def" target="_blank">singularity.def</a> file and build an apptainer with it:
 ```
@@ -37,8 +65,6 @@ and then shell into the sandbox container (use the --nv flag if you have a GPU a
 apptainer shell --nv --fakeroot singularity/helical/
 ```
 
-
-## Installation
 ### RNA models:
 - [Helix-mRNA](./model_cards/helix_mrna.md)
 - [Mamba2-mRNA](./model_cards/mamba2_mrna.md)
diff --git a/docs/model_cards/helix_mrna.md b/docs/model_cards/helix_mrna.md
index 79fefa5c..0b11b33f 100644
--- a/docs/model_cards/helix_mrna.md
+++ b/docs/model_cards/helix_mrna.md
@@ -3,7 +3,9 @@
 ## Model Details
 
 **Model Name:** Helix-mRNA  
+
 **Model Versions:** v0 
+
 **Model Description:** Helix-mRNA is a single nucleotide resolution model that combines the Mamba2 architecture with transformer components, including attention and MLP blocks. The hybrid architecture enables precise nucleotide-level analysis and prediction of mRNA sequences. By leveraging both the efficient sequence processing capabilities of Mamba2's state space architecture and the contextual understanding of transformer attention mechanisms, Helix-mRNA processes mRNA sequences at individual nucleotide resolution. The model incorporates a special 'E' character to denote the beginning of each codon, enhancing its ability to recognize and analyze codon-level patterns in mRNA sequences.
 
 ## Model Developers
diff --git a/docs/model_cards/mamba2_mrna.md b/docs/model_cards/mamba2_mrna.md
index 653510d7..e33d472b 100644
--- a/docs/model_cards/mamba2_mrna.md
+++ b/docs/model_cards/mamba2_mrna.md
@@ -3,7 +3,9 @@
 ## Model Details
 
 **Model Name:** Mamba2-mRNA  
+
 **Model Versions:** 1.0  
+
 **Model Description:** Mamba2-mRNA is a single nucleotide resolution model built using the Mamba2 architecture. The model employs 16 Mamba layers (16L) to enable precise nucleotide-level analysis and prediction of mRNA sequences. By leveraging the efficient sequence processing capabilities of Mamba2's state space architecture, Mamba2-mRNA can process mRNA sequences at individual nucleotide resolution, making it suitable for detailed mRNA sequence analysis tasks.
 
 ## Model Developers
diff --git a/helical/models/helix_mrna/fine_tuning_model.py b/helical/models/helix_mrna/fine_tuning_model.py
index 485620b5..84a31d59 100644
--- a/helical/models/helix_mrna/fine_tuning_model.py
+++ b/helical/models/helix_mrna/fine_tuning_model.py
@@ -221,6 +221,19 @@ def train(self,
         LOGGER.info(f"Fine-Tuning Complete. Epochs: {epochs}")
 
     def get_outputs(self, dataset: Dataset) -> np.ndarray:
+        """
+        Returns the outputs of the model for the given dataset.
+        
+        Parameters
+        ----------
+        dataset : Dataset
+            The dataset object returned by the `process_data` function.
+        
+        Returns
+        -------
+        np.ndarray
+            The outputs of the model for the given dataset.
+        """
         dataloader = DataLoader(dataset, collate_fn=self._collate_fn, batch_size=self.config["batch_size"], shuffle=False)
         outputs = []
 

From 6401b3c27f6f0ee1868a69db1123316a5b29959a Mon Sep 17 00:00:00 2001
From: Benoit Putzeys <157973952+bputzeys@users.noreply.github.com>
Date: Mon, 9 Dec 2024 15:25:48 +0100
Subject: [PATCH 12/27] Update version in pyproject.toml

---
 pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index ccfa5f81..6637dbbd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "helical"
-version = "0.0.1a9"
+version = "0.0.1a10"
 authors = [
   { name="Helical Team", email="support@helical-ai.com" },
 ]
@@ -60,4 +60,4 @@ allow-direct-references = true # Allows installing package from git
 Homepage = "https://github.com/helicalAI/helical"
 Issues = "https://github.com/helicalAI/helical/issues"
 Documentation = "https://helical.readthedocs.io/"
-Repository = "https://github.com/helicalAI/helical"
\ No newline at end of file
+Repository = "https://github.com/helicalAI/helical"

From df66c12bd371be3472021a0df1579533fd811eea Mon Sep 17 00:00:00 2001
From: Benoit Putzeys <157973952+bputzeys@users.noreply.github.com>
Date: Mon, 9 Dec 2024 15:50:03 +0100
Subject: [PATCH 13/27] Update how version is extracted from toml file

---
 .github/workflows/release.yml | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index b7138f60..911e7988 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -42,14 +42,6 @@ jobs:
           name: coverage-report
           path: html_cov/
 
-      # Does not seem to work but would be nice to have
-      # - name: Pytest coverage comment
-      #   uses: MishaKav/pytest-coverage-comment@main
-      #   with:
-      #     pytest-coverage-path: ./pytest-coverage.txt
-      #     junitxml-path: ./pytest.xml
-
-
       - name: Execute Geneformer v1
         run: |
           python examples/run_models/run_geneformer.py ++model_name="gf-12L-30M-i2048"
@@ -160,7 +152,7 @@ jobs:
         id: get_version
         run: |
           # Extract version from setup.py (or adjust for your version file)
-          VERSION=$(python setup.py --version)
+          VERSION=$(grep "version =" pyproject.toml | cut -d '"' -f 2)
           echo "::set-output name=version::$VERSION"
         
       - name: Check if tag exists

From 48ba576ab208cd0508b7d5a765859c41ffc66d31 Mon Sep 17 00:00:00 2001
From: Benoit Putzeys <157973952+bputzeys@users.noreply.github.com>
Date: Mon, 9 Dec 2024 15:51:58 +0100
Subject: [PATCH 14/27] Update comment

---
 .github/workflows/release.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 911e7988..03eaeed7 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -151,7 +151,7 @@ jobs:
       - name: Extract version
         id: get_version
         run: |
-          # Extract version from setup.py (or adjust for your version file)
+          # Extract version from pyproject.toml (or adjust for your version file)
           VERSION=$(grep "version =" pyproject.toml | cut -d '"' -f 2)
           echo "::set-output name=version::$VERSION"
         

From dc5461ee9700f689f847e46b8ac896665974c9ca Mon Sep 17 00:00:00 2001
From: Benoit Putzeys <benoit@helical-ai.com>
Date: Tue, 10 Dec 2024 09:50:48 +0100
Subject: [PATCH 15/27] Improve testing by mocking the actual forward pass of
 the scGPT model. Use functions for more granular and easier-to-read code in
 the scGPT.process_data function

---
 ci/tests/test_scgpt/test_scgpt_model.py |  56 +++++++++---
 helical/models/scgpt/model.py           | 108 ++++++++++++++----------
 2 files changed, 107 insertions(+), 57 deletions(-)

diff --git a/ci/tests/test_scgpt/test_scgpt_model.py b/ci/tests/test_scgpt/test_scgpt_model.py
index c96397fa..33bcc9c1 100644
--- a/ci/tests/test_scgpt/test_scgpt_model.py
+++ b/ci/tests/test_scgpt/test_scgpt_model.py
@@ -6,6 +6,8 @@
 import anndata as ad
 import numpy as np
 from scipy.sparse import csr_matrix
+import torch
+import pandas as pd
 
 class TestSCGPTModel:
     scgpt = scGPT()
@@ -58,19 +60,6 @@ def test_direct_assignment_of_genes_to_index(self):
         assert self.scgpt.gene_names == "index"
         assert self.scgpt.gene_names in self.data.var
 
-    # test get_embeddings for all three embedding modes
-    @pytest.mark.parametrize("emb_mode", ["cell", "gene", "cls"])
-    def test_get_embeddings(self, emb_mode):
-        self.scgpt.config["emb_mode"] = emb_mode
-        dataset = self.scgpt.process_data(self.data, gene_names = "gene_names")
-        embeddings = self.scgpt.get_embeddings(dataset)
-        if emb_mode == "gene":
-            assert list(embeddings[0].keys()) == ["SAMD11", "PLEKHN1", "HES4"]
-            for key in embeddings[0].keys():
-                assert len(embeddings[0][key]) == 512
-        else:
-            assert embeddings.shape == (1, 512)
-
     dummy_data = ad.read_h5ad("ci/tests/data/cell_type_sample.h5ad")
     @pytest.mark.parametrize("data, gene_names, batch_labels", 
                              [
@@ -123,4 +112,43 @@ def test_fine_tune_classification_returns_correct_shape(self):
         fine_tuned_model.train(train_input_data=tokenized_dataset, train_labels=labels)
         assert fine_tuned_model is not None
         outputs = fine_tuned_model.get_outputs(tokenized_dataset)
-        assert outputs.shape == (len(self.data), len(labels))
\ No newline at end of file
+        assert outputs.shape == (len(self.data), len(labels))
+
+    @pytest.mark.parametrize("emb_mode", ["cell", "gene", "cls"])
+    def test_get_embeddings_of_different_modes(self, mocker, emb_mode):
+        self.scgpt.config["emb_mode"] = emb_mode
+        self.scgpt.config["embsize"] = 5
+
+        # Mock the method directly on the instance
+        mocked_embeddings = torch.tensor([
+                                        [[1.0, 1.0, 1.0, 1.0, 1.0], 
+                                         [5.0, 5.0, 5.0, 5.0, 5.0], 
+                                         [1.0, 2.0, 3.0, 2.0, 1.0], 
+                                         [6.0, 6.0, 6.0, 6.0, 6.0]],
+                                        ])
+        mocker.patch.object(self.scgpt.model, "_encode", return_value=mocked_embeddings)
+
+        # mocking the normalization of embeddings makes it easier to test the output
+        mocker.patch.object(self.scgpt, "_normalize_embeddings", return_value=None)
+
+        dataset = self.scgpt.process_data(self.data, gene_names = "gene_names")
+        embeddings = self.scgpt.get_embeddings(dataset)
+        if emb_mode == "gene":
+            data_list = pd.Series({
+                            "SAMD11": np.array([5.0, 5.0, 5.0, 5.0, 5.0]),
+                            "PLEKHN1": np.array([1.0, 2.0, 3.0, 2.0, 1.0]),
+                            "HES4": np.array([6.0, 6.0, 6.0, 6.0, 6.0])
+                        })  
+            assert data_list.equals(embeddings[0])
+
+        if emb_mode == "cls":
+            assert (embeddings == np.array([1.0, 1.0, 1.0, 1.0, 1.0])).all()
+        if emb_mode == "cell":  
+            # average column wise excluding first row
+            expected = np.array([[4.       , 4.3333335, 4.6666665, 4.3333335, 4.       ]])
+            np.testing.assert_allclose(
+                embeddings, 
+                expected, 
+                rtol=1e-4,  # relative tolerance
+                atol=1e-4   # absolute tolerance
+            )
diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py
index 70de14a8..64b81d00 100644
--- a/helical/models/scgpt/model.py
+++ b/helical/models/scgpt/model.py
@@ -119,18 +119,10 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
 
         device = next(self.model.parameters()).device
 
-        # provision numpy ndarray for gene, cell and cls embeddings
-        if self.config["emb_mode"] == "gene":
-            # create dictionary mapping gene id to gene name, can't seem to find one in the vocab object
-            id_gene_dict = {i: gene for i, gene in enumerate(self.vocab.get_itos())}
-            gene_embs = []
-        else:
-            cell_embeddings = np.zeros(
-                (len(dataset), self.config["embsize"]), dtype=np.float32
-            )
-        
+        self._initialize_embeddings()
+
         with torch.no_grad(), torch.cuda.amp.autocast(enabled=True): #torch.autocast(device_type=str(device),enabled=True): # torch.cuda.amp.autocast(enabled=True):
-            count = 0
+            self.count = 0
             for data_dict in tqdm(data_loader, desc="Embedding cells"):
                 input_gene_ids = data_dict["gene"].to(device)
 
@@ -146,41 +138,71 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
                     else None,
                 )
 
-                if self.config["emb_mode"] == "cls":
-                    embeddings = embeddings[:, 0, :]  # get the <cls> position embedding
-                    embeddings = embeddings.cpu().numpy()
-                elif self.config["emb_mode"] == "cell":
-                    embeddings = embeddings[: 1:, :] # get all embeddings except the <cls> position
-                    embeddings = torch.mean(embeddings, dim=1) # mean embeddings to get cell embedding
-                    embeddings = embeddings.cpu().numpy()
-                elif self.config["emb_mode"] == "gene":
-                    embeddings = embeddings[:, 1:, :].cpu().numpy() # get all embeddings except the <cls> position
-                    gene_ids = data_dict["gene"].cpu().numpy()
-                    series = []
-
-                    # create a dictionary with gene name to gene embedding mappings and create pd series for each cell in batch
-                    for i, embedding in enumerate(embeddings):
-                        dict = {}
-                        for j, gene in enumerate(embedding, 1):
-                            if data_dict["gene"][i][j] != self.vocab[self.config["pad_token"]]:
-                                print(gene_ids[i][j])
-                                dict[id_gene_dict[gene_ids[i][j]]] = gene / np.linalg.norm(gene)
-                        
-                        series.append(pd.Series(dict))
-                if self.config["emb_mode"] != "gene":   
-                    cell_embeddings[count : count + len(embeddings)] = embeddings
-                    count += len(embeddings)
-                else:
-                    gene_embs.extend(series)
-        if self.config["emb_mode"] != "gene":
-            cell_embeddings = cell_embeddings / np.linalg.norm(
-                cell_embeddings, axis=1, keepdims=True
-            )
+                self._compute_embeddings_depending_on_mode(embeddings, data_dict)
 
-            return cell_embeddings
+        self._normalize_embeddings()
+        return self.resulting_embeddings
+        
+    def _normalize_embeddings(self) -> None:
+        """
+        Normalize the embeddings of the member variable self.resulting_embeddings.
+        """
+        if self.config["emb_mode"] != "gene":
+            self.resulting_embeddings = self.resulting_embeddings / np.linalg.norm(self.resulting_embeddings, axis=1, keepdims=True)
         else:
-            return gene_embs
+            for _, series in enumerate(self.resulting_embeddings):
+                for gene in series.keys():
+                    series[gene] = series[gene] / np.linalg.norm(series[gene])    
+        
+    def _initialize_embeddings(self):
+        """
+        Initialize the embeddings of the member variable self.resulting_embeddings,
+        including a dictionary mapping gene id to gene name.
+        """
+        if self.config["emb_mode"] == "gene":
+            self.id_gene_dict = {i: gene for i, gene in enumerate(self.vocab.get_itos())}
+        self.resulting_embeddings = []
+
+    def _compute_embeddings_depending_on_mode(self, embeddings: torch.Tensor, data_dict: dict) -> None:
+        """
+        Compute the embeddings depending on the mode set in the configuration.
 
+        Parameters
+        ----------
+        embeddings : torch.Tensor
+            The embeddings to be processed.
+        data_dict : dict
+            The data dictionary containing the data to be processed.
+        
+        Returns
+        -------
+        None
+        """
+        if self.config["emb_mode"] == "cls":
+            embeddings = embeddings[:, 0, :]  # get the <cls> position embedding
+            embeddings = embeddings.cpu().numpy()
+            self.resulting_embeddings[self.count : self.count + len(embeddings)] = embeddings
+            self.count += len(embeddings)
+
+        elif self.config["emb_mode"] == "cell":
+            embeddings = embeddings[:, 1:, :] # get all embeddings except the <cls> position
+            embeddings = torch.mean(embeddings, dim=1) # mean embeddings to get cell embedding
+            embeddings = embeddings.cpu().numpy()
+            self.resulting_embeddings[self.count : self.count + len(embeddings)] = embeddings
+            self.count += len(embeddings)
+        
+        elif self.config["emb_mode"] == "gene":
+            embeddings = embeddings[:, 1:, :].cpu().numpy() # get all embeddings except the <cls> position
+            gene_ids = data_dict["gene"].cpu().numpy()
+
+            # create a dictionary with gene name to gene embedding mappings and create pd series for each cell in batch
+            for i, embedding in enumerate(embeddings):
+                dict = {}
+                for j, gene in enumerate(embedding, 1):
+                    if data_dict["gene"][i][j] != self.vocab[self.config["pad_token"]]:
+                        dict[self.id_gene_dict[gene_ids[i][j]]] = gene
+                
+                self.resulting_embeddings.append(pd.Series(dict))  
     
     def process_data(self,
                      adata: AnnData, 

From cc8ace21b71051e99332422708cc63edb1d7a989 Mon Sep 17 00:00:00 2001
From: Matt <matt@Matthews-MacBook-Air.local>
Date: Tue, 10 Dec 2024 15:10:33 +0100
Subject: [PATCH 16/27] Replace member variables with function arguments and
 add normalization testing

---
 ci/tests/test_scgpt/test_scgpt_model.py | 41 +++++++++++++++++++-
 helical/models/scgpt/model.py           | 51 ++++++++++++-------------
 2 files changed, 63 insertions(+), 29 deletions(-)

diff --git a/ci/tests/test_scgpt/test_scgpt_model.py b/ci/tests/test_scgpt/test_scgpt_model.py
index 33bcc9c1..0beb6c7a 100644
--- a/ci/tests/test_scgpt/test_scgpt_model.py
+++ b/ci/tests/test_scgpt/test_scgpt_model.py
@@ -129,7 +129,7 @@ def test_get_embeddings_of_different_modes(self, mocker, emb_mode):
         mocker.patch.object(self.scgpt.model, "_encode", return_value=mocked_embeddings)
 
         # mocking the normalization of embeddings makes it easier to test the output
-        mocker.patch.object(self.scgpt, "_normalize_embeddings", return_value=None)
+        mocker.patch.object(self.scgpt, "_normalize_embeddings", side_effect=lambda x: x)
 
         dataset = self.scgpt.process_data(self.data, gene_names = "gene_names")
         embeddings = self.scgpt.get_embeddings(dataset)
@@ -145,10 +145,47 @@ def test_get_embeddings_of_different_modes(self, mocker, emb_mode):
             assert (embeddings == np.array([1.0, 1.0, 1.0, 1.0, 1.0])).all()
         if emb_mode == "cell":  
             # average column wise excluding first row
-            expected = np.array([[4.       , 4.3333335, 4.6666665, 4.3333335, 4.       ]])
+            expected = np.array([[4., 4.3333335, 4.6666665, 4.3333335, 4.]])
             np.testing.assert_allclose(
                 embeddings, 
                 expected, 
                 rtol=1e-4,  # relative tolerance
                 atol=1e-4   # absolute tolerance
             )
+
+    @pytest.mark.parametrize("emb_mode", ["cls", "cell"])
+    def test_normalization_cell_and_cls(self, emb_mode):
+        mocked_embeddings = np.array([[1.0, 1.0, 1.0, 1.0, 1.0], 
+                                      [5.0, 5.0, 5.0, 5.0, 5.0], 
+                                      [1.0, 2.0, 3.0, 2.0, 1.0], 
+                                      [6.0, 6.0, 6.0, 6.0, 6.0]])
+        
+        expected_normalized_embeddings = np.array([[0.4472, 0.4472, 0.4472, 0.4472, 0.4472],
+                                                   [0.4472, 0.4472, 0.4472, 0.4472, 0.4472],
+                                                   [0.2294, 0.4588, 0.6882, 0.4588, 0.2294],
+                                                   [0.4472, 0.4472, 0.4472, 0.4472, 0.4472]])
+        
+        self.scgpt.config["emb_mode"] = emb_mode
+        normalized_embeddings = np.around(self.scgpt._normalize_embeddings(mocked_embeddings), decimals=4)
+        assert np.all(np.equal(normalized_embeddings, expected_normalized_embeddings))
+
+    def test_normalization_of_gene(self):
+        mocked_embeddings = [pd.Series({
+                            "SAMD11": np.array([5.0, 5.0, 5.0, 5.0, 5.0]),
+                            "PLEKHN1": np.array([1.0, 2.0, 3.0, 2.0, 1.0]),
+                            "HES4": np.array([6.0, 6.0, 6.0, 6.0, 6.0])
+                        })]
+        expected_normalized_embeddings = [pd.Series({
+                            "SAMD11": np.array([0.4472, 0.4472, 0.4472, 0.4472, 0.4472]),
+                            "PLEKHN1": np.array([0.2294, 0.4588, 0.6882, 0.4588, 0.2294]),
+                            "HES4": np.array([0.4472, 0.4472, 0.4472, 0.4472, 0.4472])
+                        })]
+
+        self.scgpt.config["emb_mode"] = "gene"
+        normalized_embeddings = self.scgpt._normalize_embeddings(mocked_embeddings)
+
+        for expected_emb, emb in zip(expected_normalized_embeddings, normalized_embeddings):
+            for true_index, index in zip(expected_emb.keys(), emb.keys()):
+                assert np.all(np.equal(expected_emb[true_index], np.around(emb[index], decimals=4)))
+
+        
\ No newline at end of file
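The expected values in these tests follow from plain row-wise L2 normalization; a standalone numeric check reproducing the 0.4472/0.2294 fixtures:

```python
import numpy as np

emb = np.array([[5.0, 5.0, 5.0, 5.0, 5.0],
                [1.0, 2.0, 3.0, 2.0, 1.0]])

# Divide each row by its L2 norm: 5/(5*sqrt(5)) = 1/sqrt(5) ~ 0.4472,
# and row two uses sqrt(1+4+9+4+1) = sqrt(19) ~ 4.3589.
normalized = emb / np.linalg.norm(emb, axis=1, keepdims=True)
print(np.around(normalized, decimals=4))
# [[0.4472 0.4472 0.4472 0.4472 0.4472]
#  [0.2294 0.4588 0.6882 0.4588 0.2294]]
```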
diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py
index 64b81d00..938d4a0c 100644
--- a/helical/models/scgpt/model.py
+++ b/helical/models/scgpt/model.py
@@ -119,10 +119,12 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
 
         device = next(self.model.parameters()).device
 
-        self._initialize_embeddings()
+        if self.config["emb_mode"] == "gene":
+            self.id_gene_dict = {i: gene for i, gene in enumerate(self.vocab.get_itos())}
+        
+        resulting_embeddings = []
 
         with torch.no_grad(), torch.cuda.amp.autocast(enabled=True): #torch.autocast(device_type=str(device),enabled=True): # torch.cuda.amp.autocast(enabled=True):
-            self.count = 0
             for data_dict in tqdm(data_loader, desc="Embedding cells"):
                 input_gene_ids = data_dict["gene"].to(device)
 
@@ -138,71 +140,66 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
                     else None,
                 )
 
-                self._compute_embeddings_depending_on_mode(embeddings, data_dict)
+                resulting_embeddings.extend(self._compute_embeddings_depending_on_mode(embeddings, data_dict))
 
-        self._normalize_embeddings()
-        return self.resulting_embeddings
+        resulting_embeddings = self._normalize_embeddings(resulting_embeddings)
+        return resulting_embeddings
         
-    def _normalize_embeddings(self) -> None:
+    def _normalize_embeddings(self, resulting_embeddings: torch.tensor) -> np.ndarray:
         """
-        Normalize the embeddings of the member variable self.resulting_embeddings.
+        Divides each element of each embedding by the norm of that embedding
         """
         if self.config["emb_mode"] != "gene":
-            self.resulting_embeddings = self.resulting_embeddings / np.linalg.norm(self.resulting_embeddings, axis=1, keepdims=True)
+            resulting_embeddings = resulting_embeddings / np.linalg.norm(resulting_embeddings, axis=1, keepdims=True)
         else:
-            for _, series in enumerate(self.resulting_embeddings):
+            for series in resulting_embeddings:
                 for gene in series.keys():
-                    series[gene] = series[gene] / np.linalg.norm(series[gene])    
-        
-    def _initialize_embeddings(self):
-        """
-        Initialize the embeddings of the member variable self.resulting_embeddings,
-        including a dictionary mapping gene id to gene name.
-        """
-        if self.config["emb_mode"] == "gene":
-            self.id_gene_dict = {i: gene for i, gene in enumerate(self.vocab.get_itos())}
-        self.resulting_embeddings = []
+                    series[gene] = series[gene] / np.linalg.norm(series[gene])   
 
-    def _compute_embeddings_depending_on_mode(self, embeddings: torch.Tensor, data_dict: dict) -> None:
+        return resulting_embeddings
+
+    def _compute_embeddings_depending_on_mode(self, embeddings: torch.tensor, data_dict: dict) -> np.ndarray:
         """
         Compute the embeddings depending on the mode set in the configuration.
 
         Parameters
         ----------
-        embeddings : torch.Tensor
+        embeddings : torch.tensor
             The embeddings to be processed.
         data_dict : dict
             The data dictionary containing the data to be processed.
         
         Returns
         -------
-        None
+        np.ndarray
+            The embeddings corresponding to the mode selected
         """
         if self.config["emb_mode"] == "cls":
             embeddings = embeddings[:, 0, :]  # get the <cls> position embedding
             embeddings = embeddings.cpu().numpy()
-            self.resulting_embeddings[self.count : self.count + len(embeddings)] = embeddings
-            self.count += len(embeddings)
+            return embeddings
 
         elif self.config["emb_mode"] == "cell":
             embeddings = embeddings[:, 1:, :] # get all embeddings except the <cls> position
             embeddings = torch.mean(embeddings, dim=1) # mean embeddings to get cell embedding
             embeddings = embeddings.cpu().numpy()
-            self.resulting_embeddings[self.count : self.count + len(embeddings)] = embeddings
-            self.count += len(embeddings)
+            return embeddings
         
         elif self.config["emb_mode"] == "gene":
             embeddings = embeddings[:, 1:, :].cpu().numpy() # get all embeddings except the <cls> position
             gene_ids = data_dict["gene"].cpu().numpy()
 
             # create a dictionary with gene name to gene embedding mappings and create pd series for each cell in batch
+            batch_embeddings = []
             for i, embedding in enumerate(embeddings):
                 dict = {}
                 for j, gene in enumerate(embedding, 1):
                     if data_dict["gene"][i][j] != self.vocab[self.config["pad_token"]]:
                         dict[self.id_gene_dict[gene_ids[i][j]]] = gene
                 
-                self.resulting_embeddings.append(pd.Series(dict))  
+                batch_embeddings.append(pd.Series(dict))
+            
+            return batch_embeddings
     
     def process_data(self,
                      adata: AnnData, 

From b462a6709a67ab93f8a044d58bfeba9f55d5439a Mon Sep 17 00:00:00 2001
From: Matt <matt@Matthews-MacBook-Air.local>
Date: Tue, 10 Dec 2024 15:43:06 +0100
Subject: [PATCH 17/27] Change the single-cell fine-tuning examples to use the
 local file in run_models for quicker testing

---
 examples/fine_tune_models/fine_tune_UCE.py        | 9 +++++++--
 examples/fine_tune_models/fine_tune_geneformer.py | 8 +++++---
 examples/fine_tune_models/fine_tune_scgpt.py      | 7 +++++--
 3 files changed, 17 insertions(+), 7 deletions(-)

diff --git a/examples/fine_tune_models/fine_tune_UCE.py b/examples/fine_tune_models/fine_tune_UCE.py
index 342f82ee..992f8344 100644
--- a/examples/fine_tune_models/fine_tune_UCE.py
+++ b/examples/fine_tune_models/fine_tune_UCE.py
@@ -1,13 +1,18 @@
 from helical import UCEConfig, UCEFineTuningModel
 from helical.utils import get_anndata_from_hf_dataset
 from datasets import load_dataset
+import anndata as ad
 from omegaconf import DictConfig
 import hydra
 
 @hydra.main(version_base=None, config_path="../run_models/configs", config_name="uce_config")
 def run_fine_tuning(cfg: DictConfig):
-    hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
-    ann_data = get_anndata_from_hf_dataset(hf_dataset)
+    
+    # either load via huggingface
+    # hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
+    # ann_data = get_anndata_from_hf_dataset(hf_dataset)
+
+    ann_data = ad.read_h5ad("../run_models/yolksac_human.h5ad")
 
     cell_types = ann_data.obs["LVL1"][:10].tolist()
 
diff --git a/examples/fine_tune_models/fine_tune_geneformer.py b/examples/fine_tune_models/fine_tune_geneformer.py
index bbaf1275..128a6c01 100644
--- a/examples/fine_tune_models/fine_tune_geneformer.py
+++ b/examples/fine_tune_models/fine_tune_geneformer.py
@@ -1,14 +1,16 @@
 from helical import GeneformerConfig, GeneformerFineTuningModel
 from helical.utils import get_anndata_from_hf_dataset
 from datasets import load_dataset
+import anndata as ad
 import hydra
 from omegaconf import DictConfig
 
 @hydra.main(version_base=None, config_path="../run_models/configs", config_name="geneformer_config")
 def run_fine_tuning(cfg: DictConfig):
-                            
-    hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
-    ann_data = get_anndata_from_hf_dataset(hf_dataset)
+    # Option to download from HuggingFace
+    # hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
+    # ann_data = get_anndata_from_hf_dataset(hf_dataset)
+    ann_data = ad.read_h5ad("../run_models/yolksac_human.h5ad")
 
     cell_types = list(ann_data.obs["LVL1"][:10])
     label_set = set(cell_types)
diff --git a/examples/fine_tune_models/fine_tune_scgpt.py b/examples/fine_tune_models/fine_tune_scgpt.py
index ae27fd6c..60147414 100644
--- a/examples/fine_tune_models/fine_tune_scgpt.py
+++ b/examples/fine_tune_models/fine_tune_scgpt.py
@@ -1,14 +1,17 @@
 from helical import scGPTConfig, scGPTFineTuningModel
 from helical.utils import get_anndata_from_hf_dataset
 from datasets import load_dataset
+import anndata as ad
 from omegaconf import DictConfig
 import hydra
 
 
 @hydra.main(version_base=None, config_path="../run_models/configs", config_name="scgpt_config")
 def run_fine_tuning(cfg: DictConfig):
-    hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
-    ann_data = get_anndata_from_hf_dataset(hf_dataset)
+    # hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
+    # ann_data = get_anndata_from_hf_dataset(hf_dataset)
+
+    ann_data = ad.read_h5ad("../run_models/yolksac_human.h5ad")
 
     cell_types = ann_data.obs["LVL1"][:10].tolist()
     label_set = set(cell_types)

From 050c8e4af37551b5780c49459de489e933192d57 Mon Sep 17 00:00:00 2001
From: Matt <matthew@helical-ai.com>
Date: Tue, 10 Dec 2024 16:00:21 +0100
Subject: [PATCH 18/27] Make CPU the default device and override it to GPU when
 running integration tests

---
 .github/workflows/main.yml                    | 28 +++++++++----------
 .github/workflows/release.yml                 | 28 +++++++++----------
 .../fine_tune_models/fine_tune_geneformer.py  |  3 +-
 examples/fine_tune_models/fine_tune_scgpt.py  |  2 ++
 .../run_models/configs/geneformer_config.yaml |  2 +-
 .../run_models/configs/helix_mrna_config.yaml |  2 +-
 .../run_models/configs/hyena_dna_config.yaml  |  2 +-
 .../configs/mamba2_mrna_config.yaml           |  2 +-
 examples/run_models/configs/scgpt_config.yaml |  2 +-
 examples/run_models/configs/uce_config.yaml   |  2 +-
 10 files changed, 38 insertions(+), 35 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index d3007d7b..16074743 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -49,59 +49,59 @@ jobs:
 
       - name: Execute Geneformer v1
         run: |
-          python examples/run_models/run_geneformer.py ++model_name="gf-12L-30M-i2048"
+          python examples/run_models/run_geneformer.py ++model_name="gf-12L-30M-i2048" ++device="cuda"
 
       - name: Fine-tune Geneformer v1
         run: |
-          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-30M-i2048"
+          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-30M-i2048" ++device="cuda"
 
       - name: Execute Geneformer v2
         run: |
-          python examples/run_models/run_geneformer.py ++model_name="gf-12L-95M-i4096"
+          python examples/run_models/run_geneformer.py ++model_name="gf-12L-95M-i4096" ++device="cuda"
 
       - name: Fine-tune Geneformer v2
         run: |
-          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-30M-i2048"
+          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-30M-i2048" ++device="cuda"
 
       - name: Execute scGPT
         run: |
-          python examples/run_models/run_scgpt.py
+          python examples/run_models/run_scgpt.py ++device="cuda"
 
       - name: Fine-tune scGPT
         run: |
-          python examples/fine_tune_models/fine_tune_scgpt.py
+          python examples/fine_tune_models/fine_tune_scgpt.py ++device="cuda"
 
       - name: Execute UCE
         run: |
-          python examples/run_models/run_uce.py
+          python examples/run_models/run_uce.py ++device="cuda"
 
       - name: Fine-tune UCE
         run: |
-          python examples/fine_tune_models/fine_tune_UCE.py
+          python examples/fine_tune_models/fine_tune_UCE.py ++device="cuda"
 
       - name: Execute Hyena
         run: |
-          python examples/run_models/run_hyena_dna.py
+          python examples/run_models/run_hyena_dna.py ++device="cuda"
 
       - name: Execute Hyena
         run: |
-          python examples/fine_tune_models/fine_tune_hyena_dna.py
+          python examples/fine_tune_models/fine_tune_hyena_dna.py ++device="cuda"
 
       - name: Execute Helix-mRNA
         run: |
-          python examples/run_models/run_helix_mrna.py
+          python examples/run_models/run_helix_mrna.py ++device="cuda"
 
       - name: Fine-tune Helix-mRNA
         run: |
-          python examples/fine_tune_models/fine_tune_helix_mrna.py
+          python examples/fine_tune_models/fine_tune_helix_mrna.py ++device="cuda"
 
       - name: Execute Mamba2-mRNA
         run: |
-          python examples/run_models/run_mamba2_mrna.py
+          python examples/run_models/run_mamba2_mrna.py ++device="cuda"
 
       - name: Fine-tune Mamba2-mRNA
         run: |
-          python examples/fine_tune_models/fine_tune_mamba2_mrna.py
+          python examples/fine_tune_models/fine_tune_mamba2_mrna.py ++device="cuda"
 
       - name: Execute benchmarking
         run: |
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 03eaeed7..3bb06fe1 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -44,59 +44,59 @@ jobs:
 
       - name: Execute Geneformer v1
         run: |
-          python examples/run_models/run_geneformer.py ++model_name="gf-12L-30M-i2048"
+          python examples/run_models/run_geneformer.py ++model_name="gf-12L-30M-i2048" ++device="cuda"
 
       - name: Fine-tune Geneformer v1
         run: |
-          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-30M-i2048"
+          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-30M-i2048" ++device="cuda"
 
       - name: Execute Geneformer v2
         run: |
-          python examples/run_models/run_geneformer.py ++model_name="gf-12L-95M-i4096"
+          python examples/run_models/run_geneformer.py ++model_name="gf-12L-95M-i4096" ++device="cuda"
 
       - name: Fine-tune Geneformer v2
         run: |
-          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-30M-i2048"
+          python examples/fine_tune_models/fine_tune_geneformer.py ++model_name="gf-12L-30M-i2048" ++device="cuda"
 
       - name: Execute scGPT
         run: |
-          python examples/run_models/run_scgpt.py
+          python examples/run_models/run_scgpt.py ++device="cuda"
 
       - name: Fine-tune scGPT
         run: |
-          python examples/fine_tune_models/fine_tune_scgpt.py
+          python examples/fine_tune_models/fine_tune_scgpt.py ++device="cuda"
 
       - name: Execute UCE
         run: |
-          python examples/run_models/run_uce.py
+          python examples/run_models/run_uce.py ++device="cuda"
 
       - name: Fine-tune UCE
         run: |
-          python examples/fine_tune_models/fine_tune_UCE.py
+          python examples/fine_tune_models/fine_tune_UCE.py ++device="cuda"
 
       - name: Execute Hyena
         run: |
-          python examples/run_models/run_hyena_dna.py
+          python examples/run_models/run_hyena_dna.py ++device="cuda"
 
       - name: Execute Hyena
         run: |
-          python examples/fine_tune_models/fine_tune_hyena_dna.py
+          python examples/fine_tune_models/fine_tune_hyena_dna.py ++device="cuda"
 
       - name: Execute Helix-mRNA
         run: |
-          python examples/run_models/run_helix_mrna.py
+          python examples/run_models/run_helix_mrna.py ++device="cuda"
 
       - name: Fine-tune Helix-mRNA
         run: |
-          python examples/fine_tune_models/fine_tune_helix_mrna.py
+          python examples/fine_tune_models/fine_tune_helix_mrna.py ++device="cuda"
 
       - name: Execute Mamba2-mRNA
         run: |
-          python examples/run_models/run_mamba2_mrna.py
+          python examples/run_models/run_mamba2_mrna.py ++device="cuda"
 
       - name: Fine-tune Mamba2-mRNA
         run: |
-          python examples/fine_tune_models/fine_tune_mamba2_mrna.py
+          python examples/fine_tune_models/fine_tune_mamba2_mrna.py ++device="cuda"
 
       - name: Execute benchmarking
         run: |
diff --git a/examples/fine_tune_models/fine_tune_geneformer.py b/examples/fine_tune_models/fine_tune_geneformer.py
index 128a6c01..d8d203b9 100644
--- a/examples/fine_tune_models/fine_tune_geneformer.py
+++ b/examples/fine_tune_models/fine_tune_geneformer.py
@@ -7,7 +7,8 @@
 
 @hydra.main(version_base=None, config_path="../run_models/configs", config_name="geneformer_config")
 def run_fine_tuning(cfg: DictConfig):
-    # Option to download from HuggingFace
+    
+    # either load via huggingface
     # hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
     # ann_data = get_anndata_from_hf_dataset(hf_dataset)
     ann_data = ad.read_h5ad("../run_models/yolksac_human.h5ad")
diff --git a/examples/fine_tune_models/fine_tune_scgpt.py b/examples/fine_tune_models/fine_tune_scgpt.py
index 60147414..57c261d3 100644
--- a/examples/fine_tune_models/fine_tune_scgpt.py
+++ b/examples/fine_tune_models/fine_tune_scgpt.py
@@ -8,6 +8,8 @@
 
 @hydra.main(version_base=None, config_path="../run_models/configs", config_name="scgpt_config")
 def run_fine_tuning(cfg: DictConfig):
+
+    # either load via huggingface
     # hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
     # ann_data = get_anndata_from_hf_dataset(hf_dataset)
 
diff --git a/examples/run_models/configs/geneformer_config.yaml b/examples/run_models/configs/geneformer_config.yaml
index 24cbc512..7ee91173 100644
--- a/examples/run_models/configs/geneformer_config.yaml
+++ b/examples/run_models/configs/geneformer_config.yaml
@@ -2,5 +2,5 @@ model_name: "gf-12L-30M-i2048"
 batch_size: 24
 emb_layer: -1
 emb_mode: "cell"
-device: "cuda"
+device: "cpu"
 accelerator: False
\ No newline at end of file
diff --git a/examples/run_models/configs/helix_mrna_config.yaml b/examples/run_models/configs/helix_mrna_config.yaml
index 948c2184..8913ea48 100644
--- a/examples/run_models/configs/helix_mrna_config.yaml
+++ b/examples/run_models/configs/helix_mrna_config.yaml
@@ -1,3 +1,3 @@
 batch_size: 10
-device: "cuda"
+device: "cpu"
 max_length: 100
\ No newline at end of file
diff --git a/examples/run_models/configs/hyena_dna_config.yaml b/examples/run_models/configs/hyena_dna_config.yaml
index bd3967c4..3f42367c 100644
--- a/examples/run_models/configs/hyena_dna_config.yaml
+++ b/examples/run_models/configs/hyena_dna_config.yaml
@@ -11,7 +11,7 @@ checkpoint_mixer: False
 checkpoint_mlp: False
 pad_vocab_size_multiple: 8
 return_hidden_state: True
-device: "cuda"
+device: "cpu"
 layer: 
     "_name_": "hyena"
     "emb_dim": 5
diff --git a/examples/run_models/configs/mamba2_mrna_config.yaml b/examples/run_models/configs/mamba2_mrna_config.yaml
index 948c2184..8913ea48 100644
--- a/examples/run_models/configs/mamba2_mrna_config.yaml
+++ b/examples/run_models/configs/mamba2_mrna_config.yaml
@@ -1,3 +1,3 @@
 batch_size: 10
-device: "cuda"
+device: "cpu"
 max_length: 100
\ No newline at end of file
diff --git a/examples/run_models/configs/scgpt_config.yaml b/examples/run_models/configs/scgpt_config.yaml
index 518ac78b..4add9a58 100644
--- a/examples/run_models/configs/scgpt_config.yaml
+++ b/examples/run_models/configs/scgpt_config.yaml
@@ -11,5 +11,5 @@ mask_value: -1
 pad_value: -2
 world_size: 8
 accelerator: False
-device: "cuda"
+device: "cpu"
 use_fast_transformer: False
\ No newline at end of file
diff --git a/examples/run_models/configs/uce_config.yaml b/examples/run_models/configs/uce_config.yaml
index 4f9d2827..8d8ad18c 100644
--- a/examples/run_models/configs/uce_config.yaml
+++ b/examples/run_models/configs/uce_config.yaml
@@ -14,5 +14,5 @@ output_dim: 1280
 d_hid: 5120
 token_dim: 5120
 multi_gpu: False
-device: "cuda"
+device: "cpu"
 accelerator: False
\ No newline at end of file

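How the ++device="cuda" overrides interact with the new device: "cpu" defaults, as a hedged hydra sketch; the config_path assumes the script lives next to a configs directory like examples/run_models/configs:

```python
import hydra
from omegaconf import DictConfig

@hydra.main(version_base=None, config_path="configs", config_name="scgpt_config")
def main(cfg: DictConfig):
    # Prints "cpu" by default; "cuda" when invoked with ++device="cuda",
    # which is exactly what the CI steps above append.
    print(cfg.device)

if __name__ == "__main__":
    main()
```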
From 688b8105619543d30f2207f5a3e24a1184d7c6fe Mon Sep 17 00:00:00 2001
From: Matt <matthew@helical-ai.com>
Date: Tue, 10 Dec 2024 17:36:06 +0100
Subject: [PATCH 19/27] Correct path to local file being called from main
 directory

---
 examples/fine_tune_models/fine_tune_UCE.py        | 2 +-
 examples/fine_tune_models/fine_tune_geneformer.py | 2 +-
 examples/fine_tune_models/fine_tune_scgpt.py      | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/fine_tune_models/fine_tune_UCE.py b/examples/fine_tune_models/fine_tune_UCE.py
index 992f8344..47a2724e 100644
--- a/examples/fine_tune_models/fine_tune_UCE.py
+++ b/examples/fine_tune_models/fine_tune_UCE.py
@@ -12,7 +12,7 @@ def run_fine_tuning(cfg: DictConfig):
     # hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
     # ann_data = get_anndata_from_hf_dataset(hf_dataset)
 
-    ann_data = ad.read_h5ad("../run_models/yolksac_human.h5ad")
+    ann_data = ad.read_h5ad("./yolksac_human.h5ad")
 
     cell_types = ann_data.obs["LVL1"][:10].tolist()
 
diff --git a/examples/fine_tune_models/fine_tune_geneformer.py b/examples/fine_tune_models/fine_tune_geneformer.py
index d8d203b9..2cbf7197 100644
--- a/examples/fine_tune_models/fine_tune_geneformer.py
+++ b/examples/fine_tune_models/fine_tune_geneformer.py
@@ -11,7 +11,7 @@ def run_fine_tuning(cfg: DictConfig):
     # either load via huggingface
     # hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
     # ann_data = get_anndata_from_hf_dataset(hf_dataset)
-    ann_data = ad.read_h5ad("../run_models/yolksac_human.h5ad")
+    ann_data = ad.read_h5ad("./yolksac_human.h5ad")
 
     cell_types = list(ann_data.obs["LVL1"][:10])
     label_set = set(cell_types)
diff --git a/examples/fine_tune_models/fine_tune_scgpt.py b/examples/fine_tune_models/fine_tune_scgpt.py
index 57c261d3..6d0719c6 100644
--- a/examples/fine_tune_models/fine_tune_scgpt.py
+++ b/examples/fine_tune_models/fine_tune_scgpt.py
@@ -13,7 +13,7 @@ def run_fine_tuning(cfg: DictConfig):
     # hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
     # ann_data = get_anndata_from_hf_dataset(hf_dataset)
 
-    ann_data = ad.read_h5ad("../run_models/yolksac_human.h5ad")
+    ann_data = ad.read_h5ad("./yolksac_human.h5ad")
 
     cell_types = ann_data.obs["LVL1"][:10].tolist()
     label_set = set(cell_types)

From 779af323a8a811f4b8986ac4d84749809bab9fa2 Mon Sep 17 00:00:00 2001
From: Giovanni Ortolani <giovorto@pm.me>
Date: Tue, 10 Dec 2024 17:46:54 +0100
Subject: [PATCH 20/27] Fix mamba-ssm integration in release PyPI (#144)

* Remove git+ pip installation

---------

Co-authored-by: giogix2 <giovanni@Helical-AI-A6000>
---
 .github/workflows/main.yml    | 22 ++++++++++++++++++++
 .github/workflows/release.yml | 38 +++++++++++++++++++++++++++++++++++
 pyproject.toml                |  4 ++--
 3 files changed, 62 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index d3007d7b..a66e2bf2 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -22,6 +22,8 @@ jobs:
       - name: Install dependencies
         run: |
             pip install -r requirements-dev.txt
+            # Currently, the self-hosted runner uses a pre-installed version of mamba-ssm (v2.2.2). Using the latest version (v2.2.4) breaks the pip installation.
+            # TODO fix the installation of helical[mamba-ssm], and add it in the GitHub actions (CI/CD pipeline).
             pip install .
             
       # First download before tests as they make use of the downloaded files 
@@ -46,6 +48,26 @@ jobs:
       #     pytest-coverage-path: ./pytest-coverage.txt
       #     junitxml-path: ./pytest.xml
 
+  integration-tests:
+    needs: tests
+    runs-on: self-hosted
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v2
+
+      - name: setup python
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.11.8
+
+      - name: Install dependencies
+        run: |
+            pip install -r requirements-dev.txt
+      
+      # First download before tests as they make use of the downloaded files 
+      - name: Download all files
+        run: |
+          python ci/download_all.py
 
       - name: Execute Geneformer v1
         run: |
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 03eaeed7..33bf160b 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -9,6 +9,21 @@ permissions:
   contents: write
 
 jobs:
+  fresh-install:
+    runs-on: ubuntu-20.04
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v2
+
+      - name: setup python
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.11.8
+
+      - name: Install dependencies
+        run: |
+            pip install .
+
   tests:
     runs-on: self-hosted
     env:
@@ -25,6 +40,8 @@ jobs:
       - name: Install dependencies
         run: |
             pip install -r requirements-dev.txt
+            # Currently, the self-hosted runner uses a pre-installed version of mamba-ssm (v2.2.2). Using the latest version (v2.2.4) breaks the pip installation.
+            # TODO fix the installation of helical[mamba-ssm], and add it in the GitHub actions (CI/CD pipeline).
             pip install .
             
       # First download before tests as they make use of the downloaded files 
@@ -42,6 +59,27 @@ jobs:
           name: coverage-report
           path: html_cov/
 
+  integration-tests:
+    needs: tests
+    runs-on: self-hosted
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v2
+
+      - name: setup python
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.11.8
+
+      - name: Install dependencies
+        run: |
+            pip install -r requirements-dev.txt
+      
+      # First download before tests as they make use of the downloaded files 
+      - name: Download all files
+        run: |
+          python ci/download_all.py
+
       - name: Execute Geneformer v1
         run: |
           python examples/run_models/run_geneformer.py ++model_name="gf-12L-30M-i2048"
diff --git a/pyproject.toml b/pyproject.toml
index 6637dbbd..c9dbe4c7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "helical"
-version = "0.0.1a10"
+version = "0.0.1a11"
 authors = [
   { name="Helical Team", email="support@helical-ai.com" },
 ]
@@ -49,7 +49,7 @@ dependencies = [
 
 [project.optional-dependencies]
 mamba-ssm = [
-    'mamba-ssm @ git+https://github.com/state-spaces/mamba.git@v2.2.3',
+    'mamba-ssm==2.2.4',
     'causal-conv1d==1.4.0',
 ]
 

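PyPI rejects release metadata whose dependencies point at direct git URLs, which is why the extra above moves from a `git+` reference to a plain `mamba-ssm==2.2.4` pin. Below is a minimal post-install sanity check; the `PINS` mapping mirrors the versions in the pyproject.toml hunk, and the script itself is illustrative, not part of the patch:

```python
# Illustrative sanity check: confirm the environment carries the versions
# pinned in pyproject.toml's mamba-ssm extra (not part of the patch itself).
from importlib.metadata import version, PackageNotFoundError

PINS = {"mamba-ssm": "2.2.4", "causal-conv1d": "1.4.0"}  # from the hunk above

for package, expected in PINS.items():
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package} is not installed (optional extra not selected)")
        continue
    status = "OK" if installed == expected else f"MISMATCH (expected {expected})"
    print(f"{package}=={installed} {status}")
```

Running this after `pip install 'helical[mamba-ssm]'` should report both pins as OK; on the self-hosted runner described in the workflow comment, mamba-ssm would instead show as the pre-installed 2.2.2.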
From 8698c50aa3776399cd7a7aef0bf5175fac22f4b2 Mon Sep 17 00:00:00 2001
From: Matthew Wood <matthew@helical-ai.com>
Date: Tue, 10 Dec 2024 19:08:35 +0100
Subject: [PATCH 21/27] Add binning seed to scGPT

---
 examples/run_models/configs/scgpt_config.yaml | 3 ++-
 helical/models/scgpt/model.py                 | 5 +++++
 helical/models/scgpt/scgpt_config.py          | 4 +++-
 3 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/examples/run_models/configs/scgpt_config.yaml b/examples/run_models/configs/scgpt_config.yaml
index a4ea3599..b7d2f651 100644
--- a/examples/run_models/configs/scgpt_config.yaml
+++ b/examples/run_models/configs/scgpt_config.yaml
@@ -13,4 +13,5 @@ pad_value: -2
 world_size: 8
 accelerator: False
 device: "cuda"
-use_fast_transformer: False
\ No newline at end of file
+use_fast_transformer: False
+binning_seed: 123
\ No newline at end of file
diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py
index 938d4a0c..842dfd34 100644
--- a/helical/models/scgpt/model.py
+++ b/helical/models/scgpt/model.py
@@ -68,6 +68,8 @@ def __init__(self, configurer: scGPTConfig = configurer) -> None:
             downloader.download_via_name(file)
 
         self.model, self.vocab = load_model(self.config)
+
+        self.model.eval()
         
         if self.config["accelerator"]:
             self.accelerator = Accelerator(project_dir=self.config["model_path"].parent)
@@ -92,6 +94,9 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
             If `emb_mode` is set to "gene", the embeddings are returned as a list of pd.Series which contain a mapping of gene_name:embedding for each cell.
         """
         LOGGER.info(f"Inference started:")
+        np.random.seed(self.config["binning_seed"])
+
+        self.model.eval()
 
         try:
             use_batch_labels = dataset.batch_ids is not None
diff --git a/helical/models/scgpt/scgpt_config.py b/helical/models/scgpt/scgpt_config.py
index 60dcda85..f895a888 100644
--- a/helical/models/scgpt/scgpt_config.py
+++ b/helical/models/scgpt/scgpt_config.py
@@ -68,6 +68,7 @@ def __init__(
             accelerator: Optional[bool] = False,
             device: Literal["cpu", "cuda"] = "cpu",
             use_fast_transformer: bool = False,
+            binning_seed: int = 123
             ):
         
         model_name = 'best_model' # TODO: Include more models
@@ -96,5 +97,6 @@ def __init__(
             "accelerator": accelerator,
             "device": device,
             "use_fast_transformer": use_fast_transformer,
-            "MAX_LENGTH": 1200
+            "MAX_LENGTH": 1200,
+            "binning_seed": binning_seed
             }
\ No newline at end of file

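The seed matters because scGPT's preprocessing bins expression values with a stochastic component, so unseeded calls to `get_embeddings` can bin the same cell differently from run to run. A minimal sketch of the effect follows, under the assumption that the binning step draws from NumPy's global RNG; `stochastic_digitize` is a hypothetical stand-in, not the library's actual helper:

```python
import numpy as np

def stochastic_digitize(values: np.ndarray, n_bins: int) -> np.ndarray:
    """Hypothetical stand-in for scGPT-style binning: edges come from
    quantiles, and ties on the edges are broken with random jitter."""
    edges = np.quantile(values, np.linspace(0, 1, n_bins + 1)[1:-1])
    jitter = np.random.uniform(0, 1e-6, size=values.shape)  # consumes RNG state
    return np.digitize(values + jitter, edges)

counts = np.array([0.0, 1.0, 1.0, 2.0, 5.0, 5.0, 9.0])

np.random.seed(123)                  # same seed ...
first = stochastic_digitize(counts, n_bins=4)
np.random.seed(123)
second = stochastic_digitize(counts, n_bins=4)
assert (first == second).all()       # ... same bins on every call
```

Without the two `np.random.seed(123)` calls, the jitter differs between runs and borderline values can land in neighbouring bins, which is exactly the nondeterminism the config's `binning_seed` removes.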
From 55106b364b6a7bb391d3a4432af76cd0388f0783 Mon Sep 17 00:00:00 2001
From: Benoit Putzeys <benoit@helical-ai.com>
Date: Wed, 11 Dec 2024 09:17:56 +0100
Subject: [PATCH 22/27] Clean up main.yml workflow

---
 .github/workflows/main.yml | 27 +++++++--------------------
 1 file changed, 7 insertions(+), 20 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index a66e2bf2..1da5a638 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -8,8 +8,6 @@ on:
 jobs:
   tests:
     runs-on: self-hosted
-    env:
-      CUDA_VISIBLE_DEVICES: 0
     steps:
       - name: Checkout repository
         uses: actions/checkout@v2
@@ -40,17 +38,12 @@ jobs:
         with:
           name: coverage-report
           path: html_cov/
-
-      # Does not seem to work but would be nice to have
-      # - name: Pytest coverage comment
-      #   uses: MishaKav/pytest-coverage-comment@main
-      #   with:
-      #     pytest-coverage-path: ./pytest-coverage.txt
-      #     junitxml-path: ./pytest.xml
-
+          
   integration-tests:
     needs: tests
     runs-on: self-hosted
+    env:
+      CUDA_VISIBLE_DEVICES: 0
     steps:
       - name: Checkout repository
         uses: actions/checkout@v2
@@ -59,12 +52,8 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: 3.11.8
-
-      - name: Install dependencies
-        run: |
-            pip install -r requirements-dev.txt
-      
-      # First download before tests as they make use of the downloaded files 
+        
+      # Required to get the data
       - name: Download all files
         run: |
           python ci/download_all.py
@@ -133,6 +122,8 @@ jobs:
   notebooks:
     needs: tests
     runs-on: self-hosted
+    env:
+      CUDA_VISIBLE_DEVICES: 0
     steps:
       - name: Checkout repository
         uses: actions/checkout@v2
@@ -141,10 +132,6 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: 3.11.8
-
-      - name: Install dependencies
-        run: |
-            pip install -r requirements-dev.txt
             
       - name: Reduce datasets to speedup checks
         run: |

From 505a5a72a4551c1b7a90dcda6caf6906ae5d0b9f Mon Sep 17 00:00:00 2001
From: Benoit Putzeys <benoit@helical-ai.com>
Date: Wed, 11 Dec 2024 10:47:36 +0100
Subject: [PATCH 23/27] Install helical for each job to avoid a different
 package being used in case another job was scheduled in between

---
 .github/workflows/main.yml    | 12 +++++++++++-
 .github/workflows/release.yml | 24 ++++++++++++++----------
 2 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 2b2cdbaa..86adb0b6 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -52,6 +52,11 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: 3.11.8
+
+      # because jobs may not be run in the same order, we need to install the dependencies again
+      - name: Install helical
+        run: |
+            pip install .
         
       # Required to get the data
       - name: Download all files
@@ -132,7 +137,12 @@ jobs:
         uses: actions/setup-python@v5
         with:
           python-version: 3.11.8
-            
+
+      # because jobs may not be run in the same order, we need to install the dependencies again
+      - name: Install helical
+        run: |
+            pip install .
+        
       - name: Reduce datasets to speedup checks
         run: |
           sed -i 's/train\[:65%\]/train\[:5%\]/g' ./examples/notebooks/Cell-Type-Annotation.ipynb
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 39116a94..0b672b10 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -26,8 +26,6 @@ jobs:
 
   tests:
     runs-on: self-hosted
-    env:
-      CUDA_VISIBLE_DEVICES: 0
     steps:
       - name: Checkout repository
         uses: actions/checkout@v2
@@ -62,6 +60,8 @@ jobs:
   integration-tests:
     needs: tests
     runs-on: self-hosted
+    env:
+      CUDA_VISIBLE_DEVICES: 0
     steps:
       - name: Checkout repository
         uses: actions/checkout@v2
@@ -71,11 +71,12 @@ jobs:
         with:
           python-version: 3.11.8
 
-      - name: Install dependencies
+      # because jobs may not be run in the same order, we need to install the dependencies again
+      - name: Install helical
         run: |
-            pip install -r requirements-dev.txt
-      
-      # First download before tests as they make use of the downloaded files 
+            pip install .
+        
+      # Required to get the data
       - name: Download all files
         run: |
           python ci/download_all.py
@@ -116,7 +117,7 @@ jobs:
         run: |
           python examples/run_models/run_hyena_dna.py ++device="cuda"
 
-      - name: Execute Hyena
+      - name: Fine-tune Hyena
         run: |
           python examples/fine_tune_models/fine_tune_hyena_dna.py ++device="cuda"
 
@@ -144,6 +145,8 @@ jobs:
   notebooks:
     needs: tests
     runs-on: self-hosted
+    env:
+      CUDA_VISIBLE_DEVICES: 0
     steps:
       - name: Checkout repository
         uses: actions/checkout@v2
@@ -153,10 +156,11 @@ jobs:
         with:
           python-version: 3.11.8
 
-      - name: Install dependencies
+      # because jobs may not be run in the same order, we need to install the dependencies again
+      - name: Install helical
         run: |
-            pip install -r requirements-dev.txt
-            
+            pip install .
+                    
       - name: Reduce datasets to speedup checks
         run: |
           sed -i 's/train\[:65%\]/train\[:5%\]/g' ./examples/notebooks/Cell-Type-Annotation.ipynb

From 2d69c85df64e296ae4054660404e4da6f1f22623 Mon Sep 17 00:00:00 2001
From: Benoit Putzeys <157973952+bputzeys@users.noreply.github.com>
Date: Wed, 11 Dec 2024 11:37:21 +0100
Subject: [PATCH 24/27] Update version in pyproject.toml

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index c9dbe4c7..6923bbb9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "helical"
-version = "0.0.1a11"
+version = "0.0.1a12"
 authors = [
   { name="Helical Team", email="support@helical-ai.com" },
 ]

From 49f93c1818fbf3d6ba71b06106b895b170253e2b Mon Sep 17 00:00:00 2001
From: Benoit Putzeys <benoit@helical-ai.com>
Date: Wed, 11 Dec 2024 12:18:30 +0100
Subject: [PATCH 25/27] Torch had a separate seed which was not set

---
 helical/models/scgpt/model.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py
index 842dfd34..b718771b 100644
--- a/helical/models/scgpt/model.py
+++ b/helical/models/scgpt/model.py
@@ -94,8 +94,11 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
             If `emb_mode` is set to "gene", the embeddings are returned as a list of pd.Series which contain a mapping of gene_name:embedding for each cell.
         """
         LOGGER.info(f"Inference started:")
+                
+        # fix seeds
         np.random.seed(self.config["binning_seed"])
-
+        torch.manual_seed(self.config["binning_seed"])
+        
         self.model.eval()
 
         try:

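NumPy and PyTorch maintain independent global RNGs, so seeding NumPy alone (as patch 21 did) still leaves any torch-side draws nondeterministic; both seeds have to be set, as the hunk above now does. A minimal sketch of why, using hypothetical draws to stand in for the model's internals:

```python
import numpy as np
import torch

def embed_like_step(seed: int) -> torch.Tensor:
    # Seed both global RNGs, mirroring the patched get_embeddings().
    np.random.seed(seed)
    torch.manual_seed(seed)
    binned = np.random.randint(0, 51, size=8)   # NumPy-side draw (binning)
    noise = torch.randn(8)                      # torch-side draw
    return torch.as_tensor(binned, dtype=torch.float32) + noise

a = embed_like_step(123)
b = embed_like_step(123)
assert torch.equal(a, b)  # reproducible only because *both* seeds are set
```

Dropping the `torch.manual_seed` line makes the assertion fail even though NumPy is seeded, which is the gap this patch closes.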
From 3789e9eaa414a613532e40e42a6e654cffb75217 Mon Sep 17 00:00:00 2001
From: Benoit Putzeys <157973952+bputzeys@users.noreply.github.com>
Date: Wed, 11 Dec 2024 13:08:37 +0100
Subject: [PATCH 26/27] Update version in pyproject.toml

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 6923bbb9..bc8c7638 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "helical"
-version = "0.0.1a12"
+version = "0.0.1a13"
 authors = [
   { name="Helical Team", email="support@helical-ai.com" },
 ]

From 3969889018736a878ed002a2a9b2b2f19122eeef Mon Sep 17 00:00:00 2001
From: Benoit Putzeys <benoit@helical-ai.com>
Date: Wed, 11 Dec 2024 14:12:18 +0100
Subject: [PATCH 27/27] Update versions

---
 CITATION.bib                    | 4 ++--
 README.md                       | 4 ++--
 docs/index.md                   | 4 ++--
 docs/model_cards/helix_mrna.md  | 4 ++--
 docs/model_cards/mamba2_mrna.md | 4 ++--
 pyproject.toml                  | 2 +-
 6 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/CITATION.bib b/CITATION.bib
index 0307184f..dded7608 100644
--- a/CITATION.bib
+++ b/CITATION.bib
@@ -1,10 +1,10 @@
 @software{allard_2024_13135902,
   author       = {Helical Team},
-  title        = {helicalAI/helical: v0.0.1-alpha8},
+  title        = {helicalAI/helical: v0.0.1a14},
   month        = nov,
   year         = 2024,
   publisher    = {Zenodo},
-  version      = {0.0.1a6},
+  version      = {0.0.1a14},
   doi          = {10.5281/zenodo.13135902},
   url          = {https://doi.org/10.5281/zenodo.13135902}
 }
diff --git a/README.md b/README.md
index bd7b0b64..866ae568 100644
--- a/README.md
+++ b/README.md
@@ -165,11 +165,11 @@ Please use this BibTeX to cite this repository in your publications:
 ```bibtex
 @software{allard_2024_13135902,
   author       = {Helical Team},
-  title        = {helicalAI/helical: v0.0.1-alpha8},
+  title        = {helicalAI/helical: v0.0.1a14},
   month        = nov,
   year         = 2024,
   publisher    = {Zenodo},
-  version      = {0.0.1a6},
+  version      = {0.0.1a14},
   doi          = {10.5281/zenodo.13135902},
   url          = {https://doi.org/10.5281/zenodo.13135902}
 }
diff --git a/docs/index.md b/docs/index.md
index d43ff629..f39d5396 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -145,11 +145,11 @@ Please use this BibTeX to cite this repository in your publications:
 ```bibtex
 @software{allard_2024_13135902,
   author       = {Helical Team},
-  title        = {helicalAI/helical: v0.0.1-alpha8},
+  title        = {helicalAI/helical: v0.0.1a14},
   month        = nov,
   year         = 2024,
   publisher    = {Zenodo},
-  version      = {0.0.1a5},
+  version      = {0.0.1a14},
   doi          = {10.5281/zenodo.13135902},
   url          = {https://doi.org/10.5281/zenodo.13135902}
 }
diff --git a/docs/model_cards/helix_mrna.md b/docs/model_cards/helix_mrna.md
index 0b11b33f..58d97b88 100644
--- a/docs/model_cards/helix_mrna.md
+++ b/docs/model_cards/helix_mrna.md
@@ -142,11 +142,11 @@ support@helical-ai.com
 ```bibtex
 @software{allard_2024_13135902,
   author       = {Helical Team},
-  title        = {helicalAI/helical: v0.0.1-alpha8},
+  title        = {helicalAI/helical: v0.0.1a14},
   month        = nov,
   year         = 2024,
   publisher    = {Zenodo},
-  version      = {0.0.1a6},
+  version      = {0.0.1a14},
   doi          = {10.5281/zenodo.13135902},
   url          = {https://doi.org/10.5281/zenodo.13135902}
 }
diff --git a/docs/model_cards/mamba2_mrna.md b/docs/model_cards/mamba2_mrna.md
index e33d472b..865ad4f2 100644
--- a/docs/model_cards/mamba2_mrna.md
+++ b/docs/model_cards/mamba2_mrna.md
@@ -123,11 +123,11 @@ support@helical-ai.com
 ```bibtex
 @software{allard_2024_13135902,
   author       = {Helical Team},
-  title        = {helicalAI/helical: v0.0.1-alpha8},
+  title        = {helicalAI/helical: v0.0.1a14},
   month        = nov,
   year         = 2024,
   publisher    = {Zenodo},
-  version      = {0.0.1a6},
+  version      = {0.0.1a14},
   doi          = {10.5281/zenodo.13135902},
   url          = {https://doi.org/10.5281/zenodo.13135902}
 }
diff --git a/pyproject.toml b/pyproject.toml
index bc8c7638..285a105f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "helical"
-version = "0.0.1a13"
+version = "0.0.1a14"
 authors = [
   { name="Helical Team", email="support@helical-ai.com" },
 ]