From 3cbc0ed3aea8e75f01d79201f965c0aa65b20fd1 Mon Sep 17 00:00:00 2001 From: MarcSerraPeralta <43704266+MarcSerraPeralta@users.noreply.github.com> Date: Mon, 10 Feb 2025 16:33:30 +0100 Subject: [PATCH] Add minimum options in sampler (#48) --- qec_util/performance/sampler.py | 82 +++++++++++++++++++++---------- tests/performance/test_sampler.py | 33 +++++++++++++ 2 files changed, 88 insertions(+), 27 deletions(-) diff --git a/qec_util/performance/sampler.py b/qec_util/performance/sampler.py index 839b3e5..af55e7c 100644 --- a/qec_util/performance/sampler.py +++ b/qec_util/performance/sampler.py @@ -9,16 +9,23 @@ def sample_failures( dem: stim.DetectorErrorModel, decoder, + min_failures: int | float = 0, max_failures: int | float = 100, + min_time: int | float = 0, max_time: int | float = np.inf, + min_samples: int | float = 0, max_samples: int | float = 1_000_000, + batch_size: int | np.float64 | float | None = None, max_batch_size: int | float = np.inf, file_name: str | pathlib.Path | None = None, decoding_failure: Callable = lambda x: x.any(axis=1), extra_metrics: Callable = lambda _: list(), ) -> tuple[int, int, list[int]]: - """Samples decoding failures until one of three conditions is met: - (1) max. number of failures reached, (2) max. runtime reached, + """Samples decoding failures until all the minimum requirements have been + fulfilled (i.e. min. number of failures, min. runtime, min. number of samples) + and one of three conditions is met: + (1) max. number of failures reached, + (2) max. runtime reached, (3) max. number of samples taken. Parameters @@ -28,17 +35,27 @@ def sample_failures( logical flips. decoder Decoder object with a ``decode_batch`` method. + min_failures + Minimum number of failures to reach before being able to stop the sampling. max_failures Maximum number of failures to reach before stopping the calculation. Set this parameter to ``np.inf`` to not have any restriction on the maximum number of failures. 
+ min_time Minimum duration for this function (in seconds) before being able to stop the sampling. max_time Maximum duration for this function, in seconds. By default, this parameter is set to ``np.inf`` to not place any restriction on runtime. + min_samples + Minimum number of samples to reach before being able to stop the sampling. max_samples Maximum number of samples to reach before stopping the calculation. Set this parameter to ``np.inf`` to not have any restriction on the maximum number of samples. + batch_size + Number of samples to decode per batch. If ``None``, it estimates + the best ``batch_size`` given the other parameters (i.e. ``max_time``, + ``max_samples``, and ``max_failures``). max_batch_size Maximum number of samples to decode per batch. This is useful when encountering memory issues, as one can just reduce ``max_batch_size``. @@ -94,7 +111,8 @@ def sample_failures( if (num_samples >= max_samples) or (num_failures >= max_failures): return num_failures, num_samples, extra - # estimate the batch size for decoding + # check that everything works correctly and estimate the batch size for decoding, + # if needed sampler = dem.compile_sampler() defects, log_flips, _ = sampler.sample(shots=100) t_init = time.time() @@ -113,36 +131,46 @@ def sample_failures( or any(m.shape != (100,) for m in extra) ): raise ValueError("'extra_metrics' does not return a correctly shaped output.") - log_err_prob = np.average(failures) - estimated_max_samples = min( - [ - max_samples - num_samples, - max_time / run_time, - ( - (max_failures - num_failures) / log_err_prob - if log_err_prob != 0 - else np.inf - ), - ] - ) - batch_size = estimated_max_samples / 5 # perform approx 5 batches - - # avoid batch_size = 0 or np.inf and also avoid overshooting - batch_size = max([batch_size, 1]) - batch_size = min([batch_size, max_samples - num_samples]) - # int(np.inf) raises an error and it could be that both batch_size and - # max_samples are np.inf - batch_size = batch_size if
batch_size != np.inf else 200_000 - batch_size = min([batch_size, max_batch_size]) + + if batch_size is None: + log_err_prob = np.average(failures) + estimated_max_samples = min( + [ + max_samples - num_samples, + max_time / run_time, + ( + (max_failures - num_failures) / log_err_prob + if log_err_prob != 0 + else np.inf + ), + ] + ) + batch_size = estimated_max_samples / 5 # perform approx 5 batches + + # avoid batch_size = 0 or np.inf and also avoid overshooting + batch_size = max([batch_size, 1]) + batch_size = min([batch_size, max_samples - num_samples]) + # int(np.inf) raises an error and it could be that both batch_size and + # max_samples are np.inf + batch_size = batch_size if batch_size != np.inf else 200_000 + batch_size = min([batch_size, max_batch_size]) + + # ensure batch size is int batch_size = int(batch_size) + # initialize the correct size of extra metrics extra = [0 for _ in extra] # start sampling... while ( - (time.time() - t_init) < max_time - and num_failures < max_failures - and num_samples < max_samples + (time.time() - t_init) < min_time + or num_failures < min_failures + or num_samples < min_samples + or ( + (time.time() - t_init) < max_time + and num_failures < max_failures + and num_samples < max_samples + ) ): defects, log_flips, _ = sampler.sample(shots=batch_size) predictions = decoder.decode_batch(defects) diff --git a/tests/performance/test_sampler.py b/tests/performance/test_sampler.py index a94aa59..90567ca 100644 --- a/tests/performance/test_sampler.py +++ b/tests/performance/test_sampler.py @@ -67,6 +67,39 @@ def test_sampler_extra_metrics(tmp_path): return +def test_samples_minimum_requirements(): + circuit = stim.Circuit.generated( + code_task="repetition_code:memory", + distance=3, + rounds=3, + after_clifford_depolarization=0.01, + ) + dem = circuit.detector_error_model() + mwpm = Matching(dem) + + _, num_samples, _ = sample_failures( + dem, + mwpm, + max_samples=1, + min_samples=200, + max_time=0, + max_failures=1, + ) + 
assert num_samples == 200 + + num_failures, _, _ = sample_failures( + dem, + mwpm, + max_samples=1, + min_failures=20, + max_time=0, + max_failures=0, + ) + assert num_failures >= 2 + + return + + def test_sample_early_stopping(failures_file): circuit = stim.Circuit.generated( code_task="repetition_code:memory",