From 5e0a9905b27dfb7556d0fdc1dd1118c8895b42ed Mon Sep 17 00:00:00 2001 From: Aleks Kissinger Date: Mon, 16 Dec 2024 10:52:35 +0000 Subject: [PATCH] improved to_rg with a max-cut heuristic --- pyzx/graph/base.py | 4 +- pyzx/graph/graph_s.py | 1 + pyzx/pauliweb.py | 6 ++- pyzx/simplify.py | 93 +++++++++++++++++++++++-------------------- 4 files changed, 57 insertions(+), 47 deletions(-) diff --git a/pyzx/graph/base.py b/pyzx/graph/base.py index f966c303..eb78cccd 100644 --- a/pyzx/graph/base.py +++ b/pyzx/graph/base.py @@ -1011,10 +1011,10 @@ def set_auto_simplify(self, s: bool) -> None: def is_phase_gadget(self, v: VT) -> bool: """Returns True if the vertex is the 'hub' of a phase gadget""" - if self.phase(v) != 0 or self.vertex_degree(v) < 2: + if not vertex_is_zx(self.type(v)) or self.phase(v) != 0 or self.vertex_degree(v) < 2: return False for w in self.neighbors(v): - if self.vertex_degree(w) == 1: + if vertex_is_zx(self.type(w)) and self.vertex_degree(w) == 1: return True return False diff --git a/pyzx/graph/graph_s.py b/pyzx/graph/graph_s.py index 902479c6..1f06d680 100644 --- a/pyzx/graph/graph_s.py +++ b/pyzx/graph/graph_s.py @@ -219,6 +219,7 @@ def vertices_in_range(self, start, end): def edges(self, s=None, t=None): if s is not None and t is not None: + if self.connected(s, t): yield (s,t) if s < t else (t,s) elif s is not None: for t in self.graph[s]: diff --git a/pyzx/pauliweb.py b/pyzx/pauliweb.py index 084dcdb6..33e5df45 100644 --- a/pyzx/pauliweb.py +++ b/pyzx/pauliweb.py @@ -80,9 +80,11 @@ def __repr__(self): def preprocess(g: BaseGraph[VT,ET]): - g.normalize() + #g.normalize() gadgetize(g) - to_rg(g) + gadgets = set(v for v in g.vertices() if g.is_phase_gadget(v)) + boundary_spiders = set(v for v in g.vertices() if any(g.type(w) == VertexType.BOUNDARY for w in g.neighbors(v))) + to_rg(g, init_z=boundary_spiders, init_x=gadgets) in_circ = Circuit(len(g.inputs())) for j,i in enumerate(g.inputs()): diff --git a/pyzx/simplify.py b/pyzx/simplify.py index aa60db94..1f3ce5da 100644 --- a/pyzx/simplify.py +++ b/pyzx/simplify.py @@ -312,54 +312,60 @@ def to_gh(g: BaseGraph[VT,ET],quiet:bool=True) -> None: et = g.edge_type(e) g.set_edge_type(e, toggle_edge(et)) -def to_rg(g: BaseGraph[VT,ET], select:Optional[Callable[[VT],bool]]=None, change_gadgets: bool=True) -> None: + +def max_cut(g: BaseGraph[VT,ET], vs0: Optional[Set[VT]]=None, vs1: Optional[Set[VT]]=None) -> Tuple[Set[VT],Set[VT]]: + """Approximate the MAX-CUT of a graph, starting with an initial partition + + This uses the quadratic-time SG3 heuristic explained by Wang et al in https://arxiv.org/abs/2312.10895 . + """ + if vs0 == None: vs0 = set() + if vs1 == None: vs1 = set() + # print(f'vs0={vs0} vs1={vs1}') + remaining = set(g.vertices()) - vs0 - vs1 + while len(remaining) > 0: + score_max = -1 + v_max: Optional[VT] = None + in0 = True + for v in remaining: + wt0 = sum(len(list(g.edges(v,w))) for w in vs1) + wt1 = sum(len(list(g.edges(v,w))) for w in vs0) + score = abs(wt0 - wt1) + if score > score_max: + # print(f'{v}: score={score}, wt0={wt0}, wt1={wt1}') + score_max = score + v_max = v + in0 = wt0 >= wt1 + # print(f'choosing {v_max} for set {"vs0" if in0 else "vs1"}') + + if v_max == None: raise RuntimeError("No max found") + remaining.remove(v_max) + if in0: vs0.add(v_max) + else: vs1.add(v_max) + return(vs0, vs1) + +def to_rg(g: BaseGraph[VT,ET], init_z: Optional[Set[VT]]=None, init_x: Optional[Set[VT]]=None) -> None: """Try to eliminate H-edges by turning green nodes red - By default, this does a breadth-first search starting at an arbitrary node, flipping the - color of alternating layers. For a ZX-diagram that is graph-like and 2-colorable, this will - eliminate all of the interior H-edges. - - Alternatively, the function `select` can be provided instructing the method where to flip colors. + This implements a quadratic-time max-cut heuristic to eliminate H-edges. In the future, we may want + to implement a linear-time version that does a worse job for very large graphs. :param g: A ZX-graph. - :param select: A function taking in vertices and returning ``True`` or ``False``. - :param change_gadgets: A flag saying always change gadgets to X.""" + :param init_z: An optional set of vertices to make Z. + :param init_x: An optional set of vertices to make X. + """ ty = g.types() - if select is None: - remaining = set() - for w in g.vertices(): - if change_gadgets and g.is_phase_gadget(w): - if vertex_is_zx(ty[w]): - g.set_type(w, toggle_vertex(ty[w])) - for e in g.incident_edges(w): - g.set_edge_type(e, toggle_edge(g.edge_type(e))) - else: - remaining.add(w) - while len(remaining) > 0: - v = next(iter(remaining)) - # if v is a boundary, set `flip` such that its adjacent edge will not be an H-edge afterwards - if ty[v] == VertexType.BOUNDARY and g.incident_edges(v)[0] == EdgeType.SIMPLE: - flip = True - else: - flip = False - nhd = set([v]) - while len(nhd) > 0: - for w in nhd: - if flip and vertex_is_zx(ty[w]): - g.set_type(w, toggle_vertex(ty[w])) - for e in g.incident_edges(w): - g.set_edge_type(e, toggle_edge(g.edge_type(e))) - flip = not flip - remaining -= nhd - nhd = set.union(*(set(g.neighbors(w)) for w in nhd)).intersection(remaining) - - else: - for v in g.vertices(): - if g.is_phase_gadget(v) and select(v) and vertex_is_zx(ty[v]): - g.set_type(v, toggle_vertex(ty[v])) - for e in g.incident_edges(v): - g.set_edge_type(e, toggle_edge(g.edge_type(e))) + vs0, vs1 = max_cut(g, init_z, init_x) + for v in vs0: + if ty[v] == VertexType.X: + g.set_type(v, VertexType.Z) + for e in g.incident_edges(v): + g.set_edge_type(e, toggle_edge(g.edge_type(e))) + for v in vs1: + if ty[v] == VertexType.Z: + g.set_type(v, VertexType.X) + for e in g.incident_edges(v): + g.set_edge_type(e, toggle_edge(g.edge_type(e))) def gadgetize(g: BaseGraph[VT,ET]): """Convert every non-Clifford phase to a phase gadget""" @@ -680,4 +686,5 @@ def to_clifford_normal_form_graph(g: BaseGraph[VT,ET]) -> None: for q1,q2 in czs: g.add_edge((cz_v[q1],cz_v[q2]),EdgeType.HADAMARD) - to_rg(g,select=lambda v: v in v_outputs) + # TODO: re-introduce correct to_rg behaviour here + #to_rg(g,select=lambda v: v in v_outputs)