Skip to content

Commit

Permalink
Debug extraction of gene networks from the cell-separation model
Browse files · Browse the repository at this point in the history
  • Loading branch information
jkobject committed Jan 15, 2025
1 parent 47162ea commit 7033a0c
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 81 deletions.
43 changes: 13 additions & 30 deletions scprint/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,40 +633,23 @@ def forward(
bias=bias if self.attn_bias != "none" else None,
bias_layer=list(range(self.nlayers - 1)),
)
if len(get_attention_layer) > 0:
transformer_output, qkvs = transformer_output
if self.cell_transformer:
cell_output = self.cell_transformer(
cell_encoding,
x_kv=transformer_output[0]
if len(get_attention_layer) > 0
else transformer_output,
)
cell_output = self.cell_transformer(cell_encoding, x_kv=transformer_output)
transformer_output = torch.cat([cell_output, transformer_output], dim=1)
# if not provided we will mult by the current expression sum
depth_mult = expression.sum(1) if depth_mult is None else depth_mult
if len(get_attention_layer) > 0:
transformer_output, qkvs = transformer_output
return (
self._decoder(
transformer_output,
depth_mult,
get_gene_emb,
do_sample,
do_mvc,
do_class,
req_depth=req_depth if not self.depth_atinput else None,
),
qkvs,
)
else:
return self._decoder(
transformer_output,
depth_mult,
get_gene_emb,
do_sample,
do_mvc,
do_class,
req_depth=req_depth if not self.depth_atinput else None,
)
res = self._decoder(
transformer_output,
depth_mult,
get_gene_emb,
do_sample,
do_mvc,
do_class,
req_depth=req_depth if not self.depth_atinput else None,
)
return (res, qkvs) if len(get_attention_layer) > 0 else res

def configure_optimizers(self):
"""@see pl.LightningModule"""
Expand Down
102 changes: 51 additions & 51 deletions scprint/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,57 +615,57 @@ def test(
Returns:
None
"""
# metrics = {}
# res = embbed_task.default_benchmark(
# model, default_dataset="lung", do_class=do_class, coarse=False
# )
# f = open("metrics_" + name + ".json", "a")
# f.write(json.dumps({"embed_lung": res}, indent=4))
# f.close()
# metrics.update(
# {
# "emb_lung/scib": float(res["scib"]["Total"]),
# "emb_lung/ct_class": float(
# res["classif"]["cell_type_ontology_term_id"]["accuracy"]
# if do_class
# else 0
# ),
# }
# )
# print(metrics)
# res = embbed_task.default_benchmark(
# model, default_dataset="pancreas", do_class=do_class, coarse=False
# )
# f = open("metrics_" + name + ".json", "a")
# f.write(json.dumps({"embed_panc": res}, indent=4))
# f.close()
# metrics.update(
# {
# "emb_panc/scib": float(res["scib"]["Total"]),
# "emb_panc/ct_class": float(
# res["classif"]["cell_type_ontology_term_id"]["accuracy"]
# if do_class
# else 0
# ),
# }
# )
# print(metrics)
# gc.collect()
# res = denoise_task.default_benchmark(
# model, filedir + "/../../data/gNNpgpo6gATjuxTE7CCp.h5ad"
# )
# metrics.update(
# {
# "denoise/reco2full_vs_noisy2full": float(
# res["reco2full"] - res["noisy2full"]
# ),
# }
# )
# gc.collect()
# print(metrics)
# f = open("metrics_" + name + ".json", "a")
# f.write(json.dumps({"denoise": res}, indent=4))
# f.close()
metrics = {}
res = embbed_task.default_benchmark(
model, default_dataset="lung", do_class=do_class, coarse=False
)
f = open("metrics_" + name + ".json", "a")
f.write(json.dumps({"embed_lung": res}, indent=4))
f.close()
metrics.update(
{
"emb_lung/scib": float(res["scib"]["Total"]),
"emb_lung/ct_class": float(
res["classif"]["cell_type_ontology_term_id"]["accuracy"]
if do_class
else 0
),
}
)
print(metrics)
res = embbed_task.default_benchmark(
model, default_dataset="pancreas", do_class=do_class, coarse=False
)
f = open("metrics_" + name + ".json", "a")
f.write(json.dumps({"embed_panc": res}, indent=4))
f.close()
metrics.update(
{
"emb_panc/scib": float(res["scib"]["Total"]),
"emb_panc/ct_class": float(
res["classif"]["cell_type_ontology_term_id"]["accuracy"]
if do_class
else 0
),
}
)
print(metrics)
gc.collect()
res = denoise_task.default_benchmark(
model, filedir + "/../../data/gNNpgpo6gATjuxTE7CCp.h5ad"
)
metrics.update(
{
"denoise/reco2full_vs_noisy2full": float(
res["reco2full"] - res["noisy2full"]
),
}
)
gc.collect()
print(metrics)
f = open("metrics_" + name + ".json", "a")
f.write(json.dumps({"denoise": res}, indent=4))
f.close()
res = grn_task.default_benchmark(
model, "gwps", batch_size=32 if model.d_model <= 512 else 8
)
Expand Down

0 comments on commit 7033a0c

Please sign in to comment.