diff --git a/examples/run_scgpt.py b/examples/run_scgpt.py
index 54a6b83e..65d178a0 100644
--- a/examples/run_scgpt.py
+++ b/examples/run_scgpt.py
@@ -1,8 +1,8 @@
 from helical.models.scgpt.model import scGPT, scGPTConfig
 import anndata as ad
 
-model_config = scGPTConfig(batch_size=10)
-scgpt = scGPT(model_config=model_config)
+scgpt_config = scGPTConfig(batch_size=10)
+scgpt = scGPT(configurer=scgpt_config)
 
 adata = ad.read_h5ad("./10k_pbmcs_proc.h5ad")
 data = scgpt.process_data(adata[:10])
diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py
index 73fcd821..c8ecedf3 100644
--- a/helical/models/scgpt/model.py
+++ b/helical/models/scgpt/model.py
@@ -1,20 +1,3 @@
-# This tutorial covers the zero-shot integration with continual pre-trained scGPT.
-# This particular workflow works for scRNA-seq datasets without fine-tuning (or any extensive training) of scGPT.
-# Continual pre-trained scGPT (scGPT_CP) is a model that inherits the pre-trained scGPT whole-human model checkpoint,
-# and is further supervised by extra cell type labels (using the [Tabula Sapiens](https://tabula-sapiens-portal.ds.czbiohub.org/) dataset)
-# during the continual pre-training stage. We observed that the scGPT_CP model can achieve comparable or better zero-shot performance
-# on cell embedding related tasks compared to the original checkpoint, especially on datasets with observable technical batch effects.
-# This tutorial will show how to use the latent space of scGPT to integrate scRNA-seq datasets.
-# We use the `scGPT_CP` model to provide embeddings out of the box.
-# You may download it from [here](https://drive.google.com/drive/folders/1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB).
-
-# We will use the [scIB](https://www.nature.com/articles/s41592-021-01336-8) pancreas dataset as an example.
-# This dataset is publicly accessible via [here](https://figshare.com/ndownloader/files/24539828). You may place the dataset under `data` directory at the outer level.
-
-# The zero-shot integration workflow is as follows:
-# 1. [Load and pre-process the dataset](#prepare-the-datasets)
-# 2. [Generate scGPT embeddings for each cell](#generate-the-cell-embeddings)
-
 import os
 import scanpy as sc
 from helical.models.helical import HelicalBaseModel
@@ -22,16 +5,14 @@
 import numpy as np
 from anndata import AnnData
 import logging
-from typing import Optional, Literal
-from pathlib import Path
+from typing import Literal
 from accelerate import Accelerator
-from helical.services.downloader import Downloader
 from helical.models.scgpt.scgpt_utils import load_model, get_embedding
 
 os.environ['KMP_DUPLICATE_LIB_OK']='True'
 
 class scGPT(HelicalBaseModel):
-    default_config = scGPTConfig()
+    configurer = scGPTConfig()
     """scGPT Model. 
     The scGPT Model is a transformer-based model that can be used to extract gene embeddings from single-cell RNA-seq data. 
     Currently we load the continous pre-training model from the scGPT repository as default model which works best on zero-shot tasks. 
@@ -41,8 +22,8 @@ class scGPT(HelicalBaseModel):
     -------
     >>> from helical.models import scGPT,scGPTConfig
     >>> import anndata as ad
-    >>> model_config=scGPTConfig(batch_size=10)
-    >>> scgpt = scGPT(model_config=model_config)
+    >>> scgpt_config=scGPTConfig(batch_size=10)
+    >>> scgpt = scGPT(configurer=scgpt_config)
     >>> ann_data = ad.read_h5ad("./data/10k_pbmcs_proc.h5ad")
     >>> dataset = scgpt.process_data(ann_data[:100])
     >>> embeddings = scgpt.get_embeddings(dataset)
@@ -50,9 +31,7 @@ class scGPT(HelicalBaseModel):
 
     Parameters
     ----------
-    model_dir : str, optional, default = None
-        The path to the model directory. None by default, which will download the model if not present.
-    model_config : scGPTConfig, optional, default = default_config
+    configurer : scGPTConfig, optional, default = configurer
         The model configuration.
 
     Returns
@@ -64,26 +43,16 @@ class scGPT(HelicalBaseModel):
     We use the implementation from this `repository `_ , which comes from the original authors. 
     You can find the description of the method in this `paper `_. 
     """
-    def __init__(self, model_dir: Optional[str] = None, model_config: scGPTConfig = default_config) -> None:
-        
+    def __init__(self, configurer: scGPTConfig = configurer) -> None:
         super().__init__()
-        self.model_config = model_config.config
-        self.downloader = Downloader()
-        
-        if model_dir is None:
-            self.downloader.download_via_name("scgpt/scGPT_CP/vocab.json")
-            self.downloader.download_via_name("scgpt/scGPT_CP/best_model.pt")
-            self.model_dir = Path(os.path.join(self.downloader.CACHE_DIR_HELICAL,'scgpt/scGPT_CP'))
-        else:
-            self.model_dir = Path(model_dir)
-        
+        self.config = configurer.config
         self.log = logging.getLogger("scGPT-Model")
 
-        self.model,self.vocab = load_model(self.model_dir,self.model_config,device=self.model_config["device"],use_fast_transformer=False)
+        self.model, self.vocab = load_model(self.config)
 
-        if self.model_config["accelerator"]:
-            self.accelerator = Accelerator(project_dir=self.model_dir, cpu=self.model_config["accelerator"]["cpu"])
+        if self.config["accelerator"]:
+            self.accelerator = Accelerator(project_dir=self.config["model_path"].parent, cpu=self.config["accelerator"]["cpu"])
             self.model = self.accelerator.prepare(self.model)
         else:
             self.accelerator = None
@@ -104,10 +73,10 @@ def get_embeddings(self,data: AnnData) -> np.array:
         embeddings = get_embedding(data,
                                    model = self.model,
                                    vocab = self.vocab,
-                                   batch_size=self.model_config["batch_size"],
-                                   model_configs=self.model_config,
+                                   batch_size=self.config["batch_size"],
+                                   model_configs=self.config,
                                    gene_col=self.gene_column_name,
-                                   device=self.model_config["device"])
+                                   device=self.config["device"])
 
         return embeddings
 
diff --git a/helical/models/scgpt/scgpt_config.py b/helical/models/scgpt/scgpt_config.py
index fb17ca7d..a27f0608 100644
--- a/helical/models/scgpt/scgpt_config.py
+++ b/helical/models/scgpt/scgpt_config.py
@@ -1,4 +1,7 @@
 from typing import Optional
+from helical.services.downloader import Downloader
+from pathlib import Path
+import os
 
 class scGPTConfig():
     """
     Configuration class to use the scGPT Model.
@@ -33,7 +36,8 @@ class scGPTConfig():
         The accelerator configuration
     device : str, optional, default = "cpu"
         The device to use. Either use "cuda" or "cpu"
-
+    use_fast_transformer : bool, optional, default = False
+        Whether to use the fast transformer implementation or not
 
     Returns
     -------
@@ -62,10 +66,17 @@ def __init__(
             world_size: int = 8,
             accelerator: Optional[dict] = None,
             device: str = "cpu",
+            use_fast_transformer: bool = False,
             ):
 
+        model_name = 'best_model' # TODO: Include more models
+        downloader = Downloader()
+        downloader.download_via_name("scgpt/scGPT_CP/vocab.json")
+        downloader.download_via_name(f"scgpt/scGPT_CP/{model_name}.pt")
+        model_path = Path(os.path.join(downloader.CACHE_DIR_HELICAL, 'scgpt/scGPT_CP', f'{model_name}.pt'))
 
         self.config = {
+            "model_path": model_path,
            "pad_token": pad_token,
            "batch_size": batch_size,
            "fast_transformer": fast_transformer,
@@ -79,5 +90,6 @@ def __init__(
            "pad_value": pad_value,
            "world_size": world_size,
            "accelerator": accelerator,
-           "device": device
+           "device": device,
+           "use_fast_transformer": use_fast_transformer,
         }
\ No newline at end of file
diff --git a/helical/models/scgpt/scgpt_utils.py b/helical/models/scgpt/scgpt_utils.py
index 815f894d..914cb72a 100644
--- a/helical/models/scgpt/scgpt_utils.py
+++ b/helical/models/scgpt/scgpt_utils.py
@@ -1,6 +1,3 @@
-import json
-import os
-from pathlib import Path
 from typing import Optional, Union
 from os import PathLike
 
@@ -8,27 +5,22 @@
 import scanpy as sc
 import torch
 from anndata import AnnData
-from torch.utils.data import DataLoader, SequentialSampler
-from tqdm import tqdm
 
-from .data_collator import DataCollator
 from .model_dir import TransformerModel
 from .tokenizer import GeneVocab
 from .utils import load_pretrained
-
+from helical.models.scgpt.scgpt_config import scGPTConfig
 from helical.models.scgpt.tasks.cell_emb import get_batch_cell_embeddings
 
 
-def load_model(model_dir,model_configs,device='cpu',use_fast_transformer=False):
+def load_model(model_configs: scGPTConfig):
 
     # LOAD MODEL
-    model_dir = Path(model_dir)
+    model_dir = model_configs["model_path"].parent
     vocab_file = model_dir / "vocab.json"
-    model_file = model_dir / "best_model.pt"
-    pad_token = "<pad>"
-    special_tokens = [pad_token, "<cls>", "<eoc>"]
+    special_tokens = [model_configs["pad_token"], "<cls>", "<eoc>"]
 
     # vocabulary
     vocab = GeneVocab.from_file(vocab_file)
@@ -38,7 +30,7 @@ def load_model(model_dir,model_configs,device='cpu',use_fast_transformer=False):
 
     # Binning will be applied after tokenization. A possible way to do is to use the unified way of binning in the data collator.
 
-    vocab.set_default_index(vocab["<pad>"])
+    vocab.set_default_index(vocab[model_configs["pad_token"]])
 
     model = TransformerModel(
         ntoken=len(vocab),
@@ -57,18 +49,15 @@ def load_model(model_dir,model_configs,device='cpu',use_fast_transformer=False):
         use_batch_labels=False,
         domain_spec_batchnorm=False,
         explicit_zero_prob=False,
-        use_fast_transformer=use_fast_transformer,
+        use_fast_transformer=model_configs["use_fast_transformer"],
         fast_transformer_backend="flash",
         pre_norm=False,
     )
 
-    load_pretrained(model, torch.load(model_file, map_location=device), verbose=False)
-    model.to(device)
+    load_pretrained(model, torch.load(model_configs["model_path"], map_location=model_configs["device"]), verbose=False)
+    model.to(model_configs["device"])
     model.eval()
 
-    return model,vocab
-
-
-
+    return model, vocab
 
 
 def get_embedding(
     adata_or_file: Union[AnnData, PathLike],
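For reference, a minimal usage sketch of the refactored API, pieced together from the updated examples/run_scgpt.py and the docstring example above. It assumes the helical package is installed and that a pre-processed ./10k_pbmcs_proc.h5ad file is available locally; after this change, constructing scGPTConfig is what triggers the vocab/checkpoint download into the Helical cache, and scGPT is built from the config alone.

import anndata as ad
from helical.models.scgpt.model import scGPT, scGPTConfig

# Building the config downloads vocab.json and best_model.pt into the Helical
# cache and records the checkpoint location under config["model_path"].
scgpt_config = scGPTConfig(batch_size=10, device="cpu")

# The model no longer takes a model_dir argument; everything comes from the config.
scgpt = scGPT(configurer=scgpt_config)

# Assumed local file: a pre-processed 10k PBMC AnnData object.
adata = ad.read_h5ad("./10k_pbmcs_proc.h5ad")
dataset = scgpt.process_data(adata[:10])
embeddings = scgpt.get_embeddings(dataset)
print(embeddings.shape)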