Skip to content

Commit

Permalink
Added X_baseline to SF kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin Stone committed Sep 17, 2024
1 parent 8831fa3 commit f49a8b3
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions obsidian/acquisition/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@ class qSpaceFill(MCAcquisitionFunction):
"""
def __init__(self,
model: Model,
X_baseline: Tensor,
sampler: MCSampler | None = None,
objective: MCAcquisitionObjective | None = None,
posterior_transform: PosteriorTransform | None = None,
X_pending: Tensor | None = None):
X_pending: Tensor | None = None,):

if sampler is None:
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([512]))
Expand All @@ -80,6 +81,8 @@ def __init__(self,

super().__init__(model=model, sampler=sampler, objective=objective,
posterior_transform=posterior_transform, X_pending=X_pending)

self.register_buffer('X_baseline', X_baseline)

@t_batch_mode_transform()
def forward(self,
Expand All @@ -88,7 +91,7 @@ def forward(self,
Evaluate the acquisition function on the candidate set x
"""
# x dimensions: b * q * d
x_train = self.model.train_inputs[0][0] # train_inputs is a list of tuples
x_train = self.X_baseline

# For sequential mode, add pending data points to "train"
if self.X_pending is not None:
Expand Down

0 comments on commit f49a8b3

Please sign in to comment.