Skip to content

Commit

Permalink
Improve dataset_len error message (#879)
Browse files Browse the repository at this point in the history
Co-authored-by: Uri Granta <uri.granta@secondmind.ai>
  • Loading branch information
uri-granta and Uri Granta authored Oct 9, 2024
1 parent 7b3e9e8 commit 1a2e0ac
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions trieste/ask_tell_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,16 +435,18 @@ def acquisition_state(self) -> StateType | None:
@classmethod
def dataset_len(cls, datasets: Mapping[Tag, Dataset]) -> int:
"""Helper method for inferring the global dataset size."""
dataset_lens = [
tf.shape(dataset.query_points)[0]
dataset_lens = {
tag: int(tf.shape(dataset.query_points)[0])
for tag, dataset in datasets.items()
if not LocalizedTag.from_tag(tag).is_local
]
unique_lens, _ = tf.unique(dataset_lens)
}
unique_lens, _ = tf.unique(list(dataset_lens.values()))
if len(unique_lens) == 1:
return int(unique_lens[0])
else:
raise ValueError(f"Expected unique global dataset size, got {unique_lens}")
raise ValueError(
f"Expected unique global dataset size, got {unique_lens}: {dataset_lens}"
)

@classmethod
def from_record(
Expand Down

0 comments on commit 1a2e0ac

Please sign in to comment.