Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/cantinilab/scPRINT
Browse files Browse the repository at this point in the history
  • Loading branch information
jkobject committed Nov 19, 2024
2 parents 8bc566e + 20e536c commit 3aadbbe
Show file tree
Hide file tree
Showing 16 changed files with 34,679 additions and 73 deletions.
253 changes: 202 additions & 51 deletions figures/generate_GN_fig1.ipynb

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions notebooks/additional/scprint_overfit.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -959,8 +959,7 @@
}
],
"source": [
"trainer.fit(model, dat\n",
"amodule=datamodule)"
"trainer.fit(model, datamodule=datamodule)"
]
},
{
Expand Down Expand Up @@ -1255,7 +1254,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
"version": "3.10.15"
}
},
"nbformat": 4,
Expand Down
896 changes: 896 additions & 0 deletions notebooks/additional/tests/bench_denoising.ipynb

Large diffs are not rendered by default.

33,466 changes: 33,466 additions & 0 deletions notebooks/additional/tests/bench_omni.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/additional/update_lamin_or_cellxgene.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
"version": "3.10.15"
}
},
"nbformat": 4,
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"leidenalg>=0.10.0",
"django>=4.0.0",
"scikit-misc>=0.5.0",
"scDataLoader>=1.1.3",
"scDataLoader>=1.6.5",
"GRnnData>=1.1.4",
"BenGRN>=1.2.4",
"gseapy>=0.10.0",
Expand Down
3 changes: 2 additions & 1 deletion scprint/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
self.fused_adam = False
self.lr_reduce_patience = 1
self.lr_reduce_factor = 0.6
self.test_every = 1
self.lr_reduce_monitor = "val_loss"
self.name = ""
self.lr = lr
Expand Down Expand Up @@ -1103,7 +1104,7 @@ def on_validation_epoch_end(self):
self.log_adata(
gtclass=self.info, name="validation_part_" + str(self.counter)
)
if (self.current_epoch + 1) % 30 == 0:
if (self.current_epoch + 1) % self.test_every == 0:
self.on_test_epoch_end()

