Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug fix: HyenaDNA variable length input sequences #161

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ci/tests/test_hyena_dna/test_hyena_dna_fine_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def hyenaDNAFineTune(self, request):
@pytest.fixture
def mock_data(self, hyenaDNAFineTune):
    """Return tokenized sequences and one label per input sequence.

    Fix: the stale 7-element labels list left above the corrected one is
    removed; six sequences require exactly six labels.
    """
    input_sequences = ["AAAA", "CCCC", "TTTT", "ACGT", "ACGN", "ANNT"]
    labels = [0, 0, 0, 0, 0, 0]
    tokenized_sequences = hyenaDNAFineTune.process_data(input_sequences)
    return tokenized_sequences, labels

Expand Down
56 changes: 22 additions & 34 deletions ci/tests/test_hyena_dna/test_hyena_dna_model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from helical.models.hyena_dna.model import HyenaDNAConfig
import pytest
import torch
from helical.models.hyena_dna.model import HyenaDNA
from helical.models.hyena_dna.hyena_dna_utils import HyenaDNADataset
import numpy as np

@pytest.mark.parametrize("model_name, d_model, d_inner", [
("hyenadna-tiny-1k-seqlen", 128, 512),
Expand Down Expand Up @@ -39,39 +38,28 @@ def test_hyena_dna__invalid_model_names(model_name):
with pytest.raises(ValueError):
HyenaDNAConfig(model_name=model_name)

@pytest.mark.parametrize("input_sequence, expected_output", [
    # Valid DNA sequences
    ("", [0, 1]),
    ("A", [0, 7, 1]),
    ("CC", [0, 8, 8, 1]),
    ("TTTT", [0, 10, 10, 10, 10, 1]),
    ("ACGTN", [0, 7, 8, 9, 10, 11, 1]),
    ("ACGT" * 256, [0] + [7, 8, 9, 10] * 256 + [1])
])
def test_hyena_dna_process_data(input_sequence, expected_output):
    """Tokenize a single DNA sequence and compare against the expected token ids.

    Fix: the pasted diff interleaved the old parametrize cases (with a
    `raise_error` flag) and leftover docstring fragments with the new test
    body, producing invalid code; only the post-change version is kept.
    """
    model = HyenaDNA()
    output = model.process_data([input_sequence])
    expected = np.array(expected_output)
    assert np.all(np.equal(np.array(output["input_ids"][0]), expected))
@pytest.mark.parametrize("input_sequence, expected_output", [
    (
        ["A", "CC", "TTTT", "ACGTN", "ACGT"],
        [[4, 4, 4, 4, 0, 7, 1], [4, 4, 4, 0, 8, 8, 1], [4, 0, 10, 10, 10, 10, 1], [0, 7, 8, 9, 10, 11, 1], [4, 0, 7, 8, 9, 10, 1]]
    )
])
def test_hyena_dna_process_data_variable_length_sequences(input_sequence, expected_output):
    """Shorter sequences in a batch are padded (token id 4) to the longest one.

    Fix: leftover lines from the removed `raise_error` variant of the old
    test (pytest.raises branch and torch-tensor assertions) were interleaved
    with the new body by the diff paste; only the new version is kept.
    """
    model = HyenaDNA()
    dataset = model.process_data(input_sequence)
    assert np.all(np.equal(np.array(expected_output), np.array(dataset["input_ids"])))
2 changes: 1 addition & 1 deletion examples/fine_tune_models/fine_tune_hyena_dna.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

@hydra.main(version_base=None, config_path="../run_models/configs", config_name="hyena_dna_config")
def run_fine_tuning(cfg: DictConfig):
input_sequences = ["ACT"*20, "ATG"*20, "ATG"*20, "ACT"*20, "ATT"*20]
input_sequences = ["ACT"*20, "ATG"*10, "ATG"*20, "ACT"*10, "ATT"*20]
labels = [0, 2, 2, 0, 1]

hyena_dna_config = HyenaDNAConfig(**cfg)
Expand Down
9 changes: 5 additions & 4 deletions examples/run_models/run_hyena_dna.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

@hydra.main(version_base=None, config_path="configs", config_name="hyena_dna_config")
def run(cfg: DictConfig):
    """Run HyenaDNA embedding extraction on example variable-length sequences.

    Fix: the pasted diff kept both the old fixed-length pipeline
    (`'ACTG' * 256` + a second process_data/get_embeddings pass) and the new
    variable-length one, so the model was run twice; only the new pipeline
    remains.
    """
    hyena_config = HyenaDNAConfig(**cfg)
    model = HyenaDNA(configurer=hyena_config)

    # Variable-length sequences exercise the padding path in process_data.
    sequence = ["A", "CC", "TTTT", "ACGTN", "ACGT"]

    dataset = model.process_data(sequence)
    embeddings = model.get_embeddings(dataset)
    print(embeddings.shape)

if __name__ == "__main__":
Expand Down
45 changes: 34 additions & 11 deletions helical/models/hyena_dna/fine_tuning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import optim
import torch
from torch.nn.modules import loss
from .hyena_dna_utils import HyenaDNADataset
from datasets import Dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import get_scheduler
Expand Down Expand Up @@ -75,15 +75,16 @@ def _forward(self, x):

def train(
self,
train_dataset: HyenaDNADataset,
train_dataset: Dataset,
train_labels: list[int],
validation_dataset: HyenaDNADataset = None,
validation_dataset: Dataset = None,
validation_labels: list[int] = None,
optimizer: optim = optim.AdamW,
optimizer_params: dict = {'lr': 0.0001},
loss_function: loss = loss.CrossEntropyLoss(),
epochs: int = 1,
lr_scheduler_params: Optional[dict] = None):
lr_scheduler_params: Optional[dict] = None,
shuffle: bool = True):
"""Fine-tunes the Hyena-DNA model with different head modules.

Parameters
Expand All @@ -108,13 +109,15 @@ def train(
lr_scheduler_params : dict, default=None
The learning rate scheduler parameters for the transformers get_scheduler method. The optimizer will be taken from the optimizer input and should not be included in the learning scheduler parameters. If not specified, a constant learning rate will be used.
e.g. lr_scheduler_params = { 'name': 'linear', 'num_warmup_steps': 0 }. num_steps will be calculated based on the number of epochs and the length of the training dataset.
shuffle : bool, default=True
Whether to shuffle the training data or not.
"""
train_dataset.set_labels(train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=self.config["batch_size"])
train_dataset = self._add_data_column(train_dataset, np.array(train_labels))
train_dataloader = DataLoader(train_dataset, collate_fn=self._collate_fn, batch_size=self.config["batch_size"], shuffle=shuffle)

if validation_dataset is not None and validation_labels is not None:
validation_dataset.set_labels(validation_labels)
validation_dataloader = DataLoader(validation_dataset, batch_size=self.config["batch_size"])
validation_dataset = self._add_data_column(validation_dataset, np.array(validation_labels))
validation_dataloader = DataLoader(validation_dataset, collate_fn=self._collate_fn, batch_size=self.config["batch_size"])

self.to(self.config["device"])
self.model.train()
Expand Down Expand Up @@ -163,7 +166,7 @@ def train(

def get_outputs(
self,
dataset: HyenaDNADataset) -> np.ndarray:
dataset: Dataset) -> np.ndarray:
"""Get the outputs of the fine-tuned model.

Parameters
Expand All @@ -176,7 +179,7 @@ def get_outputs(
np.ndarray
The outputs of the model
"""
data_loader = DataLoader(dataset, batch_size=self.config["batch_size"])
data_loader = DataLoader(dataset, collate_fn=self._collate_fn, batch_size=self.config["batch_size"])

self.to(self.config["device"])
self.model.eval()
Expand All @@ -189,4 +192,24 @@ def get_outputs(
output = self._forward(input_data)
outputs.append(output.detach().cpu().numpy())

return np.vstack(outputs)
return np.vstack(outputs)

def _add_data_column(self, dataset: Dataset, data: list, column_name: str="labels") -> Dataset:
"""
Add a column to the dataset.

Parameters
----------
dataset : Dataset
The dataset to add the column to.
data : list
The data to add to the column.
column_name : str, optional, default="labels"
The name of the column to add.
"""
if len(data.shape) > 1:
for i in range(len(data[0])): # Assume all inner lists are the same length
dataset = dataset.add_column(f"{column_name}", [row[i] for row in data])
else: # If 1D
dataset = dataset.add_column(column_name, data)
return dataset
35 changes: 0 additions & 35 deletions helical/models/hyena_dna/hyena_dna_utils.py

This file was deleted.

48 changes: 30 additions & 18 deletions helical/models/hyena_dna/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from helical.models.hyena_dna.hyena_dna_config import HyenaDNAConfig
from helical.models.base_models import HelicalDNAModel
from tqdm import tqdm
from .hyena_dna_utils import HyenaDNADataset
from datasets import Dataset
from helical.models.hyena_dna.pretrained_model import HyenaDNAPreTrainedModel
import torch
from .standalone_hyenadna import CharacterTokenizer
Expand Down Expand Up @@ -69,44 +69,44 @@ def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None:
self.model.eval()
LOGGER.info(f"Model finished initializing.")

def process_data(self, sequences: list[str], return_tensors: str = "pt", padding: str = "max_length", truncation: bool = True) -> Dataset:
    """Process the input DNA sequences.

    Fix: the pasted diff interleaved the old per-sequence tokenization loop
    (building a HyenaDNADataset of stacked tensors) with the new batched
    tokenizer call, plus embedded review-comment chrome; only the new
    version is kept. An empty-input guard is added so max() cannot raise.

    Parameters
    ----------
    sequences : list[str]
        The input DNA sequences to be processed.
    return_tensors : str, optional, default="pt"
        The return type of the processed data.
    padding : str, optional, default="max_length"
        The padding strategy to be used.
    truncation : bool, optional, default=True
        Whether to truncate the sequences or not.

    Returns
    -------
    Dataset
        Containing processed DNA sequences.
    """
    LOGGER.info("Processing data")

    self.ensure_dna_sequence_validity(sequences)

    # Pad to the longest input plus the two special tokens the tokenizer
    # adds at the beginning and end of each sequence.
    max_length = (max(map(len, sequences)) + 2) if sequences else 2

    tokenized_sequences = self.tokenizer(
        sequences,
        return_tensors=return_tensors,
        padding=padding,
        truncation=truncation,
        max_length=max_length,
    )

    dataset = Dataset.from_dict(tokenized_sequences)
    LOGGER.info("Data processing finished.")

    return dataset

def get_embeddings(self, dataset: Dataset) -> np.ndarray:
    """Get the embeddings for the tokenized sequences.

    Fix: the pasted diff kept both the old and new DataLoader lines; only
    the new one (with the padding-aware collate_fn) is kept. The return
    annotation is corrected to np.ndarray, matching np.vstack below.

    Parameters
    ----------
    dataset : Dataset
        The output dataset from `process_data`.

    Returns
    -------
    np.ndarray
        The embeddings, stacked over all batches.
    """
    LOGGER.info("Inference started")

    train_data_loader = DataLoader(dataset, collate_fn=self._collate_fn, batch_size=self.config["batch_size"])
    with torch.inference_mode():
        embeddings = []
        for batch in tqdm(train_data_loader, desc="Getting embeddings"):
            input_data = batch["input_ids"].to(self.device)
            embeddings.append(self.model(input_data).detach().cpu().numpy())

    return np.vstack(embeddings)


def _collate_fn(self, batch):
input_ids = torch.tensor([item["input_ids"] for item in batch])
batch_dict = {
"input_ids": input_ids,
}

if "labels" in batch[0]:
batch_dict["labels"] = torch.tensor([item["labels"] for item in batch])

return batch_dict
Loading