diff --git a/pyzx/tensor.py b/pyzx/tensor.py index 81a23b96..37b4edef 100644 --- a/pyzx/tensor.py +++ b/pyzx/tensor.py @@ -132,25 +132,25 @@ def tensorfy(g: 'BaseGraph[VT,ET]', preserve_scalar:bool=True) -> np.ndarray: neigh = list(g.neighbors(v)) d = len(neigh) if v in inputs: - if types[v] != 0: raise ValueError("Wrong type for input:", v, types[v]) + if types[v] != VertexType.BOUNDARY: raise ValueError("Wrong type for input:", v, types[v]) continue # inputs already taken care of if v in outputs: if d != 1: raise ValueError("Weird output") - if types[v] != 0: raise ValueError("Wrong type for output:",v, types[v]) + if types[v] != VertexType.BOUNDARY: raise ValueError("Wrong type for output:",v, types[v]) d += 1 t = id2 else: phase = pi*phases[v] - if types[v] == 1: + if types[v] == VertexType.Z: t = Z_to_tensor(d,phase) - elif types[v] == 2: + elif types[v] == VertexType.X: t = X_to_tensor(d,phase) - elif types[v] == 3: + elif types[v] == VertexType.H_BOX: t = H_to_tensor(d,phase) - elif types[v] == 4 or types[v] == 5: + elif types[v] == VertexType.W_INPUT or types[v] == VertexType.W_OUTPUT: if phase != 0: raise ValueError("Phase on W node") t = W_to_tensor(d) - elif types[v] == 6: + elif types[v] == VertexType.Z_BOX: if phase != 0: raise ValueError("Phase on Z box") label = get_z_box_label(g, v) t = Z_box_to_tensor(d, label)