def test_step(self, *args, **kwargs):
Expand Down
2 changes: 1 addition & 1 deletion scprint/tasks/cell_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(
Args:
batch_size (int, optional): The size of the batches to be used in the DataLoader. Defaults to 64.
num_workers (int, optional): The number of worker processes to use for data loading. Defaults to 8.
how (str, optional): The method to be used for selecting valid genes. Defaults to "most expr".
how (str, optional): The method to be used for selecting valid genes. Defaults to "random expr".
max_len (int, optional): The maximum length of the gene sequence. Defaults to 1000.
add_zero_genes (int, optional): The number of zero genes to add to the gene sequence. Defaults to 100.
precision (str, optional): The precision to be used in the Trainer. Defaults to "16-mixed".
Expand Down
17 changes: 7 additions & 10 deletions scprint/tasks/grn.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,25 +345,22 @@ def aggregate(self, attn, genes):
# / attn.sum(-1).sum(-1).unsqueeze(-1).unsqueeze(-1)
# ) # .view()
if self.head_agg == "mean":
attns = attn.detach() + (attns if attns is not None else 0)
attns = attn + (attns if attns is not None else 0)
elif self.head_agg == "max":
attns = (
torch.max(attn.detach(), attns)
if attns is not None
else attn.detach()
)
attns = torch.max(attn, attns) if attns is not None else attn
elif self.head_agg == "none":
attn = attn.detach()
attn = attn.reshape(attn.shape[0], attn.shape[1], 1)
if attns is not None:
attns = torch.cat((attns, attn), dim=2)
attns = torch.cat((attns, attn.detach().cpu()), dim=2)
else:
attns = attn
attns = attn.detach().cpu()
else:
raise ValueError("head_agg must be one of 'mean', 'max' or 'None'")
if self.head_agg == "mean":
attns = attns / Qs.shape[0]
return attns.cpu().numpy()
return (
attns.detach().cpu().numpy() if self.head_agg != "none" else attns.numpy()
)

def filter(self, adj, gt=None):
if self.filtration == "thresh":
Expand Down
7 changes: 6 additions & 1 deletion scprint/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def __init__(
do_generate: bool = True,
class_scale: float = 1.5,
mask_ratio: List[float] = [], # 0.3
test_every: int = 1,
warmup_duration: int = 500,
fused_adam: bool = False,
adv_class_scale: float = 0.1,
Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(
optim (str): Optimizer to use during training. Defaults to "adamW".
weight_decay (float): Weight decay to apply during optimization. Defaults to 0.01.
name (str): Name of the training mode. Defaults to an empty string. should be an ID for the model
test_every (int): Number of epochs between testing. Defaults to 1.
"""
super().__init__()
self.do_denoise = do_denoise
Expand Down Expand Up @@ -100,6 +102,7 @@ def __init__(
self.do_adv_batch = do_adv_batch
self.run_full_forward = run_full_forward
self.name = name
self.test_every = test_every

def __repr__(self):
return (
Expand Down Expand Up @@ -131,7 +134,8 @@ def __repr__(self):
f"do_cls={self.do_cls}, "
f"do_adv_batch={self.do_adv_batch}, "
f"run_full_forward={self.run_full_forward}), "
f"name={self.name})"
f"name={self.name}, "
f"test_every={self.test_every})"
)

def setup(self, trainer, model, stage=None):
Expand Down Expand Up @@ -165,4 +169,5 @@ def setup(self, trainer, model, stage=None):
model.optim = self.optim
model.weight_decay = self.weight_decay
model.name = self.name
model.test_every = self.test_every
# model.configure_optimizers()
1 change: 1 addition & 0 deletions src/triton
Submodule triton added at 8650b4
94 changes: 92 additions & 2 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
import pytest
import scanpy as sc
import torch
from scdataloader import Preprocessor
from scdataloader import Preprocessor, DataModule
from scdataloader.utils import populate_my_ontology

from scprint import scPrint
from scprint.base import NAME
from scprint.tasks import Denoiser, Embedder, GNInfer
from scprint.trainer import TrainingMode

import lamindb as ln
from lightning.pytorch import Trainer


def test_base():
Expand Down Expand Up @@ -102,7 +106,93 @@ def test_base():
)
grn_adata = grn_inferer(model, adata)
assert "GRN" in grn_adata.varp, "GRN inference failed"
# fit scprint
# make a collection
file = ln.Artifact(adata, description="test file")
file.save()
col = ln.Collection(file, name="test dataset")
col.save()
datamodule = DataModule(
collection_name="test dataset",
gene_embeddings=os.path.join(os.path.dirname(__file__), "test_emb.parquet"),
all_clss=[
"sex_ontology_term_id",
"organism_ontology_term_id",
],
hierarchical_clss=[],
organisms=["NCBITaxon:9606"], # , "NCBITaxon:10090"],
how="most expr",
max_len=200,
add_zero_genes=0,
# how much more you will see the most present vs less present category
weight_scaler=10,
clss_to_weight=["sex_ontology_term_id"],
clss_to_pred=[
"sex_ontology_term_id",
"organism_ontology_term_id",
],
batch_size=1,
num_workers=1,
# train_oversampling=2,
validation_split=0.1,
do_gene_pos=False,
test_split=0.1,
)
_ = datamodule.setup()
model = scPrint(
genes=datamodule.genes,
d_model=64,
nhead=1,
nlayers=1,
# layers_cls = [d_model],
# labels = datamodule.labels,
# cls_hierarchy = datamodule.cls_hierarchy,
dropout=0,
transformer="normal",
precpt_gene_emb=os.path.join(os.path.dirname(__file__), "test_emb.parquet"),
mvc_decoder="inner product",
fused_dropout_add_ln=False,
checkpointing=False,
)
trainingmode = TrainingMode(
do_denoise=True,
noise=[0.1],
do_cce=False,
do_ecs=False,
do_cls=True,
do_mvc=True,
mask_ratio=[],
warmup_duration=10,
lr_reduce_patience=10,
test_every=10_000,
)
trainer = Trainer(
gradient_clip_val=500,
max_time={"minutes": 4},
limit_val_batches=1,
callbacks=[trainingmode],
accumulate_grad_batches=1,
check_val_every_n_epoch=1,
overfit_batches=1,
max_epochs=20,
reload_dataloaders_every_n_epochs=100_000,
logger=None,
num_sanity_val_steps=0,
max_steps=100,
)
initial_loss = None
for i in range(2):
trainer.fit(model, datamodule=datamodule)
trainer.fit_loop.max_epochs = 20 * (
i + 2
) # Reset max_epochs for next iteration
current_loss = trainer.callback_metrics.get("train_loss")
if initial_loss is None:
initial_loss = current_loss
else:
assert (
current_loss < initial_loss
), f"Loss not decreasing: initial {initial_loss}, current {current_loss}"
initial_loss = current_loss
# cli
# get_Seq
# sinkhorn
Expand Down
Binary file added tests/test_emb.parquet
Binary file not shown.
2 changes: 1 addition & 1 deletion tools/Geneformer
Submodule Geneformer updated from acb8c7 to 1e574f
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 3aadbbe

Please sign in to comment.