Skip to content

Commit

Permalink
Added FantasizeMixIn and reduced posterior samples for DNN
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Stone committed Sep 12, 2024
1 parent 6d350a4 commit 2d87895
Showing 1 changed file with 37 additions and 3 deletions.
40 changes: 37 additions & 3 deletions obsidian/surrogates/custom_torch.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
"""Custom implementations of PyTorch surrogate models using BoTorch API"""

from botorch.models.model import Model
from botorch.models.model import FantasizeMixin
from botorch.models.ensemble import EnsembleModel
from botorch.posteriors.ensemble import Posterior, EnsemblePosterior

import torch.nn as nn
from torch.nn import Module
from torch import Tensor

from typing import TypeVar
TFantasizeMixin = TypeVar("TFantasizeMixin", bound="FantasizeMixin")


class DNNPosterior(EnsemblePosterior):

Expand All @@ -17,7 +22,7 @@ def quantile(self, value: Tensor) -> Tensor:
return self.values.quantile(q=value.to(self.values), dim=-3, interpolation='linear')


class DNN(Model):
class DNN(EnsembleModel, FantasizeMixin):
def __init__(self,
train_X: Tensor,
train_Y: Tensor,
Expand Down Expand Up @@ -60,7 +65,7 @@ def forward(self,

def posterior(self,
X: Tensor,
n_sample: int = 16384,
n_sample: int = 512,
output_indices: list[int] = None,
observation_noise: bool | Tensor = False) -> Posterior:
"""Calculates the posterior distribution of the model at X"""
Expand All @@ -86,3 +91,32 @@ def posterior(self,
def num_outputs(self) -> int:
"""Number of outputs of the model"""
return self._num_outputs

def transform_inputs(self,
X: Tensor,
input_transform: Module = None) -> Tensor:
"""
Transform inputs.
Args:
X: A tensor of inputs
input_transform: A Module that performs the input transformation.
Returns:
A tensor of transformed inputs
"""
if input_transform is not None:
input_transform.to(X)
return input_transform(X)
try:
return self.input_transform(X)
except AttributeError:
return X

def condition_on_observations(self,
X: Tensor,
Y: Tensor) -> TFantasizeMixin:
"""
Condition the model to new observations, returning a fantasy model
"""
return self

0 comments on commit 2d87895

Please sign in to comment.