Skip to content

Commit

Permalink
Add minimum options in sampler (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcSerraPeralta authored Feb 10, 2025
1 parent cceb6d9 commit 3cbc0ed
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 27 deletions.
82 changes: 55 additions & 27 deletions qec_util/performance/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,23 @@
def sample_failures(
dem: stim.DetectorErrorModel,
decoder,
min_failures: int | float = 0,
max_failures: int | float = 100,
min_time: int | float = 0,
max_time: int | float = np.inf,
min_samples: int | float = 0,
max_samples: int | float = 1_000_000,
batch_size: int | np.float64 | float | None = None,
max_batch_size: int | float = np.inf,
file_name: str | pathlib.Path | None = None,
decoding_failure: Callable = lambda x: x.any(axis=1),
extra_metrics: Callable = lambda _: list(),
) -> tuple[int, int, list[int]]:
"""Samples decoding failures until one of three conditions is met:
(1) max. number of failures reached, (2) max. runtime reached,
"""Samples decoding failures until all the minimum requirements have been
fulfilled (i.e. min. number of failures, min. runtime, min. number of samples)
and one of three conditions is met:
(1) max. number of failures reached,
(2) max. runtime reached,
(3) max. number of samples taken.
Parameters
Expand All @@ -28,17 +35,27 @@ def sample_failures(
logical flips.
decoder
Decoder object with a ``decode_batch`` method.
min_failures
Minimum number of failures to reach before being able to stop the sampling.
max_failures
Maximum number of failures to reach before stopping the calculation.
Set this parameter to ``np.inf`` to not have any restriction on the
maximum number of failures.
min_time
Minimum duration for this function (in seconds) before being able to stop the sampling.
max_time
Maximum duration for this function, in seconds. By default, this
parameter is set to ``np.inf`` to not place any restriction on runtime.
min_samples
Minimum number of samples to reach before being able to stop the sampling.
max_samples
Maximum number of samples to reach before stopping the calculation.
Set this parameter to ``np.inf`` to not have any restriction on the
maximum number of samples.
batch_size
Number of samples to decode per batch. If ``None``, it estimates
the best ``batch_size`` given the other parameters (i.e. ``max_time``,
``max_samples``, and ``max_failures``).
max_batch_size
Maximum number of samples to decode per batch. This is useful when
encountering memory issues, as one can just reduce ``max_batch_size``.
Expand Down Expand Up @@ -94,7 +111,8 @@ def sample_failures(
if (num_samples >= max_samples) or (num_failures >= max_failures):
return num_failures, num_samples, extra

# estimate the batch size for decoding
# check that everything works correctly and estimate the batch size for decoding,
# if needed
sampler = dem.compile_sampler()
defects, log_flips, _ = sampler.sample(shots=100)
t_init = time.time()
Expand All @@ -113,36 +131,46 @@ def sample_failures(
or any(m.shape != (100,) for m in extra)
):
raise ValueError("'extra_metrics' does not return a correctly shaped output.")
log_err_prob = np.average(failures)
estimated_max_samples = min(
[
max_samples - num_samples,
max_time / run_time,
(
(max_failures - num_failures) / log_err_prob
if log_err_prob != 0
else np.inf
),
]
)
batch_size = estimated_max_samples / 5 # perform approx 5 batches

# avoid batch_size = 0 or np.inf and also avoid overshooting
batch_size = max([batch_size, 1])
batch_size = min([batch_size, max_samples - num_samples])
# int(np.inf) raises an error and it could be that both batch_size and
# max_samples are np.inf
batch_size = batch_size if batch_size != np.inf else 200_000
batch_size = min([batch_size, max_batch_size])

if batch_size is None:
log_err_prob = np.average(failures)
estimated_max_samples = min(
[
max_samples - num_samples,
max_time / run_time,
(
(max_failures - num_failures) / log_err_prob
if log_err_prob != 0
else np.inf
),
]
)
batch_size = estimated_max_samples / 5 # perform approx 5 batches

# avoid batch_size = 0 or np.inf and also avoid overshooting
batch_size = max([batch_size, 1])
batch_size = min([batch_size, max_samples - num_samples])
# int(np.inf) raises an error and it could be that both batch_size and
# max_samples are np.inf
batch_size = batch_size if batch_size != np.inf else 200_000
batch_size = min([batch_size, max_batch_size])

# ensure batch size is int
batch_size = int(batch_size)

# initialize the correct size of extra metrics
extra = [0 for _ in extra]

# start sampling...
while (
(time.time() - t_init) < max_time
and num_failures < max_failures
and num_samples < max_samples
(time.time() - t_init) < min_time
or num_failures < min_failures
or num_samples < min_samples
or (
(time.time() - t_init) < max_time
and num_failures < max_failures
and num_samples < max_samples
)
):
defects, log_flips, _ = sampler.sample(shots=batch_size)
predictions = decoder.decode_batch(defects)
Expand Down
33 changes: 33 additions & 0 deletions tests/performance/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,39 @@ def test_sampler_extra_metrics(tmp_path):
return


def test_samples_minimum_requirements():
    """Check that the ``min_*`` options keep ``sample_failures`` running.

    In both calls the ``max_*`` limits are chosen so that they are reached
    straight away; the only thing that can keep the sampling loop alive is
    the corresponding minimum requirement.
    """
    circuit = stim.Circuit.generated(
        code_task="repetition_code:memory",
        distance=3,
        rounds=3,
        after_clifford_depolarization=0.01,
    )
    dem = circuit.detector_error_model()
    mwpm = Matching(dem)

    # ``min_samples`` must override ``max_samples``, ``max_time`` and
    # ``max_failures``.
    _, num_samples, _ = sample_failures(
        dem,
        mwpm,
        max_samples=1,
        min_samples=200,
        max_time=0,
        max_failures=1,
    )
    assert num_samples == 200

    # ``min_failures`` must override the maximum limits as well. Sampling may
    # overshoot within a batch, so only a lower bound is asserted; the bound
    # has to equal the requested minimum, otherwise the test would also pass
    # when the minimum is only partially honoured.
    num_failures, _, _ = sample_failures(
        dem,
        mwpm,
        max_samples=1,
        min_failures=20,
        max_time=0,
        max_failures=0,
    )
    assert num_failures >= 20


def test_sample_early_stopping(failures_file):
circuit = stim.Circuit.generated(
code_task="repetition_code:memory",
Expand Down

0 comments on commit 3cbc0ed

Please sign in to comment.