Skip to content

Commit

Permalink
Add option in qec_util.performance.sampler_failures to store batc…
Browse files Browse the repository at this point in the history
…h data in file (#26)
  • Loading branch information
MarcSerraPeralta authored Sep 13, 2024
1 parent 5a0d906 commit f7b2ac9
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 5 deletions.
3 changes: 2 additions & 1 deletion qec_util/performance/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
lmfit_par_to_ufloat,
confidence_interval_binomial,
)
from .sampler import sample_failures
from .sampler import sample_failures, read_failures_from_file
from . import plots

__all__ = [
Expand All @@ -15,5 +15,6 @@
"lmfit_par_to_ufloat",
"confidence_interval_binomial",
"sample_failures",
"read_failures_from_file",
"plots",
]
99 changes: 96 additions & 3 deletions qec_util/performance/sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Tuple
from typing import Tuple, Optional, Callable
import time
import pathlib

import numpy as np
import stim
Expand All @@ -11,6 +12,8 @@ def sample_failures(
max_failures: int | float = 100,
max_time: int | float = 3600,
max_samples: int | float = 1_000_000,
file_name: Optional[str | pathlib.Path] = None,
decoding_failure: Callable = lambda x: x.any(axis=1),
) -> Tuple[int, int]:
"""Samples decoding failures until one of three conditions is met:
(1) max. number of failures reached, (2) max. runtime reached,
Expand All @@ -34,13 +37,33 @@ def sample_failures(
Maximum number of samples to reach before stopping the calculation.
Set this parameter to ``np.inf`` to not have any restriction on the
maximum number of samples.
file_name
Name of the file in which to store the partial results.
If the file does not exist, it will be created.
Specifying a file is useful if the computation is stop midway, so
that it can be continued in if the same file is given. It can also
be used to sample more points.
decoding_failure
Function that returns `True` if there has been a decoding failure, else
`False`. Its input is an ``np.ndarray`` of shape
``(num_samples, num_observables)`` and its output must be a boolean
``np.ndarray`` of shape ``(num_samples,)``.
By default, a decoding failure is when a logical error happened to
any of the logical observables.
Returns
-------
num_failures
Number of decoding failures.
num_samples
Number of samples taken.
Notes
-----
If ``file_name`` is specified, each batch is stored in the file in a
different line using the following format: ``num_failures num_samples\n``.
The number of failures and samples can be read using
``read_failures_from_file`` function present in the same module.
"""
if not isinstance(dem, stim.DetectorErrorModel):
raise TypeError(
Expand All @@ -51,13 +74,23 @@ def sample_failures(

sampler = dem.compile_sampler()
num_failures, num_samples = 0, 0
if (file_name is not None) and pathlib.Path(file_name).exists():
num_failures, num_samples = read_failures_from_file(file_name)
# update the maximum limits based on the already calculated samples
max_samples -= num_samples
max_failures -= num_failures

# estimate the batch size for decoding
defects, log_flips, _ = sampler.sample(shots=100)
t_init = time.time()
predictions = decoder.decode_batch(defects)
run_time = (time.time() - t_init) / 100
log_err_prob = np.average(predictions != log_flips)
failures = decoding_failure(predictions != log_flips)
if (not isinstance(failures, np.ndarray)) or (failures.shape != (100,)):
raise ValueError(
f"'decoding_function' does not return a correctly shaped output"
)
log_err_prob = np.average(failures)
estimated_max_samples = min(
[
max_samples,
Expand All @@ -79,7 +112,67 @@ def sample_failures(
defects, log_flips, _ = sampler.sample(shots=batch_size)
predictions = decoder.decode_batch(defects)
log_errors = predictions != log_flips
num_failures += log_errors.sum()
batch_failures = decoding_failure(log_errors).sum()

num_failures += batch_failures
num_samples += batch_size
if file_name is not None:
with open(file_name, "a") as file:
file.write(f"{batch_failures} {batch_size}\n")

return int(num_failures), num_samples


def read_failures_from_file(
file_name: str | pathlib.Path,
max_num_failures: int | float = np.inf,
max_num_samples: int | float = np.inf,
) -> Tuple[int, int]:
"""Returns the number of failues and samples stored in a file.
Parameters
----------
file_name
Name of the file with the data.
The structure of the file is specified in the Notes and the intended
usage is for the ``sample_failures`` function.
max_num_failues
If specified, only adds up the first batches until the number of
failures reaches or (firstly) surpasses the given number.
By default ``np.inf``, thus it adds up all the batches in the file.
max_num_samples
If specified, only adds up the first batches until the number of
samples reaches or (firstly) surpasses the given number.
By default ``np.inf``, thus it adds up all the batches in the file.
Returns
-------
num_failures
Total number of failues in the given number of samples.
num_samples
Total number of samples.
Notes
-----
The structure of ``file_name`` file is: each batch is stored in the file in a
different line using the format ``num_failures num_samples\n``.
The file ends with an empty line.
"""
if not pathlib.Path(file_name).exists():
raise FileExistsError(f"The given file ({file_name}) does not exist.")

num_failures, num_samples = 0, 0
with open(file_name, "r") as file:
for line in file:
if line == "":
continue

line = line[:-1] # remove \n character at the end
batch_failures, batch_samples = map(int, line.split(" "))
num_failures += batch_failures
num_samples += batch_samples

if num_failures >= max_num_failures or num_samples >= max_num_samples:
return num_failures, num_samples

return num_failures, num_samples
17 changes: 17 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import pytest


def pytest_addoption(parser):
parser.addoption("--show-figures", action="store_true")

Expand All @@ -9,3 +12,17 @@ def pytest_generate_tests(metafunc):
option_value = metafunc.config.option.show_figures
if "show_figures" in metafunc.fixturenames and option_value is not None:
metafunc.parametrize("show_figures", [bool(option_value)])


@pytest.fixture(scope="session")
def failures_file(tmp_path_factory):
"""This function is executed before any test is run and creates
a file in a temporary directory that can be passed to any test function.
From https://docs.pytest.org/en/stable/how-to/tmp_path.html
"""
file_name = tmp_path_factory.mktemp("data") / "tmp_failures_file.txt"
contents = "10 50\n11 50\n"
with open(file_name, "w") as file:
file.write(contents)
return file_name
60 changes: 59 additions & 1 deletion tests/performance/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,58 @@
import stim
from pymatching import Matching

from qec_util.performance import sample_failures
from qec_util.performance import sample_failures, read_failures_from_file


def test_sampler_to_file(tmp_path):
circuit = stim.Circuit.generated(
code_task="repetition_code:memory",
distance=3,
rounds=3,
after_clifford_depolarization=0.01,
)
dem = circuit.detector_error_model()
mwpm = Matching(dem)

num_failures, num_samples = sample_failures(
dem,
mwpm,
max_samples=1_000,
max_time=np.inf,
max_failures=np.inf,
file_name=tmp_path / "tmp_file.txt",
)
read_failures, read_samples = read_failures_from_file(tmp_path / "tmp_file.txt")

assert num_failures == read_failures
assert num_samples == read_samples

return


def test_sampler_from_file(failures_file):
circuit = stim.Circuit.generated(
code_task="repetition_code:memory",
distance=3,
rounds=3,
after_clifford_depolarization=0.01,
)
dem = circuit.detector_error_model()
mwpm = Matching(dem)

num_failures, num_samples = sample_failures(
dem,
mwpm,
max_samples=100,
max_time=np.inf,
max_failures=np.inf,
file_name=failures_file,
)

assert num_failures == 21
assert num_samples == 100

return


def test_sampler():
Expand Down Expand Up @@ -37,3 +88,10 @@ def test_sampler():
assert (num_failures >= 0) and (num_samples) >= 0

return


def test_read_failures_from_file(failures_file):
num_failures, num_samples = read_failures_from_file(failures_file)
assert num_failures == 21
assert num_samples == 100
return

0 comments on commit f7b2ac9

Please sign in to comment.