Skip to content

Commit

Permalink
Improve performance for default.qubit.compute_vjp (#4841)
Browse files Browse the repository at this point in the history
[sc-46350]

For 10 qubits, 2 strongly entangling layers, and an expval on each wire:


Before these changes:

<img width="837" alt="Screenshot 2023-11-15 at 9 54 01 AM"
src="https://github.com/PennyLaneAI/pennylane/assets/6364575/f8f08f51-700b-4e2d-9952-bed95a2eac2a">


After these changes:

<img width="806" alt="Screenshot 2023-11-15 at 9 54 58 AM"
src="https://github.com/PennyLaneAI/pennylane/assets/6364575/f94bfdc1-b0f8-44e8-ae3e-d5c67bb8e0f4">
  • Loading branch information
albi3ro authored Nov 30, 2023
1 parent e043de2 commit 2de14a4
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 44 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@
`qml.QNode` or `qml.execute`.
[(#4557)](https://github.com/PennyLaneAI/pennylane/pull/4557)
[(#4654)](https://github.com/PennyLaneAI/pennylane/pull/4654)
[(#4841)](https://github.com/PennyLaneAI/pennylane/pull/4841)

```pycon
>>> dev = qml.device('default.qubit')
Expand Down
19 changes: 18 additions & 1 deletion pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,12 @@ def name(self):
"""The name of the device."""
return "default.qubit"

_state_cache: Optional[dict] = None
"""
A cache to store the "pre-rotated state" for reuse between the forward pass call to ``execute`` and
subsequent calls to ``compute_vjp``. ``None`` indicates that no caching is required.
"""

# pylint:disable = too-many-arguments
def __init__(
self,
Expand Down Expand Up @@ -469,6 +475,7 @@ def execute(
circuits = [circuits]

max_workers = execution_config.device_options.get("max_workers", self._max_workers)
self._state_cache = {} if execution_config.use_device_jacobian_product else None
interface = (
execution_config.interface
if execution_config.gradient_method in {"backprop", None}
Expand All @@ -482,6 +489,7 @@ def execute(
prng_key=self._prng_key,
debugger=self._debugger,
interface=interface,
state_cache=self._state_cache,
)
for c in circuits
)
Expand Down Expand Up @@ -736,7 +744,16 @@ def compute_vjp(

max_workers = execution_config.device_options.get("max_workers", self._max_workers)
if max_workers is None:
res = tuple(adjoint_vjp(circuit, cots) for circuit, cots in zip(circuits, cotangents))

def _state(circuit):
return (
None if self._state_cache is None else self._state_cache.get(circuit.hash, None)
)

res = tuple(
adjoint_vjp(circuit, cots, state=_state(circuit))
for circuit, cots in zip(circuits, cotangents)
)
else:
vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
Expand Down
18 changes: 14 additions & 4 deletions pennylane/devices/qubit/adjoint_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,20 @@ def adjoint_vjp(tape: QuantumTape, cotangents: Tuple[Number], state=None):

ket = state if state is not None else get_final_state(tape)[0]

if np.shape(cotangents) == tuple():
cotangents = (cotangents,)
obs = qml.dot(cotangents, tape.observables)
bra = apply_operation(obs, ket)
cotangents = (cotangents,) if qml.math.shape(cotangents) == tuple() else cotangents
new_cotangents, new_observables = [], []
for c, o in zip(cotangents, tape.observables):
if not np.allclose(c, 0.0):
new_cotangents.append(c)
new_observables.append(o)
if len(new_cotangents) == 0:
return tuple(0.0 for _ in tape.trainable_params)
obs = qml.dot(new_cotangents, new_observables)
if obs._pauli_rep is not None:
flat_bra = obs._pauli_rep.dot(ket.flatten(), wire_order=list(range(tape.num_wires)))
bra = flat_bra.reshape(ket.shape)
else:
bra = apply_operation(obs, ket)

param_number = len(tape.get_parameters(trainable_only=False, operations_only=True)) - 1
trainable_param_number = len(tape.trainable_params) - 1
Expand Down
13 changes: 12 additions & 1 deletion pennylane/devices/qubit/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# limitations under the License.
"""Simulate a quantum script."""
# pylint: disable=protected-access
from typing import Optional

from numpy.random import default_rng
import numpy as np

Expand Down Expand Up @@ -197,8 +199,14 @@ def measure_final_state(circuit, state, is_state_batched, rng=None, prng_key=Non
return results


# pylint: disable=too-many-arguments
def simulate(
circuit: qml.tape.QuantumScript, rng=None, prng_key=None, debugger=None, interface=None
circuit: qml.tape.QuantumScript,
rng=None,
prng_key=None,
debugger=None,
interface=None,
state_cache: Optional[dict] = None,
) -> Result:
"""Simulate a single quantum script.
Expand All @@ -214,6 +222,7 @@ def simulate(
generated. Only for simulation using JAX.
debugger (_Debugger): The debugger to use
interface (str): The machine learning interface to create the initial state with
state_cache=None (Optional[dict]): A dictionary mapping the hash of a circuit to the pre-rotated state. Used to pass the state between forward passes and vjp calculations.
Returns:
tuple(TensorLike): The results of the simulation
Expand All @@ -229,4 +238,6 @@ def simulate(
"""
state, is_state_batched = get_final_state(circuit, debugger=debugger, interface=interface)
if state_cache is not None:
state_cache[circuit.hash] = state
return measure_final_state(circuit, state, is_state_batched, rng=rng, prng_key=prng_key)
87 changes: 49 additions & 38 deletions tests/devices/qubit/test_adjoint_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ class TestAdjointJacobian:
def test_custom_wire_labels(self, tol):
"""Test that adjoint_jacbonian works as expected when custom wire labels are used."""
qs = QuantumScript(
[qml.RX(0.123, wires="a"), qml.RY(0.456, wires="b")], [qml.expval(qml.PauliX("a"))]
[qml.RX(0.123, wires="a"), qml.RY(0.456, wires="b")],
[qml.expval(qml.PauliX("a"))],
trainable_params=[0, 1],
)
qs.trainable_params = {0, 1}

calculated_val = adjoint_jacobian(qs)

Expand All @@ -49,11 +50,11 @@ def test_pauli_rotation_gradient(self, G, theta, tol):

prep_op = qml.StatePrep(np.array([1.0, -1.0], requires_grad=False) / np.sqrt(2), wires=0)
qs = QuantumScript(
ops=[prep_op, G(theta, wires=[0])], measurements=[qml.expval(qml.PauliZ(0))]
ops=[prep_op, G(theta, wires=[0])],
measurements=[qml.expval(qml.PauliZ(0))],
trainable_params=[1],
)

qs.trainable_params = {1}

calculated_val = adjoint_jacobian(qs)
# compare to finite differences
tapes, fn = qml.gradients.finite_diff(qs)
Expand All @@ -72,9 +73,9 @@ def test_Rot_gradient(self, theta, tol):
qs = QuantumScript(
ops=[prep_op, qml.Rot(*params, wires=[0])],
measurements=[qml.expval(qml.PauliZ(0))],
trainable_params=[1, 2, 3],
)

qs.trainable_params = {1, 2, 3}
qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs = qs_valid[0]

Expand Down Expand Up @@ -110,8 +111,7 @@ def test_gradients(self, op, obs, tol):
]
measurements = [qml.expval(obs(wires=0)), qml.expval(qml.PauliZ(wires=1))]

qs = QuantumScript(ops, measurements)
qs.trainable_params = set(range(1, 1 + op.num_params))
qs = QuantumScript(ops, measurements, trainable_params=list(range(1, 1 + op.num_params)))

qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs_valid = qs_valid[0]
Expand Down Expand Up @@ -188,9 +188,9 @@ def test_gradient_gate_with_multiple_parameters(self, tol):
qs = QuantumScript(
[qml.RX(0.4, wires=[0]), qml.Rot(x, y, z, wires=[0]), qml.RY(-0.2, wires=[0])],
[qml.expval(qml.PauliZ(0))],
trainable_params=[1, 2, 3],
)

qs.trainable_params = {1, 2, 3}
qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs_valid = qs_valid[0]

Expand Down Expand Up @@ -218,9 +218,9 @@ def test_state_prep(self, prep_op, tol):
qs = QuantumScript(
[prep_op, qml.RX(0.4, wires=[0]), qml.Rot(x, y, z, wires=[0]), qml.RY(-0.2, wires=[0])],
[qml.expval(qml.PauliZ(0))],
trainable_params=[2, 3, 4],
)

qs.trainable_params = {2, 3, 4}
qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs_valid = qs_valid[0]

Expand Down Expand Up @@ -248,9 +248,9 @@ def test_gradient_of_tape_with_hermitian(self, tol):
qml.CNOT(wires=[1, 2]),
],
[qml.expval(qml.Hermitian(mx, wires=[0, 2]))],
trainable_params=[0, 1, 2],
)

qs.trainable_params = {0, 1, 2}
qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs_valid = qs_valid[0]

Expand Down Expand Up @@ -279,9 +279,9 @@ def test_gradient_of_tape_with_tensor(self, tol):
qml.CNOT(wires=[1, 2]),
],
[qml.expval(qml.PauliX(0) @ qml.PauliY(2))],
trainable_params=[0, 1, 2],
)

qs.trainable_params = {0, 1, 2}
qs_valid, _ = qml.devices.preprocess.decompose(qs, adjoint_ops)
qs_valid = qs_valid[0]

Expand All @@ -304,8 +304,7 @@ def test_with_nontrainable_parametrized(self):
qml.RY(par, wires=0),
qml.QubitUnitary(np.eye(2), wires=0),
]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))])
qs.trainable_params = [0]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))], trainable_params=[0])

grad_adjoint = adjoint_jacobian(qs)
expected = [-np.sin(par)]
Expand All @@ -319,8 +318,7 @@ class TestAdjointJVP:
def test_single_param_single_obs(self, tangents, tol):
"""Test JVP is correct for a single parameter and observable"""
x = np.array(0.654)
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0))])
qs.trainable_params = {0}
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0])

actual = adjoint_jvp(qs, tangents)

Expand All @@ -331,8 +329,11 @@ def test_single_param_single_obs(self, tangents, tol):
def test_single_param_multi_obs(self, tangents, tol):
"""Test JVP is correct for a single parameter and multiple observables"""
x = np.array(0.654)
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0))])
qs.trainable_params = {0}
qs = QuantumScript(
[qml.RY(x, 0)],
[qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0))],
trainable_params=[0],
)

actual = adjoint_jvp(qs, tangents)
assert isinstance(actual, tuple)
Expand All @@ -347,8 +348,9 @@ def test_multi_param_single_obs(self, tangents, tol):
x = np.array(0.654)
y = np.array(1.221)

qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], [qml.expval(qml.PauliY(0))])
qs.trainable_params = {0, 1}
qs = QuantumScript(
[qml.RY(x, 0), qml.RZ(y, 0)], [qml.expval(qml.PauliY(0))], trainable_params=[0, 1]
)

actual = adjoint_jvp(qs, tangents)

Expand All @@ -364,8 +366,7 @@ def test_multi_param_multi_obs(self, tangents, tol):
y = np.array(1.221)

obs = [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0)), qml.expval(qml.PauliY(0))]
qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], obs)
qs.trainable_params = {0, 1}
qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], obs, trainable_params=[0, 1])

actual = adjoint_jvp(qs, tangents)
assert isinstance(actual, tuple)
Expand Down Expand Up @@ -393,8 +394,7 @@ def test_custom_wire_labels(self, tangents, wires, tol):
qml.expval(qml.PauliY(wires[1])),
qml.expval(qml.PauliX(wires[0])),
]
qs = QuantumScript([qml.RY(x, wires[0]), qml.RX(y, wires[1])], obs)
qs.trainable_params = {0, 1}
qs = QuantumScript([qml.RY(x, wires[0]), qml.RX(y, wires[1])], obs, trainable_params=[0, 1])
assert qs.wires.tolist() == wires

actual = adjoint_jvp(qs, tangents)
Expand All @@ -416,8 +416,7 @@ def test_with_nontrainable_parametrized(self):
qml.RY(par, wires=0),
qml.QubitUnitary(np.eye(2), wires=0),
]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))])
qs.trainable_params = [0]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))], trainable_params=[0])

jvp_adjoint = adjoint_jvp(qs, tangents)
expected = [-np.sin(par) * tangents[0]]
Expand All @@ -431,8 +430,7 @@ class TestAdjointVJP:
def test_single_param_single_obs(self, cotangents, tol):
"""Test VJP is correct for a single parameter and observable"""
x = np.array(0.654)
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0))])
qs.trainable_params = {0}
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0))], trainable_params=[0])

actual = adjoint_vjp(qs, cotangents)

Expand All @@ -444,8 +442,11 @@ def test_single_param_single_obs(self, cotangents, tol):
def test_single_param_multi_obs(self, cotangents, tol):
"""Test VJP is correct for a single parameter and multiple observables"""
x = np.array(0.654)
qs = QuantumScript([qml.RY(x, 0)], [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0))])
qs.trainable_params = {0}
qs = QuantumScript(
[qml.RY(x, 0)],
[qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0))],
trainable_params=[0],
)

actual = adjoint_vjp(qs, cotangents)

Expand All @@ -458,8 +459,9 @@ def test_multi_param_single_obs(self, cotangents, tol):
x = np.array(0.654)
y = np.array(1.221)

qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], [qml.expval(qml.PauliY(0))])
qs.trainable_params = {0, 1}
qs = QuantumScript(
[qml.RY(x, 0), qml.RZ(y, 0)], [qml.expval(qml.PauliY(0))], trainable_params=[0, 1]
)

actual = adjoint_vjp(qs, cotangents)
assert isinstance(actual, tuple)
Expand All @@ -477,8 +479,7 @@ def test_multi_param_multi_obs(self, cotangents, tol):
y = np.array(1.221)

obs = [qml.expval(qml.PauliZ(0)), qml.expval(qml.PauliX(0)), qml.expval(qml.PauliY(0))]
qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], obs)
qs.trainable_params = {0, 1}
qs = QuantumScript([qml.RY(x, 0), qml.RZ(y, 0)], obs, trainable_params=[0, 1])

actual = adjoint_vjp(qs, cotangents)
assert isinstance(actual, tuple)
Expand Down Expand Up @@ -508,8 +509,7 @@ def test_custom_wire_labels(self, cotangents, wires, tol):
qml.expval(qml.PauliY(wires[1])),
qml.expval(qml.PauliX(wires[0])),
]
qs = QuantumScript([qml.RY(x, wires[0]), qml.RX(y, wires[1])], obs)
qs.trainable_params = {0, 1}
qs = QuantumScript([qml.RY(x, wires[0]), qml.RX(y, wires[1])], obs, trainable_params=[0, 1])
assert qs.wires.tolist() == wires

actual = adjoint_vjp(qs, cotangents)
Expand All @@ -531,9 +531,20 @@ def test_with_nontrainable_parametrized(self):
qml.RY(par, wires=0),
qml.QubitUnitary(np.eye(2), wires=0),
]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))])
qs.trainable_params = [0]
qs = QuantumScript(ops, [qml.expval(qml.PauliZ(0))], trainable_params=[0])

vjp_adjoint = adjoint_vjp(qs, cotangents)
expected = [-np.sin(par) * cotangents[0]]
assert np.allclose(vjp_adjoint, expected)

def test_hermitian_expval(self):
"""Test adjoint_vjp works with a hermitian expectation value."""

x = 1.2
H = qml.Hermitian(np.array([[1, 0], [0, -1]]), wires=0)
cotangent = (0.5,)

qs = QuantumScript([qml.RX(x, wires=0)], [qml.expval(H)], trainable_params=[0])

[vjp_adjoint] = adjoint_vjp(qs, cotangent)
assert qml.math.allclose(vjp_adjoint, -0.5 * np.sin(x))

0 comments on commit 2de14a4

Please sign in to comment.