
Commit 05aa29f

nothing really
jkobject committed Jan 17, 2025
1 parent 76ceb58 commit 05aa29f
Showing 4 changed files with 41 additions and 10 deletions.
config/ablation_study.yml (5 changes: 3 additions & 2 deletions)
@@ -2,6 +2,7 @@ project: scprint_ablation
 seed_everything: 50
 set_float32_matmul_precision: True
 wandblog: all
+ckpt_path: null
 log_freq: 200
 log_graph: True
 trainer:
@@ -46,7 +47,7 @@ model:
 transformer: flash
 mvc_decoder: inner product
 residual_in_fp32: True
-checkpointing: True
+checkpointing: False
 cell_specific_blocks: False
 fused_dropout_add_ln: False
 prenorm: True
@@ -76,7 +77,7 @@ data:
 validation_split: 0.05
 test_split: 0.02
 batch_size: 64
-num_workers: 24
+num_workers: 20
 hierarchical_clss:
 - cell_type_ontology_term_id
 - tissue_ontology_term_id
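
The config above follows the PyTorch Lightning CLI layout (`seed_everything`, `trainer:`, `model:`, `data:`), so a quick way to see what an ablation run actually changes is to diff two such YAML files. A minimal sketch, assuming PyYAML is installed and a hypothetical `config/base.yml` to compare against (not part of this commit):

```python
# Minimal sketch: report which keys an ablation config overrides relative to a
# base config. The helper and the config/base.yml path are illustrative
# assumptions, not part of the scPRINT repository.
import yaml


def diff_configs(base_path: str, ablation_path: str) -> dict:
    """Return {section: {key: (base_value, ablation_value)}} for keys that differ."""
    with open(base_path) as f:
        base = yaml.safe_load(f)
    with open(ablation_path) as f:
        ablation = yaml.safe_load(f)
    changed = {}
    for section in ("trainer", "model", "data"):
        b = base.get(section) or {}
        a = ablation.get(section) or {}
        diffs = {k: (b.get(k), a[k]) for k in a if b.get(k) != a[k]}
        if diffs:
            changed[section] = diffs
    return changed


if __name__ == "__main__":
    # e.g. checkpointing True -> False and num_workers 24 -> 20, as in this commit
    print(diff_configs("config/base.yml", "config/ablation_study.yml"))
```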
config/all_possible_ablations.yaml (27 changes: 26 additions & 1 deletion)
@@ -5,6 +5,8 @@
 kxtm7jtn
 9zdkhlej
 4ezrufy9
+1964988
+4ihs8zbj #697jiesc
 
 ##### 2 change the weight scaling
 data:
@@ -15,12 +17,14 @@ data:
 hzjj21j9
 fopg9qt2
 ym72b3do
+1965013
+joyets6s
 
 #### 3 change how many cells per group before changing groups (twice as less)
 trainer:
 limit_train_batches: 10000
 reload_dataloaders_every_n_epochs: 5
-max_epochs: 80
+max_epochs: 40
 scprint_training:
 test_every: 5
 
@@ -29,6 +33,9 @@ scprint_training:
 f2jxue8p
 mh40wxyo
 qu8zv48c
+1965407
+e20jt0en
+# check e20jt0en for an issue at 320K
 
 ### 4 without replacement
 
@@ -40,6 +47,8 @@ data:
 oj8i38ur
 mygjp7yc
 ltt10c1r
+1965675
+qw48w6p6
 
 #### 5 only scale on cell groups:
 data:
@@ -50,6 +59,8 @@ data:
 1890999
 h9mvhutk
 fvp021zs
+1968280
+konhvgyn
 
 ### 6 only scale on nnz:
 data:
@@ -63,6 +74,8 @@ data:
 jmhazi02
 sftebnir
 lvqytack
+1968336
+6i2xs1cf
 
 ### 7 same but just changing the seed
 
@@ -109,6 +122,18 @@ data:
 1926500
 36za6yxt
 
+#### 10 change how many cells per group before changing groups (twice as less)
+trainer:
+limit_train_batches: 20000
+reload_dataloaders_every_n_epochs: 2
+max_epochs: 60
+scprint_training:
+test_every: 5
+data:
+weight_scaler: 20
+replacement: False
+train_oversampling_per_epoch: 0.2
+
 #### 11 all (already done)
 
 #### 12 old (already done)
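
Each block in this notes file pairs run IDs with the config overrides that defined the ablation. A minimal sketch of how such an override block (here the newly added ablation #10) could be merged into the base ablation config before launching a run; the merge helper and file path are illustrative assumptions, not utilities shipped with scPRINT:

```python
# Sketch: recursively merge an ablation's override block into the base config.
import copy
import yaml


def merge_overrides(base: dict, overrides: dict) -> dict:
    """Return a copy of `base` with `overrides` merged in (overrides win)."""
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_overrides(merged[key], value)
        else:
            merged[key] = value
    return merged


# Override block for ablation #10, transcribed from the diff above.
ablation_10 = {
    "trainer": {
        "limit_train_batches": 20000,
        "reload_dataloaders_every_n_epochs": 2,
        "max_epochs": 60,
    },
    "scprint_training": {"test_every": 5},
    "data": {
        "weight_scaler": 20,
        "replacement": False,
        "train_oversampling_per_epoch": 0.2,
    },
}

with open("config/ablation_study.yml") as f:
    base_config = yaml.safe_load(f)
run_config = merge_overrides(base_config, ablation_10)
```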
scprint/model/model.py (17 changes: 11 additions & 6 deletions)
@@ -111,10 +111,10 @@ def __init__(
         self.do_cce = False
         self.cce_temp = 0.2
         self.lr = 0.0001
-        self.cce_scale = 0.05
+        self.cce_scale = 0.1
         self.do_ecs = False
         self.ecs_threshold = 0.4
-        self.ecs_scale = 0.05
+        self.ecs_scale = 0.1
         self.do_mvc = False
         self.mvc_scale = 1.0
         self.class_embd_diss_scale = 0.1
@@ -125,7 +125,7 @@ def __init__(
         self.mean_attn_tot_c = 0
         self.do_adv_batch = False
         self.run_full_forward = True
-        self.class_scale = 0.4
+        self.class_scale = 1
         self.zinb_and_mse = False
         self.do_next_tp = False
         self.do_generate = False
@@ -136,7 +136,7 @@ def __init__(
         self.weight_decay = 0.01
         self.optim = "adamW"
         self.fused_adam = False
-        self.lr_reduce_patience = 1
+        self.lr_reduce_patience = 2
         self.lr_reduce_factor = 0.6
         self.test_every = 20
         self.lr_reduce_monitor = "val_loss"
@@ -1140,22 +1140,27 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
         # manually warm up lr without a scheduler
         # making sure that we don't do this during lrfinder
         lr_scale = None
+        prev_lr = None
         if (
             self.trainer.global_step < self.warmup_duration + self.lrfinder_steps
         ) and self.lrfinder_steps <= self.trainer.global_step:
             for i, pg in enumerate(optimizer.param_groups):
                 lr_scale = min(
                     1.0, float(self.trainer.global_step + 1) / self.warmup_duration
                 )
+                prev_lr = pg["lr"]
                 pg["lr"] = lr_scale * self.hparams.lr
         for i, pg in enumerate(optimizer.param_groups):
             # if pg["lr"] < 2e-5:
             # pg["lr"] = 2e-5
             self.log("lr_" + str(i), pg["lr"])
         if optimizer.param_groups[0]["lr"] > self.hparams.lr:
             print(optimizer.param_groups[0]["lr"], self.hparams.lr)
-            print(lr_scale, self.warmup_duration, self.trainer.global_step)
-            raise ValueError("OPTIMIZER HAS INCREASED LR. WHYY?")
+            print(lr_scale, self.warmup_duration, self.trainer.global_step, prev_lr)
+            if prev_lr is not None:
+                pg["lr"] = prev_lr
+            else:
+                raise ValueError("OPTIMIZER HAS INCREASED LR. WHYY?")
 
         optimizer.step(closure=optimizer_closure)
 
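
The `optimizer_step` change keeps the manual warm-up but, instead of aborting when a param group's LR ends up above the target, rolls it back to the previously recorded value. A standalone sketch of that warm-up-and-rollback behaviour, with illustrative numbers rather than scPRINT defaults:

```python
# Standalone sketch of the warm-up-and-rollback logic added in the diff above:
# LR ramps linearly to the target over warmup_duration steps; if a param group
# somehow ends up above the target LR, it is restored to its previous value
# instead of raising. The fake param group and values are illustrative only.
target_lr = 1e-4
warmup_duration = 500
param_group = {"lr": 0.0}

for step in range(0, 1501, 250):
    lr_scale = min(1.0, float(step + 1) / warmup_duration)  # linear warm-up
    prev_lr = param_group["lr"]            # remember the last good LR
    param_group["lr"] = lr_scale * target_lr
    if param_group["lr"] > target_lr:      # cannot happen with this formula,
        param_group["lr"] = prev_lr        # but roll back instead of aborting
    print(f"step {step:4d}  lr {param_group['lr']:.2e}")
```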
scprint/trainer/trainer.py (2 changes: 1 addition & 1 deletion)
@@ -32,7 +32,7 @@ def __init__(
         do_cls: bool = True,
         do_adv_batch: bool = False,
         run_full_forward: bool = False,
-        lr: float = 0.001,
+        lr: float = 0.0001,
         dropout: float = 0.1,
         optim: str = "adamW",
         weight_decay: float = 0.01,
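
For context, the `lr`, `weight_decay`, `lr_reduce_factor`, `lr_reduce_patience` and `lr_reduce_monitor` values touched in this commit are the kind of knobs that typically feed an AdamW optimizer plus a `ReduceLROnPlateau` scheduler in a Lightning `configure_optimizers`. The wiring below is an illustrative sketch under that assumption, not the exact scPRINT implementation:

```python
# Sketch: map the commit's optimizer hyperparameters onto AdamW + ReduceLROnPlateau.
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau


def configure_optimizers(model, lr=1e-4, weight_decay=0.01,
                         lr_reduce_factor=0.6, lr_reduce_patience=2,
                         lr_reduce_monitor="val_loss"):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=lr_reduce_factor,      # multiply LR by 0.6 on plateau (a 40% cut)
        patience=lr_reduce_patience,  # wait 2 checks before reducing (was 1)
    )
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "monitor": lr_reduce_monitor},
    }


if __name__ == "__main__":
    import torch.nn as nn
    cfg = configure_optimizers(nn.Linear(8, 2))
    print(cfg["optimizer"], cfg["lr_scheduler"]["monitor"])
```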
