Improved fitting of pytorch surrogates with auto-stopping
Kevin Stone committed Sep 12, 2024
1 parent d98840f commit 90516f4
Showing 2 changed files with 87 additions and 0 deletions.
1 change: 1 addition & 0 deletions obsidian/surrogates/__init__.py
@@ -4,3 +4,4 @@
from .custom_GP import *
from .custom_torch import *
from .config import *
from .utils import *
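
For illustration (not part of this commit): with the star re-export above, the new helpers become importable directly from the surrogates package, as in the following hypothetical caller code.

# Hypothetical caller -- shows the effect of the re-export above
from obsidian.surrogates import fit_pytorch, check_parameter_change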
86 changes: 86 additions & 0 deletions obsidian/surrogates/utils.py
@@ -0,0 +1,86 @@
"""Utility functions for surrogate model handling"""

import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Module, Parameter
import torch.optim as optim


def check_parameter_change(params: list[Parameter],
                           prev_params: list[Parameter]) -> float:
    """Check the average change in model parameters.

    Returns the mean L2 norm of the difference between current and
    previous parameter tensors.
    """

    changes = []
    for param, prev_param in zip(params, prev_params):
        changes.append(torch.norm(param.data - prev_param.data))

    total_change = sum(changes)
    avg_change = float(total_change) / len(params)

    return avg_change


def fit_pytorch(model: Module,
                X: Tensor,
                y: Tensor,
                loss_fcn: Module | None = None,
                verbose: bool = False,
                max_iter: int = 5000) -> None:
    """
    Fits a PyTorch model to the given input data.

    Args:
        model (Module): The PyTorch model to be trained.
        X (Tensor): The input data.
        y (Tensor): The target data.
        loss_fcn (Module, optional): The loss function to be used.
            If not provided, the Mean Squared Error (MSE) loss function
            will be used. Defaults to ``None``.
        verbose (bool, optional): Whether to print the loss every 50
            epochs. Defaults to ``False``.
        max_iter (int, optional): The maximum number of training epochs.
            Defaults to ``5000``.
    """

    if loss_fcn is None:
        loss_fcn = nn.MSELoss()

    optimizer = optim.Adam(model.parameters(),
                           lr=1e-2,
                           weight_decay=1e-2)

    # Set up early stopping: converge once the average parameter change
    # stays below tolerance for `converge_buffer` consecutive epochs
    avg_change = 1e10
    param_tol = 1e-2
    converge_buffer = 5
    converge_count = 0

    model.train()
    for epoch in range(max_iter):

        prev_params = [param.clone().detach() for param in model.parameters()]
        optimizer.zero_grad()
        output = model(X)
        loss = loss_fcn(output, y)
        loss.backward()
        optimizer.step()

        # Check the average parameter change after this step
        current_params = [param.clone().detach() for param in model.parameters()]
        avg_change = check_parameter_change(current_params, prev_params)

        if verbose and epoch % 50 == 0:
            print(f'Epoch {epoch}: Loss {loss.item()}')

        # If the parameter change is under tolerance, increment the counter
        if avg_change < param_tol:
            converge_count += 1
        else:
            converge_count = 0

        # Auto-converge when the count reaches the desired length
        if converge_count == converge_buffer:
            if verbose:
                print(f'Converged at epoch {epoch}')
            break

    model.eval()
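
For context, a minimal usage sketch (not part of the commit): the toy model, data, and shapes below are illustrative assumptions; only `fit_pytorch` itself comes from the diff above.

# Illustrative usage only -- the model and data here are hypothetical,
# not part of the obsidian codebase
import torch
import torch.nn as nn

from obsidian.surrogates.utils import fit_pytorch

# Toy regression data: 64 samples, 3 features, 1 target
X = torch.rand(64, 3)
y = torch.sin(X.sum(dim=-1, keepdim=True))

# Small feed-forward surrogate (hypothetical architecture)
model = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))

# Trains with Adam (lr=1e-2, weight_decay=1e-2) and auto-stops once the
# average parameter change stays below 1e-2 for 5 consecutive epochs
fit_pytorch(model, X, y, verbose=True)

with torch.no_grad():
    preds = model(X)  # fit_pytorch leaves the model in eval mode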
