From b8c2488bb2c21e3d0f3a7f51ce85848a9efd0e7c Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Mon, 9 Sep 2024 22:06:42 +0100 Subject: [PATCH 1/3] Don't one hot encode binary categories --- tests/unit/test_space.py | 15 +++++++++++++-- trieste/space.py | 13 +++++++++++-- 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_space.py b/tests/unit/test_space.py index 6fb473d13..a39cd0100 100644 --- a/tests/unit/test_space.py +++ b/tests/unit/test_space.py @@ -1759,6 +1759,11 @@ def test_categorical_search_space__to_tags_raises_for_non_integers() -> None: tf.constant([[0], [0]], dtype=tf.float64), tf.constant([[1], [1]], dtype=tf.float64), ), + ( + CategoricalSearchSpace(["Y", "N"]), + tf.constant([[0], [1], [0]], dtype=tf.float64), + tf.constant([[0], [1], [0]], dtype=tf.float64), + ), ( CategoricalSearchSpace(["R", "G", "B"], dtype=tf.float32), tf.constant([[0], [2], [1]], dtype=tf.float32), @@ -1777,13 +1782,13 @@ def test_categorical_search_space__to_tags_raises_for_non_integers() -> None: ( CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]), tf.constant([[0, 0], [2, 0], [1, 1]], dtype=tf.float64), - tf.constant([[1, 0, 0, 1, 0], [0, 0, 1, 1, 0], [0, 1, 0, 0, 1]], dtype=tf.float64), + tf.constant([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]], dtype=tf.float64), ), ( CategoricalSearchSpace([["R", "G", "B"], ["Y", "N"]]), tf.constant([[[0, 0], [0, 0]], [[2, 0], [1, 1]]], dtype=tf.float64), tf.constant( - [[[1, 0, 0, 1, 0], [1, 0, 0, 1, 0]], [[0, 0, 1, 1, 0], [0, 1, 0, 0, 1]]], + [[[1, 0, 0, 0], [1, 0, 0, 0]], [[0, 0, 1, 0], [0, 1, 0, 1]]], dtype=tf.float64, ), ), @@ -1824,6 +1829,12 @@ def test_categorical_search_space_one_hot_encoding( pytest.param( CategoricalSearchSpace(["Y", "N"]), tf.constant([[0], [2], [1]]), + ValueError, + id="Out of range binary input value", + ), + pytest.param( + CategoricalSearchSpace(["Y", "N", "maybe"]), + tf.constant([[0], [3], [1]]), InvalidArgumentError, id="Out of range input value", ), diff --git a/trieste/space.py b/trieste/space.py index 4a228460c..84fed23a9 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -633,7 +633,14 @@ def tags(self) -> Sequence[Sequence[str]]: @property def one_hot_encoder(self) -> EncoderFunction: - """A one-hot encoder for the numerical indices.""" + """A one-hot encoder for the numerical indices. Note that binary categories + are left unchanged instead of adding an unnecessary second feature.""" + + def binary_encoder(x: TensorType) -> TensorType: + # no need to one-hot encode binary categories (but we should still validate) + if not tf.reduce_all((x == 0) | (x == 1)): + raise ValueError(f"Invalid value {x}") + return x def encoder(x: TensorType) -> TensorType: flat_x, unflatten = flatten_leading_dims(x) @@ -644,7 +651,9 @@ def encoder(x: TensorType) -> TensorType: ) columns = tf.split(flat_x, flat_x.shape[-1], axis=1) encoders = [ - tf.keras.layers.CategoryEncoding(num_tokens=len(ts), output_mode="one_hot") + binary_encoder + if len(ts) == 2 + else tf.keras.layers.CategoryEncoding(num_tokens=len(ts), output_mode="one_hot") for ts in self.tags ] encoded = tf.concat( From b95af67927d9ad5a08a4e7de77bd943213b0d636 Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Tue, 10 Sep 2024 09:54:19 +0100 Subject: [PATCH 2/3] Cast encoder --- tests/unit/test_space.py | 17 +++++++++++++++++ trieste/space.py | 26 +++++++++++++++++++++++--- 2 files changed, 40 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_space.py b/tests/unit/test_space.py index a39cd0100..7cbc34248 100644 --- a/tests/unit/test_space.py +++ b/tests/unit/test_space.py @@ -38,6 +38,7 @@ SearchSpace, TaggedMultiSearchSpace, TaggedProductSearchSpace, + cast_encoder, one_hot_encoder, ) from trieste.types import TensorType @@ -1870,3 +1871,19 @@ def test_unbound_search_spaces( space.lower with pytest.raises(AttributeError): space.upper + + +@pytest.mark.parametrize("input_dtype", [None, tf.float64, tf.float32]) +@pytest.mark.parametrize("output_dtype", [None, tf.float64, tf.float32]) +def test_cast_encoder(input_dtype: Optional[tf.DType], output_dtype: Optional[tf.DType]) -> None: + + query_points = tf.constant([1, 2, 3], dtype=tf.int32) + + def add_encoder(x: TensorType) -> TensorType: + assert x.dtype is (input_dtype or tf.int32) + return x + 1 + + encoder = cast_encoder(add_encoder, input_dtype=input_dtype, output_dtype=output_dtype) + points = encoder(query_points) + assert points.dtype is (output_dtype or input_dtype or tf.int32) + npt.assert_array_equal(tf.cast(query_points + 1, points.dtype), points) diff --git a/trieste/space.py b/trieste/space.py index ac78cc97c..41a73f514 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -518,6 +518,24 @@ def one_hot_encoder(space: SearchSpace) -> EncoderFunction: return space.one_hot_encoder if isinstance(space, HasOneHotEncoder) else lambda x: x +def cast_encoder( + encoder: EncoderFunction, + input_dtype: Optional[tf.DType] = None, + output_dtype: Optional[tf.DType] = None, +) -> EncoderFunction: + "A utility function for casting the input and/or output of an encoder." + + def cast_and_encode(x: TensorType) -> TensorType: + if input_dtype is not None: + x = tf.cast(x, input_dtype) + y = encoder(x) + if output_dtype is not None: + y = tf.cast(y, output_dtype) + return y + + return cast_and_encode + + def one_hot_encoded_space(space: SearchSpace) -> SearchSpace: "A bounded search space corresponding to the one-hot encoding of the given space." @@ -651,9 +669,11 @@ def encoder(x: TensorType) -> TensorType: ) columns = tf.split(flat_x, flat_x.shape[-1], axis=1) encoders = [ - binary_encoder - if len(ts) == 2 - else tf_keras.layers.CategoryEncoding(num_tokens=len(ts), output_mode="one_hot") + ( + binary_encoder + if len(ts) == 2 + else tf_keras.layers.CategoryEncoding(num_tokens=len(ts), output_mode="one_hot") + ) for ts in self.tags ] encoded = tf.concat( From 69448cfda5b5a09c419f431ef08d67894dc4a18d Mon Sep 17 00:00:00 2001 From: Uri Granta Date: Tue, 10 Sep 2024 10:11:33 +0100 Subject: [PATCH 3/3] Better exception message --- trieste/space.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trieste/space.py b/trieste/space.py index 41a73f514..4b32181cc 100644 --- a/trieste/space.py +++ b/trieste/space.py @@ -656,8 +656,8 @@ def one_hot_encoder(self) -> EncoderFunction: def binary_encoder(x: TensorType) -> TensorType: # no need to one-hot encode binary categories (but we should still validate) - if not tf.reduce_all((x == 0) | (x == 1)): - raise ValueError(f"Invalid value {x}") + if tf.reduce_any((x != 0) & (x != 1)): + raise ValueError(f"Invalid values {tf.boolean_mask(x, ((x != 0) & (x != 1)))}") return x def encoder(x: TensorType) -> TensorType: