Merge pull request #117 from helicalAI/include-raw-counts-flag

Include raw counts flag

maxiallard authored Oct 23, 2024
2 parents f422e44 + 58a765a commit 4b87198
Showing 7 changed files with 24 additions and 153 deletions.
2 changes: 1 addition & 1 deletion ci/tests/test_geneformer/test_geneformer_model.py
@@ -48,7 +48,7 @@ def test_process_data_padding_and_masking_ids(self, geneformer, mock_data):
def test_ensure_data_validity_raising_error_with_missing_ensembl_id_column(self, geneformer, mock_data):
del mock_data.var['ensembl_id']
with pytest.raises(KeyError):
- geneformer.ensure_data_validity(mock_data, "ensembl_id")
+ geneformer.ensure_rna_data_validity(mock_data, "ensembl_id")

@pytest.mark.parametrize("gene_symbols, raises_error",
[
6 changes: 4 additions & 2 deletions helical/models/base_models.py
@@ -49,7 +49,7 @@ def get_embeddings():
pass

class HelicalRNAModel(HelicalBaseFoundationModel):
- def ensure_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:
+ def ensure_rna_data_validity(self, adata: AnnData, gene_names: str, use_raw_counts: bool = True) -> None:
"""Ensures that the data contains the gene_names and has integer counts for adata.X which is saved
in 'total_counts'.
@@ -59,6 +59,8 @@ def ensure_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:
The data to be checked.
gene_names : str
The name of the column containing gene names in adata.var.
+ use_raw_counts : bool, default = True
+     Whether to use raw counts or not.
Raises
------
@@ -84,7 +86,7 @@ def ensure_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:

# verify that the data in X are integers
adata.obs["total_counts"] = adata.X.sum(axis=1)
- if not (adata.obs["total_counts"] % 1 == 0).all():
+ if use_raw_counts and not (adata.obs["total_counts"] % 1 == 0).all():
message = "The data in X must be integers."
LOGGER.error(message)
raise ValueError(message)
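The net effect of the base-model change: the integer-count check now runs only when use_raw_counts is true. A minimal sketch of that behaviour, assuming anndata and numpy are installed (the toy matrices are illustrative, not taken from the repository):

import anndata as ad
import numpy as np

# Raw counts: per-cell totals are whole numbers, so the check passes.
raw = ad.AnnData(X=np.array([[1.0, 2.0], [3.0, 4.0]]))
raw.obs["total_counts"] = raw.X.sum(axis=1)
assert (raw.obs["total_counts"] % 1 == 0).all()

# Normalized data: totals become fractional, which used to raise unconditionally.
norm = ad.AnnData(X=np.log1p(raw.X))
norm.obs["total_counts"] = norm.X.sum(axis=1)

use_raw_counts = False  # the new flag short-circuits the check
if use_raw_counts and not (norm.obs["total_counts"] % 1 == 0).all():
    raise ValueError("The data in X must be integers.")  # not reached here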
26 changes: 4 additions & 22 deletions helical/models/geneformer/model.py
@@ -120,6 +120,7 @@ def process_data(self,
adata: AnnData,
gene_names: str = "index",
output_path: Optional[str] = None,
+ use_raw_counts: bool = True,
) -> Dataset:
"""Processes the data for the Geneformer model
@@ -139,15 +140,16 @@
In that case, it is recommended to create a new column with the Ensembl IDs in the data and pass "ensembl_id" as the gene_names.
output_path : str, default = None
Whether to save the tokenized dataset to the specified output_path.
+ use_raw_counts : bool, default = True
+     Whether to use raw counts or not.
Returns
-------
Dataset
The tokenized dataset in the form of a Hugging Face Dataset object.
"""
- self.ensure_data_validity(adata, gene_names)
+ self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)

# map gene symbols to ensemble ids if provided
if gene_names != "ensembl_id":
@@ -195,23 +197,3 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
).cpu().detach().numpy()

return embeddings

- def ensure_data_validity(self, adata: AnnData, gene_names: str) -> None:
-     """Ensure that the data is eligible for processing by the Geneformer model. This checks
-     if the data contains the gene_names, and sets the total_counts column in adata.obs.
-
-     Parameters
-     ----------
-     adata : AnnData
-         The AnnData object containing the data to be processed.
-     gene_names: str
-         The column in adata.var that contains the gene names.
-
-     Raises
-     ------
-     KeyError
-         If the data is missing column names.
-     """
-     self.ensure_rna_data_validity(adata, gene_names)
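With ensure_data_validity removed, Geneformer callers go through the shared ensure_rna_data_validity via process_data. A hypothetical usage sketch follows; the Geneformer class name is inferred from the module path, and the default constructor plus the adata object are assumptions:

from helical.models.geneformer.model import Geneformer  # import path assumed

model = Geneformer()
# adata: an AnnData object prepared by the caller; here it holds
# pre-normalized values, so the integer-count check is skipped.
dataset = model.process_data(adata, gene_names="ensembl_id", use_raw_counts=False)
embeddings = model.get_embeddings(dataset)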


22 changes: 10 additions & 12 deletions helical/models/scgpt/model.py
@@ -139,13 +139,6 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
cell_embeddings, axis=1, keepdims=True
)

- # TODO?
- # return_new_adata (bool): Whether to return a new AnnData object. If False, will
- # add the cell embeddings to a new :attr:`adata.obsm` with key "X_scGPT".
- # if return_new_adata:
- #     obs_df = adata.obs[obs_to_save] if obs_to_save is not None else None
- #     return sc.AnnData(X=cell_embeddings, obs=obs_df, dtype="float32")

return cell_embeddings

def process_data(self,
@@ -154,7 +147,8 @@ def process_data(self,
fine_tuning: bool = False,
n_top_genes: int = 1800,
flavor: Literal["seurat", "cell_ranger", "seurat_v3", "seurat_v3_paper"] = "seurat_v3",
- use_batch_labels: bool = False
+ use_batch_labels: bool = False,
+ use_raw_counts: bool = True
) -> Dataset:
"""Processes the data for the scGPT model
@@ -176,14 +170,16 @@
Seurat passes the cutoffs whereas Cell Ranger passes n_top_genes.
use_batch_labels: Bool, default = False
Whether to use batch labels. Defaults to False.
+ use_raw_counts: Bool, default = True
+     Whether to use raw counts or not.
Returns
-------
Dataset
The processed dataset.
"""

- self.ensure_data_validity(adata, gene_names, use_batch_labels)
+ self.ensure_data_validity(adata, gene_names, use_batch_labels, use_raw_counts)
self.gene_names = gene_names
if fine_tuning:
# Preprocess the dataset and select `N_HVG` highly variable genes for downstream analysis.
@@ -220,7 +216,7 @@
return dataset


- def ensure_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels: bool) -> None:
+ def ensure_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels: bool, use_raw_counts = True) -> None:
"""Checks if the data is eligible for processing by the scGPT model
Parameters
@@ -229,15 +225,17 @@ def ensure_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels
The AnnData object containing the data to be validated.
gene_names : str
The name of the column containing gene names.
- use_batch_labels : str
+ use_batch_labels : bool
Whether to use batch labels.
+ use_raw_counts : bool, default = True
+     Whether to use raw counts or not.
Raises
------
KeyError
If the data is missing column names.
"""
- self.ensure_rna_data_validity(adata, gene_names)
+ self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)

if use_batch_labels:
if not "batch_id" in adata.obs:
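The scGPT path threads the flag from process_data down to ensure_data_validity and on to the shared ensure_rna_data_validity. A hypothetical call, with the scGPT class name inferred from the module path and adata assumed to be a caller-prepared AnnData object:

from helical.models.scgpt.model import scGPT  # import path assumed

model = scGPT()
dataset = model.process_data(
    adata,
    gene_names="index",
    use_batch_labels=True,  # requires a "batch_id" column in adata.obs
    use_raw_counts=True,    # keep the integer-count check for raw data
)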
9 changes: 5 additions & 4 deletions helical/models/uce/model.py
@@ -73,7 +73,8 @@ def process_data(self,
adata: AnnData,
gene_names: str = "index",
name = "test",
- filter_genes_min_cell: int = None
+ filter_genes_min_cell: int = None,
+ use_raw_counts: bool = True
) -> UCEDataset:
"""Processes the data for the Universal Cell Embedding model
@@ -90,16 +91,16 @@
The name of the dataset. Needed for when slicing AnnData objects for train and validation datasets.
filter_genes_min_cell: int, default = None
Filter threshold that defines how many times a gene should occur in all the cells.
+ use_raw_counts: bool, default = True
+     Whether to use raw counts or not.
Returns
-------
UCEDataset
Inherits from Dataset class.
"""



- self.ensure_rna_data_validity(adata, gene_names)
+ self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)

if gene_names != "index":
adata.var.index = adata.var[gene_names]
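UCE gets the same opt-out. A hypothetical call, with the UCE class name inferred from the module path and an illustrative filter threshold:

from helical.models.uce.model import UCE  # import path assumed

model = UCE()
dataset = model.process_data(
    adata,                     # caller-prepared AnnData object
    gene_names="index",
    filter_genes_min_cell=10,  # illustrative threshold
    use_raw_counts=False,      # adata holds normalized values
)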
4 changes: 0 additions & 4 deletions helical/models/uce/uce_utils.py
@@ -126,10 +126,6 @@ def load_model(model_path: Union[str, Path], model_config: Dict[str, str], all_p
model.pe_embedding = torch.nn.Embedding.from_pretrained(empty_pe)
model.load_state_dict(torch.load(model_path, map_location=model_config["device"]), strict=True)

- # TODO: Why load the protein embeddings from the `all_tokens.torch` file, pass it to this function but never use it?
- # Cause in the lines above, we populate model.pe_embeddings with the empty_pe and this if clause will be true with the
- # `all_tokens.torch` file
- # From the original, this was the comment:
# This will make sure that you don't overwrite the tokens in case you're embedding species from the training data
# We avoid doing that just in case the random seeds are different across different versions.
if all_pe.shape[0] != 145469:
108 changes: 0 additions & 108 deletions helical/utils/downloader.py
@@ -1,7 +1,4 @@
import requests
- import json
- import pickle as pkl
- import pandas as pd
from helical.utils.logger import Logger
from helical.constants.enums import LoggingType, LoggingLevel
import logging
@@ -10,8 +7,6 @@
from pathlib import Path
from tqdm import tqdm
from azure.storage.blob import BlobClient
- from azure.core.pipeline.transport import RequestsTransport
- from git import Repo
from helical.constants.paths import CACHE_DIR_HELICAL

LOGGER = logging.getLogger(__name__)
@@ -30,47 +25,6 @@ def __init__(self, loging_type = LoggingType.CONSOLE, level = LoggingLevel.INFO)
# mount the adapter to the session
self.session.mount('https://', adapter)

- def get_ensemble_mapping(self, path_to_ets_csv: Path, output: Path):
-     '''
-     Saves a mapping of the `Ensemble ID` to `display names`.
-
-     Args:
-         path_to_ets_csv: Path to the ETS csv file.
-         output: Path to where the output (.pkl) file should be saved to.
-     '''
-     try:
-         df = pd.read_csv(path_to_ets_csv)
-     except:
-         LOGGER.exception(f"Failed to open the '{path_to_ets_csv}' file. Please provide it.")
-
-     if output.is_file():
-         LOGGER.info(f"No mapping is done because mapping file already exists here: '{output}'")
-
-     else:
-         genes = df['egid'].dropna().unique()
-
-         server = "https://rest.ensembl.org/lookup/id"
-         headers={ "Content-Type" : "application/json", "Accept" : "application/json"}
-
-         ensemble_to_display_name = dict()
-
-         LOGGER.info(f"Starting to download the mappings of {len(genes)} genes from '{server}'")
-
-         # Resetting for visualization
-         self.data_length = 0
-         self.total_length = len(genes)
-
-         for i in range(0, len(genes), INTERVAL):
-             if self.display:
-                 self._display_download_progress(INTERVAL)
-             ids = {'ids':genes[i:i+INTERVAL].tolist()}
-             r = requests.post(server, headers=headers, data=json.dumps(ids))
-             decoded = r.json()
-             ensemble_to_display_name.update(decoded)
-
-         pkl.dump(ensemble_to_display_name, open(output, 'wb'))
-         LOGGER.info(f"Downloaded all mappings and saved to: '{output}'")
def download_via_link(self, output: Path, link: str) -> None:
'''
Download a file via a link.
@@ -114,25 +68,6 @@ def download_via_link(self, output: Path, link: str) -> None:
LOGGER.error(f"Failed downloading file from '{link}'")
LOGGER.info(f"File saved to: '{output}'")

- def clone_git_repo(self, destination: Path, repo_url: str, checkout: str) -> None:
-     '''
-     Clones a git repo to a destination folder if it does not yet exist.
-
-     Args:
-         destination: The path to where the git repo should be cloned to.
-         repo_url: The URL to do the git clone
-         checkout: The tag, branch or commit hash to checkout
-     '''
-
-     if destination.is_dir():
-         LOGGER.info(f"Folder: {destination} exists already. No 'git clone' is performed.")
-
-     else:
-         LOGGER.info(f"Clonging {repo_url} to {destination}")
-         repo = Repo.clone_from(repo_url, destination)
-         repo.git.checkout(checkout)
-         LOGGER.info(f"Successfully cloned and checked out '{checkout}' of {repo_url}")
def _display_download_progress(self, data_chunk_size: int) -> None:
'''
Display the download progress in console.
@@ -145,49 +80,6 @@ def _display_download_progress(self, data_chunk_size: int) -> None:
sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (LOADING_BAR_LENGTH-done)) )
sys.stdout.flush()

- def download_via_name_v0(self, name: str) -> None:
-     '''
-     Download a file via a link.
-
-     Args:
-         output: Path to the output file.
-         link: URL to download the file from.
-     '''
-     main_link = "https://helicalpackage.blob.core.windows.net/helicalpackage/data"
-     output = os.path.join(CACHE_DIR_HELICAL, name)
-
-     link = f"{main_link}/{name}"
-     if not os.path.exists(os.path.dirname(output)):
-         os.makedirs(os.path.dirname(output),exist_ok=True)
-
-     if Path(output).is_file():
-         LOGGER.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.")
-
-     else:
-         LOGGER.info(f"Starting to download: '{link}'")
-         with open(output, "wb") as f:
-             response = requests.get(link, stream=True)
-             total_length = response.headers.get('content-length')
-
-             # Resetting for visualization
-             self.data_length = 0
-             self.total_length = int(total_length)
-
-             if total_length is None: # no content length header
-                 f.write(response.content)
-             else:
-                 try:
-                     # for data in response.iter_content(chunk_size=CHUNK_SIZE):
-                     pbar = tqdm(total=int(self.total_length), unit="B", unit_scale=True)
-                     for data in tqdm(response.iter_content(chunk_size=CHUNK_SIZE)):
-                         # self._display_download_progress(len(data))
-                         f.write(data)
-                         pbar.update(len(data))
-                 except:
-                     LOGGER.error(f"Failed downloading file from '{link}'")
-         LOGGER.info(f"File saved to: '{output}'")
def download_via_name(self, name: str) -> None:
'''
Download a file via a link.
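After this cleanup, download_via_name remains the entry point for fetching packaged files from blob storage. A hypothetical call; the Downloader class name and the file name argument are assumptions for illustration:

from helical.utils.downloader import Downloader  # class name assumed

downloader = Downloader()
downloader.download_via_name("geneformer/gene_median_dictionary.pkl")  # illustrative name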
