Skip to content

Commit

Permalink
Add dem_instrs module (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcSerraPeralta authored Oct 17, 2024
1 parent 52492d4 commit c8ad42c
Show file tree
Hide file tree
Showing 9 changed files with 423 additions and 53 deletions.
23 changes: 23 additions & 0 deletions qec_util/dem_instrs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from .dem_instrs import (
get_detectors,
get_logicals,
has_separator,
decomposed_detectors,
decomposed_logicals,
remove_detectors,
sorted_dem_instr,
)
from .util import xor_probs, xor_lists


__all__ = [
"get_detectors",
"get_logicals",
"has_separator",
"decomposed_detectors",
"decomposed_logicals",
"xor_probs",
"xor_lists",
"remove_detectors",
"sorted_dem_instr",
]
153 changes: 153 additions & 0 deletions qec_util/dem_instrs/dem_instrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from collections.abc import Iterable

import stim

from .util import xor_lists


def get_detectors(dem_instr: stim.DemInstruction) -> tuple[int, ...]:
"""Returns the detector indices that are flipped in the given DEM instruction."""
if dem_instr.type != "error":
raise ValueError(f"DemInstruction is not an error, it is {dem_instr.type}.")

if has_separator(dem_instr):
return xor_lists(*decomposed_detectors(dem_instr))
else:
return tuple(
sorted(
i.val for i in dem_instr.targets_copy() if i.is_relative_detector_id()
)
)


def get_logicals(dem_instr: stim.DemInstruction) -> tuple[int, ...]:
"""Returns the logical observable indices that are flipped in the given DEM instruction."""
if dem_instr.type != "error":
raise ValueError(f"DemInstruction is not an error, it is {dem_instr.type}.")

if has_separator(dem_instr):
return xor_lists(*decomposed_logicals(dem_instr))
else:
return tuple(
sorted(
i.val for i in dem_instr.targets_copy() if i.is_logical_observable_id()
)
)


def has_separator(dem_instr: stim.DemInstruction) -> bool:
"""Returns if the given DEM instruction has a separator."""
if dem_instr.type != "error":
raise ValueError(f"DemInstruction is not an error, it is {dem_instr.type}.")

return bool([i for i in dem_instr.targets_copy() if i.is_separator()])


def decomposed_detectors(dem_instr: stim.DemInstruction) -> list[tuple[int, ...]]:
"""Returns a list of the detector indices triggered for each fault that the DEM
instruction is decomposed into.
"""
if dem_instr.type != "error":
raise ValueError(f"DemInstruction is not an error, it is {dem_instr.type}.")

list_dets = []
current = []
for e in dem_instr.targets_copy():
if e.is_separator():
list_dets.append(current)
current = []
if e.is_relative_detector_id():
current.append(e.val)
list_dets.append(current)

# process dets
list_dets = [tuple(sorted(d)) for d in list_dets]

return list_dets


def decomposed_logicals(dem_instr: stim.DemInstruction) -> list[tuple[int, ...]]:
"""Returns a list of the logical indices triggered for each fault that the DEM
instruction is decomposed into.
"""
if not isinstance(dem_instr, stim.DemInstruction):
raise TypeError(
f"'dem_instr' must be a stim.DemInstruction, but {type(dem_instr)} was given."
)
if dem_instr.type != "error":
raise ValueError(f"'dem_instr' is not an error, it is {dem_instr.type}.")

list_logs = []
current = []
for e in dem_instr.targets_copy():
if e.is_separator():
list_logs.append(current)
current = []
if e.is_logical_observable_id():
current.append(e.val)
list_logs.append(current)

# process dets
list_logs = [tuple(sorted(l)) for l in list_logs]

return list_logs


def remove_detectors(
dem_instr: stim.DemInstruction, dets: Iterable[int]
) -> stim.DemInstruction:
"""Removes the specified detector indices from the given DEM instruction."""
if not isinstance(dem_instr, stim.DemInstruction):
raise TypeError(
f"'dem_instr' must be a stim.DemInstruction, but {type(dem_instr)} was given."
)
if not isinstance(dets, Iterable):
raise TypeError(f"'dets' must be iterable, but {type(dets)} was given.")

if dem_instr.type != "error":
raise ValueError(f"'dem_instr' is not an error, it is {dem_instr.type}.")

prob = dem_instr.args_copy()
targets = [
d
for d in dem_instr.targets_copy()
if not (d.is_relative_detector_id() and (d.val in dets))
]

# recurrently check that there cannot be any separator at the beginning or
# end of a stim.DemInstruction.
correct = [False, False]
while not (correct[0] and correct[1]):
if targets[0].is_separator():
targets = targets[1:]
else:
correct[0] = True

if targets[-1].is_separator():
targets = targets[:-1]
else:
correct[1] = True

return stim.DemInstruction(type="error", targets=targets, args=prob)


def sorted_dem_instr(dem_instr: stim.DemInstruction) -> stim.DemInstruction:
"""Returns the dem_instr in an specific order.
Note that it removes the separators.
"""
if not isinstance(dem_instr, stim.DemInstruction):
raise TypeError(
f"'dem_instr' must be a stim.DemInstruction, but {type(dem_instr)} was given."
)
if dem_instr.type != "error":
return dem_instr

dets = sorted(get_detectors(dem_instr))
logs = sorted(get_logicals(dem_instr))
dets_target = list(map(stim.target_relative_detector_id, dets))
logs_target = list(map(stim.target_logical_observable_id, logs))
prob = dem_instr.args_copy()

return stim.DemInstruction(
type="error", targets=dets_target + logs_target, args=prob
)
45 changes: 45 additions & 0 deletions qec_util/dem_instrs/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from collections.abc import Iterable


def xor_two_lists(list1: Iterable, list2: Iterable) -> tuple:
"""Returns the symmetric difference of two lists.
Note that the resulting list has been sorted.
"""
return tuple(sorted(set(list1).symmetric_difference(list2)))


def xor_lists(*elements: Iterable) -> tuple:
"""Returns the symmetric difference of multiple lists.
Note that the resulting list has been sorted.
"""
output = []
for element in elements:
output = xor_two_lists(output, element)
return tuple(sorted(output))


def xor_two_probs(p: float | int, q: float | int) -> float | int:
"""Returns the probability of only one of the events happening.
Parameters
----------
p
Probability of one event.
q
Probability of the other event.
"""
return p * (1 - q) + (1 - p) * q


def xor_probs(*probs: float | int) -> float | int:
"""Returns the probability of an odd number of events happening.
Parameters
----------
*probs
Probabilities of each of the events.
"""
odd_prob = probs[0]
for prob in probs[1:]:
odd_prob = xor_two_probs(prob, odd_prob)
return odd_prob
4 changes: 4 additions & 0 deletions qec_util/dems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
is_instr_in_dem,
get_max_weight_hyperedge,
disjoint_graphs,
get_flippable_detectors,
get_flippable_logicals,
)


