Skip to content

Commit

Permalink
Merge pull request #13 from cantinilab/dev
Browse files Browse the repository at this point in the history
Dev Merge
  • Loading branch information
jkobject authored Jan 9, 2025
2 parents b28aada + 06d6282 commit 464b535
Show file tree
Hide file tree
Showing 57 changed files with 8,439 additions and 6,918 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -184,5 +184,7 @@ tests/data/step_0__predict_part_0_0.h5ad
data/human_dcm_hcm_nf/
data/geneformertest.csv
clf_omni.pkl
metrics__step0.json
metrics__*.json
config/torm.txt
torm.json
notebooks/additional/figures/
3 changes: 0 additions & 3 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
[submodule "RNABERT"]
path = scprint/tokenizers/RNABERT
url = https://github.com/jkobject/RNABERT
[submodule "tools/DeepSEM"]
path = tools/DeepSEM
url = https://github.com/jkobject/DeepSEM
Expand Down
1 change: 0 additions & 1 deletion RNABERT
Submodule RNABERT deleted from 9f411c
106 changes: 106 additions & 0 deletions config/ablation_study.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
project: scprint_ablation
seed_everything: 50
set_float32_matmul_precision: True
wandblog: all
log_freq: 200
log_graph: True
trainer:
precision: 16-mixed
strategy: ddp_find_unused_parameters_true
gradient_clip_val: 40
log_every_n_steps: 100
limit_train_batches: 5000
gradient_clip_algorithm: norm
limit_val_batches: 4000
limit_test_batches: 1 # we don't perform tests this way
reload_dataloaders_every_n_epochs: 20
max_epochs: 21
accumulate_grad_batches: 1
logger:
- class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: ${project}
save_dir: /lustre/fswork/projects/rech/xeg/uat95fg/ #/pasteur/zeus/projets/p02/ml4ig_hot/Users/jkalfon/ #/data/log/
offline: False
callbacks:
- class_path: lightning.pytorch.callbacks.StochasticWeightAveraging
init_args:
swa_lrs: 0.03
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val_loss
save_top_k: 3
save_last: True
scprint_training:
run_full_forward: False
# noise: [0.6]
do_ecs: False
do_denoise: False
class_embd_diss_scale: 0.1
do_generate: True
test_every: 5
do_cce: False
mask_ratio: 0.3 # alternative: ["TF"] (TF masking)
model:
dropout: 0.1
num_heads_kv: 2
transformer: flash
mvc_decoder: inner product
residual_in_fp32: True
checkpointing: True
cell_specific_blocks: True
fused_dropout_add_ln: False
prenorm: True
fused_mlp: False
fused_bias_fc: False
drop_path_rate: 0
freeze_embeddings: True
normalization: log
pred_embedding:
- cell_type_ontology_term_id
data:
organisms:
- NCBITaxon:9606
gene_position_tolerance: 10_000
gene_embeddings: /lustre/fswork/projects/rech/xeg/uat95fg/gene_embeddings.parquet
collection_name: scPRINT-V2 (100 random humans) #scPRINT-V2 (good quality)
how: random expr
max_len: 2200
pin_memory: True
prefetch_factor: 3
# metacell_mode: 0.2
weight_scaler: 200
do_gene_pos: ./data/main/biomart_pos.parquet
add_zero_genes: 0
train_oversampling_per_epoch: 0.15
validation_split: 0.05
test_split: 0.02
batch_size: 32
num_workers: 20
hierarchical_clss:
- cell_type_ontology_term_id
- tissue_ontology_term_id
- disease_ontology_term_id
- age_group
- assay_ontology_term_id
- self_reported_ethnicity_ontology_term_id
clss_to_weight:
- clust_cell_type
- tissue_ontology_term_id
- disease_ontology_term_id
- age_group
- assay_ontology_term_id
- self_reported_ethnicity_ontology_term_id
- sex_ontology_term_id
- organism_ontology_term_id
- cell_culture
# - nnz
clss_to_predict:
- cell_type_ontology_term_id
- tissue_ontology_term_id
- disease_ontology_term_id
- age_group
- assay_ontology_term_id
- self_reported_ethnicity_ontology_term_id
- sex_ontology_term_id
- organism_ontology_term_id
127 changes: 127 additions & 0 deletions config/all_possible_ablations.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
######
scprint_training:
run_full_forward: True
noise: [0.6]
do_denoise: True
class_embd_diss_scale: 0.1
do_generate: True
do_cce: True

###### TF masking
scprint_training:
mask_ratio: ["TF"]

#####
scprint_training:
do_ecs: True

##### Test cheap vs esm2 vs learnt


##### untrained model’s performance
trainer:
limit_train_batches: 1
limit_val_batches: 200
limit_test_batches: 1 # we don't perform tests this way
max_epochs: 2
scprint_training:
test_every: 1

##### Test attention bias faded vs none (MAKE SURE RIGHT DIR (T))
model:
attn_bias: full

#####
model:
dropout: 0
num_heads_kv: 4

#####
model:
cell_specific_blocks: False

#####
model:
freeze_embeddings: False

#####
model:
normalization: sum

##### no zinb vs mse vs both: zinb disabled
model:
zinb: False

