Include raw counts flag #117

Merged · 4 commits · Oct 23, 2024
2 changes: 1 addition & 1 deletion ci/tests/test_geneformer/test_geneformer_model.py
@@ -48,7 +48,7 @@ def test_process_data_padding_and_masking_ids(self, geneformer, mock_data):
     def test_ensure_data_validity_raising_error_with_missing_ensembl_id_column(self, geneformer, mock_data):
         del mock_data.var['ensembl_id']
         with pytest.raises(KeyError):
-            geneformer.ensure_data_validity(mock_data, "ensembl_id")
+            geneformer.ensure_rna_data_validity(mock_data, "ensembl_id")
 
     @pytest.mark.parametrize("gene_symbols, raises_error",
                              [
6 changes: 4 additions & 2 deletions helical/models/base_models.py
@@ -49,7 +49,7 @@ def get_embeddings():
         pass
 
 class HelicalRNAModel(HelicalBaseFoundationModel):
-    def ensure_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:
+    def ensure_rna_data_validity(self, adata: AnnData, gene_names: str, use_raw_counts: bool = True) -> None:
         """Ensures that the data contains the gene_names and has integer counts for adata.X which is saved
         in 'total_counts'.
@@ -59,6 +59,8 @@ def ensure_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:
             The data to be checked.
         gene_names : str
             The name of the column containing gene names in adata.var.
+        use_raw_counts : bool, default = True
+            Whether to use raw counts or not.
 
         Raises
         ------
@@ -84,7 +86,7 @@ def ensure_rna_data_validity(self, adata: AnnData, gene_names: str) -> None:
 
         # verify that the data in X are integers
         adata.obs["total_counts"] = adata.X.sum(axis=1)
-        if not (adata.obs["total_counts"] % 1 == 0).all():
+        if use_raw_counts and not (adata.obs["total_counts"] % 1 == 0).all():
             message = "The data in X must be integers."
             LOGGER.error(message)
             raise ValueError(message)
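Note on the behavior change: `ensure_rna_data_validity` now skips the integer-count check when `use_raw_counts=False`, so already-normalized (float) matrices pass validation. A minimal self-contained sketch of the new logic — `check_integer_counts` is a hypothetical stand-in for the method, not part of the package:

```python
import anndata as ad
import numpy as np

def check_integer_counts(adata: ad.AnnData, use_raw_counts: bool = True) -> None:
    # Mirrors the PR: only enforce integer counts when use_raw_counts is True.
    adata.obs["total_counts"] = adata.X.sum(axis=1)
    if use_raw_counts and not (adata.obs["total_counts"] % 1 == 0).all():
        raise ValueError("The data in X must be integers.")

# Log-normalized (float) data passes only when the check is disabled.
adata = ad.AnnData(X=np.log1p(np.random.rand(10, 5).astype(np.float32)))
check_integer_counts(adata, use_raw_counts=False)   # OK
# check_integer_counts(adata)                       # would raise ValueError
```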
26 changes: 4 additions & 22 deletions helical/models/geneformer/model.py
@@ -120,6 +120,7 @@ def process_data(self,
                      adata: AnnData,
                      gene_names: str = "index",
                      output_path: Optional[str] = None,
+                     use_raw_counts: bool = True,
                      ) -> Dataset:
         """Processes the data for the Geneformer model
 
@@ -139,15 +140,16 @@ def process_data(self,
             In that case, it is recommended to create a new column with the Ensembl IDs in the data and pass "ensembl_id" as the gene_names.
         output_path : str, default = None
             Whether to save the tokenized dataset to the specified output_path.
-
+        use_raw_counts : bool, default = True
+            Whether to use raw counts or not.
 
         Returns
         -------
         Dataset
             The tokenized dataset in the form of a Hugging Face Dataset object.
 
         """
-        self.ensure_data_validity(adata, gene_names)
+        self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)
 
         # map gene symbols to ensembl ids if provided
         if gene_names != "ensembl_id":
@@ -195,23 +197,3 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
         ).cpu().detach().numpy()
 
         return embeddings
-
-    def ensure_data_validity(self, adata: AnnData, gene_names: str) -> None:
-        """Ensure that the data is eligible for processing by the Geneformer model. This checks
-        if the data contains the gene_names, and sets the total_counts column in adata.obs.
-
-        Parameters
-        ----------
-        adata : AnnData
-            The AnnData object containing the data to be processed.
-        gene_names: str
-            The column in adata.var that contains the gene names.
-
-        Raises
-        ------
-        KeyError
-            If the data is missing column names.
-        """
-        self.ensure_rna_data_validity(adata, gene_names)
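With the redundant Geneformer-specific `ensure_data_validity` removed, the flag is simply forwarded from `process_data` to the shared validator. A hedged usage sketch — the `Geneformer` class name, its no-argument constructor, and the file path are assumptions for illustration:

```python
import anndata as ad
from helical.models.geneformer.model import Geneformer  # assumed export

adata = ad.read_h5ad("cells.h5ad")  # hypothetical input with an 'ensembl_id' column in .var

model = Geneformer()
# Data is already log-normalized, so disable the integer-count validation.
dataset = model.process_data(adata, gene_names="ensembl_id", use_raw_counts=False)
embeddings = model.get_embeddings(dataset)
```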
22 changes: 10 additions & 12 deletions helical/models/scgpt/model.py
@@ -139,13 +139,6 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
             cell_embeddings, axis=1, keepdims=True
         )
 
-        # TODO?
-        # return_new_adata (bool): Whether to return a new AnnData object. If False, will
-        # add the cell embeddings to a new :attr:`adata.obsm` with key "X_scGPT".
-        # if return_new_adata:
-        #     obs_df = adata.obs[obs_to_save] if obs_to_save is not None else None
-        #     return sc.AnnData(X=cell_embeddings, obs=obs_df, dtype="float32")
-
         return cell_embeddings
 
     def process_data(self,
@@ -154,7 +147,8 @@ def process_data(self,
                      fine_tuning: bool = False,
                      n_top_genes: int = 1800,
                      flavor: Literal["seurat", "cell_ranger", "seurat_v3", "seurat_v3_paper"] = "seurat_v3",
-                     use_batch_labels: bool = False
+                     use_batch_labels: bool = False,
+                     use_raw_counts: bool = True
                      ) -> Dataset:
         """Processes the data for the scGPT model
 
@@ -176,14 +170,16 @@ def process_data(self,
             Seurat passes the cutoffs whereas Cell Ranger passes n_top_genes.
         use_batch_labels: Bool, default = False
             Whether to use batch labels. Defaults to False.
+        use_raw_counts: Bool, default = True
+            Whether to use raw counts or not.
 
         Returns
         -------
         Dataset
             The processed dataset.
         """
 
-        self.ensure_data_validity(adata, gene_names, use_batch_labels)
+        self.ensure_data_validity(adata, gene_names, use_batch_labels, use_raw_counts)
         self.gene_names = gene_names
         if fine_tuning:
             # Preprocess the dataset and select `N_HVG` highly variable genes for downstream analysis.
@@ -220,7 +216,7 @@ def process_data(self,
         return dataset
 
 
-    def ensure_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels: bool) -> None:
+    def ensure_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels: bool, use_raw_counts = True) -> None:
         """Checks if the data is eligible for processing by the scGPT model
 
         Parameters
@@ -229,15 +225,17 @@ def ensure_data_validity(self, adata: AnnData, gene_names: str, use_batch_labels
             The AnnData object containing the data to be validated.
         gene_names : str
             The name of the column containing gene names.
-        use_batch_labels : str
+        use_batch_labels : bool
             Whether to use batch labels.
+        use_raw_counts : bool, default = True
+            Whether to use raw counts or not.
 
         Raises
         ------
         KeyError
             If the data is missing column names.
         """
-        self.ensure_rna_data_validity(adata, gene_names)
+        self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)
 
         if use_batch_labels:
             if not "batch_id" in adata.obs:
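scGPT threads the flag through its own `ensure_data_validity`, next to the existing batch-label check. A hedged sketch — the `scGPT` class name, constructor, and input file are assumptions:

```python
import anndata as ad
from helical.models.scgpt.model import scGPT  # assumed export

adata = ad.read_h5ad("normalized_cells.h5ad")  # hypothetical, float-valued X
adata.obs["batch_id"] = 0  # required whenever use_batch_labels=True

model = scGPT()
dataset = model.process_data(
    adata,
    gene_names="index",
    use_batch_labels=True,
    use_raw_counts=False,  # skip the integer-count check for normalized data
)
```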
9 changes: 5 additions & 4 deletions helical/models/uce/model.py
@@ -73,7 +73,8 @@ def process_data(self,
                      adata: AnnData,
                      gene_names: str = "index",
                      name = "test",
-                     filter_genes_min_cell: int = None
+                     filter_genes_min_cell: int = None,
+                     use_raw_counts: bool = True
                      ) -> UCEDataset:
         """Processes the data for the Universal Cell Embedding model
 
@@ -90,16 +91,16 @@ def process_data(self,
             The name of the dataset. Needed for when slicing AnnData objects for train and validation datasets.
         filter_genes_min_cell: int, default = None
             Filter threshold that defines how many times a gene should occur in all the cells.
+        use_raw_counts: bool, default = True
+            Whether to use raw counts or not.
 
         Returns
         -------
         UCEDataset
             Inherits from Dataset class.
         """
 
-
-        self.ensure_rna_data_validity(adata, gene_names)
+        self.ensure_rna_data_validity(adata, gene_names, use_raw_counts)
 
         if gene_names != "index":
             adata.var.index = adata.var[gene_names]
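UCE receives the same passthrough, alongside its gene-filter option. A hedged sketch — the `UCE` class name, constructor, and file path are assumptions:

```python
import anndata as ad
from helical.models.uce.model import UCE  # assumed export

adata = ad.read_h5ad("normalized_cells.h5ad")  # hypothetical input

model = UCE()
uce_dataset = model.process_data(
    adata,
    gene_names="index",
    filter_genes_min_cell=10,   # drop genes expressed in fewer than 10 cells
    use_raw_counts=False,       # X holds normalized values, not raw counts
)
```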
4 changes: 0 additions & 4 deletions helical/models/uce/uce_utils.py
@@ -126,10 +126,6 @@ def load_model(model_path: Union[str, Path], model_config: Dict[str, str], all_p
     model.pe_embedding = torch.nn.Embedding.from_pretrained(empty_pe)
     model.load_state_dict(torch.load(model_path, map_location=model_config["device"]), strict=True)
 
-    # TODO: Why load the protein embeddings from the `all_tokens.torch` file, pass it to this function but never use it?
-    # Cause in the lines above, we populate model.pe_embeddings with the empty_pe and this if clause will be true with the
-    # `all_tokens.torch` file
-    # From the original, this was the comment:
     # This will make sure that you don't overwrite the tokens in case you're embedding species from the training data
     # We avoid doing that just in case the random seeds are different across different versions.
     if all_pe.shape[0] != 145469:
108 changes: 0 additions & 108 deletions helical/utils/downloader.py
@@ -1,7 +1,4 @@
 import requests
-import json
-import pickle as pkl
-import pandas as pd
 from helical.utils.logger import Logger
 from helical.constants.enums import LoggingType, LoggingLevel
 import logging
@@ -10,8 +7,6 @@
 from pathlib import Path
 from tqdm import tqdm
 from azure.storage.blob import BlobClient
-from azure.core.pipeline.transport import RequestsTransport
-from git import Repo
 from helical.constants.paths import CACHE_DIR_HELICAL
 
 LOGGER = logging.getLogger(__name__)
@@ -30,47 +25,6 @@ def __init__(self, loging_type = LoggingType.CONSOLE, level = LoggingLevel.INFO)
         # mount the adapter to the session
         self.session.mount('https://', adapter)
 
-    def get_ensemble_mapping(self, path_to_ets_csv: Path, output: Path):
-        '''
-        Saves a mapping of the `Ensembl ID` to `display names`.
-
-        Args:
-            path_to_ets_csv: Path to the ETS csv file.
-            output: Path to where the output (.pkl) file should be saved to.
-        '''
-        try:
-            df = pd.read_csv(path_to_ets_csv)
-        except:
-            LOGGER.exception(f"Failed to open the '{path_to_ets_csv}' file. Please provide it.")
-
-        if output.is_file():
-            LOGGER.info(f"No mapping is done because mapping file already exists here: '{output}'")
-
-        else:
-            genes = df['egid'].dropna().unique()
-
-            server = "https://rest.ensembl.org/lookup/id"
-            headers={ "Content-Type" : "application/json", "Accept" : "application/json"}
-
-            ensemble_to_display_name = dict()
-
-            LOGGER.info(f"Starting to download the mappings of {len(genes)} genes from '{server}'")
-
-            # Resetting for visualization
-            self.data_length = 0
-            self.total_length = len(genes)
-
-            for i in range(0, len(genes), INTERVAL):
-                if self.display:
-                    self._display_download_progress(INTERVAL)
-                ids = {'ids':genes[i:i+INTERVAL].tolist()}
-                r = requests.post(server, headers=headers, data=json.dumps(ids))
-                decoded = r.json()
-                ensemble_to_display_name.update(decoded)
-
-            pkl.dump(ensemble_to_display_name, open(output, 'wb'))
-            LOGGER.info(f"Downloaded all mappings and saved to: '{output}'")
-
     def download_via_link(self, output: Path, link: str) -> None:
         '''
         Download a file via a link.
@@ -114,25 +68,6 @@ def download_via_link(self, output: Path, link: str) -> None:
             LOGGER.error(f"Failed downloading file from '{link}'")
         LOGGER.info(f"File saved to: '{output}'")
 
-    def clone_git_repo(self, destination: Path, repo_url: str, checkout: str) -> None:
-        '''
-        Clones a git repo to a destination folder if it does not yet exist.
-
-        Args:
-            destination: The path to where the git repo should be cloned to.
-            repo_url: The URL to do the git clone
-            checkout: The tag, branch or commit hash to checkout
-        '''
-
-        if destination.is_dir():
-            LOGGER.info(f"Folder: {destination} exists already. No 'git clone' is performed.")
-
-        else:
-            LOGGER.info(f"Cloning {repo_url} to {destination}")
-            repo = Repo.clone_from(repo_url, destination)
-            repo.git.checkout(checkout)
-            LOGGER.info(f"Successfully cloned and checked out '{checkout}' of {repo_url}")
-
     def _display_download_progress(self, data_chunk_size: int) -> None:
         '''
         Display the download progress in console.
@@ -145,49 +80,6 @@ def _display_download_progress(self, data_chunk_size: int) -> None:
         sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (LOADING_BAR_LENGTH-done)) )
         sys.stdout.flush()
 
-    def download_via_name_v0(self, name: str) -> None:
-        '''
-        Download a file via a link.
-
-        Args:
-            output: Path to the output file.
-            link: URL to download the file from.
-        '''
-        main_link = "https://helicalpackage.blob.core.windows.net/helicalpackage/data"
-        output = os.path.join(CACHE_DIR_HELICAL, name)
-
-        link = f"{main_link}/{name}"
-        if not os.path.exists(os.path.dirname(output)):
-            os.makedirs(os.path.dirname(output),exist_ok=True)
-
-        if Path(output).is_file():
-            LOGGER.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.")
-
-        else:
-            LOGGER.info(f"Starting to download: '{link}'")
-            with open(output, "wb") as f:
-                response = requests.get(link, stream=True)
-                total_length = response.headers.get('content-length')
-
-                # Resetting for visualization
-                self.data_length = 0
-                self.total_length = int(total_length)
-
-                if total_length is None: # no content length header
-                    f.write(response.content)
-                else:
-                    try:
-                        # for data in response.iter_content(chunk_size=CHUNK_SIZE):
-                        pbar = tqdm(total=int(self.total_length), unit="B", unit_scale=True)
-                        for data in tqdm(response.iter_content(chunk_size=CHUNK_SIZE)):
-                            # self._display_download_progress(len(data))
-                            f.write(data)
-                            pbar.update(len(data))
-                    except:
-                        LOGGER.error(f"Failed downloading file from '{link}'")
-            LOGGER.info(f"File saved to: '{output}'")
-
-
     def download_via_name(self, name: str) -> None:
         '''
         Download a file via a link.
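With the v0 variant gone, the tqdm-based `download_via_name` is the single blob-download entry point. A hedged usage sketch — the `Downloader` class name and the blob name are assumptions:

```python
from helical.utils.downloader import Downloader  # assumed export

downloader = Downloader()
# Resolves the name against the Helical blob store and caches the file
# under CACHE_DIR_HELICAL; existing files are not re-downloaded.
downloader.download_via_name("geneformer/gene_median_dictionary.pkl")  # hypothetical blob name
```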