Skip to content
This repository has been archived by the owner on Mar 15, 2022. It is now read-only.

Commit

Permalink
Sync changes from cirq (#307)
Browse files Browse the repository at this point in the history
- Update diagram tests
- Use cirq.testing.assert_same_diagram
- _canonical_exponent_period -> _period
- exponent canonicalization only happens when equating
- Fix broken `_apply_unitary_to_tensor_` method on combined double excitation gate
  • Loading branch information
Strilanc authored Nov 13, 2018
1 parent 7b94038 commit b493c53
Show file tree
Hide file tree
Showing 13 changed files with 889 additions and 886 deletions.
16 changes: 8 additions & 8 deletions openfermioncirq/gates/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _eigen_components(self):
[0, 0, 0, 1]])),
]

def _canonical_exponent_period(self) -> Optional[float]:
def _period(self) -> Optional[float]:
return 2

def _with_exponent(self,
Expand Down Expand Up @@ -84,7 +84,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs
symbols = 'fswap', 'fswap'
return cirq.CircuitDiagramInfo(
wire_symbols=symbols,
exponent=self.half_turns)
exponent=self._diagram_exponent(args))

def __str__(self) -> str:
if self.half_turns == 1:
Expand Down Expand Up @@ -180,7 +180,7 @@ def _eigen_components(self):
[0, 0, 0, 0]]))
]

def _canonical_exponent_period(self) -> Optional[float]:
def _period(self) -> Optional[float]:
return 4

def _apply_unitary_to_tensor_(self,
Expand Down Expand Up @@ -211,7 +211,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs
) -> cirq.CircuitDiagramInfo:
return cirq.CircuitDiagramInfo(
wire_symbols=('XXYY', 'XXYY'),
exponent=self.half_turns)
exponent=self._diagram_exponent(args))

def __repr__(self):
if self.half_turns == 1:
Expand Down Expand Up @@ -316,7 +316,7 @@ def _apply_unitary_to_tensor_(self,
slices=[zo, oz],
out=available_buffer)

def _canonical_exponent_period(self) -> Optional[float]:
def _period(self) -> Optional[float]:
return 4

def _with_exponent(self, exponent: Union[cirq.Symbol, float]) -> 'YXXYGate':
Expand All @@ -332,7 +332,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs
) -> cirq.CircuitDiagramInfo:
return cirq.CircuitDiagramInfo(
wire_symbols=('YXXY', '#2'),
exponent=self.half_turns)
exponent=self._diagram_exponent(args))

def __repr__(self):
if self.half_turns == 1:
Expand Down Expand Up @@ -435,7 +435,7 @@ def _eigen_components(self):
(0.5, np.diag([0, 1, 1, 0])),
]

def _canonical_exponent_period(self) -> Optional[float]:
def _period(self) -> Optional[float]:
return 2

def _with_exponent(self,
Expand All @@ -446,7 +446,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs
) -> cirq.CircuitDiagramInfo:
return cirq.CircuitDiagramInfo(
wire_symbols=('Z', 'Z'),
exponent=self.half_turns)
exponent=self._diagram_exponent(args))

def __repr__(self) -> str:
if self.half_turns == 1:
Expand Down
26 changes: 12 additions & 14 deletions openfermioncirq/gates/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ def test_fswap_matrix():
[0, 0.5-0.5j, 0.5+0.5j, 0],
[0, 0, 0, 1j]]))

