From 13b4ac7dc163900d901ded06f2b2995e23f5187a Mon Sep 17 00:00:00 2001 From: sgbaird Date: Tue, 22 Feb 2022 01:08:14 -0700 Subject: [PATCH] swa links --- crabnet/crabnet_.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crabnet/crabnet_.py b/crabnet/crabnet_.py index 7513c4f..41733a8 100644 --- a/crabnet/crabnet_.py +++ b/crabnet/crabnet_.py @@ -379,6 +379,8 @@ def _train(self): ) pred_v = np.nan_to_num(pred_v) mae_v = mean_absolute_error(true_v, pred_v) + # https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py + # https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/ self.optimizer.update_swa(mae_v) minima.append(self.optimizer.minimum_found)