diff --git a/ci/tests/test_geneformer/test_geneformer_model.py b/ci/tests/test_geneformer/test_geneformer_model.py index acfceed2..623966ad 100644 --- a/ci/tests/test_geneformer/test_geneformer_model.py +++ b/ci/tests/test_geneformer/test_geneformer_model.py @@ -1,55 +1,64 @@ import pytest -import torch -from helical.models.geneformer.model import Geneformer -from helical.models.geneformer.geneformer_config import GeneformerConfig -from helical.models.geneformer.geneformer_utils import get_embs, load_model -from helical.models.geneformer.fine_tuning_model import GeneformerFineTuningModel +from helical import GeneformerConfig, Geneformer, GeneformerFineTuningModel from anndata import AnnData +import torch +import pandas as pd +import numpy as np -class TestGeneformerModel: - @pytest.fixture(params=["gf-12L-30M-i2048", "gf-12L-95M-i4096"]) - def geneformer(self, request): - self.device = "cuda" if torch.cuda.is_available() else "cpu" - config = GeneformerConfig(model_name=request.param, device=self.device) - return Geneformer(config) - +class TestGeneformer: @pytest.fixture def mock_data(self): data = AnnData() - data.var['gene_symbols'] = ['SAMD11', 'PLEKHN1', 'HES4'] - data.var['ensembl_id'] = ['ENSG00000187634', 'ENSG00000187583', 'ENSG00000188290'] + data.var['gene_symbols'] = ['HES4', 'PLEKHN1', 'SAMD11'] data.obs["cell_type"] = ["CD4 T cells"] data.X = [[1, 2, 5]] return data - + @pytest.fixture - def fine_tune_mock_data(self): - labels = list([0]) - return labels - - def test_pass_invalid_model_name(self): - with pytest.raises(ValueError): - geneformer_config = GeneformerConfig(model_name='InvalidName') - - + def mock_embeddings_v1(self, mocker): + embs = mocker.Mock() + embs.hidden_states = [torch.tensor([[[5.0, 5.0, 5.0, 5.0, 5.0], + [1.0, 2.0, 3.0, 2.0, 1.0], + [6.0, 6.0, 6.0, 6.0, 6.0]]])]*12 + return embs + + @pytest.fixture + def mock_embeddings_v2(self, mocker): + embs = mocker.Mock() + embs.hidden_states = torch.tensor([[[6.0, 5.0, 7.0, 5.0, 5.0], + [5.0, 5.0, 5.0, 5.0, 5.0], + [1.0, 2.0, 3.0, 2.0, 1.0], + [6.0, 6.0, 6.0, 6.0, 6.0], + [6.0, 6.0, 1.0, 6.0, 2.0]]]).repeat(12, 1, 1, 1) + return embs + + @pytest.fixture(params=["gf-12L-30M-i2048", "gf-12L-95M-i4096"]) + def geneformer(self, request): + config = GeneformerConfig(model_name=request.param, batch_size=5) + geneformer = Geneformer(config) + return geneformer + def test_process_data_mapping_to_ensemble_ids(self, geneformer, mock_data): - assert mock_data.var['ensembl_id'][0] == 'ENSG00000187634' + # geneformer modifies the anndata in place and maps the gene names to ensembl id + geneformer.process_data(mock_data, gene_names="gene_symbols") + assert mock_data.var['ensembl_id'][0] == 'ENSG00000188290' # is the same as the above line but more verbose (linking the gene symbol to the ensembl id) assert mock_data.var[mock_data.var['gene_symbols'] == 'SAMD11']['ensembl_id'].values[0] == 'ENSG00000187634' assert mock_data.var[mock_data.var['gene_symbols'] == 'PLEKHN1']['ensembl_id'].values[0] == 'ENSG00000187583' assert mock_data.var[mock_data.var['gene_symbols'] == 'HES4']['ensembl_id'].values[0] == 'ENSG00000188290' - - def test_process_data_padding_and_masking_ids(self, geneformer, mock_data): - # for this token mapping, the padding token is 0 and the mask token is 1 - geneformer.process_data(mock_data, gene_names='gene_symbols') - assert geneformer.gene_token_dict.get("") == 0 - assert geneformer.gene_token_dict.get("") == 1 + + @pytest.mark.parametrize("invalid_model_names", ["gf-12L-35M-i2048", "gf-34L-30M-i5000"]) + def 
test_pass_invalid_model_name(self, invalid_model_names): + with pytest.raises(ValueError): + GeneformerConfig(model_name=invalid_model_names) def test_ensure_data_validity_raising_error_with_missing_ensembl_id_column(self, geneformer, mock_data): + geneformer.process_data(mock_data, gene_names="gene_symbols") del mock_data.var['ensembl_id'] with pytest.raises(KeyError): geneformer.ensure_rna_data_validity(mock_data, "ensembl_id") - + + @pytest.mark.parametrize("gene_symbols, raises_error", [ (['ENSGSAMD11', 'ENSGPLEKHN1', 'ENSGHES4'], True), # humans @@ -66,126 +75,86 @@ def test_ensembl_data_is_caught(self, geneformer, mock_data, gene_symbols, raise else: geneformer.process_data(mock_data, "gene_symbols") - def test_cls_mode_with_v1_model_config(self, geneformer, mock_data): + + def test_cls_mode_with_v1_model_config(self, geneformer): if geneformer.config["special_token"]: pytest.skip("This test is only for v1 models and should thus be only executed once.") with pytest.raises(ValueError): - config = GeneformerConfig(model_name="gf-12L-30M-i2048", device="cpu", emb_mode='cls') - - def test_get_embs_cell_mode(self, geneformer, mock_data): - tokenized_dataset = geneformer.process_data(mock_data, gene_names='gene_symbols') - model = load_model("Pretrained", geneformer.files_config["model_files_dir"], self.device) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - embs = get_embs( - model, - tokenized_dataset, - emb_mode="cell", - layer_to_quant=-1, - pad_token_id=geneformer.pad_token_id, - forward_batch_size=1, - token_gene_dict=geneformer.gene_token_dict, - device=device - ) - assert embs.shape == (1, model.config.hidden_size) - - def test_get_embs_cls_mode(self, geneformer, mock_data): - if not geneformer.config["special_token"]: - pytest.skip("This test is only for models with special tokens (v2)") - tokenized_dataset = geneformer.process_data(mock_data, gene_names='gene_symbols') - model = load_model("Pretrained", geneformer.files_config["model_files_dir"], self.device) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - embs = get_embs( - model, - tokenized_dataset, - emb_mode="cls", - layer_to_quant=-1, - pad_token_id=geneformer.pad_token_id, - forward_batch_size=1, - gene_token_dict=geneformer.gene_token_dict, - device=device - ) - assert embs.shape == (1, model.config.hidden_size) - - def test_get_embs_gene_mode(self, geneformer, mock_data): - tokenized_dataset = geneformer.process_data(mock_data, gene_names='gene_symbols') - model = load_model("Pretrained", geneformer.files_config["model_files_dir"], self.device) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - embs = get_embs( - model, - tokenized_dataset, - emb_mode="gene", - layer_to_quant=-1, - pad_token_id=geneformer.pad_token_id, - forward_batch_size=1, - gene_token_dict=geneformer.gene_token_dict, - device=device - ) - assert embs.shape[0] == 1 - assert embs.shape[2] == model.config.hidden_size - - def test_get_embs_different_layer(self, geneformer, mock_data): - tokenized_dataset = geneformer.process_data(mock_data, gene_names='gene_symbols') - model = load_model("Pretrained", geneformer.files_config["model_files_dir"], self.device) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - embs_last = get_embs( - model, - tokenized_dataset, - emb_mode="cell", - layer_to_quant=-1, - pad_token_id=geneformer.pad_token_id, - forward_batch_size=1, - gene_token_dict=geneformer.gene_token_dict, - device=device - ) - embs_first = get_embs( - model, 
- tokenized_dataset, - emb_mode="cell", - layer_to_quant=0, - pad_token_id=geneformer.pad_token_id, - forward_batch_size=1, - gene_token_dict=geneformer.gene_token_dict, - device=device - ) - assert not torch.allclose(embs_last, embs_first) - - def test_get_embs_cell_mode(self, geneformer, mock_data): - tokenized_dataset = geneformer.process_data(mock_data, gene_names='gene_symbols') - model = load_model("Pretrained", geneformer.files_config["model_files_dir"], self.device) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - embs = get_embs( - model, - tokenized_dataset, - emb_mode="cell", - layer_to_quant=-1, - pad_token_id=geneformer.pad_token_id, - forward_batch_size=1, - gene_token_dict=geneformer.gene_token_dict, - device=device - ) - assert embs.shape == (1, model.config.hidden_size) - - def test_cls_eos_tokens_presence(self, geneformer, mock_data): - geneformer.process_data(mock_data, gene_names='gene_symbols') - if geneformer.config["special_token"]: - assert "" in geneformer.tk.gene_token_dict - assert "" in geneformer.tk.gene_token_dict - else: - assert "" not in geneformer.tk.gene_token_dict - assert "" not in geneformer.tk.gene_token_dict + GeneformerConfig(model_name="gf-12L-30M-i2048", emb_mode='cls') + + @pytest.mark.parametrize("emb_mode", ["cell", "gene"]) + def test_get_embeddings_of_different_modes_v1(self, emb_mode, mock_data, mock_embeddings_v1, mocker): + config = GeneformerConfig(model_name="gf-12L-30M-i2048", batch_size=5, emb_mode=emb_mode) + geneformer = Geneformer(config) + mocker.patch.object(geneformer.model, "forward", return_value=mock_embeddings_v1) + + dataset = geneformer.process_data(mock_data, gene_names ="gene_symbols") + embeddings = geneformer.get_embeddings(dataset) + if emb_mode == "gene": + data_list = pd.Series({ + "ENSG00000187583": np.array([1.0, 2.0, 3.0, 2.0, 1.0]), + "ENSG00000187634": np.array([5.0, 5.0, 5.0, 5.0, 5.0]), + "ENSG00000188290": np.array([6.0, 6.0, 6.0, 6.0, 6.0]) + }) + for key in data_list.index: + assert np.all(np.equal(embeddings[0][key], data_list[key])) + + if emb_mode == "cell": + expected = np.array([[4, 4.333333, 4.666667, 4.333333, 4]]) + np.testing.assert_allclose( + embeddings, + expected, + rtol=1e-4, + atol=1e-4 + ) + + @pytest.mark.parametrize("emb_mode", ["cell", "gene", "cls"]) + def test_get_embeddings_of_different_modes_v2(self, emb_mode, mock_data, mock_embeddings_v2, mocker): + config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=5, emb_mode=emb_mode) + geneformer = Geneformer(config) + mocker.patch.object(geneformer.model, "forward", return_value=mock_embeddings_v2) + + dataset = geneformer.process_data(mock_data, gene_names ="gene_symbols") + embeddings = geneformer.get_embeddings(dataset) + if emb_mode == "gene": + data_list = pd.Series({ + "ENSG00000187583": np.array([1.0, 2.0, 3.0, 2.0, 1.0]), + "ENSG00000187634": np.array([5.0, 5.0, 5.0, 5.0, 5.0]), + "ENSG00000188290": np.array([6.0, 6.0, 6.0, 6.0, 6.0]) + }) + for key in data_list.index: + assert np.all(np.equal(embeddings[0][key], data_list[key])) + + if emb_mode == "cls": + assert (embeddings == np.array([6.0, 5.0, 7.0, 5.0, 5.0])).all() + if emb_mode == "cell": + expected = np.array([[4, 4.333333, 4.666667, 4.333333, 4]]) + np.testing.assert_allclose( + embeddings, + expected, + rtol=1e-4, + atol=1e-4 + ) + + @pytest.mark.parametrize("emb_mode", ["cell", "gene"]) + def test_fine_tune_classifier_returns_correct_shape(self, emb_mode, mock_data): + fine_tuned_model = 
GeneformerFineTuningModel(GeneformerConfig(emb_mode=emb_mode), fine_tuning_head="classification", output_size=1) + tokenized_dataset = fine_tuned_model.process_data(mock_data, gene_names='gene_symbols') + tokenized_dataset = tokenized_dataset.add_column('labels', list([0])) + + fine_tuned_model.train(train_dataset=tokenized_dataset, label='labels') - def test_model_input_size(self, geneformer): - assert geneformer.config["input_size"] == geneformer.configurer.model_map[geneformer.config["model_name"]]['input_size'] + outputs = fine_tuned_model.get_outputs(tokenized_dataset) + assert outputs.shape == (len(mock_data), 1) - def test_fine_tune_classifier_returns_correct_shape(self, mock_data, fine_tune_mock_data): - device = "cuda" if torch.cuda.is_available() else "cpu" - fine_tuned_model = GeneformerFineTuningModel(GeneformerConfig(device=device), fine_tuning_head="classification", output_size=1) + def test_fine_tune_classifier_cls_returns_correct_shape(self, mock_data): + fine_tuned_model = GeneformerFineTuningModel(GeneformerConfig(model_name="gf-12L-95M-i4096", emb_mode="cls"), fine_tuning_head="classification", output_size=1) tokenized_dataset = fine_tuned_model.process_data(mock_data, gene_names='gene_symbols') - tokenized_dataset = tokenized_dataset.add_column('labels', fine_tune_mock_data) + tokenized_dataset = tokenized_dataset.add_column('labels', [0]) fine_tuned_model.train(train_dataset=tokenized_dataset, label='labels') - assert fine_tuned_model is not None + outputs = fine_tuned_model.get_outputs(tokenized_dataset) - assert outputs.shape == (len(mock_data), len(fine_tune_mock_data)) + assert outputs.shape == (len(mock_data), 1) + - \ No newline at end of file diff --git a/examples/fine_tune_models/fine_tune_geneformer.py b/examples/fine_tune_models/fine_tune_geneformer.py index 2cbf7197..0715ef9a 100644 --- a/examples/fine_tune_models/fine_tune_geneformer.py +++ b/examples/fine_tune_models/fine_tune_geneformer.py @@ -11,6 +11,8 @@ def run_fine_tuning(cfg: DictConfig): # either load via huggingface # hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists") # ann_data = get_anndata_from_hf_dataset(hf_dataset) + + # or load directly from anndata file ann_data = ad.read_h5ad("./yolksac_human.h5ad") cell_types = list(ann_data.obs["LVL1"][:10]) @@ -32,5 +34,8 @@ def classes_to_ids(example): geneformer_fine_tune.train(train_dataset=dataset) + outputs = geneformer_fine_tune.get_outputs(dataset) + print(outputs) + if __name__ == "__main__": run_fine_tuning() \ No newline at end of file diff --git a/examples/run_models/run_geneformer.py b/examples/run_models/run_geneformer.py index cdcbf023..b5612291 100644 --- a/examples/run_models/run_geneformer.py +++ b/examples/run_models/run_geneformer.py @@ -20,7 +20,7 @@ def run(cfg: DictConfig): dataset = geneformer.process_data(ann_data[:10]) embeddings = geneformer.get_embeddings(dataset) - print(embeddings.shape) + print(embeddings) if __name__ == "__main__": run() \ No newline at end of file diff --git a/helical/models/classification/classifier.py b/helical/models/classification/classifier.py index d12813c4..144ca7c0 100644 --- a/helical/models/classification/classifier.py +++ b/helical/models/classification/classifier.py @@ -99,7 +99,7 @@ def train_classifier_head(self, LOGGER.info(f"Getting training embeddings with {base_model.__class__.__name__}.") dataset = base_model.process_data(train_anndata, gene_names) self.gene_names = gene_names - x = 
base_model.get_embeddings(dataset) + x = np.array(base_model.get_embeddings(dataset)) # then, train the classification model LOGGER.info(f"Training classification model '{self.name}'.") diff --git a/helical/models/geneformer/fine_tuning_model.py b/helical/models/geneformer/fine_tuning_model.py index 40c89bc3..45214495 100644 --- a/helical/models/geneformer/fine_tuning_model.py +++ b/helical/models/geneformer/fine_tuning_model.py @@ -4,7 +4,7 @@ import torch from torch import optim from torch.nn.modules import loss -from helical.models.geneformer.geneformer_utils import gen_attention_mask, get_model_input_size, pad_tensor_list +from helical.models.geneformer.geneformer_utils import gen_attention_mask, get_model_input_size, pad_tensor_list, _check_for_expected_special_tokens, mean_nonpadding_embs from datasets import Dataset import logging from tqdm import trange @@ -87,26 +87,41 @@ def __init__(self, self.fine_tuning_head.set_dim_size(self.config["embsize"]) - def _forward(self, input_ids: torch.Tensor, attention_mask_minibatch: torch.Tensor) -> torch.Tensor: + def _forward(self, input_ids: torch.tensor, attention_mask_minibatch: torch.tensor, original_lengths: torch.tensor) -> torch.tensor: """ Forward method of the fine-tuning model. Parameters ---------- - input_ids : torch.Tensor + input_ids : torch.tensor The input ids to the fine-tuning model. - attention_mask_minibatch : torch.Tensor + attention_mask_minibatch : torch.tensor The attention mask for the input tensor. + original_lengths: torch.tensor + The original lengths of the inputs without padding Returns ------- - torch.Tensor + torch.tensor The output tensor of the fine-tuning model. """ outputs = self.model(input_ids=input_ids, attention_mask=attention_mask_minibatch) - final_layer = outputs.hidden_states[-1] - cls_seq = final_layer[:, 0, :] - final = self.fine_tuning_head(cls_seq) + batch_embeddings = outputs.hidden_states[-1] + + if self.emb_mode == "cls" and self.cls_present: + batch_embeddings = batch_embeddings[:, 0, :] + else: + length = original_lengths + if self.cls_present: + batch_embeddings = batch_embeddings[:, 1:, :] # Get all layers except the cls embs + if self.eos_present: + length -= 2 # length is used for the mean calculation, 2 is subtracted because we have taken both the cls and eos embeddings out + else: + length -= 1 # length is subtracted because just the cls is removed + + batch_embeddings = mean_nonpadding_embs(batch_embeddings, length) + + final = self.fine_tuning_head(batch_embeddings) return final def train( @@ -160,33 +175,9 @@ def train( lr_scheduler_params = {'name': 'linear', 'num_warmup_steps': 0, 'num_training_steps': 5} """ - model_input_size = get_model_input_size(self.model) - cls_present = any("" in key for key in self.gene_token_dict.keys()) - eos_present = any("" in key for key in self.gene_token_dict.keys()) - if self.emb_mode == "cls": - if cls_present is False: - message = " token missing in token dictionary" - logger.error(message) - raise ValueError(message) - # Check to make sure that the first token of the filtered input data is cls token - cls_token_id = self.gene_token_dict[""] - if cls_token_id != train_dataset["input_ids"][0][0]: - message = "First token is not token value" - logger.error(message) - assert ( - train_dataset["input_ids"][0][0] == cls_token_id - ), "First token is not token value" - elif self.emb_mode == "cell": - if cls_present: - logger.warning( - "CLS token present in token dictionary, excluding from average." 
- ) - if eos_present: - logger.warning( - "EOS token present in token dictionary, excluding from average." - ) + _check_for_expected_special_tokens(train_dataset, self.emb_mode, self.cls_present, self.eos_present, self.tk.gene_token_dict) total_batch_length = len(train_dataset) #initialise optimizer @@ -229,7 +220,7 @@ def train( input_data_minibatch, max_len, self.pad_token_id, model_input_size ).to(self.device) - outputs = self._forward(input_ids=input_data_minibatch, attention_mask_minibatch=gen_attention_mask(minibatch)) + outputs = self._forward(input_ids=input_data_minibatch, attention_mask_minibatch=gen_attention_mask(minibatch), original_lengths=minibatch["length"]) loss = loss_function(outputs, minibatch[label]) loss.backward() batch_loss += loss.item() @@ -262,7 +253,7 @@ def train( ).to(self.device) with torch.no_grad(): - outputs = self._forward(input_ids=input_data_minibatch, attention_mask_minibatch=gen_attention_mask(minibatch)) + outputs = self._forward(input_ids=input_data_minibatch, attention_mask_minibatch=gen_attention_mask(minibatch), original_lengths=minibatch["length"]) val_loss += loss_function(outputs, minibatch[label]).item() count += 1.0 testing_loop.set_postfix({"val_loss": val_loss/count}) @@ -294,31 +285,7 @@ def get_outputs( dataset_length = len(dataset) - cls_present = any("" in key for key in self.gene_token_dict.keys()) - eos_present = any("" in key for key in self.gene_token_dict.keys()) - if self.emb_mode == "cls": - if cls_present is False: - message = " token missing in token dictionary" - logger.error(message) - raise ValueError(message) - assert cls_present, " token missing in token dictionary" - # Check to make sure that the first token of the filtered input data is cls token - cls_token_id = self.gene_token_dict[""] - if cls_token_id != dataset["input_ids"][0][0]: - message = "First token is not token value" - logger.error(message) - assert ( - dataset["input_ids"][0][0] == cls_token_id - ), "First token is not token value" - elif self.emb_mode == "cell": - if cls_present: - logger.warning( - "CLS token present in token dictionary, excluding from average." - ) - if eos_present: - logger.warning( - "EOS token present in token dictionary, excluding from average." - ) + _check_for_expected_special_tokens(dataset, self.emb_mode, self.cls_present, self.eos_present, self.tk.gene_token_dict) output = [] testing_loop = trange(0, dataset_length, self.config["batch_size"], desc="Generating Outputs", leave=(not silent)) @@ -335,7 +302,7 @@ def get_outputs( ).to(self.device) with torch.no_grad(): - outputs = self._forward(input_ids=input_data_minibatch, attention_mask_minibatch=gen_attention_mask(minibatch)) + outputs = self._forward(input_ids=input_data_minibatch, attention_mask_minibatch=gen_attention_mask(minibatch), original_lengths=minibatch["length"]) output.append(outputs.clone().detach()) del outputs del minibatch diff --git a/helical/models/geneformer/geneformer_config.py b/helical/models/geneformer/geneformer_config.py index c54c1d96..6d9d88b8 100644 --- a/helical/models/geneformer/geneformer_config.py +++ b/helical/models/geneformer/geneformer_config.py @@ -15,7 +15,9 @@ class GeneformerConfig(): emb_layer : int, optional, default = -1 The embedding layer emb_mode : Literal["cls", "cell", "gene"], optional, default="cell" - The embedding mode + The embedding mode to use. "cls" is only available for Geneformer v2 models, returning the embeddings of the cls token. 
+ For cell level embeddings, a mean across all embeddings excluding the cls token is returned. + For gene level embeddings, each gene token embedding is returned along with the corresponding ensembl ID. device : Literal["cpu", "cuda"], optional, default="cpu" The device to use. Either use "cuda" or "cpu". accelerator : bool, optional, default=False diff --git a/helical/models/geneformer/geneformer_tokenizer.py b/helical/models/geneformer/geneformer_tokenizer.py index 4569a548..6e1b71cd 100644 --- a/helical/models/geneformer/geneformer_tokenizer.py +++ b/helical/models/geneformer/geneformer_tokenizer.py @@ -293,6 +293,9 @@ def __init__( k: v for k, v in self.gene_mapping_dict.items() if v in gene_keys_set } + # Maps a token back to the Ensembl ID + self.token_to_ensembl_dict = {value: key for key, value in self.gene_token_dict.items()} + # protein-coding and miRNA gene list dictionary for selecting .h5ad columns for tokenization self.genelist_dict = dict(zip(self.gene_keys, [True] * len(self.gene_keys))) diff --git a/helical/models/geneformer/geneformer_utils.py b/helical/models/geneformer/geneformer_utils.py index cf35fe13..fcfb5432 100644 --- a/helical/models/geneformer/geneformer_utils.py +++ b/helical/models/geneformer/geneformer_utils.py @@ -11,6 +11,8 @@ from transformers import ( BertForMaskedLM ) +import pandas as pd +import numpy as np logger = logging.getLogger(__name__) @@ -35,6 +37,92 @@ def load_mappings(gene_symbols): pkl.dump(gene_id_to_ensemble, open('./human_gene_to_ensemble_id.pkl', 'wb')) return gene_id_to_ensemble +def _compute_embeddings_depending_on_mode(embeddings: torch.tensor, data_dict: dict, emb_mode: str, cls_present: bool, eos_present: bool, token_to_ensembl_dict: dict): + """ + Compute the different embeddings for each emb_mode + + Parameters + ----------- + embeddings: torch.tensor + The embedding batch output by the model. + data_dict: dict + The minibatch data dictionary used as an input to the model. + emb_mode: str + The mode in which the embeddings are to be computed. + cls_present: bool + Whether the <cls> token is present in the token dictionary. + eos_present: bool + Whether the <eos> token is present in the token dictionary. + token_to_ensembl_dict: dict + The token to Ensembl ID dictionary from the tokenizer. + """ + if emb_mode == "cell": + length = data_dict['length'] + if cls_present: + embeddings = embeddings[:, 1:, :] # Get all layers except the cls embs + if eos_present: + length -= 2 # length is used for the mean calculation, 2 is subtracted because we have taken both the cls and eos embeddings out + else: + length -= 1 # length is subtracted because just the cls is removed + + batch_embeddings = mean_nonpadding_embs(embeddings, length).cpu().numpy() + + elif emb_mode == "gene": + if cls_present: + embeddings = embeddings[:, 1:, :] + if eos_present: + embeddings = embeddings[:, :-1, :] + + batch_embeddings = [] + for embedding, ids in zip(embeddings, data_dict["input_ids"]): + cell_dict = {} + if cls_present: + ids = ids[1:] + if eos_present: + ids = ids[:-1] + for id, gene_emb in zip(ids, embedding): + cell_dict[token_to_ensembl_dict[id.item()]] = gene_emb.cpu().numpy() + + batch_embeddings.append(pd.Series(cell_dict)) + + elif emb_mode == "cls": + batch_embeddings = embeddings[:, 0, :].cpu().numpy() # CLS token layer + + return batch_embeddings +
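# Minimal sketch of the cell-level pooling used above, assuming `mean_nonpadding_embs`
# averages each cell's first `length` token embeddings; that assumption is why `length`
# is reduced by 1 or 2 once the <cls> (and <eos>) positions are sliced off. The helper
# name `masked_mean` below is illustrative only, not part of the patch.
import torch

def masked_mean(embeddings: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # embeddings: (batch, seq_len, hidden); lengths: (batch,) real token count per cell
    positions = torch.arange(embeddings.size(1), device=embeddings.device)
    mask = (positions[None, :] < lengths[:, None]).unsqueeze(-1)  # (batch, seq_len, 1)
    return (embeddings * mask).sum(dim=1) / lengths[:, None].clamp(min=1)

# Two cells padded to four tokens, hidden size 2; the second cell has two real tokens.
embs = torch.tensor([[[1.0, 1.0], [3.0, 3.0], [5.0, 5.0], [0.0, 0.0]],
                     [[2.0, 2.0], [4.0, 4.0], [0.0, 0.0], [0.0, 0.0]]])
print(masked_mean(embs, torch.tensor([3, 2])))  # tensor([[3., 3.], [3., 3.]])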
+def _check_for_expected_special_tokens(dataset, emb_mode, cls_present, eos_present, gene_token_dict): + """ + Check for the expected special tokens in the dataset. + + Parameters + ----------- + dataset: dict + The batch dictionary with input ids. + emb_mode: str + The mode in which the embeddings are to be computed. + cls_present: bool + Whether the <cls> token is present in the token dictionary. + eos_present: bool + Whether the <eos> token is present in the token dictionary. + gene_token_dict: dict + The gene token dictionary from the tokenizer. + """ + if emb_mode == "cls": + message = "<cls> token missing in token dictionary" + if not cls_present: + logger.error(message) + raise ValueError(message) + + if dataset["input_ids"][0][0] != gene_token_dict["<cls>"]: + message = "First token is not <cls> token value" + logger.error(message) + raise ValueError(message) + + elif emb_mode == "cell": + if cls_present: + logger.warning("CLS token present in token dictionary, excluding from average.") + if eos_present: + logger.warning("EOS token present in token dictionary, excluding from average.")
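# Usage sketch for the check above, assuming this patch is installed: in "cls" mode the
# first token of every cell must be the <cls> id, otherwise a ValueError is raised. The
# token ids in this toy dictionary are placeholders.
from helical.models.geneformer.geneformer_utils import _check_for_expected_special_tokens

gene_token_dict = {"<pad>": 0, "<mask>": 1, "<cls>": 2, "<eos>": 3, "ENSG00000187634": 10}
good = {"input_ids": [[2, 10, 3]]}  # starts with <cls>
bad = {"input_ids": [[10, 3]]}      # <cls> missing up front

_check_for_expected_special_tokens(good, "cls", True, True, gene_token_dict)  # passes silently
try:
    _check_for_expected_special_tokens(bad, "cls", True, True, gene_token_dict)
except ValueError as err:
    print(err)  # First token is not <cls> token value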
# extract embeddings def get_embs( @@ -45,6 +133,9 @@ def get_embs( pad_token_id, forward_batch_size, gene_token_dict, + token_to_ensembl_dict, + cls_present, + eos_present, device, silent=False, @@ -53,25 +144,7 @@ total_batch_length = len(filtered_input_data) embs_list = [] - # Check if CLS and EOS token is present in the token dictionary - cls_present = any("<cls>" in key for key in gene_token_dict.keys()) - eos_present = any("<eos>" in key for key in gene_token_dict.keys()) - if emb_mode == "cls": - assert cls_present, "<cls> token missing in token dictionary" - # Check to make sure that the first token of the filtered input data is cls token - cls_token_id = gene_token_dict["<cls>"] - assert ( - filtered_input_data["input_ids"][0][0] == cls_token_id - ), "First token is not <cls> token value" - elif emb_mode == "cell": - if cls_present: - logger.warning( - "CLS token present in token dictionary, excluding from average." - ) - if eos_present: - logger.warning( - "EOS token present in token dictionary, excluding from average." - ) + _check_for_expected_special_tokens(filtered_input_data, emb_mode, cls_present, eos_present, gene_token_dict) overall_max_len = 0 for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)): @@ -80,7 +153,6 @@ minibatch = filtered_input_data.select([i for i in range(i, max_range)]) max_len = int(max(minibatch["length"])) - original_lens = torch.tensor(minibatch["length"],device=device) minibatch.set_format(type="torch",device=device) input_data_minibatch = minibatch["input_ids"] @@ -97,26 +169,7 @@ embs_i = outputs.hidden_states[layer_to_quant] - if emb_mode == "cell": - if cls_present: - non_cls_embs = embs_i[:, 1:, :] # Get all layers except the cls embs - if eos_present: - mean_embs = mean_nonpadding_embs(non_cls_embs, original_lens - 2) - else: - mean_embs = mean_nonpadding_embs(non_cls_embs, original_lens - 1) - else: - mean_embs = mean_nonpadding_embs(embs_i, original_lens) - - embs_list.append(mean_embs) - del mean_embs - - elif emb_mode == "gene": - embs_list.append(embs_i) - - elif emb_mode == "cls": - cls_embs = embs_i[:, 0, :].clone().detach() # CLS token layer - embs_list.append(cls_embs) - del cls_embs + embs_list.extend(_compute_embeddings_depending_on_mode(embs_i, minibatch, emb_mode, cls_present, eos_present, token_to_ensembl_dict)) overall_max_len = max(overall_max_len, max_len) del outputs @@ -125,20 +178,10 @@ del embs_i torch.cuda.empty_cache() + if emb_mode != "gene": + embs_list = np.array(embs_list) - if emb_mode == "cell" or emb_mode == "cls": - embs_stack = torch.cat(embs_list, dim=0) - elif emb_mode == "gene": - embs_stack = pad_tensor_list( - embs_list, - overall_max_len, - pad_token_id, - model_input_size, - 1, - pad_3d_tensor, - ) - return embs_stack - + return embs_list def downsample_and_sort(data, max_ncells): num_cells = len(data) @@ -152,8 +195,6 @@ data_sorted = data_subset.sort("length", reverse=True) return data_sorted - - def quant_layers(model): layer_nums = [] for name, parameter in model.named_parameters(): diff --git a/helical/models/geneformer/model.py b/helical/models/geneformer/model.py index 51024988..0d863e40 100644 --- a/helical/models/geneformer/model.py +++ b/helical/models/geneformer/model.py @@ -12,7 +12,6 @@ from helical.utils.mapping import map_gene_symbols_to_ensembl_ids from datasets import Dataset from typing import Optional -from accelerate import Accelerator LOGGER = logging.getLogger(__name__) class Geneformer(HelicalRNAModel): @@ -91,17 +90,6 @@ def __init__(self, configurer: GeneformerConfig = default_configurer) -> None: self.layer_to_quant = quant_layers(self.model) + self.config['emb_layer'] self.emb_mode = self.config['emb_mode'] self.forward_batch_size = self.config['batch_size'] - - if self.config['accelerator']: - self.accelerator = Accelerator(project_dir=self.configurer.model_dir) - self.model = self.accelerator.prepare(self.model) - else: - self.accelerator = None - - # load token dictionary (Ensembl IDs:token) - with open(self.files_config["token_path"], "rb") as f: - self.gene_token_dict = pickle.load(f) - self.pad_token_id = self.gene_token_dict.get("<pad>") self.tk = TranscriptomeTokenizer(custom_attr_name_dict=self.config["custom_attr_name_dict"], nproc=self.config['nproc'], @@ -112,6 +100,10 @@ def __init__(self, configurer: GeneformerConfig = default_configurer) -> None: gene_mapping_file = self.files_config["ensembl_dict_path"], ) + self.pad_token_id = self.tk.gene_token_dict["<pad>"] + self.cls_present = True if 
"" in self.tk.gene_token_dict else False + self.eos_present = True if "" in self.tk.gene_token_dict else False + LOGGER.info(f"Model finished initializing.") def process_data(self, @@ -193,8 +185,11 @@ def get_embeddings(self, dataset: Dataset) -> np.array: self.layer_to_quant, self.pad_token_id, self.forward_batch_size, - self.gene_token_dict, + self.tk.gene_token_dict, + self.tk.token_to_ensembl_dict, + self.cls_present, + self.eos_present, self.device - ).cpu().detach().numpy() + ) return embeddings