diff --git a/pyzx/graph/base.py b/pyzx/graph/base.py index 811efb2a..8fce9aa0 100644 --- a/pyzx/graph/base.py +++ b/pyzx/graph/base.py @@ -95,7 +95,6 @@ def __init__(self) -> None: self.phase_master: Optional['simplify.Simplifier'] = None self.phase_mult: Dict[int,Literal[1,-1]] = dict() self.max_phase_index: int = -1 - self._vdata: Dict[VT,Dict[str,Any]] = dict() # merge_vdata(v0,v1) is an optional, custom function for merging # vdata of v1 into v0 during spider fusion etc. @@ -215,52 +214,6 @@ def map_qubits(self, qubit_map:Mapping[int,Tuple[float,float]]) -> None: self.set_qubit(v, qf) self.set_row(v, rf) - - # def replace_subgraph(self, left_row: FloatInt, right_row: FloatInt, replace: BaseGraph[VT,ET]) -> None: - # """Deletes the subgraph of all nodes with rank strictly between ``left_row`` - # and ``right_row`` and replaces it with the graph ``replace``. - # The amount of nodes on the left row should match the amount of inputs of - # the replacement graph and the same for the right row and the outputs. - # The graphs are glued together based on the qubit index of the vertices.""" - - # qleft = [v for v in self.vertices() if self.row(v)==left_row] - # qright= [v for v in self.vertices() if self.row(v)==right_row] - # r_inputs = replace.inputs() - # r_outputs = replace.outputs() - # if len(qleft) != len(r_inputs): - # raise TypeError("Inputs do not match glueing vertices") - # if len(qright) != len(r_outputs): - # raise TypeError("Outputs do not match glueing vertices") - # if set(self.qubit(v) for v in qleft) != set(replace.qubit(v) for v in r_inputs): - # raise TypeError("Input qubit indices do not match") - # if set(self.qubit(v) for v in qright)!= set(replace.qubit(v) for v in r_outputs): - # raise TypeError("Output qubit indices do not match") - - # self.remove_vertices([v for v in self.vertices() if (left_row < self.row(v) and self.row(v) < right_row)]) - # self.remove_edges([self.edge(s,t) for s in qleft for t in qright if self.connected(s,t)]) - # rdepth = replace.depth() -1 - # for v in (v for v in self.vertices() if self.row(v)>=right_row): - # self.set_row(v, self.row(v)+rdepth) - - # vtab = {} - # for v in replace.vertices(): - # if v in r_inputs or v in r_outputs: continue - # vtab[v] = self.add_vertex(replace.type(v), - # replace.qubit(v), - # replace.row(v)+left_row, - # replace.phase(v), - # replace.is_ground(v)) - # for v in r_inputs: - # vtab[v] = [i for i in qleft if self.qubit(i) == replace.qubit(v)][0] - - # for v in r_outputs: - # vtab[v] = [i for i in qright if self.qubit(i) == replace.qubit(v)][0] - - # etab = {e:self.edge(vtab[replace.edge_s(e)],vtab[replace.edge_t(e)]) for e in replace.edges()} - # self.add_edges(etab.values()) - # for e,f in etab.items(): - # self.set_edge_type(f, replace.edge_type(e)) - def compose(self, other: BaseGraph[VT,ET]) -> None: """Inserts a graph after this one. The amount of qubits of the graphs must match. Also available by the operator `graph1 + graph2`""" @@ -302,7 +255,7 @@ def compose(self, other: BaseGraph[VT,ET]) -> None: qubit=other.qubit(v), row=offset + other.row(v), ground=other.is_ground(v)) - if v in other._vdata: self._vdata[w] = other._vdata[v] + self.set_vdata_dict(w, other.vdata_dict(v)) vtab[v] = w for e in other.edges(): s,t = other.edge_st(e) @@ -325,11 +278,10 @@ def tensor(self, other: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]: height = max((self.qubits().values()), default=0) + 1 rs = other.rows() phases = other.phases() - vdata = other._vdata vertex_map = dict() for v in other.vertices(): w = g.add_vertex(ts[v],qs[v]+height,rs[v],phases[v],g.is_ground(v)) - if v in vdata: g._vdata[w] = vdata[v] + g.set_vdata_dict(w, other.vdata_dict(v)) vertex_map[v] = w for e in other.edges(): s,t = other.edge_st(e) @@ -964,6 +916,10 @@ def set_position(self, vertex: VT, q: FloatInt, r: FloatInt): self.set_qubit(vertex, q) self.set_row(vertex, r) + def clear_vdata(self, vertex: VT) -> None: + """Removes all vdata associated to a vertex""" + raise NotImplementedError("Not implemented on backend" + type(self).backend) + def vdata_keys(self, vertex: VT) -> Sequence[str]: """Returns an iterable of the vertex data key names. Used e.g. in making a copy of the graph in a backend-independent way.""" @@ -978,6 +934,14 @@ def set_vdata(self, vertex: VT, key: str, val: Any) -> None: """Sets the vertex data associated to key to val.""" raise NotImplementedError("Not implemented on backend" + type(self).backend) + def vdata_dict(self, vertex: VT) -> Dict[str, Any]: + return { key: self.vdata(vertex, key) for key in self.vdata_keys(vertex) } + + def set_vdata_dict(self, vertex: VT, d: Dict[str, Any]) -> None: + self.clear_vdata(vertex) + for k, v in d.items(): + self.set_vdata(vertex, k, v) + def is_well_formed(self) -> bool: """Returns whether the graph is a well-formed ZX-diagram. This means that it has no isolated boundary vertices, diff --git a/pyzx/graph/diff.py b/pyzx/graph/diff.py index f68713f8..830e8e24 100644 --- a/pyzx/graph/diff.py +++ b/pyzx/graph/diff.py @@ -61,7 +61,6 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None: self.new_edges.append((g2.edge_st(e), g2.edge_type(e))) for e in Counter(old_edges - new_edges).elements(): - s,t = g1.edge_st(e) self.removed_edges.append(e) for v in new_verts: @@ -70,8 +69,10 @@ def calculate_diff(self, g1: BaseGraph[VT,ET], g2: BaseGraph[VT,ET]) -> None: self.changed_vertex_types[v] = g2.type(v) if g1.phase(v) != g2.phase(v): self.changed_phases[v] = g2.phase(v) - if g1._vdata.get(v, None) != g2._vdata.get(v, None): - self.changed_vdata[v] = g2._vdata.get(v, None) + d1 = g1.vdata_dict(v) + d2 = g2.vdata_dict(v) + if d1 != d2: + self.changed_vdata[v] = d2 pos1 = g1.qubit(v), g1.row(v) pos2 = g2.qubit(v), g2.row(v) if pos1 != pos2: @@ -106,7 +107,7 @@ def apply_diff(self,g: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]: if v in self.changed_phases: g.set_phase(v,self.changed_phases[v]) if v in self.changed_vdata: - g._vdata[v] = self.changed_vdata[v] + g.set_vdata_dict(v, self.changed_vdata[v]) for st, ty in self.new_edges: g.add_edge(st,ty) @@ -124,7 +125,7 @@ def apply_diff(self,g: BaseGraph[VT,ET]) -> BaseGraph[VT,ET]: for v in self.changed_vdata: if v in self.new_verts: continue - g._vdata[v] = self.changed_vdata[v] + g.set_vdata_dict(v, self.changed_vdata[v]) for e in self.changed_edge_types: g.set_edge_type(e,self.changed_edge_types[e]) diff --git a/pyzx/graph/graph_s.py b/pyzx/graph/graph_s.py index 39ce5443..379a75b7 100644 --- a/pyzx/graph/graph_s.py +++ b/pyzx/graph/graph_s.py @@ -337,6 +337,9 @@ def set_ground(self, vertex, flag=True): else: self._grounds.discard(vertex) + def clear_vdata(self, vertex): + if vertex in self._vdata: + del self._vdata[vertex] def vdata_keys(self, vertex): return self._vdata.get(vertex, {}).keys() def vdata(self, vertex, key, default=0): diff --git a/pyzx/graph/multigraph.py b/pyzx/graph/multigraph.py index fbf63d8a..951d14d6 100644 --- a/pyzx/graph/multigraph.py +++ b/pyzx/graph/multigraph.py @@ -412,6 +412,9 @@ def set_ground(self, vertex, flag=True): else: self._grounds.discard(vertex) + def clear_vdata(self, vertex): + if vertex in self._vdata: + del self._vdata[vertex] def vdata_keys(self, vertex): return self._vdata.get(vertex, {}).keys() def vdata(self, vertex, key, default=0):