Expand All @@ -13,4 +15,6 @@
"is_instr_in_dem",
"get_max_weight_hyperedge",
"disjoint_graphs",
"get_flippable_detectors",
"get_flippable_logicals",
]
67 changes: 41 additions & 26 deletions qec_util/dems/dems.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import stim
import networkx as nx

from .util import sorting_index
from ..dem_instrs import get_detectors, get_logicals, sorted_dem_instr


def remove_gauge_detectors(dem: stim.DetectorErrorModel) -> stim.DetectorErrorModel:
Expand Down Expand Up @@ -65,30 +65,18 @@ def dem_difference(
)

dem_1_ordered = stim.DetectorErrorModel()
num_dets = dem_1.num_detectors
for dem_instr in dem_1.flattened():
if dem_instr.type != "error":
continue

# remove separators
targets = [t for t in dem_instr.targets_copy() if not t.is_separator()]

targets = sorted(targets, key=lambda x: sorting_index(x, num_dets))
prob = dem_instr.args_copy()[0]
dem_1_ordered.append("error", prob, targets)
dem_1_ordered.append(sorted_dem_instr(dem_instr))

dem_2_ordered = stim.DetectorErrorModel()
num_dets = dem_2.num_detectors
for dem_instr in dem_2.flattened():
if dem_instr.type != "error":
continue

# remove separators
targets = [t for t in dem_instr.targets_copy() if not t.is_separator()]

targets = sorted(targets, key=lambda x: sorting_index(x, num_dets))
prob = dem_instr.args_copy()[0]
dem_2_ordered.append("error", prob, targets)
dem_2_ordered.append(sorted_dem_instr(dem_instr))

diff_1 = stim.DetectorErrorModel()
for dem_instr in dem_1_ordered:
Expand Down Expand Up @@ -120,20 +108,13 @@ def is_instr_in_dem(
f"'dem' must be a stim.DetectorErrorModel, but {type(dem)} was given."
)

num_dets = dem.num_detectors
prob = dem_instr.args_copy()[0]
targets = [t for t in dem_instr.targets_copy() if not t.is_separator()]
targets = sorted(targets, key=lambda x: sorting_index(x, num_dets))
dem_instr = sorted_dem_instr(dem_instr)

for instr in dem.flattened():
if instr.type != "error":
continue
if instr.args_copy()[0] != prob:
for other_instr in dem.flattened():
if other_instr.type != "error":
continue

other_targets = [t for t in instr.targets_copy() if not t.is_separator()]
other_targets = sorted(other_targets, key=lambda x: sorting_index(x, num_dets))
if other_targets == targets:
if dem_instr == sorted_dem_instr(other_instr):
return True

return False
Expand Down Expand Up @@ -208,3 +189,37 @@ def disjoint_graphs(dem: stim.DetectorErrorModel) -> list[list[int]]:
subgraphs = [list(c) for c in nx.connected_components(g)]

return subgraphs


def get_flippable_detectors(dem: stim.DetectorErrorModel) -> set[int]:
"""Returns a the detector indices present in the given DEM
that are triggered by some errors.
"""
if not isinstance(dem, stim.DetectorErrorModel):
raise TypeError(
f"'dem' must be a stim.DetectorErrorModel, but {type(dem)} was given."
)

dets = set()
for dem_instr in dem.flattened():
if dem_instr.type == "error":
dets.update(get_detectors(dem_instr))

return dets


def get_flippable_logicals(dem: stim.DetectorErrorModel) -> set[int]:
"""Returns a the logical observable indices present in the given DEM
that are triggered by some errors.
"""
if not isinstance(dem, stim.DetectorErrorModel):
raise TypeError(
f"'dem' must be a stim.DetectorErrorModel, but {type(dem)} was given."
)

logs = set()
for dem_instr in dem.flattened():
if dem_instr.type == "error":
logs.update(get_logicals(dem_instr))

return logs
26 changes: 0 additions & 26 deletions qec_util/dems/util.py

This file was deleted.

Loading

0 comments on commit c8ad42c

Please sign in to comment.