diff --git a/utils/loss.py b/utils/loss.py index 29c1a93..49e01f7 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -77,7 +77,6 @@ def __init__(self, loss, reducer): f"Using Loss scaling with scaling range: {self.loss_scaling_range}" ) - @lru_cache(maxsize=32) def get_weights( self, shape: torch.Size, dtype: torch.dtype, device: torch.device ) -> Tensor: @@ -130,6 +129,7 @@ def init_criterion(args): else: loss = loss(reduction=args.reduction if args.loss_scaling is None else "none") + # TODO: Add decay to loss scaling if args.loss_scaling is None: return loss elif args.loss_scaling == "normal-scaling":