Skip to content

Commit

Permalink
Remove superfluous encodings
Browse files Browse the repository at this point in the history
  • Loading branch information
Uri Granta committed Jul 31, 2024
1 parent 25d20ca commit c071112
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 12 deletions.
9 changes: 0 additions & 9 deletions trieste/models/gpflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@
SupportsPredictY,
TrainableProbabilisticModel,
TrajectorySampler,
encode_query_points,
)
from ..optimizer import BatchOptimizer, Optimizer, OptimizeResult
from .inducing_point_selectors import InducingPointSelector
Expand Down Expand Up @@ -165,7 +164,6 @@ def _ensure_variable_model_data(self) -> None:
)

@inherit_check_shapes
@encode_query_points
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
f_mean, f_var = self.predict(query_points)
return self.model.likelihood.predict_mean_and_var(query_points, f_mean, f_var)
Expand Down Expand Up @@ -355,7 +353,6 @@ def get_internal_data(self) -> Dataset:
"""
return Dataset(self.model.data[0], self.model.data[1])

@encode_query_points
def conditional_predict_f(
self, query_points: TensorType, additional_data: Dataset
) -> tuple[TensorType, TensorType]:
Expand Down Expand Up @@ -421,7 +418,6 @@ def conditional_predict_f(

return mean_qp_new, var_qp_new

@encode_query_points
def conditional_predict_joint(
self, query_points: TensorType, additional_data: Dataset
) -> tuple[TensorType, TensorType]:
Expand Down Expand Up @@ -492,7 +488,6 @@ def conditional_predict_joint(

return mean_qp_new, cov_qp_new

@encode_query_points
def conditional_predict_f_sample(
self, query_points: TensorType, additional_data: Dataset, num_samples: int
) -> TensorType:
Expand All @@ -514,7 +509,6 @@ def conditional_predict_f_sample(
) # [..., (S), P, N]
return tf.linalg.adjoint(samples) # [..., (S), N, L]

@encode_query_points
def conditional_predict_y(
self, query_points: TensorType, additional_data: Dataset
) -> tuple[TensorType, TensorType]:
Expand Down Expand Up @@ -623,7 +617,6 @@ def inducing_point_selector(
return self._inducing_point_selector

@inherit_check_shapes
@encode_query_points
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
f_mean, f_var = self.predict(query_points)
return self.model.likelihood.predict_mean_and_var(query_points, f_mean, f_var)
Expand Down Expand Up @@ -950,7 +943,6 @@ def inducing_point_selector(self) -> Optional[InducingPointSelector[SparseVariat
return self._inducing_point_selector

@inherit_check_shapes
@encode_query_points
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
f_mean, f_var = self.predict(query_points)
return self.model.likelihood.predict_mean_and_var(query_points, f_mean, f_var)
Expand Down Expand Up @@ -1267,7 +1259,6 @@ def model(self) -> VGP:
return self._model

@inherit_check_shapes
@encode_query_points
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
f_mean, f_var = self.predict(query_points)
return self.model.likelihood.predict_mean_and_var(query_points, f_mean, f_var)
Expand Down
5 changes: 3 additions & 2 deletions trieste/models/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,8 +765,9 @@ def encoder(self) -> EncoderFunction | None:


def encode_query_points(f: C) -> C:
"""Decorator for encoding query points, applicable to HasEncoder model methods such as predict
whose first argument is query_points.
"""
Decorator for automatically encoding query points, applicable to HasEncoder model methods
such as predict whose first argument is query_points.
"""

@functools.wraps(f)
Expand Down
1 change: 0 additions & 1 deletion trieste/models/keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ def predict_ensemble(self, query_points: TensorType) -> tuple[TensorType, Tensor
return predicted_means, predicted_vars

@inherit_check_shapes
@encode_query_points
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
"""
Return ``num_samples`` samples at ``query_points``. We use the mixture approximation in
Expand Down

0 comments on commit c071112

Please sign in to comment.