cirq.testing.assert_apply_unitary_to_tensor_is_consistent_with_unitary(
cirq.testing.assert_has_consistent_apply_unitary_for_various_exponents(
val=ofc.FSWAP,
exponents=[1, -0.5, 0.5, 0.25, -0.25, 0.1, cirq.Symbol('s')])


def test_xxyy_init():
assert ofc.XXYYGate(half_turns=0.5).half_turns == 0.5
assert ofc.XXYYGate(half_turns=1.5).half_turns == 1.5
assert ofc.XXYYGate(half_turns=5).half_turns == 1
assert ofc.XXYYGate(half_turns=5).half_turns == 5


def test_xxyy_init_with_multiple_args_fails():
Expand Down Expand Up @@ -108,7 +108,7 @@ def test_xxyy_decompose(half_turns):


def test_xxyy_matrix():
cirq.testing.assert_apply_unitary_to_tensor_is_consistent_with_unitary(
cirq.testing.assert_has_consistent_apply_unitary_for_various_exponents(
ofc.XXYY,
exponents=[1, -0.5, 0.5, 0.25, -0.25, 0.1, cirq.Symbol('s')])

Expand Down Expand Up @@ -151,7 +151,7 @@ def test_xxyy_matrix():
def test_yxxy_init():
assert ofc.YXXYGate(half_turns=0.5).half_turns == 0.5
assert ofc.YXXYGate(half_turns=1.5).half_turns == 1.5
assert ofc.YXXYGate(half_turns=5).half_turns == 1
assert ofc.YXXYGate(half_turns=5).half_turns == 5


def test_yxxy_init_with_multiple_args_fails():
Expand Down Expand Up @@ -190,7 +190,7 @@ def test_yxxy_decompose(half_turns):


def test_yxxy_matrix():
cirq.testing.assert_apply_unitary_to_tensor_is_consistent_with_unitary(
cirq.testing.assert_has_consistent_apply_unitary_for_various_exponents(
ofc.YXXY,
exponents=[1, -0.5, 0.5, 0.25, -0.25, 0.1, cirq.Symbol('s')])

Expand Down Expand Up @@ -233,8 +233,8 @@ def test_yxxy_matrix():

def test_zz_init():
assert ofc.ZZGate(half_turns=0.5).half_turns == 0.5
assert ofc.ZZGate(half_turns=1.5).half_turns == -0.5
assert ofc.ZZGate(half_turns=5).half_turns == 1
assert ofc.ZZGate(half_turns=1.5).half_turns == 1.5
assert ofc.ZZGate(half_turns=5).half_turns == 5


def test_zz_init_with_multiple_args_fails():
Expand Down Expand Up @@ -267,7 +267,7 @@ def test_zz_repr():


def test_zz_matrix():
cirq.testing.assert_apply_unitary_to_tensor_is_consistent_with_unitary(
cirq.testing.assert_has_consistent_apply_unitary_for_various_exponents(
ofc.ZZ,
exponents=[1, -0.5, 0.5, 0.25, -0.25, 0.1, cirq.Symbol('s')])

Expand Down Expand Up @@ -336,13 +336,11 @@ def test_zz_matrix():
])
def test_two_qubit_rotation_gates_on_simulator(
gate, half_turns, initial_state, correct_state, atol):
simulator = cirq.google.XmonSimulator()
a, b = cirq.LineQubit.range(2)
circuit = cirq.Circuit.from_ops(gate(a, b)**half_turns)
initial_state = initial_state.astype(numpy.complex64)
result = simulator.simulate(circuit, initial_state=initial_state)
result = circuit.apply_unitary_effect_to_state(initial_state)
cirq.testing.assert_allclose_up_to_global_phase(
result.final_state, correct_state, atol=atol)
result, correct_state, atol=atol)


def test_common_gate_text_diagrams():
Expand All @@ -361,11 +359,11 @@ def test_common_gate_text_diagrams():
b: ───×ᶠ───×ᶠ^0.5───XXYY───#2─────Z───
""")

assert circuit.to_text_diagram(use_unicode_characters=False).strip() == """
cirq.testing.assert_has_diagram(circuit, """
a: ---fswap---fswap-------XXYY---YXXY---Z---
| | | | |
b: ---fswap---fswap^0.5---XXYY---#2-----Z---
""".strip()
""", use_unicode_characters=False)

circuit = cirq.Circuit.from_ops(
ofc.XXYY(a, b)**0.5,
Expand Down
42 changes: 23 additions & 19 deletions openfermioncirq/gates/four_qubit_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def state_swap_eigen_component(x: str, y: str, sign: int = 1):
└ ┘
Args:
x, y: The states to swap, as bitstrings.
x: The first state to swap, as a bitstring.
y: The second state to swap, as a bitstring.
sign: The sign of the off-diagonal elements (indicated by +/-1).
Returns: The eigen-component.
Expand Down Expand Up @@ -136,7 +137,7 @@ def _apply_unitary_to_tensor_(self,
slices=[a, b],
out=available_buffer)

def _canonical_exponent_period(self) -> Optional[float]:
def _period(self) -> Optional[float]:
return 2

def _with_exponent(self,
Expand Down Expand Up @@ -187,8 +188,9 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs
'/\\ \/',
'\/ /\\',
'\/ /\\')
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols,
exponent=self.half_turns)
return cirq.CircuitDiagramInfo(
wire_symbols=wire_symbols,
exponent=self._diagram_exponent(args))

def __repr__(self):
if self.half_turns == 1:
Expand Down Expand Up @@ -262,7 +264,7 @@ def _eigen_components(self):
# projector onto subspace spanned by basis states with
# Hamming weight != 2
zero_component = np.diag([int(bin(i).count('1') != 2)
for i in range(16)])
for i in range(16)])

state_pairs = (('1001', '0110'),
('0101', '1010'),
Expand All @@ -271,12 +273,12 @@ def _eigen_components(self):
plus_minus_components = tuple(
(weight * sign / 2,
state_swap_eigen_component(state_pair[0], state_pair[1], sign))
for weight, state_pair in zip(self.weights, state_pairs)
for sign in (-1, 1))
for weight, state_pair in zip(self.weights, state_pairs)
for sign in (-1, 1))

return ((0, zero_component),) + plus_minus_components

def _canonical_exponent_period(self) -> Optional[float]:
def _period(self) -> Optional[float]:
return None

def _with_exponent(self,
Expand Down Expand Up @@ -339,7 +341,7 @@ def _circuit_diagram_info_(self, args: cirq.CircuitDiagramInfoArgs
else:
wire_symbols = ('a*a*aa',) * 4
return cirq.CircuitDiagramInfo(wire_symbols=wire_symbols,
exponent=self.half_turns)
exponent=self._diagram_exponent(args))

def absorb_exponent_into_weights(self):
self.weights = tuple((w * self._exponent) % 4 for w in self.weights)
Expand All @@ -352,26 +354,28 @@ def _apply_unitary_to_tensor_(self,
) -> Union[np.ndarray, NotImplementedType]:
if cirq.is_parameterized(self):
return NotImplemented
inner_matrix = cirq.unitary(cirq.Rx(-np.pi*self.half_turns))
am = cirq.unitary(cirq.Rx(-np.pi * self.half_turns * self.weights[0]))
bm = cirq.unitary(cirq.Rx(-np.pi * self.half_turns * self.weights[1]))
cm = cirq.unitary(cirq.Rx(-np.pi * self.half_turns * self.weights[2]))

a1 = cirq.slice_for_qubits_equal_to(axes, 0b0011)
b1 = cirq.slice_for_qubits_equal_to(axes, 0b1001)
c1 = cirq.slice_for_qubits_equal_to(axes, 0b0101)
a1 = cirq.slice_for_qubits_equal_to(axes, 0b1001)
b1 = cirq.slice_for_qubits_equal_to(axes, 0b0101)
c1 = cirq.slice_for_qubits_equal_to(axes, 0b0011)

a2 = cirq.slice_for_qubits_equal_to(axes, 0b1100)
b2 = cirq.slice_for_qubits_equal_to(axes, 0b0110)
c2 = cirq.slice_for_qubits_equal_to(axes, 0b1010)
a2 = cirq.slice_for_qubits_equal_to(axes, 0b0110)
b2 = cirq.slice_for_qubits_equal_to(axes, 0b1010)
c2 = cirq.slice_for_qubits_equal_to(axes, 0b1100)

cirq.apply_matrix_to_slices(target_tensor,
inner_matrix,
am,
slices=[a1, a2],
out=available_buffer)
cirq.apply_matrix_to_slices(available_buffer,
inner_matrix,
bm,
slices=[b1, b2],
out=target_tensor)
return cirq.apply_matrix_to_slices(target_tensor,
inner_matrix,
cm,
slices=[c1, c2],
out=available_buffer)

Expand Down
Loading

0 comments on commit b493c53

Please sign in to comment.