From 9e0ea232d6d79661746f77c1a76b76a691611851 Mon Sep 17 00:00:00 2001
From: Alessandro Vullo
Date: Tue, 17 Sep 2024 16:59:13 +0100
Subject: [PATCH] Encode the reparam sampler.

---
 trieste/models/gpflux/sampler.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/trieste/models/gpflux/sampler.py b/trieste/models/gpflux/sampler.py
index 6435f01e4..4aef9937a 100644
--- a/trieste/models/gpflux/sampler.py
+++ b/trieste/models/gpflux/sampler.py
@@ -72,6 +72,7 @@ def __init__(self, sample_size: int, model: GPfluxPredictor):
             )
             for _ in range(len(self._model_gpflux.f_layers))
         ]
+        self._encode = lambda x: model.encode(x)
 
     @property
     def _model_gpflux(self) -> tf.Module:
@@ -96,7 +97,9 @@ def sample(self, at: TensorType, *, jitter: float = DEFAULTS.JITTER) -> TensorTy
         tf.debugging.assert_shapes([(at, [..., 1, None])])
         tf.debugging.assert_greater_equal(jitter, 0.0)
 
-        samples = tf.repeat(at[..., None, :, :], self._sample_size, axis=-3)  # [..., S, 1, D]
+        samples = tf.repeat(
+            self._encode(at[..., None, :, :]), self._sample_size, axis=-3
+        )  # [..., S, 1, D]
         for i, layer in enumerate(self._model_gpflux.f_layers):
             if isinstance(layer, LatentVariableLayer):
                 if not self._initialized: