Skip to content

Commit

Permalink
added ability to compute focused gflows
Browse files Browse the repository at this point in the history
  • Loading branch information
akissinger committed Nov 10, 2024
1 parent 398254a commit 9394cc6
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 70 deletions.
58 changes: 36 additions & 22 deletions pyzx/gflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,22 @@

from typing import Dict, Set, Tuple, Optional

from networkx import neighbors

from .extract import bi_adj
from .linalg import Mat2
from .graph.base import BaseGraph, VertexType, VT, ET
from .utils import vertex_is_zx


def gflow(
g: BaseGraph[VT, ET], delayed: bool=False, pauli: bool=False
g: BaseGraph[VT, ET], focus: bool=False, reverse: bool=False, pauli: bool=False
) -> Optional[Tuple[Dict[VT, int], Dict[VT, Set[VT]], int]]:
r"""Compute the gflow of a diagram in graph-like form.
:param g: A ZX-graph.
:param delayed: Compute the maximally-delayed gflow
:param focus: Compute the focussed gflow
:param reverse: Reverse the roles of inputs and outputs
:param pauli: Compute the Pauli flow, restricted to {XZ, X} measurements
Based on algorithm by Perdrix and Mhalla.
Expand Down Expand Up @@ -72,7 +75,6 @@ def gflow(
gflow: Dict[VT, Set[VT]] = {}
ty = g.types()

processed: Set[VT] = set(g.outputs()) | g.grounds()
vertices: Set[VT] = set(v for v in g.vertices() if vertex_is_zx(ty[v]))
pattern_inputs: Set[VT] = set()
pattern_outputs: Set[VT] = set()
Expand All @@ -82,11 +84,16 @@ def gflow(
pattern_inputs |= set(n for n in g.neighbors(inp) if vertex_is_zx(ty[n]))
for outp in g.outputs():
pattern_outputs |= set(n for n in g.neighbors(outp) if vertex_is_zx(ty[n]))

if reverse:
pattern_inputs, pattern_outputs = pattern_outputs, pattern_inputs

if pauli:
paulis = set(v for v in vertices.difference(pattern_inputs) if g.phase(v) in (0,1))

processed = pattern_outputs.copy()
processed: Set[VT] = pattern_outputs.copy() | g.grounds()
non_outputs = list(vertices.difference(pattern_outputs))
zerovec = Mat2.zeros(len(non_outputs), 1)
for v in processed:
l[v] = 0

Expand All @@ -95,26 +102,33 @@ def gflow(
correct = set()
processed_prime = [
v
for v in processed.difference(pattern_inputs) | paulis
if delayed or any(w not in processed for w in g.neighbors(v))
]
candidates = [
v
for v in vertices.difference(processed)
if any(w in processed_prime for w in g.neighbors(v))
for v in (processed | paulis).difference(pattern_inputs)
if focus or any(w not in processed for w in g.neighbors(v))
]

zerovec = Mat2.zeros(len(candidates), 1)

m = bi_adj(g, processed_prime, candidates)
for index, u in enumerate(candidates):
vu = zerovec.copy()
vu.data[index][0] = 1
x = m.solve(vu)
if x:
correct.add(u)
gflow[u] = {processed_prime[i] for i in range(x.rows()) if x.data[i][0]}
l[u] = k
if focus:
clean = non_outputs
else:
clean = [v for v in vertices
if v not in processed and
any(w in processed_prime for w in g.neighbors(v))]

# candidates = [
# v
# for v in vertices.difference(processed)
# if any(w in processed_prime for w in g.neighbors(v))
# ]

m = bi_adj(g, processed_prime, clean)
for index, u in enumerate(clean):
if not focus or (u not in processed and any(w in processed_prime for w in g.neighbors(v))):
vu = zerovec.copy()
vu.data[index][0] = 1
x = m.solve(vu)
if x:
correct.add(u)
gflow[u] = {processed_prime[i] for i in range(x.rows()) if x.data[i][0]}
l[u] = k

if not correct:
if len(vertices) == len(processed):
Expand Down
Loading

0 comments on commit 9394cc6

Please sign in to comment.