diff --git a/openfermioncirq/primitives/swap_network.py b/openfermioncirq/primitives/swap_network.py index 87b15633..4d1394a0 100644 --- a/openfermioncirq/primitives/swap_network.py +++ b/openfermioncirq/primitives/swap_network.py @@ -12,7 +12,7 @@ """The linear swap network.""" -from typing import Callable, Sequence +from typing import Callable, Sequence, List import cirq @@ -21,10 +21,10 @@ def swap_network(qubits: Sequence[cirq.QubitId], operation: Callable[ - [int, int, cirq.QubitId, cirq.QubitId], cirq.OP_TREE]= - lambda p, q, p_qubit, q_qubit: (), + [int, int, cirq.QubitId, cirq.QubitId], cirq.OP_TREE + ] = lambda p, q, p_qubit, q_qubit: (), fermionic: bool=False, - offset: bool=False) -> cirq.OP_TREE: + offset: bool=False) -> List[cirq.Operation]: """Apply operations to pairs of qubits or modes using a swap network. This is used for applying operations between arbitrary pairs of qubits or @@ -109,11 +109,11 @@ def swap_network(qubits: Sequence[cirq.QubitId], Args: qubits: The qubits sorted so that the j-th qubit in the Sequence represents the j-th qubit or fermionic mode. - operation: A call to this function takes the form - ``operation(p, q, p_qubit, q_qubit)`` - where p and q are indices reprenting either qubits or fermionic - modes, and p_qubit and q_qubit are the qubits which represent them. - It returns the gate that should be applied to these qubits. + operation: Returns extra interactions to perform between qubits/modes as + they are swapped past each other. A call to this function takes the + form ``operation(p, q, p_qubit, q_qubit)`` where p and q are indices + representing either qubits or fermionic modes, and p_qubit and + q_qubit are the qubits which are currently storing those modes. fermionic: If True, use fermionic swaps under the JWT (that is, swap fermionic modes instead of qubits). If False, use normal qubit swaps. @@ -123,6 +123,7 @@ def swap_network(qubits: Sequence[cirq.QubitId], n_qubits = len(qubits) order = list(range(n_qubits)) swap_gate = FSWAP if fermionic else cirq.SWAP + result = [] for layer_num in range(n_qubits): lowest_active_qubit = (layer_num + offset) % 2 @@ -130,6 +131,9 @@ def swap_network(qubits: Sequence[cirq.QubitId], for i in range(lowest_active_qubit, n_qubits - 1, 2)) for i, j in active_pairs: p, q = order[i], order[j] - yield operation(p, q, qubits[i], qubits[j]) - yield swap_gate(qubits[i], qubits[j]) + extra_ops = operation(p, q, qubits[i], qubits[j]) + result.extend(cirq.flatten_op_tree(extra_ops)) + result.append(swap_gate(qubits[i], qubits[j])) order[i], order[j] = q, p + + return result diff --git a/openfermioncirq/primitives/swap_network_test.py b/openfermioncirq/primitives/swap_network_test.py index 7ba2b32f..74b4db39 100644 --- a/openfermioncirq/primitives/swap_network_test.py +++ b/openfermioncirq/primitives/swap_network_test.py @@ -94,3 +94,8 @@ def test_swap_network(): │ ×─× ×─× │ │ │ │ │ """.strip() + + +def test_reusable(): + ops = swap_network(cirq.LineQubit.range(5)) + assert list(ops) == list(ops)