Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/develop' into khurram/trustregions
Browse files Browse the repository at this point in the history
  • Loading branch information
khurram-ghani committed Aug 2, 2023
2 parents c4ce7ab + a6d3470 commit 20b5687
Show file tree
Hide file tree
Showing 19 changed files with 167 additions and 45 deletions.
27 changes: 27 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,21 @@

from __future__ import annotations

from typing import Iterable

import pytest
from _pytest.config import Config
from _pytest.config.argparsing import Parser
from check_shapes import (
DocstringFormat,
ShapeCheckingState,
get_enable_check_shapes,
get_enable_function_call_precompute,
get_rewrite_docstrings,
set_enable_check_shapes,
set_enable_function_call_precompute,
set_rewrite_docstrings,
)


def pytest_addoption(parser: Parser) -> None:
Expand Down Expand Up @@ -74,3 +86,18 @@ def pytest_collection_modifyitems(config: Config, items: list[pytest.Item]) -> N
import tensorflow as tf

tf.config.experimental_run_functions_eagerly(True)


@pytest.fixture(autouse=True)
def enable_shape_checks() -> Iterable[None]:
# ensure `check_shapes` is always enabled for tests
old_enable = get_enable_check_shapes()
old_rewrite_docstrings = get_rewrite_docstrings()
old_function_call_precompute = get_enable_function_call_precompute()
set_enable_check_shapes(ShapeCheckingState.ENABLED)
set_rewrite_docstrings(DocstringFormat.SPHINX)
set_enable_function_call_precompute(True)
yield
set_enable_function_call_precompute(old_function_call_precompute)
set_rewrite_docstrings(old_rewrite_docstrings)
set_enable_check_shapes(old_enable)
5 changes: 3 additions & 2 deletions tests/unit/acquisition/function/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow.python.autograph.impl.api import StagingError

