diff --git a/trieste/ask_tell_optimization.py b/trieste/ask_tell_optimization.py index 4052ed2ab..f2c2da6d5 100644 --- a/trieste/ask_tell_optimization.py +++ b/trieste/ask_tell_optimization.py @@ -435,16 +435,18 @@ def acquisition_state(self) -> StateType | None: @classmethod def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int: """Helper method for inferring the global dataset size.""" - dataset_lens = [ - tf.shape(dataset.query_points)[0] + dataset_lens = { + tag: int(tf.shape(dataset.query_points)[0]) for tag, dataset in datasets.items() if not LocalizedTag.from_tag(tag).is_local - ] - unique_lens, _ = tf.unique(dataset_lens) + } + unique_lens, _ = tf.unique(list(dataset_lens.values())) if len(unique_lens) == 1: return int(unique_lens[0]) else: - raise ValueError(f"Expected unique global dataset size, got {unique_lens}") + raise ValueError( + f"Expected unique global dataset size, got {unique_lens}: {dataset_lens}" + ) @classmethod def from_record(