Skip to content

Commit

Permalink
Stabilized mean prediction of ensemble surrogates
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Stone committed Sep 12, 2024
1 parent 2d87895 commit abdcc19
Showing 1 changed file with 11 additions and 16 deletions.
27 changes: 11 additions & 16 deletions obsidian/surrogates/botorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from .base import SurrogateModel
from .config import model_class_dict
from .utils import fit_pytorch

from obsidian.utils import tensordict_to_dict, dict_to_tensordict
from obsidian.exceptions import SurrogateFitError
Expand All @@ -10,11 +11,11 @@
from botorch.fit import fit_gpytorch_mll
from botorch.optim.fit import fit_gpytorch_mll_torch, fit_gpytorch_mll_scipy
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.ensemble import EnsembleModel
from gpytorch.mlls import ExactMarginalLogLikelihood

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import warnings
Expand Down Expand Up @@ -156,20 +157,7 @@ def fit(self,
raise SurrogateFitError('BoTorch model failed to fit')
else:
self.loss_fcn = nn.MSELoss()
self.optimizer = optim.Adam(self.torch_model.parameters(), lr=1e-2)

self.torch_model.train()
for epoch in range(200):
self.optimizer.zero_grad()
output = self.torch_model(X_p)
loss = self.loss_fcn(output, y_p)
loss.backward()
self.optimizer.step()

if (epoch % 50 == 0 and self.verbose):
print(f'Epoch {epoch}: Loss {loss.item()}')

self.torch_model.eval()
fit_pytorch(self.torch_model, X_p, y_p, loss_fcn=self.loss_fcn, verbose=self.verbose)

self.is_fit = True

Expand Down Expand Up @@ -229,7 +217,14 @@ def predict(self,
X_p = self._prepare(X)

pred_posterior = self.torch_model.posterior(X_p)
mu = pred_posterior.mean.detach().cpu().squeeze(-1)

# We would prefer to have stability in the mean of ensemble models,
# So, we will not re-sample for prediction but use forward methods
if isinstance(self.torch_model, EnsembleModel):
mu = self.torch_model.forward(X_p).detach()
else:
mu = pred_posterior.mean.detach().cpu().squeeze(-1)

if q is not None:
if (q < 0) or (q > 1):
raise ValueError('Quantile must be between 0 and 1')
Expand Down

0 comments on commit abdcc19

Please sign in to comment.