Bug fix: HyenaDNA variable length input sequences (#161)
* Change HyenaDNA dataset to allow for different length sequences and make it easier to use

* Add logging back to data processing for HyenaDNA
mattwoodx authored Dec 18, 2024
1 parent 5daf950 commit 69d0a83
Showing 7 changed files with 93 additions and 105 deletions.
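In practice, the change means `process_data` now accepts a batch of sequences with different lengths. A minimal usage sketch (package, class, and model names as they appear in the diffs below; the printed shape depends on the configured model):

    from helical.models.hyena_dna.model import HyenaDNA, HyenaDNAConfig

    # Sequences in one batch no longer need to share a length;
    # shorter ones are padded during process_data.
    model = HyenaDNA(configurer=HyenaDNAConfig(model_name="hyenadna-tiny-1k-seqlen"))
    dataset = model.process_data(["ACGT", "AC", "ACGTACGT"])
    embeddings = model.get_embeddings(dataset)
    print(embeddings.shape)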
2 changes: 1 addition & 1 deletion ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py
@@ -12,7 +12,7 @@ def hyenaDNAFineTune(self, request):
@pytest.fixture
def mock_data(self, hyenaDNAFineTune):
    input_sequences = ["AAAA", "CCCC", "TTTT", "ACGT", "ACGN", "ANNT"]
-    labels = [0, 0, 0, 0, 0, 0, 0]
+    labels = [0, 0, 0, 0, 0, 0]
    tokenized_sequences = hyenaDNAFineTune.process_data(input_sequences)
    return tokenized_sequences, labels

56 changes: 22 additions & 34 deletions ci/tests/test_hyena_dna/test_hyena_dna_model.py
@@ -1,8 +1,7 @@
from helical.models.hyena_dna.model import HyenaDNAConfig
import pytest
-import torch
from helical.models.hyena_dna.model import HyenaDNA
-from helical.models.hyena_dna.hyena_dna_utils import HyenaDNADataset
+import numpy as np

@pytest.mark.parametrize("model_name, d_model, d_inner", [
("hyenadna-tiny-1k-seqlen", 128, 512),
@@ -39,39 +38,28 @@ def test_hyena_dna__invalid_model_names(model_name):
    with pytest.raises(ValueError):
        HyenaDNAConfig(model_name=model_name)

@pytest.mark.parametrize("input_sequence, expected_output, raise_error", [
@pytest.mark.parametrize("input_sequence, expected_output", [
# Valid DNA sequences
("", [0, 1], False),
("A", [0, 7, 1], False),
("CC", [0, 8, 8, 1], False),
("TTTT", [0, 10, 10, 10, 10, 1], False),
("ACGTN", [0, 7, 8, 9, 10, 11, 1], False),
("ACGT" * 256, [0] + [7, 8, 9, 10] * 256 + [1], False),
# Invalid sequences / sequences with uncertain 'N' nucleodites
("BHIK", [0, 6, 6, 6, 6, 1], True),
("ANNTBH", [0, 7, 11, 11, 10, 6, 6, 1], True),
("", [0, 1]),
("A", [0, 7, 1]),
("CC", [0, 8, 8, 1]),
("TTTT", [0, 10, 10, 10, 10, 1]),
("ACGTN", [0, 7, 8, 9, 10, 11, 1]),
("ACGT" * 256, [0] + [7, 8, 9, 10] * 256 + [1])
])
def test_hyena_dna_process_data(input_sequence, expected_output, raise_error):
"""
Test the process_data method of the HyenaDNA model.
The input DNA sequence is tokenized and the output is compared to the expected output.
Args:
input_sequence (str): The input DNA sequence to be processed.
expected_output (int): The expected output of the process_data method.
Returns:
None
def test_hyena_dna_process_data(input_sequence, expected_output):
model = HyenaDNA()
output = model.process_data([input_sequence])
expected = np.array(expected_output)
assert np.all(np.equal(np.array(output["input_ids"][0]), expected))

Raises:
AssertionError: If the output of the process_data method does not match the expected output.
"""
@pytest.mark.parametrize("input_sequence, expected_output", [
(
["A", "CC", "TTTT", "ACGTN", "ACGT"],
[[4, 4, 4, 4, 0, 7, 1], [4, 4, 4, 0, 8, 8, 1], [4, 0, 10, 10, 10, 10, 1], [0, 7, 8, 9, 10, 11, 1], [4, 0, 7, 8, 9, 10, 1]]
)
])
def test_hyena_dna_process_data_variable_length_sequences(input_sequence, expected_output):
model = HyenaDNA()
if raise_error:
with pytest.raises(ValueError):
model.process_data([input_sequence])
else:
output = model.process_data([input_sequence])
expected = torch.tensor([expected_output])
assert torch.equal(output.sequences, expected)
dataset = model.process_data(input_sequence)
assert np.all(np.equal(np.array(expected_output), np.array(dataset["input_ids"])))
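The expected values in the variable-length test encode the padding scheme: ids 0 and 1 bracket each sequence as start/end special tokens, ids 7-10 map to A/C/G/T, 11 to N, and shorter sequences are left-padded with id 4 up to the longest sequence in the batch plus the two special tokens. A small sketch reproducing the expectations above (the id mapping is inferred from the test data, not taken from the tokenizer source):

    def expected_ids(seq: str, longest: int) -> list[int]:
        token = {"A": 7, "C": 8, "G": 9, "T": 10, "N": 11}
        body = [0] + [token[base] for base in seq] + [1]  # 0/1 bracket the sequence
        return [4] * (longest + 2 - len(body)) + body     # left-pad with id 4

    print(expected_ids("A", 5))   # [4, 4, 4, 4, 0, 7, 1], matching the first case above
    print(expected_ids("CC", 5))  # [4, 4, 4, 0, 8, 8, 1]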
2 changes: 1 addition & 1 deletion examples/fine_tune_models/fine_tune_hyena_dna.py
@@ -4,7 +4,7 @@

@hydra.main(version_base=None, config_path="../run_models/configs", config_name="hyena_dna_config")
def run_fine_tuning(cfg: DictConfig):
input_sequences = ["ACT"*20, "ATG"*20, "ATG"*20, "ACT"*20, "ATT"*20]
input_sequences = ["ACT"*20, "ATG"*10, "ATG"*20, "ACT"*10, "ATT"*20]
labels = [0, 2, 2, 0, 1]

hyena_dna_config = HyenaDNAConfig(**cfg)
9 changes: 5 additions & 4 deletions examples/run_models/run_hyena_dna.py
@@ -4,12 +4,13 @@

@hydra.main(version_base=None, config_path="configs", config_name="hyena_dna_config")
def run(cfg: DictConfig):

hyena_config = HyenaDNAConfig(**cfg)
model = HyenaDNA(configurer = hyena_config)
sequence = ['ACTG' * int(1024/4)]
tokenized_sequence = model.process_data(sequence)
embeddings = model.get_embeddings(tokenized_sequence)

sequence = ["A", "CC", "TTTT", "ACGTN", "ACGT"]

dataset = model.process_data(sequence)
embeddings = model.get_embeddings(dataset)
print(embeddings.shape)

if __name__ == "__main__":
45 changes: 34 additions & 11 deletions helical/models/hyena_dna/fine_tuning_model.py
@@ -4,7 +4,7 @@
from torch import optim
import torch
from torch.nn.modules import loss
-from .hyena_dna_utils import HyenaDNADataset
+from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import get_scheduler
@@ -75,15 +75,16 @@ def _forward(self, x):

    def train(
        self,
-        train_dataset: HyenaDNADataset,
+        train_dataset: Dataset,
        train_labels: list[int],
-        validation_dataset: HyenaDNADataset = None,
+        validation_dataset: Dataset = None,
        validation_labels: list[int] = None,
        optimizer: optim = optim.AdamW,
        optimizer_params: dict = {'lr': 0.0001},
        loss_function: loss = loss.CrossEntropyLoss(),
        epochs: int = 1,
-        lr_scheduler_params: Optional[dict] = None):
+        lr_scheduler_params: Optional[dict] = None,
+        shuffle: bool = True):
        """Fine-tunes the Hyena-DNA model with different head modules.

        Parameters
@@ -108,13 +109,15 @@ def train(
        lr_scheduler_params : dict, default=None
            The learning rate scheduler parameters for the transformers get_scheduler method. The optimizer will be taken from the optimizer input and should not be included in the learning scheduler parameters. If not specified, a constant learning rate will be used.
            e.g. lr_scheduler_params = { 'name': 'linear', 'num_warmup_steps': 0 }. num_steps will be calculated based on the number of epochs and the length of the training dataset.
+        shuffle : bool, default=True
+            Whether to shuffle the training data or not.
        """
-        train_dataset.set_labels(train_labels)
-        train_dataloader = DataLoader(train_dataset, batch_size=self.config["batch_size"])
+        train_dataset = self._add_data_column(train_dataset, np.array(train_labels))
+        train_dataloader = DataLoader(train_dataset, collate_fn=self._collate_fn, batch_size=self.config["batch_size"], shuffle=shuffle)

        if validation_dataset is not None and validation_labels is not None:
-            validation_dataset.set_labels(validation_labels)
-            validation_dataloader = DataLoader(validation_dataset, batch_size=self.config["batch_size"])
+            validation_dataset = self._add_data_column(validation_dataset, np.array(validation_labels))
+            validation_dataloader = DataLoader(validation_dataset, collate_fn=self._collate_fn, batch_size=self.config["batch_size"])

        self.to(self.config["device"])
        self.model.train()
@@ -163,7 +166,7 @@ def train(

    def get_outputs(
        self,
-        dataset: HyenaDNADataset) -> np.ndarray:
+        dataset: Dataset) -> np.ndarray:
"""Get the outputs of the fine-tuned model.
Parameters
@@ -176,7 +179,7 @@ def get_outputs(
        np.ndarray
            The outputs of the model
        """
-        data_loader = DataLoader(dataset, batch_size=self.config["batch_size"])
+        data_loader = DataLoader(dataset, collate_fn=self._collate_fn, batch_size=self.config["batch_size"])

        self.to(self.config["device"])
        self.model.eval()
@@ -189,4 +192,24 @@ def get_outputs(
                output = self._forward(input_data)
                outputs.append(output.detach().cpu().numpy())

-        return np.vstack(outputs)
+        return np.vstack(outputs)
+
+    def _add_data_column(self, dataset: Dataset, data: list, column_name: str="labels") -> Dataset:
+        """
+        Add a column to the dataset.
+
+        Parameters
+        ----------
+        dataset : Dataset
+            The dataset to add the column to.
+        data : list
+            The data to add to the column.
+        column_name : str, optional, default="labels"
+            The name of the column to add.
+        """
+        if len(data.shape) > 1:
+            for i in range(len(data[0])): # Assume all inner lists are the same length
+                dataset = dataset.add_column(f"{column_name}", [row[i] for row in data])
+        else: # If 1D
+            dataset = dataset.add_column(column_name, data)
+        return dataset
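For reference, the one-dimensional branch of `_add_data_column` maps directly onto the Hugging Face `datasets` API; a standalone sketch with made-up values:

    from datasets import Dataset

    ds = Dataset.from_dict({"input_ids": [[0, 7, 1], [0, 8, 1]]})  # hypothetical tokenized rows
    ds = ds.add_column("labels", [0, 1])  # the 1-D path of _add_data_column
    print(ds[0])  # {'input_ids': [0, 7, 1], 'labels': 0}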
35 changes: 0 additions & 35 deletions helical/models/hyena_dna/hyena_dna_utils.py

This file was deleted.

49 changes: 30 additions & 19 deletions helical/models/hyena_dna/model.py
@@ -2,7 +2,7 @@
from helical.models.hyena_dna.hyena_dna_config import HyenaDNAConfig
from helical.models.base_models import HelicalDNAModel
from tqdm import tqdm
-from .hyena_dna_utils import HyenaDNADataset
+from datasets import Dataset
from helical.models.hyena_dna.pretrained_model import HyenaDNAPreTrainedModel
import torch
from .standalone_hyenadna import CharacterTokenizer
@@ -69,44 +69,43 @@ def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None:
        self.model.eval()
        LOGGER.info(f"Model finished initializing.")

-    def process_data(self, sequence: list[str]) -> HyenaDNADataset:
-        """Process the input DNA sequences.
+    def process_data(self, sequences: list[str], return_tensors: str="pt", padding: str="max_length", truncation: bool=True) -> Dataset:
+        """Process the input DNA sequence.

        Parameters
        ----------
        sequences : list[str]
            The input DNA sequences to be processed.
+        return_tensors : str, optional, default="pt"
+            The return type of the processed data.
+        padding : str, optional, default="max_length"
+            The padding strategy to be used.
+        truncation : bool, optional, default=True
+            Whether to truncate the sequences or not.

        Returns
        -------
-        HyenaDNADataset
+        Dataset
            Containing processed DNA sequences.
        """
-        LOGGER.info(f"Processing data")
+        LOGGER.info("Processing data")

-        self.ensure_dna_sequence_validity(sequence)
+        self.ensure_dna_sequence_validity(sequences)

-        processed_sequences = []
-        for seq in tqdm(sequence, desc="Processing sequences"):
-            tok_seq = self.tokenizer(seq)
-            tok_seq_input_ids = tok_seq["input_ids"]
+        max_length = len(max(sequences, key=len))+2 # +2 for special tokens at the beginning and end of sequences

-            tensor = torch.LongTensor(tok_seq_input_ids)
-            tensor = tensor.to(self.device)
-            processed_sequences.append(tensor)
+        tokenized_sequences = self.tokenizer(sequences, return_tensors=return_tensors, padding=padding, truncation=truncation, max_length=max_length)

-        dataset = HyenaDNADataset(torch.stack(processed_sequences))
+        dataset = Dataset.from_dict(tokenized_sequences)
        LOGGER.info(f"Data processing finished.")

        return dataset

-    def get_embeddings(self, dataset: HyenaDNADataset) -> torch.Tensor:
+    def get_embeddings(self, dataset: Dataset) -> torch.Tensor:
        """Get the embeddings for the tokenized sequence.

        Parameters
        ----------
-        dataset : HyenaDNADataset
+        dataset : Dataset
            The output dataset from `process_data`.

        Returns
@@ -117,11 +116,23 @@ def get_embeddings(self, dataset: HyenaDNADataset) -> torch.Tensor:
"""
LOGGER.info(f"Inference started")

train_data_loader = DataLoader(dataset, batch_size=self.config["batch_size"])
train_data_loader = DataLoader(dataset, collate_fn=self._collate_fn, batch_size=self.config["batch_size"])
with torch.inference_mode():
embeddings = []
for batch in tqdm(train_data_loader, desc="Getting embeddings"):
input_data = batch["input_ids"].to(self.device)
embeddings.append(self.model(input_data).detach().cpu().numpy())

return np.vstack(embeddings)


def _collate_fn(self, batch):
input_ids = torch.tensor([item["input_ids"] for item in batch])
batch_dict = {
"input_ids": input_ids,
}

if "labels" in batch[0]:
batch_dict["labels"] = torch.tensor([item["labels"] for item in batch])

return batch_dict
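The custom collate function is needed because rows of a `datasets.Dataset` come back as plain Python lists, while the model expects a stacked tensor; a minimal standalone sketch of the same pattern (values are hypothetical):

    import torch
    from datasets import Dataset
    from torch.utils.data import DataLoader

    def collate_fn(batch):
        # Stack the equal-length, padded id lists into one tensor per batch
        return {"input_ids": torch.tensor([item["input_ids"] for item in batch])}

    ds = Dataset.from_dict({"input_ids": [[4, 0, 7, 1], [0, 8, 8, 1]]})
    loader = DataLoader(ds, collate_fn=collate_fn, batch_size=2)
    print(next(iter(loader))["input_ids"].shape)  # torch.Size([2, 4])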