##### no zinb vs mse vs both: zinb and mse together
scprint_training:
zinb_and_mse: True

####
model:
depth_atinput: True

####
scprint_training:
do_mvc: True

##### need to run them until the end
data:
organisms:
- NCBITaxon:9606
- NCBITaxon:10090
collection_name: scPRINT-V2 (good quality)

##### need to run them until the end
data:
organisms:
- NCBITaxon:9606
- NCBITaxon:10090
collection_name: scPRINT-V2 full

#####
data:
gene_position_tolerance: 1000

#####
data:
do_gene_pos: False

#####
data:
weight_scaler: 3000

##### Test multi context vs fixed
scprint_training:
var_context_length: True
data:
max_len: 3200

######
data:
metacell_mode: 0.2

#####
data:
clss_to_weight:
- clust_cell_type
- tissue_ontology_term_id
- disease_ontology_term_id
- age_group
- assay_ontology_term_id
- self_reported_ethnicity_ontology_term_id
- sex_ontology_term_id
- organism_ontology_term_id
- cell_culture
- nnz

#####
data:
clss_to_weight:
- cell_type_ontology_term_id
- tissue_ontology_term_id
- disease_ontology_term_id
- age_group
- assay_ontology_term_id
- self_reported_ethnicity_ontology_term_id
- sex_ontology_term_id
- organism_ontology_term_id
65 changes: 22 additions & 43 deletions config/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,34 @@ log_freq: 200
log_graph: True
trainer:
precision: 16-mixed
# profiler: simple
gradient_clip_val: 100
log_every_n_steps: 100
limit_train_batches: 7000
limit_val_batches: 2000
limit_train_batches: 20000
limit_val_batches: 4000
limit_test_batches: 1 # we don't perform tests this way
reload_dataloaders_every_n_epochs: 1
reload_dataloaders_every_n_epochs: 5
accumulate_grad_batches: 1
max_time:
hours: 71
logger:
- class_path: lightning.pytorch.loggers.WandbLogger
init_args:
project: ${project}
save_dir: data/log/
offline: True
save_dir: /pasteur/zeus/projets/p02/ml4ig_hot/Users/jkalfon/ #/data/log/
offline: False
callbacks:
- class_path: lightning.pytorch.callbacks.StochasticWeightAveraging
init_args:
swa_lrs: 0.03
# - class_path: lightning.pytorch.callbacks.StochasticWeightAveraging
# init_args:
# swa_lrs: 0.03
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
monitor: val_loss
save_top_k: 20
save_top_k: 10
save_last: True

#- class_path: lightning.pytorch.callbacks.LearningRateFinder
#init_args:
# mode: exponential
#plugins:
# - class_path: lightning.pytorch.plugins.environments.SLURMEnvironment
# requeue_signal: signal.SIGHUP
model:
dropout: 0.1
transformer: flash
mvc_decoder: inner product
residual_in_fp32: True
depth_atinput: True
num_heads_kv: null
fused_dropout_add_ln: False
prenorm: True
Expand All @@ -53,56 +44,44 @@ model:
freeze_embeddings: True
pred_embedding:
- cell_type_ontology_term_id
- disease_ontology_term_id
- self_reported_ethnicity_ontology_term_id
- sex_ontology_term_id
# - disease_ontology_term_id
# - self_reported_ethnicity_ontology_term_id
# - sex_ontology_term_id
data:
organisms:
- NCBITaxon:9606
- NCBITaxon:10090
gene_position_tolerance: 10_000
gene_embeddings: ./data/main/gene_embeddings.parquet
collection_name: all no zhang13M #preprocessed dataset
collection_name: all #preprocessed dataset, all no zhang13M
how: random expr
max_len: 2200
weight_scaler: 50
pin_memory: True
prefetch_factor: 3
weight_scaler: 100
do_gene_pos: ./data/main/biomart_pos.parquet
add_zero_genes: 0
train_oversampling_per_epoch: 0.3
validation_split: 0.02
test_split: 0.02
train_oversampling_per_epoch: 0.2
validation_split: 0.05
test_split: 0.05
batch_size: 64
num_workers: 12
# TODO: drop tissue & dev stage until "part of" is taken into account
num_workers: 16
hierarchical_clss:
- cell_type_ontology_term_id
#- tissue_ontology_term_id
- disease_ontology_term_id
#- development_stage_ontology_term_id
- assay_ontology_term_id
- self_reported_ethnicity_ontology_term_id
clss_to_weight:
- cell_type_ontology_term_id
# - tissue_ontology_term_id
- disease_ontology_term_id
# - development_stage_ontology_term_id
- assay_ontology_term_id
- self_reported_ethnicity_ontology_term_id
- sex_ontology_term_id
- organism_ontology_term_id
# - cell_culture
all_clss:
clss_to_predict:
- cell_type_ontology_term_id
# - tissue_ontology_term_id
- disease_ontology_term_id
# - development_stage_ontology_term_id
- assay_ontology_term_id
- self_reported_ethnicity_ontology_term_id
- sex_ontology_term_id
- organism_ontology_term_id
#- heat_diff
#- total_counts
#- nnz
#- dpt_group
#- dataset_id
#- cell_culture
Loading

0 comments on commit 464b535

Please sign in to comment.