Skip to content

Commit

Permalink
swa links
Browse files Browse the repository at this point in the history
  • Loading branch information
sgbaird committed Feb 22, 2022
1 parent 9369c34 commit 13b4ac7
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions crabnet/crabnet_.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,8 @@ def _train(self):
)
pred_v = np.nan_to_num(pred_v)
mae_v = mean_absolute_error(true_v, pred_v)
# https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py
# https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
self.optimizer.update_swa(mae_v)
minima.append(self.optimizer.minimum_found)

Expand Down

0 comments on commit 13b4ac7

Please sign in to comment.