diff --git a/examples/run_hyena_dna.py b/examples/run_hyena_dna.py index 60c77fb8..d6ba91d9 100644 --- a/examples/run_hyena_dna.py +++ b/examples/run_hyena_dna.py @@ -1,4 +1,5 @@ from helical.models.hyena_dna.model import HyenaDNA, HyenaDNAConfig + hyena_config = HyenaDNAConfig(model_name = "hyenadna-tiny-1k-seqlen-d256") model = HyenaDNA(configurer = hyena_config) sequence = 'ACTG' * int(1024/4) diff --git a/helical/models/geneformer/model.py b/helical/models/geneformer/model.py index 598e4e47..bb98efeb 100644 --- a/helical/models/geneformer/model.py +++ b/helical/models/geneformer/model.py @@ -14,6 +14,7 @@ from accelerate import Accelerator import pickle as pkl +LOGGER = logging.getLogger(__name__) class Geneformer(HelicalBaseModel): """Geneformer Model. The Geneformer Model is a transformer-based model that can be used to extract gene embeddings from single-cell RNA-seq data. @@ -50,7 +51,6 @@ def __init__(self, configurer: GeneformerConfig = default_configurer) -> None: super().__init__() self.config = configurer - self.log = logging.getLogger("Geneformer-Model") self.device = self.config.device downloader = Downloader() @@ -70,7 +70,8 @@ def __init__(self, configurer: GeneformerConfig = default_configurer) -> None: self.model = self.accelerator.prepare(self.model) else: self.accelerator = None - + LOGGER.info(f"Model finished initializing.") + def process_data(self, data: AnnData, nproc: int = 4,use_gene_symbols=True, output_path: Optional[str] = None) -> Dataset: """Processes the data for the UCE model @@ -140,7 +141,7 @@ def get_embeddings(self, dataset: Dataset) -> np.array: np.array The gene embeddings in the form of a numpy array """ - self.log.info(f"Inference started") + LOGGER.info(f"Inference started:") embeddings = get_embs( self.model, dataset, diff --git a/helical/models/hyena_dna/model.py b/helical/models/hyena_dna/model.py index 907d21cb..648a89f6 100644 --- a/helical/models/hyena_dna/model.py +++ b/helical/models/hyena_dna/model.py @@ -6,6 +6,7 @@ import torch from .standalone_hyenadna import CharacterTokenizer from helical.services.downloader import Downloader +LOGGER = logging.getLogger(__name__) class HyenaDNA(HelicalBaseModel): """HyenaDNA model.""" @@ -14,7 +15,6 @@ class HyenaDNA(HelicalBaseModel): def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None: super().__init__() self.config = configurer.config - self.log = logging.getLogger("Hyena-DNA-Model") downloader = Downloader() for file in self.config["list_of_files_to_download"]: @@ -35,7 +35,7 @@ def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None: # prep model and forward self.model.to(self.device) self.model.eval() - + LOGGER.info(f"Model finished initializing.") def process_data(self, sequence): @@ -48,5 +48,6 @@ def process_data(self, sequence): return tok_seq def get_embeddings(self, tok_seq): + LOGGER.info(f"Inference started") with torch.inference_mode(): return self.model(tok_seq) \ No newline at end of file diff --git a/helical/models/hyena_dna/pretrained_model.py b/helical/models/hyena_dna/pretrained_model.py index bae48468..9fab5f51 100644 --- a/helical/models/hyena_dna/pretrained_model.py +++ b/helical/models/hyena_dna/pretrained_model.py @@ -4,6 +4,9 @@ from transformers import PreTrainedModel import re from .standalone_hyenadna import HyenaDNAModel +import logging + +LOGGER = logging.getLogger(__name__) # helper 1 def inject_substring(orig_str): @@ -82,6 +85,6 @@ def from_pretrained(cls, # scratch model has now been updated scratch_model.load_state_dict(state_dict) - print("Loaded pretrained weights ok!") + LOGGER.info("Loaded pretrained weights ok!") return scratch_model \ No newline at end of file diff --git a/helical/models/scgpt/model.py b/helical/models/scgpt/model.py index 22079979..0a4f1011 100644 --- a/helical/models/scgpt/model.py +++ b/helical/models/scgpt/model.py @@ -10,6 +10,7 @@ from helical.models.scgpt.scgpt_utils import load_model, get_embedding from helical.services.downloader import Downloader +LOGGER = logging.getLogger(__name__) os.environ['KMP_DUPLICATE_LIB_OK']='True' class scGPT(HelicalBaseModel): @@ -48,7 +49,6 @@ def __init__(self, configurer: scGPTConfig = configurer) -> None: super().__init__() self.config = configurer.config - self.log = logging.getLogger("scGPT-Model") downloader = Downloader() for file in self.config["list_of_files_to_download"]: @@ -61,9 +61,9 @@ def __init__(self, configurer: scGPTConfig = configurer) -> None: self.model = self.accelerator.prepare(self.model) else: self.accelerator = None - + LOGGER.info(f"Model finished initializing.") - def get_embeddings(self,data: AnnData) -> np.array: + def get_embeddings(self, data: AnnData) -> np.array: """Gets the gene embeddings Returns @@ -71,7 +71,7 @@ def get_embeddings(self,data: AnnData) -> np.array: np.array The gene embeddings in the form of a numpy array """ - self.log.info(f"Inference started:") + LOGGER.info(f"Inference started:") # The extracted embedding is stored in the `X_scGPT` field of `obsm` in AnnData. # for local development, only get embeddings for the first 100 entries @@ -82,7 +82,6 @@ def get_embeddings(self,data: AnnData) -> np.array: model_configs=self.config, gene_col=self.gene_column_name, device=self.config["device"]) - return embeddings diff --git a/helical/models/uce/model.py b/helical/models/uce/model.py index 5e4ba7f3..1cf0ef08 100644 --- a/helical/models/uce/model.py +++ b/helical/models/uce/model.py @@ -8,6 +8,7 @@ from accelerate import Accelerator from helical.services.downloader import Downloader +LOGGER = logging.getLogger(__name__) class UCE(HelicalBaseModel): """Universal Cell Embedding Model. This model reads in single-cell RNA-seq data and outputs gene embeddings. This model particularly uses protein-embeddings generated by ESM2. @@ -41,7 +42,6 @@ class UCE(HelicalBaseModel): def __init__(self, configurer: UCEConfig = default_configurer) -> None: super().__init__() self.config = configurer.config - self.log = logging.getLogger("UCE-Model") downloader = Downloader() for file in self.config["list_of_files_to_download"]: @@ -58,7 +58,7 @@ def __init__(self, configurer: UCEConfig = default_configurer) -> None: self.model = self.accelerator.prepare(self.model) else: self.accelerator = None - self.log.info(f"Model finished initializing.") + LOGGER.info(f"Model finished initializing.") def process_data(self, data: AnnData, species: str = "human", @@ -112,6 +112,6 @@ def get_embeddings(self, dataloader: DataLoader) -> np.array: np.array The gene embeddings in the form of a numpy array """ - self.log.info(f"Inference started") + LOGGER.info(f"Inference started") embeddings = get_gene_embeddings(self.model, dataloader, self.accelerator) return embeddings diff --git a/helical/models/uce/uce_utils.py b/helical/models/uce/uce_utils.py index fe713b99..6486e50c 100644 --- a/helical/models/uce/uce_utils.py +++ b/helical/models/uce/uce_utils.py @@ -14,7 +14,7 @@ from helical.models.uce.uce_model import TransformerModel from helical.models.uce.uce_dataset import UCEDataset -logger = logging.getLogger(__name__) +LOGGER = logging.getLogger(__name__) def process_data(anndata, model_config, @@ -53,7 +53,7 @@ def process_data(anndata, dataset_chroms, dataset_start = get_positions(Path(files_config["spec_chrom_csv_path"]), species, filtered_adata) if not (len(dataset_chroms) == len(dataset_start) == num_genes == pe_row_idxs.shape[0]): - logger.error(f'Invalid input dimensions for the UCEDataset! ' + LOGGER.error(f'Invalid input dimensions for the UCEDataset! ' f'dataset_chroms: {len(dataset_chroms)}, ' f'dataset_start: {len(dataset_start)}, ' f'num_genes: {num_genes}, ' @@ -75,7 +75,7 @@ def process_data(anndata, collate_fn=dataset.collator_fn, num_workers=0) - logger.info(f'UCE Dataset and DataLoader prepared. Setting batch_size={batch_size} for inference.') + LOGGER.info(f'UCE Dataset and DataLoader prepared. Setting batch_size={batch_size} for inference.') if accelerator is not None: dataloader = accelerator.prepare(dataloader) @@ -166,9 +166,9 @@ def prepare_expression_counts_file(gene_expression: np.array, name: str, folder_ fp = np.memmap(filename, dtype='int64', mode='w+', shape=shape) fp[:] = gene_expression[:] fp.flush() - logger.info(f"Passed the gene expressions (with shape={shape} and max gene count data {gene_expression.max()}) to {filename}") + LOGGER.info(f"Passed the gene expressions (with shape={shape} and max gene count data {gene_expression.max()}) to {filename}") except: - logger.error(f"Error during preparation of npz file {filename}.") + LOGGER.error(f"Error during preparation of npz file {filename}.") raise Exception ## writing a funciton to load the model diff --git a/helical/services/downloader.py b/helical/services/downloader.py index abf7b5a3..bce80119 100644 --- a/helical/services/downloader.py +++ b/helical/services/downloader.py @@ -14,13 +14,13 @@ from git import Repo from helical.constants.paths import CACHE_DIR_HELICAL +LOGGER = logging.getLogger(__name__) INTERVAL = 1000 # interval to get gene mappings CHUNK_SIZE = 1024 * 1024 * 10 #8192 # size of individual chunks to download LOADING_BAR_LENGTH = 50 # size of the download progression bar in console class Downloader(Logger): def __init__(self, loging_type = LoggingType.CONSOLE, level = LoggingLevel.INFO) -> None: super().__init__(loging_type, level) - self.log = logging.getLogger("Downloader") self.display = True # manually create a requests session @@ -41,10 +41,10 @@ def get_ensemble_mapping(self, path_to_ets_csv: Path, output: Path): try: df = pd.read_csv(path_to_ets_csv) except: - self.log.exception(f"Failed to open the '{path_to_ets_csv}' file. Please provide it.") + LOGGER.exception(f"Failed to open the '{path_to_ets_csv}' file. Please provide it.") if output.is_file(): - self.log.info(f"No mapping is done because mapping file already exists here: '{output}'") + LOGGER.info(f"No mapping is done because mapping file already exists here: '{output}'") else: genes = df['egid'].dropna().unique() @@ -54,7 +54,7 @@ def get_ensemble_mapping(self, path_to_ets_csv: Path, output: Path): ensemble_to_display_name = dict() - self.log.info(f"Starting to download the mappings of {len(genes)} genes from '{server}'") + LOGGER.info(f"Starting to download the mappings of {len(genes)} genes from '{server}'") # Resetting for visualization self.data_length = 0 @@ -69,7 +69,7 @@ def get_ensemble_mapping(self, path_to_ets_csv: Path, output: Path): ensemble_to_display_name.update(decoded) pkl.dump(ensemble_to_display_name, open(output, 'wb')) - self.log.info(f"Downloaded all mappings and saved to: '{output}'") + LOGGER.info(f"Downloaded all mappings and saved to: '{output}'") def download_via_link(self, output: Path, link: str) -> None: ''' @@ -81,10 +81,10 @@ def download_via_link(self, output: Path, link: str) -> None: ''' if output.is_file(): - self.log.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.") + LOGGER.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.") else: - self.log.info(f"Starting to download: '{link}'") + LOGGER.info(f"Starting to download: '{link}'") with open(output, "wb") as f: response = requests.get(link, stream=True) total_length = response.headers.get('content-length') @@ -102,8 +102,8 @@ def download_via_link(self, output: Path, link: str) -> None: self._display_download_progress(len(data)) f.write(data) except: - self.log.error(f"Failed downloading file from '{link}'") - self.log.info(f"File saved to: '{output}'") + LOGGER.error(f"Failed downloading file from '{link}'") + LOGGER.info(f"File saved to: '{output}'") def clone_git_repo(self, destination: Path, repo_url: str, checkout: str) -> None: ''' @@ -116,13 +116,13 @@ def clone_git_repo(self, destination: Path, repo_url: str, checkout: str) -> Non ''' if destination.is_dir(): - self.log.info(f"Folder: {destination} exists already. No 'git clone' is performed.") + LOGGER.info(f"Folder: {destination} exists already. No 'git clone' is performed.") else: - self.log.info(f"Clonging {repo_url} to {destination}") + LOGGER.info(f"Clonging {repo_url} to {destination}") repo = Repo.clone_from(repo_url, destination) repo.git.checkout(checkout) - self.log.info(f"Successfully cloned and checked out '{checkout}' of {repo_url}") + LOGGER.info(f"Successfully cloned and checked out '{checkout}' of {repo_url}") def _display_download_progress(self, data_chunk_size: int) -> None: ''' @@ -152,10 +152,10 @@ def download_via_name_v0(self, name: str) -> None: os.makedirs(os.path.dirname(output),exist_ok=True) if Path(output).is_file(): - self.log.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.") + LOGGER.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.") else: - self.log.info(f"Starting to download: '{link}'") + LOGGER.info(f"Starting to download: '{link}'") with open(output, "wb") as f: response = requests.get(link, stream=True) total_length = response.headers.get('content-length') @@ -175,8 +175,8 @@ def download_via_name_v0(self, name: str) -> None: f.write(data) pbar.update(len(data)) except: - self.log.error(f"Failed downloading file from '{link}'") - self.log.info(f"File saved to: '{output}'") + LOGGER.error(f"Failed downloading file from '{link}'") + LOGGER.info(f"File saved to: '{output}'") def download_via_name(self, name: str) -> None: @@ -201,19 +201,19 @@ def download_via_name(self, name: str) -> None: if not os.path.exists(os.path.dirname(output)): os.makedirs(os.path.dirname(output),exist_ok=True) - self.log.info(f"Creating Folder {os.path.dirname(output)}") + LOGGER.info(f"Creating Folder {os.path.dirname(output)}") if Path(output).is_file(): - self.log.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.") + LOGGER.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.") else: - self.log.info(f"Starting to download: '{blob_url}'") + LOGGER.info(f"Starting to download: '{blob_url}'") # disabling logging info messages from Azure package as there are too many logging.disable(logging.INFO) self.display_azure_download_progress(blob_client, blob_url, output) logging.disable(logging.NOTSET) - self.log.info(f"File saved to: '{output}'") + LOGGER.info(f"File saved to: '{output}'") def display_azure_download_progress(self, blob_client: BlobClient, blob_url: str, output: Path) -> None: """ @@ -248,7 +248,7 @@ def progress_callback(bytes_transferred, total_bytes): sample_blob.write(download_stream.readall()) except: - self.log.error(f"Failed downloading file from '{blob_url}'") + LOGGER.error(f"Failed downloading file from '{blob_url}'") if self.display: pbar.close()