From d1fff6eb7f4b42fba5610ea4b961528211cf1249 Mon Sep 17 00:00:00 2001
From: ancestor-mithril
Date: Fri, 10 Jan 2025 18:53:38 +0200
Subject: [PATCH] Added cutmix and mixup to pipeline

---
 experiment_runner.py |  1 +
 main.py              |  6 ++++++
 utils/trainer.py     | 17 ++++++++++++-----
 utils/transforms.py  |  8 ++++++++
 4 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/experiment_runner.py b/experiment_runner.py
index 1682467..ef185dd 100644
--- a/experiment_runner.py
+++ b/experiment_runner.py
@@ -109,6 +109,7 @@ def create_run(
             f" --disable_progress_bar"
             f" --stderr"
             f" --verbose"
+            f" --cutmix_mixup"
         ) + (
             " --half" if torch.cuda.is_available() else ""
         ) + (
diff --git a/main.py b/main.py
index 83325fe..03f311b 100644
--- a/main.py
+++ b/main.py
@@ -74,6 +74,12 @@
         default=False,
         help="log to stderr instead of stdout",
     )
+    parser.add_argument(
+        "--cutmix_mixup",
+        action="store_true",
+        default=False,
+        help="use cutmix and mixup",
+    )

     args = parser.parse_args()
     args.scheduler_params = json.loads(args.scheduler_params.replace("'", '"'))
diff --git a/utils/trainer.py b/utils/trainer.py
index 8c9731f..8392963 100644
--- a/utils/trainer.py
+++ b/utils/trainer.py
@@ -2,6 +2,7 @@
 import re
 from datetime import datetime
 from functools import cached_property
+from typing import Tuple

 import torch
 from timed_decorator.simple_timed import timed
@@ -19,6 +20,7 @@
 from utils.loss import init_criterion
 from utils.optimizer import init_optimizer
 from utils.scheduler import init_scheduler
+from utils.transforms import init_cutmix_or_mixup
 from utils.utils import seed_everything, try_optimize


@@ -53,6 +55,7 @@ def __init__(self, args):
         self.train_loader, self.test_loader = init_loaders(
             args, self.train_dataset, self.test_dataset, pin_memory
         )
+        self.cutmix_or_mixup = init_cutmix_or_mixup(args.cutmix_mixup, self.train_dataset.num_classes)

         self.model = init_model(args, self.train_dataset.num_classes).to(self.device)

@@ -170,8 +173,7 @@ def train(self):
         total_loss = 0.0

         for inputs, targets in self.train_loader:
-            inputs = self.prepare_inputs(inputs, self.device)
-            targets = targets.to(self.device, non_blocking=True)
+            inputs, targets = self.prepare_data(inputs, targets, self.device)

             with torch.autocast(self.device.type, enabled=self.args.half):
                 outputs = self.model(inputs)
@@ -185,6 +187,8 @@ def train(self):
             total_loss += loss.item()
             predicted = outputs.argmax(1)
             total += targets.size(0)
+            if targets.ndim > 1:
+                targets = targets.argmax(1)
             correct += predicted.eq(targets).sum().item()

         return {"Train/Accuracy": 100.0 * correct / total, "Train/Loss": total_loss}
@@ -286,8 +290,11 @@ def maybe_clip(self):
             self.scaler.unscale_(self.optimizer)
             clip_grad_norm_(self.model.parameters(), self.args.clip_value)

-    def prepare_inputs(self, x: Tensor, device: torch.device) -> Tensor:
+    def prepare_data(self, x: Tensor, y: Tensor, device: torch.device) -> Tuple[Tensor, Tensor]:
         x = x.to(device, non_blocking=True)
+        y = y.to(device, non_blocking=True)
         if self.batch_transforms_device is not None:
-            return self.batch_transforms_device(x)
-        return x
+            x = self.batch_transforms_device(x)
+        if self.cutmix_or_mixup is not None:
+            x, y = self.cutmix_or_mixup(x, y)
+        return x, y
diff --git a/utils/transforms.py b/utils/transforms.py
index 6cac0df..68f532d 100644
--- a/utils/transforms.py
+++ b/utils/transforms.py
@@ -222,3 +222,11 @@ def init_transforms(args) -> DatasetTransforms:
     if args.dataset in ("MNIST", "DirtyMNIST"):
         return MNISTTransforms(args)
     raise NotImplementedError(f"Transforms not implemented for {args.dataset}")
+
+
+def init_cutmix_or_mixup(use_cutmix_or_mixup, num_classes):
+    if use_cutmix_or_mixup:
+        cutmix = v2.CutMix(num_classes=num_classes)
+        mixup = v2.MixUp(num_classes=num_classes)
+        return v2.RandomChoice([cutmix, mixup])
+    return None
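
The new init_cutmix_or_mixup helper pairs torchvision's batch-level CutMix and MixUp transforms behind v2.RandomChoice, so each training batch gets exactly one of the two, picked at random. Below is a minimal, self-contained sketch of the resulting data flow, assuming torchvision >= 0.16 (where transforms.v2 ships CutMix and MixUp); the batch shape and num_classes=10 are illustrative, not values from the patch.

import torch
from torchvision.transforms import v2

num_classes = 10  # illustrative; the pipeline passes train_dataset.num_classes
cutmix_or_mixup = v2.RandomChoice([
    v2.CutMix(num_classes=num_classes),
    v2.MixUp(num_classes=num_classes),
])

# A fake batch: 8 CIFAR-sized images with integer class labels.
inputs = torch.randn(8, 3, 32, 32)
targets = torch.randint(0, num_classes, (8,))

# One of CutMix/MixUp is chosen at random and applied batch-wise; the
# integer labels come back as soft targets of shape (batch, num_classes).
inputs, targets = cutmix_or_mixup(inputs, targets)
print(targets.shape)  # torch.Size([8, 10])

# Hence the new guard in utils/trainer.py: collapse soft targets back to
# hard labels before comparing against predictions.
if targets.ndim > 1:
    targets = targets.argmax(1)

Because the targets become soft, the train loop's accuracy is computed against the dominant mixed class via argmax, which is what the targets.ndim > 1 guard in utils/trainer.py handles.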