diff --git a/pyzx/gflow.py b/pyzx/gflow.py index a81b0dc5..26252e7c 100644 --- a/pyzx/gflow.py +++ b/pyzx/gflow.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from fractions import Fraction from typing import Dict, Set, Tuple, Optional from .extract import bi_adj @@ -76,7 +77,8 @@ def gflow( vertices: Set[VT] = set(v for v in g.vertices() if vertex_is_zx(ty[v])) pattern_inputs: Set[VT] = set() pattern_outputs: Set[VT] = set() - paulis: Set[VT] = set() + pauli_x: Set[VT] = set() + pauli_y: Set[VT] = set() for inp in g.inputs(): pattern_inputs |= set(n for n in g.neighbors(inp) if vertex_is_zx(ty[n])) @@ -87,7 +89,8 @@ def gflow( pattern_inputs, pattern_outputs = pattern_outputs, pattern_inputs if pauli: - paulis = set(v for v in vertices.difference(pattern_inputs) if g.phase(v) in (0,1)) + pauli_x = set(v for v in vertices if g.phase(v) in (0,1)) + pauli_y = set(v for v in vertices if g.phase(v) in (Fraction(1,2),Fraction(-1,2))) processed: Set[VT] = pattern_outputs.copy() | g.grounds() non_outputs = list(vertices.difference(pattern_outputs)) @@ -100,7 +103,7 @@ def gflow( correct: Set[VT] = set() processed_prime = [ v - for v in (processed | paulis).difference(pattern_inputs) + for v in (processed | pauli_x | pauli_y).difference(pattern_inputs) if focus or any(w not in processed for w in g.neighbors(v)) ]