Refactor loggers using __name__ of files
bputzeys committed May 22, 2024
1 parent afc3a31 commit e2b86d7
Showing 8 changed files with 45 additions and 40 deletions.
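
The change replaces per-class logger names such as `logging.getLogger("Geneformer-Model")` with module-level loggers keyed on `__name__`. A minimal sketch of the resulting pattern (the entry-point configuration is an illustrative assumption, not part of this commit):

```python
import logging

# Module-level logger: its name is the module's dotted import path
# (e.g. "helical.models.geneformer.model"), so an application can
# route or filter each module's output from one place.
LOGGER = logging.getLogger(__name__)

def get_embeddings():
    LOGGER.info("Inference started")

if __name__ == "__main__":
    # Libraries only emit records; the application decides handlers and levels.
    logging.basicConfig(level=logging.INFO)
    get_embeddings()
```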
1 change: 1 addition & 0 deletions examples/run_hyena_dna.py
@@ -1,4 +1,5 @@
from helical.models.hyena_dna.model import HyenaDNA, HyenaDNAConfig

hyena_config = HyenaDNAConfig(model_name = "hyenadna-tiny-1k-seqlen-d256")
model = HyenaDNA(configurer = hyena_config)
sequence = 'ACTG' * int(1024/4)
7 changes: 4 additions & 3 deletions helical/models/geneformer/model.py
@@ -14,6 +14,7 @@
from accelerate import Accelerator
import pickle as pkl

LOGGER = logging.getLogger(__name__)
class Geneformer(HelicalBaseModel):
"""Geneformer Model.
The Geneformer Model is a transformer-based model that can be used to extract gene embeddings from single-cell RNA-seq data.
@@ -50,7 +51,6 @@ def __init__(self, configurer: GeneformerConfig = default_configurer) -> None:
super().__init__()
self.config = configurer

self.log = logging.getLogger("Geneformer-Model")
self.device = self.config.device

downloader = Downloader()
@@ -70,7 +70,8 @@ def __init__(self, configurer: GeneformerConfig = default_configurer) -> None:
self.model = self.accelerator.prepare(self.model)
else:
self.accelerator = None

LOGGER.info(f"Model finished initializing.")

def process_data(self, data: AnnData, nproc: int = 4,use_gene_symbols=True, output_path: Optional[str] = None) -> Dataset:
"""Processes the data for the UCE model
@@ -140,7 +141,7 @@ def get_embeddings(self, dataset: Dataset) -> np.array:
np.array
The gene embeddings in the form of a numpy array
"""
self.log.info(f"Inference started")
LOGGER.info(f"Inference started:")
embeddings = get_embs(
self.model,
dataset,
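
A side note on what the `__name__`-based names buy here (a sketch, assuming the module paths shown in this diff): callers can tune verbosity per module, or per package subtree, without touching library code.

```python
import logging

# With LOGGER = logging.getLogger(__name__) inside
# helical/models/geneformer/model.py, the logger's name is the dotted
# module path, so one module can be quieted independently:
logging.getLogger("helical.models.geneformer.model").setLevel(logging.WARNING)

# Or silence the whole package subtree at once via the parent logger:
logging.getLogger("helical.models").setLevel(logging.ERROR)
```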
5 changes: 3 additions & 2 deletions helical/models/hyena_dna/model.py
@@ -6,6 +6,7 @@
import torch
from .standalone_hyenadna import CharacterTokenizer
from helical.services.downloader import Downloader
LOGGER = logging.getLogger(__name__)

class HyenaDNA(HelicalBaseModel):
"""HyenaDNA model."""
@@ -14,7 +15,6 @@ def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None:
def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None:
super().__init__()
self.config = configurer.config
self.log = logging.getLogger("Hyena-DNA-Model")

downloader = Downloader()
for file in self.config["list_of_files_to_download"]:
@@ -35,7 +35,7 @@ def __init__(self, configurer: HyenaDNAConfig = default_configurer) -> None:
# prep model and forward
self.model.to(self.device)
self.model.eval()

LOGGER.info(f"Model finished initializing.")

def process_data(self, sequence):

@@ -48,5 +48,6 @@ def process_data(self, sequence):
return tok_seq

def get_embeddings(self, tok_seq):
LOGGER.info(f"Inference started")
with torch.inference_mode():
return self.model(tok_seq)
5 changes: 4 additions & 1 deletion helical/models/hyena_dna/pretrained_model.py
@@ -4,6 +4,9 @@
from transformers import PreTrainedModel
import re
from .standalone_hyenadna import HyenaDNAModel
import logging

LOGGER = logging.getLogger(__name__)

# helper 1
def inject_substring(orig_str):
@@ -82,6 +85,6 @@ def from_pretrained(cls,

# scratch model has now been updated
scratch_model.load_state_dict(state_dict)
print("Loaded pretrained weights ok!")
LOGGER.info("Loaded pretrained weights ok!")
return scratch_model

9 changes: 4 additions & 5 deletions helical/models/scgpt/model.py
@@ -10,6 +10,7 @@
from helical.models.scgpt.scgpt_utils import load_model, get_embedding
from helical.services.downloader import Downloader

LOGGER = logging.getLogger(__name__)
os.environ['KMP_DUPLICATE_LIB_OK']='True'

class scGPT(HelicalBaseModel):
@@ -48,7 +49,6 @@ def __init__(self, configurer: scGPTConfig = configurer) -> None:

super().__init__()
self.config = configurer.config
self.log = logging.getLogger("scGPT-Model")

downloader = Downloader()
for file in self.config["list_of_files_to_download"]:
@@ -61,17 +61,17 @@ def __init__(self, configurer: scGPTConfig = configurer) -> None:
self.model = self.accelerator.prepare(self.model)
else:
self.accelerator = None

LOGGER.info(f"Model finished initializing.")

def get_embeddings(self,data: AnnData) -> np.array:
def get_embeddings(self, data: AnnData) -> np.array:
"""Gets the gene embeddings
Returns
-------
np.array
The gene embeddings in the form of a numpy array
"""
self.log.info(f"Inference started:")
LOGGER.info(f"Inference started:")
# The extracted embedding is stored in the `X_scGPT` field of `obsm` in AnnData.
# for local development, only get embeddings for the first 100 entries

@@ -82,7 +82,6 @@ def get_embeddings(self,data: AnnData) -> np.array:
model_configs=self.config,
gene_col=self.gene_column_name,
device=self.config["device"])


return embeddings

6 changes: 3 additions & 3 deletions helical/models/uce/model.py
@@ -8,6 +8,7 @@
from accelerate import Accelerator
from helical.services.downloader import Downloader

LOGGER = logging.getLogger(__name__)
class UCE(HelicalBaseModel):
"""Universal Cell Embedding Model. This model reads in single-cell RNA-seq data and outputs gene embeddings.
This model particularly uses protein-embeddings generated by ESM2.
@@ -41,7 +42,6 @@ class UCE(HelicalBaseModel):
def __init__(self, configurer: UCEConfig = default_configurer) -> None:
super().__init__()
self.config = configurer.config
self.log = logging.getLogger("UCE-Model")

downloader = Downloader()
for file in self.config["list_of_files_to_download"]:
@@ -58,7 +58,7 @@ def __init__(self, configurer: UCEConfig = default_configurer) -> None:
self.model = self.accelerator.prepare(self.model)
else:
self.accelerator = None
self.log.info(f"Model finished initializing.")
LOGGER.info(f"Model finished initializing.")

def process_data(self, data: AnnData,
species: str = "human",
@@ -112,6 +112,6 @@ def get_embeddings(self, dataloader: DataLoader) -> np.array:
np.array
The gene embeddings in the form of a numpy array
"""
self.log.info(f"Inference started")
LOGGER.info(f"Inference started")
embeddings = get_gene_embeddings(self.model, dataloader, self.accelerator)
return embeddings
10 changes: 5 additions & 5 deletions helical/models/uce/uce_utils.py
@@ -14,7 +14,7 @@
from helical.models.uce.uce_model import TransformerModel
from helical.models.uce.uce_dataset import UCEDataset

logger = logging.getLogger(__name__)
LOGGER = logging.getLogger(__name__)

def process_data(anndata,
model_config,
@@ -53,7 +53,7 @@ def process_data(anndata,
dataset_chroms, dataset_start = get_positions(Path(files_config["spec_chrom_csv_path"]), species, filtered_adata)

if not (len(dataset_chroms) == len(dataset_start) == num_genes == pe_row_idxs.shape[0]):
logger.error(f'Invalid input dimensions for the UCEDataset! '
LOGGER.error(f'Invalid input dimensions for the UCEDataset! '
f'dataset_chroms: {len(dataset_chroms)}, '
f'dataset_start: {len(dataset_start)}, '
f'num_genes: {num_genes}, '
@@ -75,7 +75,7 @@ def process_data(anndata,
collate_fn=dataset.collator_fn,
num_workers=0)

logger.info(f'UCE Dataset and DataLoader prepared. Setting batch_size={batch_size} for inference.')
LOGGER.info(f'UCE Dataset and DataLoader prepared. Setting batch_size={batch_size} for inference.')

if accelerator is not None:
dataloader = accelerator.prepare(dataloader)
@@ -166,9 +166,9 @@ def prepare_expression_counts_file(gene_expression: np.array, name: str, folder_
fp = np.memmap(filename, dtype='int64', mode='w+', shape=shape)
fp[:] = gene_expression[:]
fp.flush()
logger.info(f"Passed the gene expressions (with shape={shape} and max gene count data {gene_expression.max()}) to {filename}")
LOGGER.info(f"Passed the gene expressions (with shape={shape} and max gene count data {gene_expression.max()}) to {filename}")
except:
logger.error(f"Error during preparation of npz file {filename}.")
LOGGER.error(f"Error during preparation of npz file {filename}.")
raise Exception

## writing a function to load the model
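
For context on the surrounding code: `prepare_expression_counts_file` streams the counts matrix to disk through `numpy.memmap` before logging the result. A minimal sketch of that pattern (file name and shape are hypothetical):

```python
import numpy as np

# Write a (cells x genes) int64 array into a disk-backed buffer; flush()
# pushes the pages to disk so a reader can memory-map them lazily later.
counts = np.arange(32, dtype="int64").reshape(4, 8)
fp = np.memmap("counts.dat", dtype="int64", mode="w+", shape=counts.shape)
fp[:] = counts[:]
fp.flush()
```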
42 changes: 21 additions & 21 deletions helical/services/downloader.py
@@ -14,13 +14,13 @@
from git import Repo
from helical.constants.paths import CACHE_DIR_HELICAL

LOGGER = logging.getLogger(__name__)
INTERVAL = 1000 # interval to get gene mappings
CHUNK_SIZE = 1024 * 1024 * 10 #8192 # size of individual chunks to download
LOADING_BAR_LENGTH = 50 # size of the download progression bar in console
class Downloader(Logger):
def __init__(self, loging_type = LoggingType.CONSOLE, level = LoggingLevel.INFO) -> None:
super().__init__(loging_type, level)
self.log = logging.getLogger("Downloader")
self.display = True

# manually create a requests session
@@ -41,10 +41,10 @@ def get_ensemble_mapping(self, path_to_ets_csv: Path, output: Path):
try:
df = pd.read_csv(path_to_ets_csv)
except:
self.log.exception(f"Failed to open the '{path_to_ets_csv}' file. Please provide it.")
LOGGER.exception(f"Failed to open the '{path_to_ets_csv}' file. Please provide it.")

if output.is_file():
self.log.info(f"No mapping is done because mapping file already exists here: '{output}'")
LOGGER.info(f"No mapping is done because mapping file already exists here: '{output}'")

else:
genes = df['egid'].dropna().unique()
@@ -54,7 +54,7 @@ def get_ensemble_mapping(self, path_to_ets_csv: Path, output: Path):

ensemble_to_display_name = dict()

self.log.info(f"Starting to download the mappings of {len(genes)} genes from '{server}'")
LOGGER.info(f"Starting to download the mappings of {len(genes)} genes from '{server}'")

# Resetting for visualization
self.data_length = 0
@@ -69,7 +69,7 @@ def get_ensemble_mapping(self, path_to_ets_csv: Path, output: Path):
ensemble_to_display_name.update(decoded)

pkl.dump(ensemble_to_display_name, open(output, 'wb'))
self.log.info(f"Downloaded all mappings and saved to: '{output}'")
LOGGER.info(f"Downloaded all mappings and saved to: '{output}'")

def download_via_link(self, output: Path, link: str) -> None:
'''
@@ -81,10 +81,10 @@ def download_via_link(self, output: Path, link: str) -> None:
'''

if output.is_file():
self.log.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.")
LOGGER.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.")

else:
self.log.info(f"Starting to download: '{link}'")
LOGGER.info(f"Starting to download: '{link}'")
with open(output, "wb") as f:
response = requests.get(link, stream=True)
total_length = response.headers.get('content-length')
@@ -102,8 +102,8 @@ def download_via_link(self, output: Path, link: str) -> None:
self._display_download_progress(len(data))
f.write(data)
except:
self.log.error(f"Failed downloading file from '{link}'")
self.log.info(f"File saved to: '{output}'")
LOGGER.error(f"Failed downloading file from '{link}'")
LOGGER.info(f"File saved to: '{output}'")

def clone_git_repo(self, destination: Path, repo_url: str, checkout: str) -> None:
'''
@@ -116,13 +116,13 @@ def clone_git_repo(self, destination: Path, repo_url: str, checkout: str) -> None:
'''

if destination.is_dir():
self.log.info(f"Folder: {destination} exists already. No 'git clone' is performed.")
LOGGER.info(f"Folder: {destination} exists already. No 'git clone' is performed.")

else:
self.log.info(f"Cloning {repo_url} to {destination}")
LOGGER.info(f"Cloning {repo_url} to {destination}")
repo = Repo.clone_from(repo_url, destination)
repo.git.checkout(checkout)
self.log.info(f"Successfully cloned and checked out '{checkout}' of {repo_url}")
LOGGER.info(f"Successfully cloned and checked out '{checkout}' of {repo_url}")

def _display_download_progress(self, data_chunk_size: int) -> None:
'''
@@ -152,10 +152,10 @@ def download_via_name_v0(self, name: str) -> None:
os.makedirs(os.path.dirname(output),exist_ok=True)

if Path(output).is_file():
self.log.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.")
LOGGER.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.")

else:
self.log.info(f"Starting to download: '{link}'")
LOGGER.info(f"Starting to download: '{link}'")
with open(output, "wb") as f:
response = requests.get(link, stream=True)
total_length = response.headers.get('content-length')
@@ -175,8 +175,8 @@ def download_via_name_v0(self, name: str) -> None:
f.write(data)
pbar.update(len(data))
except:
self.log.error(f"Failed downloading file from '{link}'")
self.log.info(f"File saved to: '{output}'")
LOGGER.error(f"Failed downloading file from '{link}'")
LOGGER.info(f"File saved to: '{output}'")


def download_via_name(self, name: str) -> None:
@@ -201,19 +201,19 @@ def download_via_name(self, name: str) -> None:

if not os.path.exists(os.path.dirname(output)):
os.makedirs(os.path.dirname(output),exist_ok=True)
self.log.info(f"Creating Folder {os.path.dirname(output)}")
LOGGER.info(f"Creating Folder {os.path.dirname(output)}")

if Path(output).is_file():
self.log.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.")
LOGGER.info(f"File: '{output}' exists already. File is not overwritten and nothing is downloaded.")

else:
self.log.info(f"Starting to download: '{blob_url}'")
LOGGER.info(f"Starting to download: '{blob_url}'")
# disabling logging info messages from Azure package as there are too many
logging.disable(logging.INFO)
self.display_azure_download_progress(blob_client, blob_url, output)
logging.disable(logging.NOTSET)

self.log.info(f"File saved to: '{output}'")
LOGGER.info(f"File saved to: '{output}'")

def display_azure_download_progress(self, blob_client: BlobClient, blob_url: str, output: Path) -> None:
"""
@@ -248,7 +248,7 @@ def progress_callback(bytes_transferred, total_bytes):

sample_blob.write(download_stream.readall())
except:
self.log.error(f"Failed downloading file from '{blob_url}'")
LOGGER.error(f"Failed downloading file from '{blob_url}'")

if self.display:
pbar.close()
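
`download_via_name` mutes the chatty Azure SDK by calling `logging.disable(logging.INFO)` for the duration of the transfer and restoring it with `logging.disable(logging.NOTSET)`. A narrower alternative (a sketch, not what this commit does) is to raise the level of the third-party loggers only, leaving the module's own `LOGGER` untouched:

```python
import logging

# The azure-* packages log under the "azure" logger namespace, so this
# quiets their INFO chatter without disabling INFO process-wide.
logging.getLogger("azure").setLevel(logging.WARNING)
```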
