Very short and compact README.
The code was made in PyCharm, using Python 3.9.5. Creating a new environment and installing the packages from requirements.txt should work.
Otherwise, one of the two provided environment configs can be used: `environment-base.yaml` is the environment everything was developed in, so it is quite bloated with unused packages; `environment-clean.yaml` contains only the necessary packages and should work as well.
Main components of the code:

- `main.py`

  Run for training. Parameters come from several sources:

  - Argparser in `main` (train-specific arguments):
    - `num_workers`: (int) number of CPU cores available
    - `root_dir`: (str) root directory of main.py
    - `name`: (str) name of your training run
    - `wandb_project`: (str) name of the wandb project for logging
    - `warmup_batches`: (int) number of batches at the start of training that use only pixel-loss
    - `gan`: if this flag is present, training runs with adversarial loss
    - `no_checkpointing`: if this flag is present, only the best checkpoint (based on SSIM) is saved; otherwise one per epoch
  - Argparser in `trainer` (data-specific arguments):
    - `std`: (float) standard deviation of the data, for normalisation
    - `middle_slices`: (int) number of slices to select from the middle of a subject's volume
    - `every_other`: (int) select every x-th slice from a subject
    - `sampler`: (str) grid- or label-sampler; for the label-sampler, a mask is required
  - Argparser built into PyTorch Lightning:
    - `gpus`: (int) number of GPUs available
    - `max_epochs`: (int) maximum number of epochs
  - `config` in `main` (dict containing all variables that can be varied):
    - `optimizer`: (str) adam or sgd
    - `b1`, `b2`: (floats) betas for adam (if needed)
    - `batch_size`: (int) batch size for training
    - `num_filters`: (int) number of filters in the generator architecture
    - `learning_rate_G`: (float) learning rate of the generator
    - `learning_rate_D`: (float) learning rate of the discriminator (if needed)
    - `patch_size`: (int) patch size
    - `alpha_edge`: (float) weighting of the edge-loss
    - `alpha_pixel`: (float) weighting of the pixel-loss
    - `alpha_perceptual`: (float) weighting of the perceptual (VGG) loss
    - `alpha_adversarial`: (float) weighting of the adversarial loss (if needed)
    - `gan_mode`: (str) kind of adversarial loss (vanilla (RaSGAN) or wgan) (if needed)
    - `edge_loss`: (int) kind of edge loss (1, 2 or 3, see `edgeloss.py`)
    - `netD_freq`: (int) how often the discriminator is trained relative to the generator (if needed)
    - `patch_overlap`: (float, 0-1) fraction of overlap between patches, if the grid-sampler is used
    - `generator`: generator architecture (ESRGAN, RRDB, FSRCNN, DeepUResnet, DeepUResnet_v2)
    - `nr_sim_train`: (int) number of simulated subjects for training
    - `nr_hcp_train`: (int) number of HCP subjects for training
    - `nr_sim_val`: (int) number of simulated subjects for validation
    - `nr_hcp_val`: (int) number of HCP subjects for validation
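  For concreteness, a config dict of the shape described above might look like this (a sketch; the values are illustrative assumptions, not the repo's defaults):

  ```python
  # Illustrative config dict matching the keys listed above.
  config = {
      'optimizer': 'adam',
      'b1': 0.9, 'b2': 0.999,        # Adam betas
      'batch_size': 16,
      'num_filters': 64,
      'learning_rate_G': 1e-4,
      'learning_rate_D': 1e-4,
      'patch_size': 64,
      'alpha_edge': 0.3,
      'alpha_pixel': 0.7,
      'alpha_perceptual': 1.0,
      'alpha_adversarial': 0.1,
      'gan_mode': 'vanilla',         # RaSGAN-style loss
      'edge_loss': 2,                # see edgeloss.py
      'netD_freq': 1,
      'patch_overlap': 0.5,          # fraction, 0-1, for the grid-sampler
      'generator': 'ESRGAN',
      'nr_sim_train': 30, 'nr_hcp_train': 30,
      'nr_sim_val': 10, 'nr_hcp_val': 10,
  }
  ```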
  It is meant to be run from the command line, for example:

      python main.py --gan --name mixed-data --wandb_project example --gpus 1 --log_every_n_steps 500 --max_epochs 25 --no_checkpointing
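  The train-specific flags in the example above could be wired up with argparse roughly as follows; this is a hypothetical reconstruction (flag names follow the list above, defaults are assumptions):

  ```python
  import argparse

  # Sketch of the train-specific parser in main.py; defaults are illustrative.
  parser = argparse.ArgumentParser()
  parser.add_argument('--num_workers', type=int, default=4,
                      help='number of CPU cores available')
  parser.add_argument('--root_dir', type=str, default='.',
                      help='root directory of main.py')
  parser.add_argument('--name', type=str, required=True,
                      help='name of the training run')
  parser.add_argument('--wandb_project', type=str, default='example',
                      help='wandb project for logging')
  parser.add_argument('--warmup_batches', type=int, default=0,
                      help='batches trained with only pixel-loss')
  parser.add_argument('--gan', action='store_true',
                      help='train with adversarial loss')
  parser.add_argument('--no_checkpointing', action='store_true',
                      help='save only the best (SSIM) checkpoint')

  # Parse the flags from the example command (Lightning flags omitted here).
  args = parser.parse_args(['--gan', '--name', 'mixed-data'])
  ```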
- `dataset_tio.py`

  Base class for the data; the dataset is built using TorchIO, exploiting its patch-based pipeline.

- `trainer_org.py`

  Accompanying trainer for non-GAN training. Written using PyTorch Lightning, with logging to wandb.

- `trainer_gan.py`

  Accompanying trainer for GAN training. Written using PyTorch Lightning, with logging to wandb.
- `predict.py`

  Run for SR generation using a model checkpoint. To combine with USM:
  - USM before SR: turn on augmentation in the dataloading.
  - USM after SR: uncomment a couple of lines before the SR output is saved.

  It is meant to be run from the command line, for example:

      python predict.py --gan --generator ESRGAN --source sim
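  Assuming USM here stands for unsharp masking, as is conventional in image processing, the operation can be sketched in a few lines of numpy/scipy (`sigma` and `amount` are illustrative defaults, not the repo's values):

  ```python
  import numpy as np
  from scipy.ndimage import gaussian_filter

  def unsharp_mask(img, sigma=1.0, amount=1.0):
      """Sharpen an image by adding back the difference between it
      and a Gaussian-blurred copy (classic unsharp masking)."""
      blurred = gaussian_filter(img.astype(np.float32), sigma=sigma)
      return img + amount * (img - blurred)
  ```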
- `calculate_scores.py`

  Run to calculate metrics on generated SR images. It is meant to be run from the command line, for example:

      python calculate_scores.py --source sim
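  The exact metrics computed by `calculate_scores.py` are not listed here; as an example of this kind of image-quality metric, a self-contained PSNR implementation in numpy:

  ```python
  import numpy as np

  def psnr(gt, sr, data_range=1.0):
      """Peak signal-to-noise ratio between a ground-truth image and
      an SR result, in dB. Higher is better; identical images -> inf."""
      mse = np.mean((gt.astype(np.float64) - sr.astype(np.float64)) ** 2)
      if mse == 0:
          return float('inf')
      return 10.0 * np.log10(data_range ** 2 / mse)
  ```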
- `sweep.yaml`

  For a hyperparameter sweep using wandb (everything in the config can be searched).
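  The sweep file itself is not shown here; a hypothetical sweep configuration, expressed as the equivalent Python dict accepted by `wandb.sweep()` (the metric name and parameter ranges are assumptions, with hyperparameter names taken from the config above):

  ```python
  # Illustrative wandb sweep config; sweep.yaml would encode the same
  # structure in YAML.
  sweep_config = {
      'method': 'bayes',
      'metric': {'name': 'val_ssim', 'goal': 'maximize'},
      'parameters': {
          'learning_rate_G': {'min': 1e-5, 'max': 1e-3},
          'batch_size': {'values': [8, 16, 32]},
          'num_filters': {'values': [32, 64]},
      },
  }
  # sweep_id = wandb.sweep(sweep_config, project='example')
  ```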
Data is not present in this repo, but should be located in the `data` folder in the root:
.
├── main.py
├── data # data-folder
│ ├── brain_real_t1w_mri # real data
│ │ ├── MRBrainS18 # Data from MRBrainS18 challenge
│ │ │ ├── GT # 1mm images
│ │ │ └── MSK # segmentations of GT data
│ │ ├── OASIS # Data from OASIS-1
│ │ │ ├── LR # 1mm images
│ │ │ └── MSK # segmentations of LR data
│ │ ├── HCP # data from HCP database
│ │ │ ├── LR # 1mm images
│ │ │ ├── HR # 0.7mm images
│ │ │ └── MSK # segmentations of HR data
│ ├── brain_simulated_t1w_mri # simulated data
│ │ ├── HR_img # 0.7mm images
│ │ ├── HR_msk # segmentations of HR data
│ │ └── LR_img # 1mm images
│ │ └── ...
│ └── ...
├── dataset_tio.py
├── trainer_org.py
├── trainer_gan.py
├── predict.py
├── calculate_scores.py
└── ...