Move Downloader to scGPT configurer
bputzeys committed May 21, 2024
1 parent facf151 commit 8996133
Showing 4 changed files with 38 additions and 68 deletions.
4 changes: 2 additions & 2 deletions examples/run_scgpt.py
@@ -1,8 +1,8 @@
 from helical.models.scgpt.model import scGPT, scGPTConfig
 import anndata as ad

-model_config = scGPTConfig(batch_size=10)
-scgpt = scGPT(model_config=model_config)
+scgpt_config = scGPTConfig(batch_size=10)
+scgpt = scGPT(configurer=scgpt_config)

 adata = ad.read_h5ad("./10k_pbmcs_proc.h5ad")
 data = scgpt.process_data(adata[:10])
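For reference, a minimal sketch of the example script as it reads after this change; the embedding call at the end is not shown in this hunk and is assumed from the class docstring:

from helical.models.scgpt.model import scGPT, scGPTConfig
import anndata as ad

# As of this commit, constructing the configurer also downloads the
# checkpoint and vocabulary into the Helical cache.
scgpt_config = scGPTConfig(batch_size=10)
scgpt = scGPT(configurer=scgpt_config)

adata = ad.read_h5ad("./10k_pbmcs_proc.h5ad")
data = scgpt.process_data(adata[:10])
embeddings = scgpt.get_embeddings(data)  # assumed follow-up, mirroring the docstring example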
57 changes: 13 additions & 44 deletions helical/models/scgpt/model.py
@@ -1,37 +1,18 @@
-# This tutorial covers the zero-shot integration with continual pre-trained scGPT.
-# This particular workflow works for scRNA-seq datasets without fine-tuning (or any extensive training) of scGPT.
-# Continual pre-trained scGPT (scGPT_CP) is a model that inherits the pre-trained scGPT whole-human model checkpoint,
-# and is further supervised by extra cell type labels (using the [Tabula Sapiens](https://tabula-sapiens-portal.ds.czbiohub.org/) dataset)
-# during the continual pre-training stage. We observed that the scGPT_CP model can achieve comparable or better zero-shot performance
-# on cell embedding related tasks compared to the original checkpoint, especially on datasets with observable technical batch effects.
-# This tutorial will show how to use the latent space of scGPT to integrate scRNA-seq datasets.
-# We use the `scGPT_CP` model to provide embeddings out of the box.
-# You may download it from [here](https://drive.google.com/drive/folders/1_GROJTzXiAV8HB4imruOTk6PEGuNOcgB).
-
-# We will use the [scIB](https://www.nature.com/articles/s41592-021-01336-8) pancreas dataset as an example.
-# This dataset is publicly accessible via [here](https://figshare.com/ndownloader/files/24539828). You may place the dataset under the `data` directory at the outer level.
-
-# The zero-shot integration workflow is as follows:
-# 1. [Load and pre-process the dataset](#prepare-the-datasets)
-# 2. [Generate scGPT embeddings for each cell](#generate-the-cell-embeddings)
-
 import os
 import scanpy as sc
 from helical.models.helical import HelicalBaseModel
 from helical.models.scgpt.scgpt_config import scGPTConfig
 import numpy as np
 from anndata import AnnData
 import logging
-from typing import Optional, Literal
-from pathlib import Path
+from typing import Literal
 from accelerate import Accelerator
-from helical.services.downloader import Downloader
 from helical.models.scgpt.scgpt_utils import load_model, get_embedding

 os.environ['KMP_DUPLICATE_LIB_OK']='True'

 class scGPT(HelicalBaseModel):
-    default_config = scGPTConfig()
+    configurer = scGPTConfig()
     """scGPT Model.
     The scGPT Model is a transformer-based model that can be used to extract gene embeddings from single-cell RNA-seq data.
     Currently we load the continual pre-trained model from the scGPT repository as the default model, which works best on zero-shot tasks.
@@ -41,18 +22,16 @@ class scGPT(HelicalBaseModel):
     -------
     >>> from helical.models import scGPT, scGPTConfig
     >>> import anndata as ad
-    >>> model_config = scGPTConfig(batch_size=10)
-    >>> scgpt = scGPT(model_config=model_config)
+    >>> scgpt_config = scGPTConfig(batch_size=10)
+    >>> scgpt = scGPT(configurer=scgpt_config)
     >>> ann_data = ad.read_h5ad("./data/10k_pbmcs_proc.h5ad")
     >>> dataset = scgpt.process_data(ann_data[:100])
     >>> embeddings = scgpt.get_embeddings(dataset)

     Parameters
     ----------
-    model_dir : str, optional, default = None
-        The path to the model directory. None by default, which will download the model if not present.
-    model_config : scGPTConfig, optional, default = default_config
+    configurer : scGPTConfig, optional, default = configurer
         The model configuration.

     Returns
@@ -64,26 +43,16 @@ class scGPT(HelicalBaseModel):
     We use the implementation from this `repository <https://github.com/bowang-lab/scGPT>`_ , which comes from the original authors. You can find the description of the method in this `paper <https://www.nature.com/articles/s41592-024-02201-0>`_.
     """

-    def __init__(self, model_dir: Optional[str] = None, model_config: scGPTConfig = default_config) -> None:
-
+    def __init__(self, configurer: scGPTConfig = configurer) -> None:

         super().__init__()
-        self.model_config = model_config.config
-        self.downloader = Downloader()
-
-        if model_dir is None:
-            self.downloader.download_via_name("scgpt/scGPT_CP/vocab.json")
-            self.downloader.download_via_name("scgpt/scGPT_CP/best_model.pt")
-            self.model_dir = Path(os.path.join(self.downloader.CACHE_DIR_HELICAL, 'scgpt/scGPT_CP'))
-        else:
-            self.model_dir = Path(model_dir)
-
+        self.config = configurer.config
         self.log = logging.getLogger("scGPT-Model")

-        self.model, self.vocab = load_model(self.model_dir, self.model_config, device=self.model_config["device"], use_fast_transformer=False)
+        self.model, self.vocab = load_model(self.config)

-        if self.model_config["accelerator"]:
-            self.accelerator = Accelerator(project_dir=self.model_dir, cpu=self.model_config["accelerator"]["cpu"])
+        if self.config["accelerator"]:
+            self.accelerator = Accelerator(project_dir=self.config["model_path"].parent, cpu=self.config["accelerator"]["cpu"])
             self.model = self.accelerator.prepare(self.model)
         else:
             self.accelerator = None
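Taken together, the constructor is now config-driven. A minimal sketch of the resulting call sequence, assuming only the config keys visible in this diff (model_path, device, accelerator, batch_size):

from helical.models.scgpt.scgpt_config import scGPTConfig
from helical.models.scgpt.scgpt_utils import load_model

# Constructing the configurer downloads vocab.json and best_model.pt
# into the Helical cache and records the checkpoint path.
configurer = scGPTConfig()

config = configurer.config                  # plain dict consumed by the model layer
model, vocab = load_model(config)           # weights come from config["model_path"]
model_dir = config["model_path"].parent     # same directory also holds vocab.json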
@@ -104,10 +73,10 @@ def get_embeddings(self,data: AnnData) -> np.array:
         embeddings = get_embedding(data,
                                    model=self.model,
                                    vocab=self.vocab,
-                                   batch_size=self.model_config["batch_size"],
-                                   model_configs=self.model_config,
+                                   batch_size=self.config["batch_size"],
+                                   model_configs=self.config,
                                    gene_col=self.gene_column_name,
-                                   device=self.model_config["device"])
+                                   device=self.config["device"])

         return embeddings
16 changes: 14 additions & 2 deletions helical/models/scgpt/scgpt_config.py
@@ -1,4 +1,7 @@
 from typing import Optional
+from helical.services.downloader import Downloader
+from pathlib import Path
+import os
 class scGPTConfig():
     """
     Configuration class to use the scGPT Model.
@@ -33,7 +36,8 @@ class scGPTConfig():
         The accelerator configuration
     device : str, optional, default = "cpu"
         The device to use. Either use "cuda" or "cpu"
+    use_fast_transformer : bool, optional, default = False
+        Whether to use fast transformer or not

     Returns
     -------
@@ -62,10 +66,17 @@ def __init__(
         world_size: int = 8,
         accelerator: Optional[dict] = None,
         device: str = "cpu",
+        use_fast_transformer: bool = False,
     ):

+        model_name = 'best_model'  # TODO: Include more models
+        downloader = Downloader()
+        downloader.download_via_name("scgpt/scGPT_CP/vocab.json")
+        downloader.download_via_name(f"scgpt/scGPT_CP/{model_name}.pt")
+        model_path = Path(os.path.join(downloader.CACHE_DIR_HELICAL, 'scgpt/scGPT_CP', f'{model_name}.pt'))
+
         self.config = {
+            "model_path": model_path,
             "pad_token": pad_token,
             "batch_size": batch_size,
             "fast_transformer": fast_transformer,
@@ -79,5 +90,6 @@ def __init__(
             "pad_value": pad_value,
             "world_size": world_size,
             "accelerator": accelerator,
-            "device": device
+            "device": device,
+            "use_fast_transformer": use_fast_transformer,
         }
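A quick usage sketch of the reworked configurer, assuming defaults for the parameters not shown in this hunk. Note the design choice: constructing a scGPTConfig now has the side effect of fetching the weights, which is why scGPT.__init__ no longer needs a Downloader or a model_dir argument:

from helical.models.scgpt.scgpt_config import scGPTConfig

# Side effect on construction: downloads scgpt/scGPT_CP/vocab.json and
# scgpt/scGPT_CP/best_model.pt into the Helical cache.
configurer = scGPTConfig(batch_size=10, device="cpu", use_fast_transformer=False)

print(configurer.config["model_path"])  # <cache>/scgpt/scGPT_CP/best_model.pt
print(configurer.config["batch_size"])  # 10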
29 changes: 9 additions & 20 deletions helical/models/scgpt/scgpt_utils.py
@@ -1,34 +1,26 @@
 import json
 import os
 from pathlib import Path
 from typing import Optional, Union
 from os import PathLike

 import numpy as np
 import scanpy as sc
 import torch
 from anndata import AnnData
 from torch.utils.data import DataLoader, SequentialSampler
 from tqdm import tqdm

 from .data_collator import DataCollator
 from .model_dir import TransformerModel
 from .tokenizer import GeneVocab
 from .utils import load_pretrained

+from helical.models.scgpt.scgpt_config import scGPTConfig
 from helical.models.scgpt.tasks.cell_emb import get_batch_cell_embeddings


-def load_model(model_dir, model_configs, device='cpu', use_fast_transformer=False):
+def load_model(model_configs: scGPTConfig):

     # LOAD MODEL
-    model_dir = Path(model_dir)
+    model_dir = model_configs["model_path"].parent
     vocab_file = model_dir / "vocab.json"
-    model_file = model_dir / "best_model.pt"
-    pad_token = "<pad>"
-    special_tokens = [pad_token, "<cls>", "<eoc>"]
+    special_tokens = [model_configs["pad_token"], "<cls>", "<eoc>"]

     # vocabulary
     vocab = GeneVocab.from_file(vocab_file)
@@ -38,7 +30,7 @@ def load_model(model_dir,model_configs,device='cpu',use_fast_transformer=False):

     # Binning will be applied after tokenization. A possible way to do this is to use the unified binning in the data collator.

-    vocab.set_default_index(vocab["<pad>"])
+    vocab.set_default_index(vocab[model_configs["pad_token"]])

     model = TransformerModel(
         ntoken=len(vocab),
@@ -57,18 +49,15 @@ def load_model(model_dir,model_configs,device='cpu',use_fast_transformer=False):
         use_batch_labels=False,
         domain_spec_batchnorm=False,
         explicit_zero_prob=False,
-        use_fast_transformer=use_fast_transformer,
+        use_fast_transformer=model_configs["use_fast_transformer"],
         fast_transformer_backend="flash",
         pre_norm=False,
     )

-    load_pretrained(model, torch.load(model_file, map_location=device), verbose=False)
-    model.to(device)
+    load_pretrained(model, torch.load(model_configs["model_path"], map_location=model_configs["device"]), verbose=False)
+    model.to(model_configs["device"])
     model.eval()
-    return model,vocab
+    return model, vocab

 def get_embedding(
     adata_or_file: Union[AnnData, PathLike],
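A minimal sketch of the new call pattern, assuming a configurer built as in scgpt_config.py above. Note that load_model actually receives configurer.config, a plain dict indexed with keys like model_configs["pad_token"], even though the annotation names scGPTConfig:

from helical.models.scgpt.scgpt_config import scGPTConfig
from helical.models.scgpt.scgpt_utils import load_model

configurer = scGPTConfig()                    # downloads weights, sets config["model_path"]
model, vocab = load_model(configurer.config)  # returns model on config["device"], in eval mode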
