Skip to content

Commit

Permalink
fix interface detection with closure variables (#6892)
Browse files Browse the repository at this point in the history
**Context:**

We were not properly detecting the interface of a qnode when all
parameters were being passed as closure variables:
```
import jax
import numpy as np

V = jax.numpy.array([[ 0.53672126+0.j        , -0.1126064 -2.41479668j],
              [-0.1126064 +2.41479668j,  1.48694623+0.j        ]])
eigen_vals, eigen_vecs = jax.numpy.linalg.eigh(V)
umat = eigen_vecs.T
wires = range(len(umat))

@jax.jit
@qml.qnode(qml.device("lightning.qubit", wires = wires))
def circuit():
   qml.BasisRotation(wires=wires, unitary_matrix=umat)
   return qml.state()

print(circuit())
```

Since with the qnode, we detect the interface from what it gets called
with, we did not detect that the `umat` was a jax variable, and that
it's decomposition would produce tracers.

**Description of the Change:**

We stop detecting the interface from what the qnode is called with. We
only detect the interface from what is on the tape.

We also detecting jitting by creating a new array and seeing if it a
tracer. Closure variables are not tracers until they are transformed, so
detecting jitting with unmodified closure variables will not work.

**Benefits:**

Better use of closure variables.

**Possible Drawbacks:**

**Related GitHub Issues:**

---------

Co-authored-by: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com>
  • Loading branch information
albi3ro and andrijapau authored Jan 29, 2025
1 parent 04bfe4d commit 58d4f4f
Show file tree
Hide file tree
Showing 18 changed files with 78 additions and 115 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@

<h3>Bug fixes 🐛</h3>

* The interface is now detected from the data in the circuit, not the arguments to the `QNode`. This allows
interface data to be strictly passed as closure variables and still be detected.
[(#6892)](https://github.com/PennyLaneAI/pennylane/pull/6892)

* `BasisState` now casts its input to integers.
[(#6844)](https://github.com/PennyLaneAI/pennylane/pull/6844)

Expand Down
2 changes: 2 additions & 0 deletions pennylane/devices/reference_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def simulate(tape: qml.tape.QuantumTape, seed=None) -> qml.typing.Result:
# 2) apply all the operations
for op in tape.operations:
op_mat = op.matrix(wire_order=tape.wires)
if qml.math.get_interface(op_mat) != "numpy":
raise ValueError("Reference qubit can only work with numpy data.")
state = qml.math.matmul(op_mat, state)

# 3) perform measurements
Expand Down
11 changes: 2 additions & 9 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,22 +808,14 @@ def _impl_call(self, *args, **kwargs) -> qml.typing.Result:
# construct the tape
tape = self.construct(args, kwargs)

if self.interface == "auto":
interface = qml.math.get_interface(*args, *list(kwargs.values()))
try:
interface = get_canonical_interface_name(interface)
except ValueError:
interface = Interface.NUMPY
else:
interface = self.interface
# Calculate the classical jacobians if necessary
self._transform_program.set_classical_component(self, args, kwargs)

res = qml.execute(
(tape,),
device=self.device,
diff_method=self.diff_method,
interface=interface,
interface=self.interface,
transform_program=self._transform_program,
gradient_kwargs=self.gradient_kwargs,
**self.execute_kwargs,
Expand All @@ -835,6 +827,7 @@ def _impl_call(self, *args, **kwargs) -> qml.typing.Result:
if (
len(tape.get_parameters(trainable_only=False)) == 0
and not self._transform_program.is_informative
and self.interface != "auto"
):
res = _convert_to_interface(res, qml.math.get_canonical_interface_name(self.interface))

Expand Down
38 changes: 6 additions & 32 deletions pennylane/workflow/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,38 +35,12 @@
]


def _get_jax_interface_name(tapes):
"""Check all parameters in each tape and output the name of the suitable
JAX interface.
This function checks each tape and determines if any of the gate parameters
was transformed by a JAX transform such as ``jax.jit``. If so, it outputs
the name of the JAX interface with jit support.
Note that determining if jit support should be turned on is done by
checking if parameters are abstract. Parameters can be abstract not just
for ``jax.jit``, but for other JAX transforms (vmap, pmap, etc.) too. The
reason is that JAX doesn't have a public API for checking whether or not
the execution is within the jit transform.
Args:
tapes (Sequence[.QuantumTape]): batch of tapes to execute
Returns:
str: name of JAX interface that fits the tape parameters, "jax" or
"jax-jit"
def _get_jax_interface_name() -> Interface:
"""Check if we are in a jitting context by creating a dummy array and seeing if it's
abstract.
"""
for t in tapes:
for op in t:
# Unwrap the observable from a MeasurementProcess
if not isinstance(op, qml.ops.Prod):
op = getattr(op, "obs", op)
if op is not None:
# Some MeasurementProcess objects have op.obs=None
if any(qml.math.is_abstract(param) for param in op.data):
return Interface.JAX_JIT

return Interface.JAX
x = qml.math.asarray([0], like="jax")
return Interface.JAX_JIT if qml.math.is_abstract(x) else Interface.JAX


# pylint: disable=import-outside-toplevel
Expand Down Expand Up @@ -118,7 +92,7 @@ def _resolve_interface(interface: Union[str, Interface], tapes: QuantumScriptBat
"version of jax to enable the 'jax' interface." # pragma: no cover
) from e # pragma: no cover

interface = _get_jax_interface_name(tapes)
interface = _get_jax_interface_name()

return interface

Expand Down
40 changes: 40 additions & 0 deletions tests/devices/test_reference_qubit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2018-2025 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for reference qubit.
"""

import pytest

import pennylane as qml


@pytest.mark.parametrize(
"interface",
(
pytest.param("autograd", marks=pytest.mark.autograd),
pytest.param("jax", marks=pytest.mark.jax),
pytest.param("torch", marks=pytest.mark.torch),
pytest.param("tensorflow", marks=pytest.mark.tf),
),
)
def test_error_on_non_numpy_data(interface):
"""Test that an error is thrown in the interface data is not numpy."""

x = qml.math.asarray(0.5, like=interface)
tape = qml.tape.QuantumScript([qml.RX(x, 0)], [qml.expval(qml.Z(0))])
dev = qml.device("reference.qubit", wires=1)

with pytest.raises(ValueError, match="Reference qubit can only work with numpy data."):
dev.execute(tape)
2 changes: 1 addition & 1 deletion tests/logging/test_logging_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def circuit(params):
"pennylane.workflow.execution",
[
"device=<default.qubit device (wires=2)",
"diff_method=None, interface=Interface.AUTOGRAD",
"diff_method=None, interface=auto",
],
),
(
Expand Down
6 changes: 1 addition & 5 deletions tests/workflow/interfaces/execute/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,11 +241,7 @@ def cost(a):
res = qml.jacobian(cost)(a)
assert res.shape == () # pylint: disable=no-member

# compare to standard tape jacobian
tape = qml.tape.QuantumScript([qml.RY(a, wires=0)], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = param_shift(tape)
expected = fn(device.execute(tapes))
expected = -qml.math.sin(a)

assert expected.shape == ()
assert np.allclose(res, expected, atol=atol_for_shots(shots), rtol=0)
Expand Down
6 changes: 1 addition & 5 deletions tests/workflow/interfaces/execute/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,7 @@ def cost(a):
if not shots.has_partitioned_shots:
assert res.shape == () # pylint: disable=no-member

# compare to standard tape jacobian
tape = qml.tape.QuantumScript([qml.RY(a, wires=0)], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = param_shift(tape)
expected = fn(device.execute(tapes))
expected = -qml.math.sin(a)

assert expected.shape == ()
assert np.allclose(res, expected, atol=atol_for_shots(shots), rtol=0)
Expand Down
6 changes: 1 addition & 5 deletions tests/workflow/interfaces/execute/test_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,7 @@ def cost(a):
res = tape.jacobian(cost_res, a, experimental_use_pfor=not device_vjp)
assert res.shape == () # pylint: disable=no-member

# compare to standard tape jacobian
tape = qml.tape.QuantumScript([qml.RY(a, wires=0)], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = param_shift(tape)
expected = fn(device.execute(tapes))
expected = -qml.math.sin(a)

assert expected.shape == ()
assert np.allclose(res, expected, atol=atol_for_shots(shots), rtol=0)
Expand Down
6 changes: 1 addition & 5 deletions tests/workflow/interfaces/execute/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,7 @@ def cost(a):
if not shots.has_partitioned_shots:
assert res.shape == () # pylint: disable=no-member

# compare to standard tape jacobian
tape = qml.tape.QuantumScript([qml.RY(a, wires=0)], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = param_shift(tape)
expected = fn(device.execute(tapes))
expected = -qml.math.sin(a)

assert expected.shape == ()
if shots.has_partitioned_shots:
Expand Down
22 changes: 0 additions & 22 deletions tests/workflow/interfaces/qnode/test_autograd_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,28 +80,6 @@
class TestQNode:
"""Test that using the QNode with Autograd integrates with the PennyLane stack"""

# pylint:disable=unused-argument
def test_execution_no_interface(
self, interface, dev, diff_method, grad_on_execution, device_vjp
):
"""Test execution works without an interface"""
if diff_method == "backprop":
pytest.skip("Test does not support backprop")

@qnode(dev, interface=None)
def circuit(a):
qml.RY(a, wires=0)
qml.RX(0.2, wires=0)
return qml.expval(qml.PauliZ(0))

a = np.array(0.1, requires_grad=True)

res = circuit(a)

# without the interface, the QNode simply returns a scalar array or float
assert isinstance(res, (np.ndarray, float))
assert qml.math.shape(res) == tuple() # pylint: disable=comparison-with-callable

def test_execution_with_interface(
self, interface, dev, diff_method, grad_on_execution, device_vjp
):
Expand Down
13 changes: 13 additions & 0 deletions tests/workflow/interfaces/qnode/test_jax_jit_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -3219,3 +3219,16 @@ def circuit(x):
finally:
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_enable_x64", True)


def test_no_inputs_jitting():
"""Test that if we jit a qnode with no inputs, we can still detect the jitting and proper interface."""

@jax.jit
@qml.qnode(qml.device("reference.qubit", wires=1))
def circuit():
qml.StatePrep(jax.numpy.array([1, 0]), 0)
return qml.state()

res = circuit()
assert qml.math.allclose(res, jax.numpy.array([1, 0]))
6 changes: 1 addition & 5 deletions tests/workflow/interfaces/run/test_autograd_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,7 @@ def cost(a):
res = qml.jacobian(cost)(a)
assert res.shape == () # pylint: disable=no-member

# compare to standard tape jacobian
tape = qml.tape.QuantumScript([qml.RY(a, wires=0)], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = qml.gradients.param_shift(tape)
expected = fn(device.execute(tapes))
expected = -qml.math.sin(a)

assert expected.shape == ()
assert np.allclose(res, expected, atol=atol_for_shots(shots), rtol=0)
Expand Down
6 changes: 1 addition & 5 deletions tests/workflow/interfaces/run/test_jax_jit_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,7 @@ def cost(a):
if not shots.has_partitioned_shots:
assert res.shape == () # pylint: disable=no-member

# compare to standard tape jacobian
tape = qml.tape.QuantumScript([qml.RY(a, wires=0)], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = qml.gradients.param_shift(tape)
expected = fn(device.execute(tapes))
expected = -qml.math.sin(a)

assert expected.shape == ()
assert np.allclose(res, expected, atol=atol_for_shots(shots), rtol=0)
Expand Down
6 changes: 1 addition & 5 deletions tests/workflow/interfaces/run/test_jax_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,7 @@ def cost(a):
if not shots.has_partitioned_shots:
assert res.shape == () # pylint: disable=no-member

# compare to standard tape jacobian
tape = qml.tape.QuantumScript([qml.RY(a, wires=0)], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = qml.gradients.param_shift(tape)
expected = fn(device.execute(tapes))
expected = -qml.math.sin(a)

assert expected.shape == ()
assert np.allclose(res, expected, atol=atol_for_shots(shots), rtol=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,7 @@ def cost(a):
res = tape.jacobian(cost_res, a, experimental_use_pfor=not device_vjp)
assert res.shape == () # pylint: disable=no-member

# compare to standard tape jacobian
tape = qml.tape.QuantumScript([qml.RY(a, wires=0)], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = qml.gradients.param_shift(tape)
expected = fn(device.execute(tapes))

expected = -qml.math.sin(a)
assert expected.shape == ()
assert np.allclose(res, expected, atol=atol_for_shots(shots), rtol=0)
assert np.allclose(res, -tf.sin(a), atol=atol_for_shots(shots))
Expand Down
6 changes: 1 addition & 5 deletions tests/workflow/interfaces/run/test_tensorflow_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,7 @@ def cost(a):
res = tape.jacobian(cost_res, a, experimental_use_pfor=not device_vjp)
assert res.shape == () # pylint: disable=no-member

# compare to standard tape jacobian
tape = qml.tape.QuantumScript([qml.RY(a, wires=0)], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = qml.gradients.param_shift(tape)
expected = fn(device.execute(tapes))
expected = -qml.math.sin(a)

assert expected.shape == ()
assert np.allclose(res, expected, atol=atol_for_shots(shots), rtol=0)
Expand Down
6 changes: 1 addition & 5 deletions tests/workflow/interfaces/run/test_torch_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,7 @@ def cost(a):
if not shots.has_partitioned_shots:
assert res.shape == () # pylint: disable=no-member

# compare to standard tape jacobian
tape = qml.tape.QuantumScript([qml.RY(a, wires=0)], [qml.expval(qml.PauliZ(0))])
tape.trainable_params = [0]
tapes, fn = qml.gradients.param_shift(tape)
expected = fn(device.execute(tapes))
expected = -qml.math.sin(a)

assert expected.shape == ()
if shots.has_partitioned_shots:
Expand Down

0 comments on commit 58d4f4f

Please sign in to comment.