Skip to content

Commit

Permalink
Promote gradient_kwargs to a positional keyword argument in QNode (
Browse files Browse the repository at this point in the history
…#6828)

**Context:**
`gradient_kwargs` is now a positional keyword argument for the `QNode`.
This means you can not simply express,
```python
qml.QNode(func, dev, h=1)
```
instead, you must deliberately,
```python
qml.QNode(func, dev, gradient_kwargs={"h":1})
```
This allows easier and cleaner input validation. 

This PR could have wide-spread impact as it is very common to just
specify `gradient_kwargs` casually as additional kwargs.

### **Eco-System**
- [x] Catalyst: PennyLaneAI/catalyst#1480
- [x] Lightning:
PennyLaneAI/pennylane-lightning#1045
- [x] QML Demos: No instances of deprecated code found.

### **Plugins**
- [x] Pennylane-AQT: No instances of deprecated code found.
- [x] Pennylane-Qiskit: No instances of deprecated code found.
- [x] Pennylane-IonQ: No instances of deprecated code found.
- [x] Pennylane-Qrack: No instances of deprecated code found.
- [x] Pennylane-Cirq: No instances of deprecated code found.
- [x] Pennylane-Qulacs: No instances of deprecated code found.

**Description of the Change:**

Allow additional kwargs for now to ensure same functionality, but raise
a deprecation warning. Append those additional kwargs to the internal
gradient_kwargs dictionary.

**Benefits:** Improved input validation for users.

**Possible Drawbacks:** Might have missed some eco-system changes.
Especially with CI **sometimes** not raising
`PennyLaneDeprecationWarning`s as errors 😢 .

[sc-81531]

---------

Co-authored-by: Christina Lee <christina@xanadu.ai>
  • Loading branch information
andrijapau and albi3ro authored Jan 23, 2025
1 parent 875ae11 commit 8a12fa5
Show file tree
Hide file tree
Showing 22 changed files with 434 additions and 260 deletions.
7 changes: 7 additions & 0 deletions doc/development/deprecations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@ deprecations are listed below.
Pending deprecations
--------------------

* Specifying gradient keyword arguments as any additional keyword argument to the qnode is deprecated
and will be removed in v0.42. The gradient keyword arguments should be passed to the new
keyword argument ``gradient_kwargs`` via an explicit dictionary, like ``gradient_kwargs={"h": 1e-4}``.

- Deprecated in v0.41
- Will be removed in v0.42

* The `qml.gradients.hamiltonian_grad` function has been deprecated.
This gradient recipe is not required with the new operator arithmetic system.

Expand Down
6 changes: 6 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@

<h3>Deprecations 👋</h3>

* Specifying gradient keyword arguments as any additional keyword argument to the qnode is deprecated
and will be removed in v0.42. The gradient keyword arguments should be passed to the new
keyword argument `gradient_kwargs` via an explicit dictionary. This change will improve qnode argument
validation.
[(#6828)](https://github.com/PennyLaneAI/pennylane/pull/6828)

* The `qml.gradients.hamiltonian_grad` function has been deprecated.
This gradient recipe is not required with the new operator arithmetic system.
[(#6849)](https://github.com/PennyLaneAI/pennylane/pull/6849)
Expand Down
35 changes: 24 additions & 11 deletions pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,8 @@ def _validate_gradient_kwargs(gradient_kwargs: dict) -> None:
elif kwarg not in qml.gradients.SUPPORTED_GRADIENT_KWARGS:
warnings.warn(
f"Received gradient_kwarg {kwarg}, which is not included in the list of "
"standard qnode gradient kwargs."
"standard qnode gradient kwargs. Please specify all gradient kwargs through "
"the gradient_kwargs argument as a dictionary."
)


Expand Down Expand Up @@ -284,9 +285,7 @@ class QNode:
as the name suggests. If not provided,
the device will determine the best choice automatically. For usage details, please refer to the
:doc:`dynamic quantum circuits page </introduction/dynamic_quantum_circuits>`.
Keyword Args:
**kwargs: Any additional keyword arguments provided are passed to the differentiation
gradient_kwargs (dict): A dictionary of keyword arguments that are passed to the differentiation
method. Please refer to the :mod:`qml.gradients <.gradients>` module for details
on supported options for your chosen gradient transform.
Expand Down Expand Up @@ -505,10 +504,12 @@ def __init__(
device_vjp: Union[None, bool] = False,
postselect_mode: Literal[None, "hw-like", "fill-shots"] = None,
mcm_method: Literal[None, "deferred", "one-shot", "tree-traversal"] = None,
**gradient_kwargs,
gradient_kwargs: Optional[dict] = None,
**kwargs,
):
self._init_args = locals()
del self._init_args["self"]
del self._init_args["kwargs"]

if logger.isEnabledFor(logging.DEBUG):
logger.debug(
Expand Down Expand Up @@ -536,7 +537,16 @@ def __init__(
if not isinstance(device, qml.devices.Device):
device = qml.devices.LegacyDeviceFacade(device)

gradient_kwargs = gradient_kwargs or {}
if kwargs:
if any(k in qml.gradients.SUPPORTED_GRADIENT_KWARGS for k in list(kwargs.keys())):
warnings.warn(
f"Specifying gradient keyword arguments {list(kwargs.keys())} is deprecated and will be removed in v0.42. Instead, please specify all arguments in the gradient_kwargs argument.",
qml.PennyLaneDeprecationWarning,
)
gradient_kwargs |= kwargs
_validate_gradient_kwargs(gradient_kwargs)

if "shots" in inspect.signature(func).parameters:
warnings.warn(
"Detected 'shots' as an argument to the given quantum function. "
Expand Down Expand Up @@ -676,16 +686,19 @@ def circuit(x):
tensor(0.5403, dtype=torch.float64)
"""
if not kwargs:
valid_params = (
set(self._init_args.copy().pop("gradient_kwargs"))
| qml.gradients.SUPPORTED_GRADIENT_KWARGS
)
valid_params = set(self._init_args.copy()) | qml.gradients.SUPPORTED_GRADIENT_KWARGS
raise ValueError(
f"Must specify at least one configuration property to update. Valid properties are: {valid_params}."
)
original_init_args = self._init_args.copy()
gradient_kwargs = original_init_args.pop("gradient_kwargs")
original_init_args.update(gradient_kwargs)
# gradient_kwargs defaults to None
original_init_args["gradient_kwargs"] = original_init_args["gradient_kwargs"] or {}
# nested dictionary update
new_gradient_kwargs = kwargs.pop("gradient_kwargs", {})
old_gradient_kwargs = original_init_args.get("gradient_kwargs").copy()
old_gradient_kwargs.update(new_gradient_kwargs)
kwargs["gradient_kwargs"] = old_gradient_kwargs

original_init_args.update(kwargs)
updated_qn = QNode(**original_init_args)
# pylint: disable=protected-access
Expand Down
55 changes: 34 additions & 21 deletions tests/gradients/core/test_pulse_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,7 +1385,10 @@ def test_simple_qnode_expval(self, num_split_times, shots, tol, seed):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

@qml.qnode(
dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
dev,
interface="jax",
diff_method=stoch_pulse_grad,
gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
Expand Down Expand Up @@ -1415,7 +1418,10 @@ def test_simple_qnode_expval_two_evolves(self, num_split_times, shots, tol, seed
ham_y = qml.pulse.constant * qml.PauliX(0)

@qml.qnode(
dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
dev,
interface="jax",
diff_method=stoch_pulse_grad,
gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_x)(params[0], T_x)
Expand Down Expand Up @@ -1444,7 +1450,10 @@ def test_simple_qnode_probs(self, num_split_times, shots, tol, seed):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

@qml.qnode(
dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
dev,
interface="jax",
diff_method=stoch_pulse_grad,
gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
Expand All @@ -1471,7 +1480,10 @@ def test_simple_qnode_probs_expval(self, num_split_times, shots, tol, seed):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

@qml.qnode(
dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
dev,
interface="jax",
diff_method=stoch_pulse_grad,
gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
Expand Down Expand Up @@ -1504,7 +1516,10 @@ def test_simple_qnode_jit(self, num_split_times, time_interface):
ham_single_q_const = qml.pulse.constant * qml.PauliY(0)

@qml.qnode(
dev, interface="jax", diff_method=stoch_pulse_grad, num_split_times=num_split_times
dev,
interface="jax",
diff_method=stoch_pulse_grad,
gradient_kwargs={"num_split_times": num_split_times},
)
def circuit(params, T=None):
qml.evolve(ham_single_q_const)(params, T)
Expand Down Expand Up @@ -1543,8 +1558,7 @@ def ansatz(params):
dev,
interface="jax",
diff_method=stoch_pulse_grad,
num_split_times=num_split_times,
sampler_seed=seed,
gradient_kwargs={"num_split_times": num_split_times, "sampler_seed": seed},
)
qnode_backprop = qml.QNode(ansatz, dev, interface="jax")

Expand Down Expand Up @@ -1575,8 +1589,7 @@ def test_qnode_probs_expval_broadcasting(self, num_split_times, shots, tol, seed
dev,
interface="jax",
diff_method=stoch_pulse_grad,
num_split_times=num_split_times,
use_broadcasting=True,
gradient_kwargs={"num_split_times": num_split_times, "use_broadcasting": True},
)
def circuit(params):
qml.evolve(ham_single_q_const)(params, T)
Expand Down Expand Up @@ -1620,18 +1633,22 @@ def ansatz(params):
dev,
interface="jax",
diff_method=stoch_pulse_grad,
num_split_times=num_split_times,
use_broadcasting=True,
sampler_seed=seed,
gradient_kwargs={
"num_split_times": num_split_times,
"use_broadcasting": True,
"sampler_seed": seed,
},
)
circuit_no_bc = qml.QNode(
ansatz,
dev,
interface="jax",
diff_method=stoch_pulse_grad,
num_split_times=num_split_times,
use_broadcasting=False,
sampler_seed=seed,
gradient_kwargs={
"num_split_times": num_split_times,
"use_broadcasting": False,
"sampler_seed": seed,
},
)
params = [jnp.array(0.4)]
jac_bc = jax.jacobian(circuit_bc)(params)
Expand Down Expand Up @@ -1685,9 +1702,7 @@ def ansatz(params):
dev,
interface="jax",
diff_method=qml.gradients.stoch_pulse_grad,
num_split_times=7,
use_broadcasting=True,
sampler_seed=seed,
gradient_kwargs={"num_split_times": 7, "sampler_seed": seed, "use_broadcasting": True},
)
cost_jax = qml.QNode(ansatz, dev, interface="jax")
params = (0.42,)
Expand Down Expand Up @@ -1730,9 +1745,7 @@ def ansatz(params):
dev,
interface="jax",
diff_method=qml.gradients.stoch_pulse_grad,
num_split_times=7,
use_broadcasting=True,
sampler_seed=seed,
gradient_kwargs={"num_split_times": 7, "sampler_seed": seed, "use_broadcasting": True},
)
cost_jax = qml.QNode(ansatz, dev, interface="jax")

Expand Down
2 changes: 1 addition & 1 deletion tests/gradients/finite_diff/test_spsa_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def test_invalid_sampler_rng(self):
"""Tests that if sampler_rng has an unexpected type, an error is raised."""
dev = qml.device("default.qubit", wires=1)

@qml.qnode(dev, diff_method="spsa", sampler_rng="foo")
@qml.qnode(dev, diff_method="spsa", gradient_kwargs={"sampler_rng": "foo"})
def circuit(param):
qml.RX(param, wires=0)
return qml.expval(qml.PauliZ(0))
Expand Down
6 changes: 5 additions & 1 deletion tests/gradients/parameter_shift/test_cv_gradients.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,11 @@ def qf(x, y):

grad_F = jax.grad(qf)(*par)

@qml.qnode(device=gaussian_dev, diff_method="parameter-shift", force_order2=True)
@qml.qnode(
device=gaussian_dev,
diff_method="parameter-shift",
gradient_kwargs={"force_order2": True},
)
def qf2(x, y):
qml.Displacement(0.5, 0, wires=[0])
qml.Squeezing(x, 0, wires=[0])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1970,12 +1970,12 @@ def expval(self, observable, **kwargs):

dev = DeviceSupporingSpecialObservable(wires=1, shots=None)

@qml.qnode(dev, diff_method="parameter-shift", broadcast=broadcast)
@qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"broadcast": broadcast})
def qnode(x):
qml.RY(x, wires=0)
return qml.expval(SpecialObservable(wires=0))

@qml.qnode(dev, diff_method="parameter-shift", broadcast=broadcast)
@qml.qnode(dev, diff_method="parameter-shift", gradient_kwargs={"broadcast": broadcast})
def reference_qnode(x):
qml.RY(x, wires=0)
return qml.expval(qml.PauliZ(wires=0))
Expand Down
13 changes: 11 additions & 2 deletions tests/resource/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,15 @@ class TestSpecsTransform:
"""Tests for the transform specs using the QNode"""

def sample_circuit(self):

@qml.transforms.merge_rotations
@qml.transforms.undo_swaps
@qml.transforms.cancel_inverses
@qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift", shifts=pnp.pi / 4)
@qml.qnode(
qml.device("default.qubit"),
diff_method="parameter-shift",
gradient_kwargs={"shifts": pnp.pi / 4},
)
def circuit(x):
qml.RandomLayers(qml.numpy.array([[1.0, 2.0]]), wires=(0, 1))
qml.RX(x, wires=0)
Expand Down Expand Up @@ -222,7 +227,11 @@ def test_splitting_transforms(self):

@qml.transforms.split_non_commuting
@qml.transforms.merge_rotations
@qml.qnode(qml.device("default.qubit"), diff_method="parameter-shift", shifts=pnp.pi / 4)
@qml.qnode(
qml.device("default.qubit"),
diff_method="parameter-shift",
gradient_kwargs={"shifts": pnp.pi / 4},
)
def circuit(x):
qml.RandomLayers(qml.numpy.array([[1.0, 2.0]]), wires=(0, 1))
qml.RX(x, wires=0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def circuit(coeffs):
qml.MottonenStatePreparation(coeffs, wires=[0, 1])
return qml.probs(wires=[0, 1])

circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", h=0.05)
circuit_fd = qml.QNode(circuit, dev, diff_method="finite-diff", gradient_kwargs={"h": 0.05})
circuit_ps = qml.QNode(circuit, dev, diff_method="parameter-shift")
circuit_exact = qml.QNode(circuit, dev_no_shots)

Expand Down
Loading

0 comments on commit 8a12fa5

Please sign in to comment.