Added cutmix and mixup to pipeline
ancestor-mithril committed Jan 10, 2025
1 parent ceb34db commit d1fff6e
Showing 4 changed files with 27 additions and 5 deletions.
experiment_runner.py: 1 addition & 0 deletions

@@ -109,6 +109,7 @@ def create_run(
         f" --disable_progress_bar"
         f" --stderr"
         f" --verbose"
+        f" --cutmix_mixup"
     ) + (
         " --half" if torch.cuda.is_available() else ""
     ) + (
main.py: 6 additions & 0 deletions

@@ -74,6 +74,12 @@
     default=False,
     help="log to stderr instead of stdout",
 )
+parser.add_argument(
+    "--cutmix_mixup",
+    action="store_true",
+    default=False,
+    help="use cutmix and mixup",
+)

 args = parser.parse_args()
 args.scheduler_params = json.loads(args.scheduler_params.replace("'", '"'))
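Note: experiment_runner.py now appends --cutmix_mixup to every generated command, so all scheduled runs train with the augmentation, while main.py keeps it off by default for manual runs. A minimal sketch of how the new store_true flag parses (the flag definition is copied from the diff; the rest of the parser is elided):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--cutmix_mixup",
    action="store_true",
    default=False,
    help="use cutmix and mixup",
)

print(parser.parse_args([]).cutmix_mixup)                  # False by default
print(parser.parse_args(["--cutmix_mixup"]).cutmix_mixup)  # True when the flag is passed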
utils/trainer.py: 12 additions & 5 deletions

@@ -2,6 +2,7 @@
 import re
 from datetime import datetime
 from functools import cached_property
+from typing import Tuple

 import torch
 from timed_decorator.simple_timed import timed

@@ -19,6 +20,7 @@
 from utils.loss import init_criterion
 from utils.optimizer import init_optimizer
 from utils.scheduler import init_scheduler
+from utils.transforms import init_cutmix_or_mixup
 from utils.utils import seed_everything, try_optimize


@@ -53,6 +55,7 @@ def __init__(self, args):
         self.train_loader, self.test_loader = init_loaders(
             args, self.train_dataset, self.test_dataset, pin_memory
         )
+        self.cutmix_or_mixup = init_cutmix_or_mixup(args.cutmix_mixup, self.train_dataset.num_classes)

         self.model = init_model(args, self.train_dataset.num_classes).to(self.device)

@@ -170,8 +173,7 @@ def train(self):
         total_loss = 0.0

         for inputs, targets in self.train_loader:
-            inputs = self.prepare_inputs(inputs, self.device)
-            targets = targets.to(self.device, non_blocking=True)
+            inputs, targets = self.prepare_data(inputs, targets, self.device)

             with torch.autocast(self.device.type, enabled=self.args.half):
                 outputs = self.model(inputs)

@@ -185,6 +187,8 @@
             total_loss += loss.item()
             predicted = outputs.argmax(1)
             total += targets.size(0)
+            if targets.ndim > 1:
+                targets = targets.argmax(1)
             correct += predicted.eq(targets).sum().item()

         return {"Train/Accuracy": 100.0 * correct / total, "Train/Loss": total_loss}
@@ -286,8 +290,11 @@ def maybe_clip(self):
             self.scaler.unscale_(self.optimizer)
             clip_grad_norm_(self.model.parameters(), self.args.clip_value)

-    def prepare_inputs(self, x: Tensor, device: torch.device) -> Tensor:
+    def prepare_data(self, x: Tensor, y: Tensor, device: torch.device) -> Tuple[Tensor, Tensor]:
         x = x.to(device, non_blocking=True)
+        y = y.to(device, non_blocking=True)
         if self.batch_transforms_device is not None:
-            return self.batch_transforms_device(x)
-        return x
+            x = self.batch_transforms_device(x)
+        if self.cutmix_or_mixup is not None:
+            x, y = self.cutmix_or_mixup(x, y)
         return x, y
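Note: v2.CutMix and v2.MixUp replace integer labels with mixed class-probability targets of shape [batch, num_classes], which is why train() now recovers hard labels via argmax before counting correct predictions. A minimal sketch of the shape change, assuming torchvision >= 0.16 for transforms.v2:

import torch
from torchvision.transforms import v2

num_classes = 10
cutmix_or_mixup = v2.RandomChoice(
    [v2.CutMix(num_classes=num_classes), v2.MixUp(num_classes=num_classes)]
)

images = torch.rand(8, 3, 32, 32)
targets = torch.randint(0, num_classes, (8,))       # hard labels, ndim == 1

images, targets = cutmix_or_mixup(images, targets)  # soft labels, ndim == 2
print(targets.shape)                                # torch.Size([8, 10])

if targets.ndim > 1:                                # same guard as in train()
    targets = targets.argmax(1)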
utils/transforms.py: 8 additions & 0 deletions

@@ -222,3 +222,11 @@ def init_transforms(args) -> DatasetTransforms:
     if args.dataset in ("MNIST", "DirtyMNIST"):
         return MNISTTransforms(args)
     raise NotImplementedError(f"Transforms not implemented for {args.dataset}")
+
+
+def init_cutmix_or_mixup(use_cutmix_or_mixup, num_classes):
+    if use_cutmix_or_mixup:
+        cutmix = v2.CutMix(num_classes=num_classes)
+        mixup = v2.MixUp(num_classes=num_classes)
+        return v2.RandomChoice([cutmix, mixup])
+    return None
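Note: the helper returns None when the flag is off, and prepare_data then skips the mixing step entirely. When it is on, v2.RandomChoice applies exactly one of the two transforms per batch, chosen at random, so a batch is either cut-mixed or mixed-up, never both. The soft targets it produces feed straight into the existing criterion, assuming that criterion is cross-entropy: nn.CrossEntropyLoss has accepted class-probability targets since PyTorch 1.10. A minimal sketch of that compatibility:

import torch
from torch import nn

logits = torch.randn(4, 10)
soft_targets = torch.softmax(torch.randn(4, 10), dim=1)  # stand-in for mixed labels
print(nn.CrossEntropyLoss()(logits, soft_targets).item())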
