Skip to content

Commit

Permalink
lazy-load tape and operator batch sizes (#4911)
Browse files Browse the repository at this point in the history
**Context:**
We compute operator batch size and initialization, and it's oftentimes
unnecessary. We should only compute it when it's requested

**Description of the Change:**
Only compute operator/tape batch size when they are requested.
Some details:
- inlined the batch-size computation for `ScalarSymbolicOp`
- removed the `_check_batching` overrides in controlled and adjoint
because they are not needed (codecov complained). They both inherit from
`SymbolicOp`, and that doesn't call `super().__init__` so it was never
computing it anyway

**Benefits:**
- Will run expensive computations less frequently
- the Tensorflow issue when pre-computing batch sizes has disappeared!

**Possible Drawbacks:**
the `_check_batching` function did some data validation as well.
operators initialized with poorly-shaped data would fail right away, but
now we'll only see the error when the batch size is requested

**Related GitHub Issues:**
[sc-45972]

---------

Co-authored-by: David Wierichs <david.wierichs@xanadu.ai>
  • Loading branch information
timmysilv and dwierichs authored Dec 5, 2023
1 parent 1b0da44 commit ac0a6b7
Show file tree
Hide file tree
Showing 18 changed files with 149 additions and 101 deletions.
3 changes: 3 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@
* `qml.ArbitraryUnitary` now supports batching.
[(#4745)](https://github.com/PennyLaneAI/pennylane/pull/4745)

* Operator and tape batch sizes are evaluated lazily.
[(#4911)](https://github.com/PennyLaneAI/pennylane/pull/4911)

<h4>Performance improvements and benchmarking</h4>

* `default.qubit` no longer uses a dense matrix for `MultiControlledX` for more than 8 operation wires.
Expand Down
4 changes: 2 additions & 2 deletions pennylane/_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ def _local_tape_expand(tape, depth, stop_at):
# Update circuit info
new_tape.wires = copy.copy(tape.wires)
new_tape.num_wires = tape.num_wires
new_tape._batch_size = tape.batch_size
new_tape._output_dim = tape.output_dim
new_tape._batch_size = tape._batch_size
new_tape._output_dim = tape._output_dim
return new_tape


Expand Down
41 changes: 16 additions & 25 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@

SUPPORTED_INTERFACES = {"numpy", "scipy", "autograd", "torch", "tensorflow", "jax"}
__use_new_opmath = False
_UNSET_BATCH_SIZE = -1 # indicates that the (lazy) batch size has not yet been accessed/computed


class OperatorPropertyUndefined(Exception):
Expand Down Expand Up @@ -662,10 +663,10 @@ def compute_decomposition(theta, wires):
the ``_check_batching`` method, by comparing the shape of the input data to
the expected shape. Therefore, it is necessary to call ``_check_batching`` on
any new input parameters passed to the operator. By default, any class inheriting
from :class:`~.operation.Operator` will do so within its ``__init__`` method, and
other objects may do so when updating the data of the ``Operator``.
from :class:`~.operation.Operator` will do so the first time its
``batch_size`` property is accessed.
``_check_batching`` modifies the following class attributes:
``_check_batching`` modifies the following instance attributes:
- ``_ndim_params``: The number of dimensions of the parameters passed to
``_check_batching``. For an ``Operator`` that does _not_ set the ``ndim_params``
Expand All @@ -678,9 +679,9 @@ def compute_decomposition(theta, wires):
not support broadcasting will report to not be broadcasted independently of the
input.
Both attributes are not defined if ``_check_batching`` is not called. Therefore it
*needs to be called* within custom ``__init__`` implementations, either directly
or by calling ``Operator.__init__``.
These two properties are defined lazily, and accessing the public version of either
one of them (in other words, without the leading underscore) for the first time will
trigger a call to ``_check_batching``, which validates and sets these properties.
"""
# pylint: disable=too-many-public-methods, too-many-instance-attributes

Expand Down Expand Up @@ -1076,39 +1077,25 @@ def __init__(self, *params, wires=None, id=None):
f"{len(self._wires)} wires given, {self.num_wires} expected."
)

self._check_batching(params)
self._batch_size = _UNSET_BATCH_SIZE
self._ndim_params = _UNSET_BATCH_SIZE

self.data = tuple(np.array(p) if isinstance(p, (list, tuple)) else p for p in params)

self.queue()

def _check_batching(self, params):
def _check_batching(self):
"""Check if the expected numbers of dimensions of parameters coincides with the
ones received and sets the ``_batch_size`` attribute.
Args:
params (tuple): Parameters with which the operator is instantiated
The check always passes and sets the ``_batch_size`` to ``None`` for the default
``Operator.ndim_params`` property but subclasses may overwrite it to define fixed
expected numbers of dimensions, allowing to infer a batch size.
"""
self._batch_size = None
params = self.data

try:
ndims = tuple(qml.math.ndim(p) for p in params)
except ValueError as e:
# TODO:[dwierichs] When using tf.function with an input_signature that contains
# an unknown-shaped input, ndim() will not be able to determine the number of
# dimensions because they are not specified yet. Failing example: Let `fun` be
# a single-parameter QNode.
# `tf.function(fun, input_signature=(tf.TensorSpec(shape=None, dtype=tf.float32),))`
# There might be a way to support batching nonetheless, which remains to be
# investigated. For now, the batch_size is left to be `None` when instantiating
# an operation with abstract parameters that make `qml.math.ndim` fail.
if any(qml.math.is_abstract(p) for p in params):
return
raise e
ndims = tuple(qml.math.ndim(p) for p in params)

if any(len(qml.math.shape(p)) >= 1 and qml.math.shape(p)[0] is None for p in params):
# if the batch dimension is unknown, then skip the validation
Expand Down Expand Up @@ -1192,6 +1179,8 @@ def ndim_params(self):
Returns:
tuple: Number of dimensions for each trainable parameter.
"""
if self._batch_size is _UNSET_BATCH_SIZE:
self._check_batching()
return self._ndim_params

@property
Expand All @@ -1206,6 +1195,8 @@ def batch_size(self):
Returns:
int or None: Size of the parameter broadcasting dimension if present, else ``None``.
"""
if self._batch_size is _UNSET_BATCH_SIZE:
self._check_batching()
return self._batch_size

@property
Expand Down
4 changes: 0 additions & 4 deletions pennylane/ops/op_math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,10 +288,6 @@ def __init__(self, base=None, id=None):
def __repr__(self):
return f"Adjoint({self.base})"

# pylint: disable=protected-access
def _check_batching(self, params):
self.base._check_batching(params)

@property
def ndim_params(self):
return self.base.ndim_params
Expand Down
6 changes: 3 additions & 3 deletions pennylane/ops/op_math/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import pennylane as qml
from pennylane import math
from pennylane.operation import Operator
from pennylane.operation import Operator, _UNSET_BATCH_SIZE
from pennylane.wires import Wires

# pylint: disable=too-many-instance-attributes
Expand Down Expand Up @@ -67,9 +67,9 @@ def __init__(self, *operands: Operator, id=None): # pylint: disable=super-init-
self._overlapping_ops = None
self._pauli_rep = self._build_pauli_rep()
self.queue()
self._check_batching(None) # unused param
self._batch_size = _UNSET_BATCH_SIZE

def _check_batching(self, _):
def _check_batching(self):
batch_sizes = {op.batch_size for op in self if op.batch_size is not None}
if len(batch_sizes) > 1:
raise ValueError(
Expand Down
4 changes: 0 additions & 4 deletions pennylane/ops/op_math/controlled.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,10 +365,6 @@ def hash(self):
def has_matrix(self):
return self.base.has_matrix

# pylint: disable=protected-access
def _check_batching(self, params):
self.base._check_batching(params)

@property
def batch_size(self):
return self.base.batch_size
Expand Down
32 changes: 16 additions & 16 deletions pennylane/ops/op_math/symbolicop.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import numpy as np

import pennylane as qml
from pennylane.operation import Operator
from pennylane.operation import Operator, _UNSET_BATCH_SIZE
from pennylane.queuing import QueuingManager


Expand Down Expand Up @@ -166,10 +166,24 @@ class ScalarSymbolicOp(SymbolicOp):
def __init__(self, base, scalar: float, id=None):
self.scalar = np.array(scalar) if isinstance(scalar, list) else scalar
super().__init__(base, id=id)
self._batch_size = self._check_and_compute_batch_size(scalar)
self._batch_size = _UNSET_BATCH_SIZE

@property
def batch_size(self):
if self._batch_size is _UNSET_BATCH_SIZE:
base_batch_size = self.base.batch_size
if qml.math.ndim(self.scalar) == 0:
# coeff is not batched
self._batch_size = base_batch_size
else:
# coeff is batched
scalar_size = qml.math.size(self.scalar)
if base_batch_size is not None and base_batch_size != scalar_size:
raise ValueError(
"Broadcasting was attempted but the broadcasted dimensions "
f"do not match: {scalar_size}, {base_batch_size}."
)
self._batch_size = scalar_size
return self._batch_size

@property
Expand All @@ -181,20 +195,6 @@ def data(self, new_data):
self.scalar = new_data[0]
self.base.data = new_data[1:]

def _check_and_compute_batch_size(self, scalar):
batched_scalar = qml.math.ndim(scalar) > 0
scalar_size = qml.math.size(scalar)
if not batched_scalar:
# coeff is not batched
return self.base.batch_size
# coeff is batched
if self.base.batch_size is not None and self.base.batch_size != scalar_size:
raise ValueError(
"Broadcasting was attempted but the broadcasted dimensions "
f"do not match: {scalar_size}, {self.base.batch_size}."
)
return scalar_size

@property
def has_matrix(self):
return self.base.has_matrix or isinstance(self.base, qml.Hamiltonian)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/ops/qubit/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def __init__(
# while H.coeffs is the original tensor
super().__init__(*coeffs_flat, wires=self._wires, id=id)

def _check_batching(self, params):
def _check_batching(self):
"""Override for Hamiltonian, batching is not yet supported."""

def label(self, decimals=None, base_label=None, cache=None):
Expand Down
18 changes: 9 additions & 9 deletions pennylane/tape/qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
Shots,
)
from pennylane.typing import TensorLike
from pennylane.operation import Observable, Operator, Operation
from pennylane.operation import Observable, Operator, Operation, _UNSET_BATCH_SIZE
from pennylane.pytrees import register_pytree
from pennylane.queuing import AnnotatedQueue, process_queue
from pennylane.wires import Wires
Expand Down Expand Up @@ -198,8 +198,8 @@ def __init__(
self._trainable_params = trainable_params
self._graph = None
self._specs = None
self._output_dim = 0
self._batch_size = None
self._output_dim = None
self._batch_size = _UNSET_BATCH_SIZE

self.wires = _empty_wires
self.num_wires = 0
Expand Down Expand Up @@ -347,11 +347,15 @@ def batch_size(self):
Returns:
int or None: The batch size of the quantum script if present, else ``None``.
"""
if self._batch_size is _UNSET_BATCH_SIZE:
self._update_batch_size()
return self._batch_size

@property
def output_dim(self):
"""The (inferred) output dimension of the quantum script."""
if self._output_dim is None:
self._update_output_dim() # this will set _batch_size if it isn't already
return self._output_dim

@property
Expand Down Expand Up @@ -434,10 +438,6 @@ def _update(self):
self._update_par_info() # Updates _par_info; O(ops+obs)

self._update_observables() # Updates _obs_sharing_wires and _obs_sharing_wires_id
self._update_batch_size() # Updates _batch_size; O(ops)

# The following line requires _batch_size to be up to date
self._update_output_dim() # Updates _output_dim; O(obs)

def _update_circuit_info(self):
"""Update circuit metadata
Expand Down Expand Up @@ -890,8 +890,8 @@ def copy(self, copy_operations=False):
new_qscript._update_par_info()
new_qscript._obs_sharing_wires = self._obs_sharing_wires
new_qscript._obs_sharing_wires_id = self._obs_sharing_wires_id
new_qscript._batch_size = self.batch_size
new_qscript._output_dim = self.output_dim
new_qscript._batch_size = self._batch_size
new_qscript._output_dim = self._output_dim

return new_qscript

Expand Down
8 changes: 4 additions & 4 deletions pennylane/tape/tape.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ def stop_at(obj): # pylint: disable=unused-argument
# Update circuit info
new_tape.wires = copy.copy(tape.wires)
new_tape.num_wires = tape.num_wires
new_tape._batch_size = tape.batch_size
new_tape._output_dim = tape.output_dim
new_tape._batch_size = tape._batch_size
new_tape._output_dim = tape._output_dim
return new_tape


Expand Down Expand Up @@ -278,8 +278,8 @@ def expand_tape_state_prep(tape, skip_first=True):
# Update circuit info
new_tape.wires = copy.copy(tape.wires)
new_tape.num_wires = tape.num_wires
new_tape._batch_size = tape.batch_size
new_tape._output_dim = tape.output_dim
new_tape._batch_size = tape._batch_size
new_tape._output_dim = tape._output_dim
return new_tape


Expand Down
5 changes: 3 additions & 2 deletions tests/devices/qubit/test_apply_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import numpy as np
from scipy.stats import unitary_group
import pennylane as qml
from pennylane.operation import _UNSET_BATCH_SIZE, Operation


from pennylane.devices.qubit.apply_operation import (
Expand Down Expand Up @@ -51,7 +52,7 @@ def test_custom_operator_with_matrix():
)

# pylint: disable=too-few-public-methods
class CustomOp(qml.operation.Operation):
class CustomOp(Operation):
num_wires = 1

def matrix(self):
Expand Down Expand Up @@ -800,7 +801,7 @@ def test_batch_size_set_if_missing(self, method, ml_framework):
param = qml.math.asarray([0.1, 0.2, 0.3], like=ml_framework)
state = np.ones((2, 2)) / 2
op = qml.RX(param, 0)
op._batch_size = None # pylint:disable=protected-access
assert op._batch_size is _UNSET_BATCH_SIZE # pylint:disable=protected-access
state = method(op, state)
assert state.shape == (3, 2, 2)
assert op.batch_size == 3
Expand Down
3 changes: 2 additions & 1 deletion tests/ops/op_math/test_composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,11 @@ def test_batch_size_not_all_batched(self):
def test_different_batch_sizes_raises_error(self):
"""Test that an error is raised if the operands have different batch sizes."""
base = qml.RX(np.array([1.2, 2.3, 3.4]), 0)
op = ValidOp(base, qml.RY(1, 0), qml.RZ(np.array([1, 2, 3, 4]), wires=2))
with pytest.raises(
ValueError, match="Broadcasting was attempted but the broadcasted dimensions"
):
_ = ValidOp(base, qml.RY(1, 0), qml.RZ(np.array([1, 2, 3, 4]), wires=2))
_ = op.batch_size

def test_decomposition_raises_error(self):
"""Test that calling decomposition() raises a ValueError."""
Expand Down
3 changes: 2 additions & 1 deletion tests/ops/op_math/test_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,11 @@ def test_batching_properties(self):
def test_different_batch_sizes_raises_error(self):
"""Test that using different batch sizes for base and scalar raises an error."""
base = qml.RX(np.array([1.2, 2.3, 3.4]), 0)
op = Exp(base, np.array([0.1, 1.2, 2.3, 3.4]))
with pytest.raises(
ValueError, match="Broadcasting was attempted but the broadcasted dimensions"
):
_ = Exp(base, np.array([0.1, 1.2, 2.3, 3.4]))
_ = op.batch_size


class TestMatrix:
Expand Down
3 changes: 2 additions & 1 deletion tests/ops/op_math/test_pow_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,10 +374,11 @@ def test_batching_properties(self, power_method):
def test_different_batch_sizes_raises_error(self, power_method):
"""Test that using different batch sizes for base and scalar raises an error."""
base = qml.RX(np.array([1.2, 2.3, 3.4]), 0)
op = power_method(base, np.array([0.1, 1.2, 2.3, 3.4]))
with pytest.raises(
ValueError, match="Broadcasting was attempted but the broadcasted dimensions"
):
_ = power_method(base, np.array([0.1, 1.2, 2.3, 3.4]))
_ = op.batch_size

op_pauli_reps = (
(qml.PauliZ(wires=0), 1, qml.pauli.PauliSentence({qml.pauli.PauliWord({0: "Z"}): 1})),
Expand Down
3 changes: 2 additions & 1 deletion tests/ops/op_math/test_sprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,10 +771,11 @@ def test_batching_properties(self):
def test_different_batch_sizes_raises_error(self):
"""Test that using different batch sizes for base and scalar raises an error."""
base = qml.RX(np.array([1.2, 2.3, 3.4]), 0)
op = qml.s_prod(np.array([0.1, 1.2, 2.3, 3.4]), base)
with pytest.raises(
ValueError, match="Broadcasting was attempted but the broadcasted dimensions"
):
_ = qml.s_prod(np.array([0.1, 1.2, 2.3, 3.4]), base)
_ = op.batch_size


class TestSimplify:
Expand Down
Loading

0 comments on commit ac0a6b7

Please sign in to comment.