Skip to content

Commit

Permalink
Return ensembl ID for each embedding for emb_mode="gene" (#157)
Browse files Browse the repository at this point in the history
* Update version in pyproject.toml

* Convert embs to np array before checking shape in classifier

* Convert embs to np array before checking shape in classifier

* Add docstrings and correct tests

* Incorrect if statement and test updates

---------

Co-authored-by: Benoit Putzeys <157973952+bputzeys@users.noreply.github.com>
  • Loading branch information
mattwoodx and bputzeys authored Dec 17, 2024
1 parent f9effa1 commit eede0cf
Show file tree
Hide file tree
Showing 9 changed files with 260 additions and 278 deletions.
259 changes: 114 additions & 145 deletions ci/tests/test_geneformer/test_geneformer_model.py
Original file line number Diff line number Diff line change
@@ -1,55 +1,64 @@
import pytest
import torch
from helical.models.geneformer.model import Geneformer
from helical.models.geneformer.geneformer_config import GeneformerConfig
from helical.models.geneformer.geneformer_utils import get_embs, load_model
from helical.models.geneformer.fine_tuning_model import GeneformerFineTuningModel
from helical import GeneformerConfig, Geneformer, GeneformerFineTuningModel
from anndata import AnnData
import torch
import pandas as pd
import numpy as np

# NOTE(review): this span is the *removed* side of a rendered diff (it is
# superseded by the TestGeneformer class below); the scrape stripped the
# diff markers and indentation.
class TestGeneformerModel:
"""Legacy test class: one Geneformer instance per supported checkpoint."""
@pytest.fixture(params=["gf-12L-30M-i2048", "gf-12L-95M-i4096"])
def geneformer(self, request):
# Device is stashed on self so later tests can pass it to load_model(...).
self.device = "cuda" if torch.cuda.is_available() else "cpu"
config = GeneformerConfig(model_name=request.param, device=self.device)
return Geneformer(config)

class TestGeneformer:
"""Tests for the Geneformer model wrapper (data processing, embeddings, fine-tuning)."""
@pytest.fixture
def mock_data(self):
"""Build a minimal one-cell AnnData with three annotated genes."""
data = AnnData()
# NOTE(review): gene_symbols is assigned twice below — this looks like the
# old and new sides of the diff rendered together; only one assignment
# should survive in the real file (the second would win as written).
data.var['gene_symbols'] = ['SAMD11', 'PLEKHN1', 'HES4']
data.var['ensembl_id'] = ['ENSG00000187634', 'ENSG00000187583', 'ENSG00000188290']
data.var['gene_symbols'] = ['HES4', 'PLEKHN1', 'SAMD11']
data.obs["cell_type"] = ["CD4 T cells"]
# Expression matrix: a single cell with counts for the three genes.
data.X = [[1, 2, 5]]
return data

@pytest.fixture
def fine_tune_mock_data(self):
"""Single-label target list matching the one-cell mock_data fixture."""
labels = list([0])
return labels

def test_pass_invalid_model_name(self):
"""GeneformerConfig must reject a model name it does not recognise."""
with pytest.raises(ValueError):
geneformer_config = GeneformerConfig(model_name='InvalidName')


# NOTE(review): the @pytest.fixture decorator present on mock_embeddings_v2
# below is missing here — presumably lost in the diff rendering; confirm.
def mock_embeddings_v1(self, mocker):
"""Mocked forward() output for v1 models: hidden_states is a list of 12
identical (1, 3, 5) tensors — one cell, three gene tokens, hidden size 5."""
embs = mocker.Mock()
# The same (1, 3, 5) tensor repeated 12 times — presumably one entry per
# layer of the 12-layer model; confirm against get_embs' layer indexing.
embs.hidden_states = [torch.tensor([[[5.0, 5.0, 5.0, 5.0, 5.0],
[1.0, 2.0, 3.0, 2.0, 1.0],
[6.0, 6.0, 6.0, 6.0, 6.0]]])]*12
return embs

@pytest.fixture
def mock_embeddings_v2(self, mocker):
"""Mocked forward() output for v2 models: hidden_states is a single tensor
of shape (12, 1, 5, 5) — a (1, 5, 5) block repeated along a new leading
axis. Five token rows: the first presumably corresponds to <cls> (the cls
test below asserts it), then three gene tokens, then a trailing row."""
embs = mocker.Mock()
embs.hidden_states = torch.tensor([[[6.0, 5.0, 7.0, 5.0, 5.0],
[5.0, 5.0, 5.0, 5.0, 5.0],
[1.0, 2.0, 3.0, 2.0, 1.0],
[6.0, 6.0, 6.0, 6.0, 6.0],
[6.0, 6.0, 1.0, 6.0, 2.0]]]).repeat(12, 1, 1, 1)
return embs

@pytest.fixture(params=["gf-12L-30M-i2048", "gf-12L-95M-i4096"])
def geneformer(self, request):
"""Yield a Geneformer instance for each of the two supported checkpoints."""
config = GeneformerConfig(model_name=request.param, batch_size=5)
geneformer = Geneformer(config)
return geneformer

# TODO(review): "ensemble" in the test name is presumably a typo for "ensembl";
# renaming would be safe since nothing references a test function by name.
def test_process_data_mapping_to_ensemble_ids(self, geneformer, mock_data):
"""process_data must remap each gene symbol to its Ensembl ID in place."""
assert mock_data.var['ensembl_id'][0] == 'ENSG00000187634'
# geneformer modifies the anndata in place and maps the gene names to ensembl id
geneformer.process_data(mock_data, gene_names="gene_symbols")
assert mock_data.var['ensembl_id'][0] == 'ENSG00000188290'
# is the same as the above line but more verbose (linking the gene symbol to the ensembl id)
assert mock_data.var[mock_data.var['gene_symbols'] == 'SAMD11']['ensembl_id'].values[0] == 'ENSG00000187634'
assert mock_data.var[mock_data.var['gene_symbols'] == 'PLEKHN1']['ensembl_id'].values[0] == 'ENSG00000187583'
assert mock_data.var[mock_data.var['gene_symbols'] == 'HES4']['ensembl_id'].values[0] == 'ENSG00000188290'

def test_process_data_padding_and_masking_ids(self, geneformer, mock_data):
"""After processing, the tokenizer's special-token IDs must be fixed:
<pad> -> 0 and <mask> -> 1."""
# for this token mapping, the padding token is 0 and the mask token is 1
geneformer.process_data(mock_data, gene_names='gene_symbols')
assert geneformer.gene_token_dict.get("<pad>") == 0
assert geneformer.gene_token_dict.get("<mask>") == 1

@pytest.mark.parametrize("invalid_model_names", ["gf-12L-35M-i2048", "gf-34L-30M-i5000"])
def test_pass_invalid_model_name(self, invalid_model_names):
"""GeneformerConfig must raise ValueError for near-miss model names."""
with pytest.raises(ValueError):
GeneformerConfig(model_name=invalid_model_names)

def test_ensure_data_validity_raising_error_with_missing_ensembl_id_column(self, geneformer, mock_data):
"""Dropping the ensembl_id column after processing must make
ensure_rna_data_validity raise KeyError."""
geneformer.process_data(mock_data, gene_names="gene_symbols")
del mock_data.var['ensembl_id']
with pytest.raises(KeyError):
geneformer.ensure_rna_data_validity(mock_data, "ensembl_id")



@pytest.mark.parametrize("gene_symbols, raises_error",
[
(['ENSGSAMD11', 'ENSGPLEKHN1', 'ENSGHES4'], True), # humans
Expand All @@ -66,126 +75,86 @@ def test_ensembl_data_is_caught(self, geneformer, mock_data, gene_symbols, raise
else:
geneformer.process_data(mock_data, "gene_symbols")

def test_cls_mode_with_v1_model_config(self, geneformer, mock_data):

def test_cls_mode_with_v1_model_config(self, geneformer):
if geneformer.config["special_token"]:
pytest.skip("This test is only for v1 models and should thus be only executed once.")
with pytest.raises(ValueError):
config = GeneformerConfig(model_name="gf-12L-30M-i2048", device="cpu", emb_mode='cls')

# NOTE(review): removed-side diff content. self.device is only set by the
# legacy TestGeneformerModel fixture above, and the keyword here is
# token_gene_dict while the sibling tests pass gene_token_dict — confirm
# which name get_embs actually accepted.
def test_get_embs_cell_mode(self, geneformer, mock_data):
"""get_embs in "cell" mode returns one hidden-size embedding per cell."""
tokenized_dataset = geneformer.process_data(mock_data, gene_names='gene_symbols')
model = load_model("Pretrained", geneformer.files_config["model_files_dir"], self.device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embs = get_embs(
model,
tokenized_dataset,
emb_mode="cell",
layer_to_quant=-1,
pad_token_id=geneformer.pad_token_id,
forward_batch_size=1,
token_gene_dict=geneformer.gene_token_dict,
device=device
)
# One cell in mock_data -> one row of width hidden_size.
assert embs.shape == (1, model.config.hidden_size)

def test_get_embs_cls_mode(self, geneformer, mock_data):
"""get_embs in "cls" mode (v2/special-token models only) returns one
hidden-size embedding per cell, taken from the <cls> token."""
if not geneformer.config["special_token"]:
pytest.skip("This test is only for models with special tokens (v2)")
tokenized_dataset = geneformer.process_data(mock_data, gene_names='gene_symbols')
model = load_model("Pretrained", geneformer.files_config["model_files_dir"], self.device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embs = get_embs(
model,
tokenized_dataset,
emb_mode="cls",
layer_to_quant=-1,
pad_token_id=geneformer.pad_token_id,
forward_batch_size=1,
gene_token_dict=geneformer.gene_token_dict,
device=device
)
assert embs.shape == (1, model.config.hidden_size)

def test_get_embs_gene_mode(self, geneformer, mock_data):
"""get_embs in "gene" mode returns per-gene embeddings: leading axis is
the cell count, trailing axis is hidden_size (middle axis is per-gene,
not asserted here)."""
tokenized_dataset = geneformer.process_data(mock_data, gene_names='gene_symbols')
model = load_model("Pretrained", geneformer.files_config["model_files_dir"], self.device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embs = get_embs(
model,
tokenized_dataset,
emb_mode="gene",
layer_to_quant=-1,
pad_token_id=geneformer.pad_token_id,
forward_batch_size=1,
gene_token_dict=geneformer.gene_token_dict,
device=device
)
assert embs.shape[0] == 1
assert embs.shape[2] == model.config.hidden_size

def test_get_embs_different_layer(self, geneformer, mock_data):
"""Embeddings quantised from the last layer (-1) must differ from those
taken from the first layer (0)."""
tokenized_dataset = geneformer.process_data(mock_data, gene_names='gene_symbols')
model = load_model("Pretrained", geneformer.files_config["model_files_dir"], self.device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embs_last = get_embs(
model,
tokenized_dataset,
emb_mode="cell",
layer_to_quant=-1,
pad_token_id=geneformer.pad_token_id,
forward_batch_size=1,
gene_token_dict=geneformer.gene_token_dict,
device=device
)
embs_first = get_embs(
model,
tokenized_dataset,
emb_mode="cell",
layer_to_quant=0,
pad_token_id=geneformer.pad_token_id,
forward_batch_size=1,
gene_token_dict=geneformer.gene_token_dict,
device=device
)
assert not torch.allclose(embs_last, embs_first)

# NOTE(review): duplicate of test_get_embs_cell_mode defined earlier in this
# diff (this copy passes gene_token_dict instead of token_gene_dict). In a
# real module the later definition would silently shadow the earlier one —
# likely both sides of the diff rendered together.
def test_get_embs_cell_mode(self, geneformer, mock_data):
"""get_embs in "cell" mode returns one hidden-size embedding per cell."""
tokenized_dataset = geneformer.process_data(mock_data, gene_names='gene_symbols')
model = load_model("Pretrained", geneformer.files_config["model_files_dir"], self.device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embs = get_embs(
model,
tokenized_dataset,
emb_mode="cell",
layer_to_quant=-1,
pad_token_id=geneformer.pad_token_id,
forward_batch_size=1,
gene_token_dict=geneformer.gene_token_dict,
device=device
)
assert embs.shape == (1, model.config.hidden_size)

def test_cls_eos_tokens_presence(self, geneformer, mock_data):
"""Models with special_token=True must have <cls> and <eos> in the
tokenizer vocabulary; models without must not."""
geneformer.process_data(mock_data, gene_names='gene_symbols')
if geneformer.config["special_token"]:
assert "<cls>" in geneformer.tk.gene_token_dict
assert "<eos>" in geneformer.tk.gene_token_dict
else:
assert "<cls>" not in geneformer.tk.gene_token_dict
assert "<eos>" not in geneformer.tk.gene_token_dict
GeneformerConfig(model_name="gf-12L-30M-i2048", emb_mode='cls')

@pytest.mark.parametrize("emb_mode", ["cell", "gene"])
def test_get_embeddings_of_different_modes_v1(self, emb_mode, mock_data, mock_embeddings_v1, mocker):
"""With the model's forward() mocked, check v1 embeddings per mode:
"gene" returns a mapping keyed by Ensembl ID (the commit's headline
feature), "cell" returns one averaged vector per cell."""
config = GeneformerConfig(model_name="gf-12L-30M-i2048", batch_size=5, emb_mode=emb_mode)
geneformer = Geneformer(config)
mocker.patch.object(geneformer.model, "forward", return_value=mock_embeddings_v1)

dataset = geneformer.process_data(mock_data, gene_names ="gene_symbols")
embeddings = geneformer.get_embeddings(dataset)
if emb_mode == "gene":
# Expected per-gene vectors, keyed by Ensembl ID, matching the rows of
# the mocked hidden states.
data_list = pd.Series({
"ENSG00000187583": np.array([1.0, 2.0, 3.0, 2.0, 1.0]),
"ENSG00000187634": np.array([5.0, 5.0, 5.0, 5.0, 5.0]),
"ENSG00000188290": np.array([6.0, 6.0, 6.0, 6.0, 6.0])
})
for key in data_list.index:
assert np.all(np.equal(embeddings[0][key], data_list[key]))

if emb_mode == "cell":
# Expected value equals the column-wise mean of the three mocked gene
# vectors above — consistent with cell mode averaging over genes.
expected = np.array([[4, 4.333333, 4.666667, 4.333333, 4]])
np.testing.assert_allclose(
embeddings,
expected,
rtol=1e-4,
atol=1e-4
)

@pytest.mark.parametrize("emb_mode", ["cell", "gene", "cls"])
def test_get_embeddings_of_different_modes_v2(self, emb_mode, mock_data, mock_embeddings_v2, mocker):
"""Same as the v1 test but for a v2 (special-token) model, adding the
"cls" mode, which must return the <cls> token's vector."""
config = GeneformerConfig(model_name="gf-12L-95M-i4096", batch_size=5, emb_mode=emb_mode)
geneformer = Geneformer(config)
mocker.patch.object(geneformer.model, "forward", return_value=mock_embeddings_v2)

dataset = geneformer.process_data(mock_data, gene_names ="gene_symbols")
embeddings = geneformer.get_embeddings(dataset)
if emb_mode == "gene":
# Per-gene vectors keyed by Ensembl ID; the <cls>/<eos> rows of the
# mocked hidden states are excluded from the gene mapping.
data_list = pd.Series({
"ENSG00000187583": np.array([1.0, 2.0, 3.0, 2.0, 1.0]),
"ENSG00000187634": np.array([5.0, 5.0, 5.0, 5.0, 5.0]),
"ENSG00000188290": np.array([6.0, 6.0, 6.0, 6.0, 6.0])
})
for key in data_list.index:
assert np.all(np.equal(embeddings[0][key], data_list[key]))

if emb_mode == "cls":
# The first row of the mocked hidden states — the <cls> position.
assert (embeddings == np.array([6.0, 5.0, 7.0, 5.0, 5.0])).all()
if emb_mode == "cell":
# Column-wise mean of the three gene rows (special tokens excluded).
expected = np.array([[4, 4.333333, 4.666667, 4.333333, 4]])
np.testing.assert_allclose(
embeddings,
expected,
rtol=1e-4,
atol=1e-4
)

@pytest.mark.parametrize("emb_mode", ["cell", "gene"])
def test_fine_tune_classifier_returns_correct_shape(self, emb_mode, mock_data):
fine_tuned_model = GeneformerFineTuningModel(GeneformerConfig(emb_mode=emb_mode), fine_tuning_head="classification", output_size=1)
tokenized_dataset = fine_tuned_model.process_data(mock_data, gene_names='gene_symbols')
tokenized_dataset = tokenized_dataset.add_column('labels', list([0]))

fine_tuned_model.train(train_dataset=tokenized_dataset, label='labels')

# NOTE(review): removed-side fragment rendered mid-way through another test's
# body in this diff view — it does not belong at this position in the file.
def test_model_input_size(self, geneformer):
"""config["input_size"] must match the configurer's model_map entry."""
assert geneformer.config["input_size"] == geneformer.configurer.model_map[geneformer.config["model_name"]]['input_size']
outputs = fine_tuned_model.get_outputs(tokenized_dataset)
assert outputs.shape == (len(mock_data), 1)

def test_fine_tune_classifier_returns_correct_shape(self, mock_data, fine_tune_mock_data):
device = "cuda" if torch.cuda.is_available() else "cpu"
fine_tuned_model = GeneformerFineTuningModel(GeneformerConfig(device=device), fine_tuning_head="classification", output_size=1)
def test_fine_tune_classifier_cls_returns_correct_shape(self, mock_data):
fine_tuned_model = GeneformerFineTuningModel(GeneformerConfig(model_name="gf-12L-95M-i4096", emb_mode="cls"), fine_tuning_head="classification", output_size=1)
tokenized_dataset = fine_tuned_model.process_data(mock_data, gene_names='gene_symbols')
tokenized_dataset = tokenized_dataset.add_column('labels', fine_tune_mock_data)
tokenized_dataset = tokenized_dataset.add_column('labels', [0])

fine_tuned_model.train(train_dataset=tokenized_dataset, label='labels')
assert fine_tuned_model is not None

outputs = fine_tuned_model.get_outputs(tokenized_dataset)
assert outputs.shape == (len(mock_data), len(fine_tune_mock_data))
assert outputs.shape == (len(mock_data), 1)



5 changes: 5 additions & 0 deletions examples/fine_tune_models/fine_tune_geneformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ def run_fine_tuning(cfg: DictConfig):
# either load via huggingface
# hf_dataset = load_dataset("helical-ai/yolksac_human",split="train[:5%]", trust_remote_code=True, download_mode="reuse_cache_if_exists")
# ann_data = get_anndata_from_hf_dataset(hf_dataset)

# or load directly from anndata file
ann_data = ad.read_h5ad("./yolksac_human.h5ad")

cell_types = list(ann_data.obs["LVL1"][:10])
Expand All @@ -32,5 +34,8 @@ def classes_to_ids(example):

geneformer_fine_tune.train(train_dataset=dataset)

outputs = geneformer_fine_tune.get_outputs(dataset)
print(outputs)

if __name__ == "__main__":
run_fine_tuning()
2 changes: 1 addition & 1 deletion examples/run_models/run_geneformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def run(cfg: DictConfig):
dataset = geneformer.process_data(ann_data[:10])
embeddings = geneformer.get_embeddings(dataset)

print(embeddings.shape)
print(embeddings)

if __name__ == "__main__":
run()
2 changes: 1 addition & 1 deletion helical/models/classification/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def train_classifier_head(self,
LOGGER.info(f"Getting training embeddings with {base_model.__class__.__name__}.")
dataset = base_model.process_data(train_anndata, gene_names)
self.gene_names = gene_names
x = base_model.get_embeddings(dataset)
x = np.array(base_model.get_embeddings(dataset))

# then, train the classification model
LOGGER.info(f"Training classification model '{self.name}'.")
Expand Down
Loading

0 comments on commit eede0cf

Please sign in to comment.