Skip to content

Commit

Permalink
Encode the reparam sampler.
Browse files Browse the repository at this point in the history
  • Loading branch information
avullo committed Sep 17, 2024
1 parent dbe623b commit 9e0ea23
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion trieste/models/gpflux/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(self, sample_size: int, model: GPfluxPredictor):
)
for _ in range(len(self._model_gpflux.f_layers))
]
self._encode = lambda x: model.encode(x)

@property
def _model_gpflux(self) -> tf.Module:
Expand All @@ -96,7 +97,9 @@ def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorTy
tf.debugging.assert_shapes([(at, [..., 1, None])])
tf.debugging.assert_greater_equal(jitter, 0.0)

samples = tf.repeat(at[..., None, :, :], self._sample_size, axis=-3) # [..., S, 1, D]
samples = tf.repeat(
self._encode(at[..., None, :, :]), self._sample_size, axis=-3
) # [..., S, 1, D]
for i, layer in enumerate(self._model_gpflux.f_layers):
if isinstance(layer, LatentVariableLayer):
if not self._initialized:
Expand Down

0 comments on commit 9e0ea23

Please sign in to comment.