from tests.util.misc import (
TF_DEBUGGING_ERROR_TYPES,
Expand Down Expand Up @@ -282,7 +283,7 @@ def test_expected_improvement_switches_to_improvement_on_feasible_points() -> No
def test_expected_improvement_raises_for_invalid_batch_size(at: TensorType) -> None:
ei = expected_improvement(QuadraticMeanAndRBFKernel(), tf.constant([1.0]))

with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
with pytest.raises(StagingError):
ei(at)


Expand Down Expand Up @@ -946,7 +947,7 @@ def test_expected_constrained_improvement_raises_for_invalid_batch_size(at: Tens

eci = builder.prepare_acquisition_function({NA: QuadraticMeanAndRBFKernel()}, datasets=data)

with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
with pytest.raises(StagingError):
eci(at)


Expand Down
3 changes: 2 additions & 1 deletion tests/unit/models/gpflow/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest
import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes.exceptions import ShapeMismatchError
from scipy import stats

from tests.util.misc import TF_DEBUGGING_ERROR_TYPES, ShapeLike, quadratic, random_seed
Expand Down Expand Up @@ -140,7 +141,7 @@ def test_independent_reparametrization_sampler_sample_raises_for_invalid_at_shap
) -> None:
sampler = IndependentReparametrizationSampler(1, QuadraticMeanAndRBFKernel(), qmc=qmc)

with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
with pytest.raises(ShapeMismatchError):
sampler.sample(tf.zeros(shape))


Expand Down
2 changes: 2 additions & 0 deletions tests/unit/models/gpflux/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy.testing as npt
import pytest
import tensorflow as tf
from check_shapes import inherit_check_shapes
from gpflow.conditionals.util import sample_mvn
from gpflux.helpers import construct_basic_inducing_variables, construct_basic_kernel
from gpflux.layers import GPLayer
Expand Down Expand Up @@ -57,6 +58,7 @@ def model_keras(self) -> tf.keras.Model:
def optimizer(self) -> tf.keras.optimizers.Optimizer:
return self._optimizer

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
# Taken from GPflow implementation of `GPModel.predict_f_samples` in gpflow.models.model
mean, cov = self._model_gpflux.predict_f(query_points, full_cov=True)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/objectives/test_multi_objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
import numpy.testing as npt
import pytest
import tensorflow as tf
from check_shapes.exceptions import ShapeMismatchError

from tests.util.misc import TF_DEBUGGING_ERROR_TYPES
from trieste.objectives.multi_objectives import DTLZ1, DTLZ2, VLMOP2, MultiObjectiveTestProblem
from trieste.types import TensorType

Expand Down Expand Up @@ -142,7 +142,7 @@ def test_gen_pareto_front_is_equal_to_math_defined(
def test_func_raises_specified_input_dim_not_align_with_actual_input_dim(
obj_inst: MultiObjectiveTestProblem, actual_x: TensorType
) -> None:
with pytest.raises(TF_DEBUGGING_ERROR_TYPES):
with pytest.raises(ShapeMismatchError):
obj_inst.objective(actual_x)


Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_bayesian_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy.testing as npt
import pytest
import tensorflow as tf
from check_shapes import inherit_check_shapes

from tests.util.misc import (
FixedAcquisitionRule,
Expand Down Expand Up @@ -596,6 +597,7 @@ def __init__(self, data: Dataset):
super().__init__()
self._data = data

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
mean, var = super().predict(query_points)
return mean, var / len(self._data)
Expand Down
10 changes: 9 additions & 1 deletion tests/util/models/gpflow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes import inherit_check_shapes
from gpflow.models import GPR, SGPR, SVGP, VGP, GPModel
from typing_extensions import Protocol

Expand All @@ -45,6 +46,7 @@
SupportsCovarianceWithTopFidelity,
SupportsGetKernel,
SupportsGetObservationNoise,
SupportsPredictJoint,
)
from trieste.models.optimizer import Optimizer
from trieste.types import TensorType
Expand All @@ -70,6 +72,7 @@ def optimize(self, dataset: Dataset) -> None:
class GaussianMarginal(ProbabilisticModel):
"""A probabilistic model with Gaussian marginal distribution. Assumes events of shape [N]."""

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
mean, var = self.predict(query_points)
samples = tfp.distributions.Normal(mean, tf.sqrt(var)).sample(num_samples)
Expand All @@ -95,10 +98,12 @@ def __init__(
def __repr__(self) -> str:
return f"GaussianProcess({self._mean_functions!r}, {self._kernels!r})"

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
mean, cov = self.predict_joint(query_points[..., None, :])
return tf.squeeze(mean, -2), tf.squeeze(cov, [-2, -1])

@inherit_check_shapes
def predict_joint(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
means = [f(query_points) for f in self._mean_functions]
covs = [k.tensor(query_points, query_points, 1, 1)[..., None, :, :] for k in self._kernels]
Expand All @@ -116,7 +121,7 @@ def covariance_between_points(
return tf.concat(covs, axis=-3)


class GaussianProcessWithoutNoise(GaussianMarginal, HasReparamSampler):
class GaussianProcessWithoutNoise(GaussianMarginal, SupportsPredictJoint, HasReparamSampler):
"""A (static) Gaussian process over a vector random variable with independent reparam sampler
but without noise variance."""

Expand All @@ -131,10 +136,12 @@ def __init__(
def __repr__(self) -> str:
return f"GaussianProcessWithoutNoise({self._mean_functions!r}, {self._kernels!r})"

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
mean, cov = self.predict_joint(query_points[..., None, :])
return tf.squeeze(mean, -2), tf.squeeze(cov, [-2, -1])

@inherit_check_shapes
def predict_joint(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
means = [f(query_points) for f in self._mean_functions]
covs = [k.tensor(query_points, query_points, 1, 1)[..., None, :, :] for k in self._kernels]
Expand Down Expand Up @@ -278,6 +285,7 @@ def covariance_with_top_fidelity(self, x: TensorType) -> TensorType:
mean, _ = self.predict(x)
return tf.ones_like(mean, dtype=mean.dtype) # dummy covariances of correct shape

@inherit_check_shapes
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
fmean, fvar = self.predict(query_points)
yvar = fvar + tf.constant(1.0, dtype=fmean.dtype) # dummy noise variance
Expand Down
9 changes: 5 additions & 4 deletions trieste/acquisition/function/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes import check_shapes

from ...data import Dataset
from ...models import ProbabilisticModel, ReparametrizationSampler
Expand Down Expand Up @@ -212,11 +213,11 @@ def update(self, eta: TensorType) -> None:
self._eta.assign(eta)

@tf.function
@check_shapes(
"x: [N..., 1, D] # This acquisition function only supports batch sizes of one",
"return: [N..., L]",
)
def __call__(self, x: TensorType) -> TensorType:
tf.debugging.assert_shapes(
[(x, [..., 1, None])],
message="This acquisition function only supports batch sizes of one.",
)
mean, variance = self._model.predict(tf.squeeze(x, -2))
normal = tfp.distributions.Normal(mean, tf.sqrt(variance))
return (self._eta - mean) * normal.cdf(self._eta) + variance * normal.prob(self._eta)
Expand Down
24 changes: 22 additions & 2 deletions trieste/acquisition/function/greedy_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import gpflow
import tensorflow as tf
import tensorflow_probability as tfp
from check_shapes import check_shapes
from typing_extensions import Protocol, runtime_checkable

from ...data import Dataset
Expand Down Expand Up @@ -652,14 +653,19 @@ def update_fantasized_data(self, fantasized_data: Dataset) -> None:
self._fantasized_query_points.assign(fantasized_data.query_points)
self._fantasized_observations.assign(fantasized_data.observations)

@check_shapes(
"query_points: [batch..., N, D]",
"return[0]: [batch..., ..., N, L]",
"return[1]: [batch..., ..., N, L]",
)
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
"""
This function wraps conditional_predict_f. It cannot directly call
conditional_predict_f, since it does not accept query_points with rank > 2.
We use map_fn to allow leading dimensions for query_points.
:param query_points: shape [...*, N, d]
:return: mean, shape [...*, ..., N, L] and cov, shape [...*, ..., L, N],
:return: mean, shape [...*, ..., N, L] and cov, shape [...*, ..., N, L],
where ... are the leading dimensions of fantasized_data
"""

Expand All @@ -671,6 +677,11 @@ def fun(qp: TensorType) -> tuple[TensorType, TensorType]: # pragma: no cover (t

return _broadcast_predict(query_points, fun)

@check_shapes(
"query_points: [batch..., N, D]",
"return[0]: [batch..., ..., N, L]",
"return[1]: [batch..., ..., L, N, N]",
)
def predict_joint(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
"""
This function wraps conditional_predict_joint. It cannot directly call
Expand All @@ -690,6 +701,10 @@ def fun(qp: TensorType) -> tuple[TensorType, TensorType]: # pragma: no cover (t

return _broadcast_predict(query_points, fun)

@check_shapes(
"query_points: [batch..., N, D]",
"return: [batch..., ..., S, N, L]",
)
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
"""
This function wraps conditional_predict_f_sample. It cannot directly call
Expand All @@ -716,14 +731,19 @@ def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
) # [B, ..., S, L]
return _restore_leading_dim(samples, leading_dim)

@check_shapes(
"query_points: [broadcast batch..., N, D]",
"return[0]: [batch..., ..., N, L]",
"return[1]: [batch..., ..., N, L]",
)
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
"""
This function wraps conditional_predict_y. It cannot directly call
conditional_predict_joint, since it does not accept query_points with rank > 2.
We use tf.map_fn to allow leading dimensions for query_points.
:param query_points: shape [...*, N, D]
:return: mean, shape [...*, ..., N, L] and var, shape [...*, ..., L, N],
:return: mean, shape [...*, ..., N, L] and var, shape [...*, ..., N, L],
where ... are the leading dimensions of fantasized_data
"""

Expand Down
5 changes: 5 additions & 0 deletions trieste/models/gpflow/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import gpflow
import tensorflow as tf
from check_shapes import inherit_check_shapes
from gpflow.models import GPModel
from gpflow.posteriors import BasePosterior, PrecomputeCacheType
from typing_extensions import Protocol
Expand Down Expand Up @@ -93,13 +94,15 @@ def update_posterior_cache(self) -> None:
def model(self) -> GPModel:
"""The underlying GPflow model."""

@inherit_check_shapes
def predict(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
mean, cov = (self._posterior or self.model).predict_f(query_points)
# posterior predict can return negative variance values [cf GPFlow issue #1813]
if self._posterior is not None:
cov = tf.clip_by_value(cov, 1e-12, cov.dtype.max)
return mean, cov

@inherit_check_shapes
def predict_joint(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
mean, cov = (self._posterior or self.model).predict_f(query_points, full_cov=True)
# posterior predict can return negative variance values [cf GPFlow issue #1813]
Expand All @@ -109,9 +112,11 @@ def predict_joint(self, query_points: TensorType) -> tuple[TensorType, TensorTyp
)
return mean, cov

@inherit_check_shapes
def sample(self, query_points: TensorType, num_samples: int) -> TensorType:
return self.model.predict_f_samples(query_points, num_samples)

@inherit_check_shapes
def predict_y(self, query_points: TensorType) -> tuple[TensorType, TensorType]:
return self.model.predict_y(query_points)

Expand Down
Loading

0 comments on commit 20b5687

Please sign in to comment.