diff --git a/pyzx/graph/base.py b/pyzx/graph/base.py index 8f6cfec4..c90d0c5b 100644 --- a/pyzx/graph/base.py +++ b/pyzx/graph/base.py @@ -142,9 +142,13 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> 'BaseGraph': graph did not. """ from .graph import Graph # imported here to prevent circularity + from .multigraph import Multigraph if (backend is None): backend = type(self).backend g = Graph(backend = backend) + if isinstance(self, Multigraph) and isinstance(g, Multigraph): + g.set_auto_simplify(self._auto_simplify) # type: ignore + # mypy issue https://github.com/python/mypy/issues/16413 g.track_phases = self.track_phases g.scalar = self.scalar.copy(conjugate=adjoint) g.merge_vdata = self.merge_vdata @@ -390,14 +394,19 @@ def merge(self, other: 'BaseGraph') -> Tuple[List[VT],List[ET]]: def subgraph_from_vertices(self,verts: List[VT]) -> 'BaseGraph': """Returns the subgraph consisting of the specified vertices.""" from .graph import Graph # imported here to prevent circularity + from .multigraph import Multigraph g = Graph(backend=type(self).backend) + if isinstance(self, Multigraph) and isinstance(g, Multigraph): + g.set_auto_simplify(self._auto_simplify) # type: ignore + # mypy issue https://github.com/python/mypy/issues/16413 ty = self.types() rs = self.rows() qs = self.qubits() phase = self.phases() grounds = self.grounds() - edges = [self.edge(v,w) for v in verts for w in verts if self.connected(v,w)] + edges = [e for e in self.edges() \ + if self.edge_st(e)[0] in verts and self.edge_st(e)[1] in verts] vert_map = dict() for v in verts: diff --git a/pyzx/graph/diff.py b/pyzx/graph/diff.py index 5f22057a..512c51c6 100644 --- a/pyzx/graph/diff.py +++ b/pyzx/graph/diff.py @@ -14,9 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import json +from collections import Counter from typing import Any, Callable, Generic, Optional, List, Dict, Tuple -import copy from ..utils import VertexType, EdgeType, FractionLike, FloatInt, phase_to_s from .base import BaseGraph, VT, ET @@ -56,12 +57,11 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None: self.new_edges = [] self.removed_edges = [] - for e in (new_edges - old_edges): + for e in Counter(new_edges - old_edges).elements(): self.new_edges.append((g2.edge_st(e), g2.edge_type(e))) - for e in (old_edges - new_edges): + for e in Counter(old_edges - new_edges).elements(): s,t = g1.edge_st(e) - if s in self.removed_verts or t in self.removed_verts: continue self.removed_edges.append(e) for v in new_verts: @@ -94,8 +94,8 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None: def apply_diff(self,g: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]: g = copy.deepcopy(g) - g.remove_vertices(self.removed_verts) g.remove_edges(self.removed_edges) + g.remove_vertices(self.removed_verts) for v in self.new_verts: g.add_vertex_indexed(v) g.set_position(v,*self.changed_pos[v]) diff --git a/pyzx/graph/multigraph.py b/pyzx/graph/multigraph.py index 7a2541ce..617b683e 100644 --- a/pyzx/graph/multigraph.py +++ b/pyzx/graph/multigraph.py @@ -102,11 +102,11 @@ def clone(self) -> 'Multigraph': cpy.phase_mult = self.phase_mult.copy() cpy.max_phase_index = self.max_phase_index return cpy - + def set_auto_simplify(self, s: bool): """Automatically remove parallel edges as edges are added""" self._auto_simplify = s - + def multigraph(self): return False