Skip to content

Commit

Permalink
Completed fantasize method for DNN surrogate
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Stone committed Sep 17, 2024
1 parent f49a8b3 commit 39dce24
Showing 1 changed file with 36 additions and 2 deletions.
38 changes: 36 additions & 2 deletions obsidian/surrogates/custom_torch.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
"""Custom implementations of PyTorch surrogate models using BoTorch API"""

from .utils import fit_pytorch

from obsidian.config import TORCH_DTYPE

from botorch.models.model import FantasizeMixin
from botorch.models.ensemble import EnsembleModel
from botorch.models.ensemble import EnsembleModel, Model
from botorch.posteriors.ensemble import Posterior, EnsemblePosterior

import torch
import torch.nn as nn
from torch.nn import Module
from torch import Tensor
Expand Down Expand Up @@ -38,6 +43,13 @@ def __init__(self,
if p_dropout < 0 or p_dropout > 1:
raise ValueError("p_dropout must be in [0, 1]")

self.register_buffer('train_X', train_X)
self.register_buffer('train_Y', train_Y)
self.register_buffer('p_dropout', torch.tensor(p_dropout, dtype=TORCH_DTYPE))
self.register_buffer('h_width', torch.tensor(h_width, dtype=torch.int))
self.register_buffer('h_layers', torch.tensor(h_layers, dtype=torch.int))
self.register_buffer('num_outputs', torch.tensor(num_outputs, dtype=torch.int))

self.input_layer = nn.Sequential(
nn.Linear(train_X.shape[-1], h_width),
nn.PReLU(),
Expand All @@ -54,6 +66,7 @@ def __init__(self,

self.outer_layer = nn.Linear(h_width, num_outputs)
self._num_outputs = num_outputs
self.to(TORCH_DTYPE)

def forward(self,
x: Tensor) -> Tensor:
Expand Down Expand Up @@ -119,4 +132,25 @@ def condition_on_observations(self,
"""
Condition the model to new observations, returning a fantasy model
"""
return self

X_c = torch.concat((self.train_X, X), axis=0)
Y_c = torch.concat((self.train_Y, Y), axis=0)

# Create a new model based on the current one
fantasy = self.__class__(train_X=X_c, train_Y=Y_c,
p_dropout=float(self.p_dropout),
h_width=int(self.h_width), h_layers=int(self.h_layers),
num_outputs=int(self.num_outputs))

# Fit to the new data
fit_pytorch(fantasy, X_c, Y_c)

return fantasy

def fantasize(self,
X: Tensor) -> Model:

Y_f = self.forward(X).detach()
fantasy = self.condition_on_observations(X, Y_f)

return fantasy

0 comments on commit 39dce24

Please sign in to comment.