Merge pull request #19 from anton-bushuiev/main
Improve training submission scripts
anton-bushuiev authored Jun 3, 2024
2 parents 07db264 + 7ae8b93 commit 25fd3cc
Showing 6 changed files with 57 additions and 10 deletions.
4 changes: 4 additions & 0 deletions massspecgym/models/base.py
@@ -62,6 +62,10 @@ def configure_optimizers(self):
self.parameters(), lr=self.lr, weight_decay=self.weight_decay
)

def get_checkpoint_monitors(self) -> list[dict]:
monitors = [{"monitor": "val_loss", "mode": "min", "early_stopping": True}]
return monitors

def _update_metric(
self,
name: str,
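The new get_checkpoint_monitors hook is what the training script uses to decide which metrics to checkpoint and early-stop on. A minimal sketch of how a model subclass could extend it (the class names and the extra "val_hit_rate@1" metric are assumptions for illustration, not part of this commit):

# Sketch only: MyRetrievalModel, MassSpecGymModel, and the extra metric are hypothetical.
from massspecgym.models.base import MassSpecGymModel

class MyRetrievalModel(MassSpecGymModel):
    def get_checkpoint_monitors(self) -> list[dict]:
        # Keep the default val_loss monitor and add a metric to maximize,
        # without early stopping on it.
        monitors = super().get_checkpoint_monitors()
        monitors.append({"monitor": "val_hit_rate@1", "mode": "max", "early_stopping": False})
        return monitors

Each returned dict is consumed by the callback loop added to scripts/train_retrieval.py further below; only the first monitor also keeps a last checkpoint (save_last=(i == 0)).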
2 changes: 1 addition & 1 deletion massspecgym/models/retrieval/deepsets.py
@@ -47,7 +47,7 @@ def step(

# Log loss
self.log(
- metric_pref + "loss_step",
+ metric_pref + "loss",
loss,
batch_size=x.size(0),
sync_dist=True,
2 changes: 1 addition & 1 deletion massspecgym/models/retrieval/fingerprint_ffn.py
@@ -59,7 +59,7 @@ def step(

# Log loss
self.log(
- metric_pref + "loss_step",
+ metric_pref + "loss",
loss,
batch_size=x.size(0),
sync_dist=True,
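In both retrieval models, dropping the "_step" suffix makes the logged key line up with the new checkpoint monitor: assuming metric_pref is "val_" during validation (an assumption about how the prefix is built elsewhere in the codebase), the call now logs "val_loss", which is exactly the metric that the default get_checkpoint_monitors entry and the callbacks in train_retrieval.py watch.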
16 changes: 12 additions & 4 deletions scripts/submit_train_retrieval.sh → scripts/submit.sh
@@ -1,12 +1,20 @@
#!/bin/bash

- # Get submission running time as first positional argument
- # Expected format: HH:MM
+ # Get the name of the file and time limit as positional arguments
+ # Expected format for time: HH:MM
if [ -z "$1" ]
then
echo "Error: No submission script name provided."
exit 1
else
script_name="$1"
fi

if [ -z "$2" ]
then
time="48:00:00"
else
time="${1}:00"
time="${2}:00"
fi

# Init logging dir and common file
@@ -30,7 +38,7 @@ job_id=$(sbatch \
--error="${outdir}/${job_key}"_errout.txt \
--job-name="${job_key}" \
--export=job_key="${job_key}" \
"${train_dir}/train_retrieval.sh"
"${train_dir}/${script_name}"
)

# Extract job ID from the output
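With the rename to scripts/submit.sh, the submission script to run is passed as the first positional argument and an optional HH:MM time limit as the second, with seconds appended and a 48-hour default. For illustration only (the working directory and exact invocation are assumptions), a call might look like ./scripts/submit.sh train_retrieval.sh 24:00.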
39 changes: 36 additions & 3 deletions scripts/train_retrieval.py
@@ -1,9 +1,11 @@
- import argparse
+ import argparse
from pathlib import Path

import pandas as pd

import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from massspecgym.data import RetrievalDataset, MassSpecDataModule
from massspecgym.transforms import MolFingerprinter, SpecBinner
@@ -12,6 +14,9 @@

parser = argparse.ArgumentParser()

# Submission
parser.add_argument('--job_key', type=str, required=True)

# Experiment setup
parser.add_argument('--run_name', type=str, required=True)
parser.add_argument('--project_name', type=str, default='MassSpecGymRetrieval')
@@ -25,6 +30,7 @@
parser.add_argument('--accelerator', type=str, default='gpu')
parser.add_argument('--devices', type=int, default=8)
parser.add_argument('--log_every_n_steps', type=int, default=50)
parser.add_argument('--val_check_interval', type=float, default=1.0)

# General hyperparameters
parser.add_argument('--batch_size', type=int, default=64)
@@ -69,23 +75,47 @@ def main(args):
)

# Init logger
# You may need to run wandb init first to use the wandb logger
if args.no_wandb:
logger = None
else:
logger = pl.loggers.WandbLogger(
name=args.run_name,
project=args.project_name,
log_model=False,
config=args
)

# Init callbacks for checkpointing and early stopping
callbacks = []
for i, monitor in enumerate(model.get_checkpoint_monitors()):
monitor_name = monitor['monitor']
checkpoint = pl.callbacks.ModelCheckpoint(
monitor=monitor_name,
save_top_k=1,
mode=monitor['mode'],
dirpath=Path(args.project_name) / args.job_key,
filename=f'{{step:06d}}-{{{monitor_name}:03.03f}}',
auto_insert_metric_name=True,
save_last=(i == 0)
)
callbacks.append(checkpoint)
if monitor.get('early_stopping', False):
early_stopping = EarlyStopping(
monitor=monitor_name,
mode=monitor['mode'],
verbose=True
)
callbacks.append(early_stopping)

# Init trainer
trainer = Trainer(
accelerator=args.accelerator,
devices=args.devices,
max_epochs=args.max_epochs,
logger=logger,
- log_every_n_steps=args.log_every_n_steps
+ log_every_n_steps=args.log_every_n_steps,
val_check_interval=args.val_check_interval,
callbacks=callbacks
)

# Validate before training
@@ -96,6 +126,9 @@ def main(args):
# Train
trainer.fit(model, datamodule=data_module)

# Test
trainer.test(model, datamodule=data_module)


if __name__ == "__main__":
args = parser.parse_args([] if "__file__" not in globals() else None)
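The new callback block turns each monitor dict into a ModelCheckpoint and, when requested, an EarlyStopping on the same metric, while val_check_interval controls how often those metrics are refreshed. A standalone sketch of what the loop builds for the default val_loss monitor (the project name and job key below are hypothetical placeholders, and the example filename in the comment is only indicative):

# Sketch mirroring the callback loop in scripts/train_retrieval.py.
from pathlib import Path

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

monitor = {"monitor": "val_loss", "mode": "min", "early_stopping": True}
monitor_name = monitor["monitor"]

checkpoint = pl.callbacks.ModelCheckpoint(
    monitor=monitor_name,
    save_top_k=1,                                   # keep only the best checkpoint
    mode=monitor["mode"],
    dirpath=Path("MassSpecGymRetrieval") / "example_job_key",  # hypothetical job key
    filename=f"{{step:06d}}-{{{monitor_name}:03.03f}}",
    auto_insert_metric_name=True,                   # e.g. step=001000-val_loss=0.123.ckpt
    save_last=True,                                 # only the first monitor keeps last.ckpt
)
early_stopping = EarlyStopping(monitor=monitor_name, mode=monitor["mode"], verbose=True)
callbacks = [checkpoint, early_stopping]

These callbacks are then passed to the Trainer alongside val_check_interval, and trainer.test(...) now runs after training as well.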
4 changes: 3 additions & 1 deletion scripts/train_retrieval.sh
@@ -13,4 +13,6 @@ export SLURM_GPUS_PER_NODE=8
export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

srun --export=ALL --preserve-env python3 train_retrieval.py \
- --run_name=debug
+ --job_key="${job_key}" \
+ --run_name=debug_0.5_v3 \
+ --val_check_interval=0.5
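For context: --val_check_interval=0.5 makes Lightning run validation halfway through and at the end of each training epoch, so the val_loss-based checkpointing and early stopping above are evaluated twice per epoch, and the job_key exported by scripts/submit.sh is what names the checkpoint directory (args.project_name / args.job_key).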
