Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes Wires objects as wire labels bug #6933

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@

<h3>Bug fixes 🐛</h3>

* Fix `qml.wires.Wires` initialization to disallow `Wires` objects as wires labels.
Now, `Wires` is idempotent, e.g. `Wires([Wires([0]), Wires([1])])==Wires([0, 1])`.
[(#6933)](https://github.com/PennyLaneAI/pennylane/pull/6933)

* `qml.capture.PlxprInterpreter` now correctly handles propagation of constants when interpreting higher-order primitives
[(#6913)](https://github.com/PennyLaneAI/pennylane/pull/6913)

Expand Down
15 changes: 13 additions & 2 deletions pennylane/wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ def _process(wires):
if len(set_of_wires) != len(tuple_of_wires):
raise WireError(f"Wires must be unique; got {wires}.")

return tuple_of_wires
# Tuple of wires are flattened by iterating through each wire label and
# checking if it is a Wires object. If so, flatten the Wires object into a tuple of wire labels.
# The nested tuple of wires are then stitched together using itertools.chain.
andrijapau marked this conversation as resolved.
Show resolved Hide resolved
return tuple(itertools.chain(*(_flatten_wires_object(x) for x in tuple_of_wires)))
PietropaoloFrisoni marked this conversation as resolved.
Show resolved Hide resolved


class Wires(Sequence):
Expand All @@ -120,7 +123,7 @@ class Wires(Sequence):
"""

def _flatten(self):
"""Serialize Wires into a flattened representation according to the PyTree convension."""
"""Serialize Wires into a flattened representation according to the PyTree convention."""
return self._labels, ()

@classmethod
Expand Down Expand Up @@ -731,5 +734,13 @@ def __rxor__(self, other):

WiresLike = Union[Wires, Iterable[Hashable], Hashable]


def _flatten_wires_object(wire_label):
"""Converts the input to a tuple of wire labels."""
if isinstance(wire_label, Wires):
return wire_label.labels
return [wire_label]


# Register Wires as a PyTree-serializable class
register_pytree(Wires, Wires._flatten, Wires._unflatten) # pylint: disable=protected-access
12 changes: 9 additions & 3 deletions tests/test_wires.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@
class TestWires:
"""Tests for the ``Wires`` class."""

def test_wires_object_as_label(self):
"""Tests that a Wires object can be used as a label for another Wires object."""
assert Wires([0, 1]) == Wires([Wires([0]), Wires([1])])
assert Wires(["a", "b", 1]) == Wires([Wires(["a", "b"]), Wires([1])])
assert Wires([Wires([(0, 0), (0, 1)])]) == Wires([(0, 0), (0, 1)])

def test_error_if_wires_none(self):
"""Tests that a TypeError is raised if None is given as wires."""
with pytest.raises(TypeError, match="Must specify a set of wires."):
Expand Down Expand Up @@ -74,7 +80,7 @@ def test_creation_from_wires_lists(self):
"""Tests that a Wires object can be created from a list of Wires."""

wires = Wires([Wires([0]), Wires([1]), Wires([2])])
assert wires.labels == (Wires([0]), Wires([1]), Wires([2]))
assert wires.labels == (0, 1, 2)

@pytest.mark.parametrize(
"iterable", [[1, 0, 4], ["a", "b", "c"], [0, 1, None], ["a", 1, "ancilla"]]
Expand Down Expand Up @@ -148,7 +154,7 @@ def test_contains(
wires = Wires([0, 1, 2, 3, Wires([4, 5]), None])

assert 0 in wires
assert Wires([4, 5]) in wires
assert Wires([4, 5]) not in wires
assert None in wires
assert Wires([1]) not in wires
assert Wires([0, 3]) not in wires
Expand All @@ -170,7 +176,7 @@ def test_contains_wires(

assert not wires.contains_wires(0) # wrong type
assert not wires.contains_wires([0, 1]) # wrong type
assert not wires.contains_wires(
assert wires.contains_wires(
Wires([4, 5])
) # looks up 4 and 5 in wires, which are not present

Expand Down