Skip to content

Commit

Permalink
Include use_raw_counts flag, giving the option to
Browse files Browse the repository at this point in the history
the user to still run models with counts which are not integers.
  • Loading branch information
bputzeys committed Oct 23, 2024
1 parent dc76657 commit 58a765a
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 41 deletions.
2 changes: 1 addition & 1 deletion ci/tests/test_geneformer/test_geneformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_process_data_padding_and_masking_ids(self, geneformer, mock_data):
def test_ensure_data_validity_raising_error_with_missing_ensembl_id_column(self, geneformer, mock_data):
del mock_data.var['ensembl_id']
with pytest.raises(KeyError):
geneformer.ensure_data_validity(mock_data, "ensembl_id")
geneformer.ensure_rna_data_validity(mock_data, "ensembl_id")

@pytest.mark.parametrize("gene_symbols, raises_error",
[
Expand Down
6 changes: 4 additions & 2 deletions helical/models/base_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_embeddings():
pass

class HelicalRNAModel(HelicalBaseFoundationModel):
def ensure_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:
def ensure_rna_data_validity(self, adata: AnnData, gene_names: str, use_raw_counts: bool = True) -> None:
"""Ensures that the data contains the gene_names and has integer counts for adata.X which is saved
in 'total_counts'.
Expand All @@ -59,6 +59,8 @@ def ensure_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:
The data to be checked.
gene_names : str
The name of the column containing gene names in adata.var.
use_raw_counts : bool, default = True
Whether to use raw counts or not.
Raises
------
Expand All @@ -84,7 +86,7 @@ def ensure_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:

# verify that the data in X are integers
adata.obs["total_counts"] = adata.X.sum(axis=1)
if not (adata.obs["total_counts"] % 1 == 0).all():
if use_raw_counts and not (adata.obs["total_counts"] % 1 == 0).all():
message = "The data in X must be integers."
LOGGER.error(message)
raise ValueError(message)
Expand Down
26 changes: 4 additions & 22 deletions helical/models/geneformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def process_data(self,
adata: AnnData,
gene_names: str = "index",
output_path: Optional[str] = None,
use_raw_counts: bool = True,
) -> Dataset:
"""Processes the data for the Geneformer model
Expand All @@ -139,15 +140,16 @@ def process_data(self,
In that case, it is recommended to create a new column with the Ensemble IDs in the data and pass "ensembl_id" as the gene_names.
output_path : str, default = None
Whether to save the tokenized dataset to the specified output_path.
use_raw_counts : bool, default = True
Whether to use raw counts or not.
Returns
-------
Dataset
The tokenized dataset in the form of a Hugginface Dataset object.
"""
self.ensure_data_validity(adata, gene_names)
self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)

# map gene symbols to ensemble ids if provided
if gene_names != "ensembl_id":
Expand Down Expand Up @@ -195,23 +197,3 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
).cpu().detach().numpy()

return embeddings

def ensure_data_validity(self, adata: AnnData, gene_names: str) -> None:
"""Ensure that the data is eligible for processing by the Geneformer model. This checks
if the data contains the gene_names, and sets the total_counts column in adata.obs.
Parameters
----------
adata : AnnData
The AnnData object containing the data to be processed.
gene_names: str
The column in adata.var that contains the gene names.
Raises
------
KeyError
If the data is missing column names.
"""
self.ensure_rna_data_validity(adata, gene_names)


22 changes: 10 additions & 12 deletions helical/models/scgpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,6 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
cell_embeddings, axis=1, keepdims=True
)

# TODO?
# return_new_adata (bool): Whether to return a new AnnData object. If False, will
# add the cell embeddings to a new :attr:`adata.obsm` with key "X_scGPT".
# if return_new_adata:
# obs_df = adata.obs[obs_to_save] if obs_to_save is not None else None
# return sc.AnnData(X=cell_embeddings, obs=obs_df, dtype="float32")

return cell_embeddings

def process_data(self,
Expand All @@ -154,7 +147,8 @@ def process_data(self,
fine_tuning: bool = False,
n_top_genes: int = 1800,
flavor: Literal["seurat", "cell_ranger", "seurat_v3", "seurat_v3_paper"] = "seurat_v3",
use_batch_labels: bool = False
use_batch_labels: bool = False,
use_raw_counts: bool = True
) -> Dataset:
"""Processes the data for the scGPT model
Expand All @@ -176,14 +170,16 @@ def process_data(self,
Seurat passes the cutoffs whereas Cell Ranger passes n_top_genes.
use_batch_labels: Bool, default = False
Whether to use batch labels. Defaults to False.
use_raw_counts: Bool, default = True
Whether to use raw counts or not.
Returns
-------
Dataset
The processed dataset.
"""

self.ensure_data_validity(adata, gene_names, use_batch_labels)
self.ensure_data_validity(adata, gene_names, use_batch_labels, use_raw_counts)
self.gene_names = gene_names
if fine_tuning:
# Preprocess the dataset and select `N_HVG` highly variable genes for downstream analysis.
Expand Down Expand Up @@ -220,7 +216,7 @@ def process_data(self,
return dataset


def ensure_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels: bool) -> None:
def ensure_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels: bool, use_raw_counts = True) -> None:
"""Checks if the data is eligible for processing by the scGPT model
Parameters
Expand All @@ -229,15 +225,17 @@ def ensure_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels
The AnnData object containing the data to be validated.
gene_names : str
The name of the column containing gene names.
use_batch_labels : str
use_batch_labels : bool
Wheter to use batch labels.
use_raw_counts : bool, default = True
Whether to use raw counts or not.
Raises
------
KeyError
If the data is missing column names.
"""
self.ensure_rna_data_validity(adata, gene_names)
self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)

if use_batch_labels:
if not "batch_id" in adata.obs:
Expand Down
9 changes: 5 additions & 4 deletions helical/models/uce/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ def process_data(self,
adata: AnnData,
gene_names: str = "index",
name = "test",
filter_genes_min_cell: int = None
filter_genes_min_cell: int = None,
use_raw_counts: bool = True
) -> UCEDataset:
"""Processes the data for the Universal Cell Embedding model
Expand All @@ -90,16 +91,16 @@ def process_data(self,
The name of the dataset. Needed for when slicing AnnData objects for train and validation datasets.
filter_genes_min_cell: int, default = None
Filter threshold that defines how many times a gene should occur in all the cells.
use_raw_counts: bool, default = True
Whether to use raw counts or not.
Returns
-------
UCEDataset
Inherits from Dataset class.
"""



self.ensure_rna_data_validity(adata, gene_names)
self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)

if gene_names != "index":
adata.var.index = adata.var[gene_names]
Expand Down

0 comments on commit 58a765a

Please sign in to comment.