diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 9d88744ff53..c1aeee48214 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -96,6 +96,9 @@ allowing error types to be more consistent with the context the `decompose` function is used in. [(#5669)](https://github.com/PennyLaneAI/pennylane/pull/5669) +* The `qml.pytrees` module now has `flatten` and `unflatten` methods for serializing pytrees. + [(#5701)](https://github.com/PennyLaneAI/pennylane/pull/5701) + * Empty initialization of `PauliVSpace` is permitted. [(#5675)](https://github.com/PennyLaneAI/pennylane/pull/5675) diff --git a/pennylane/pytrees.py b/pennylane/pytrees.py index ddc82a66e19..1952d2d0d91 100644 --- a/pennylane/pytrees.py +++ b/pennylane/pytrees.py @@ -1,4 +1,4 @@ -# Copyright 2018-2023 Xanadu Quantum Technologies Inc. +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,8 +14,8 @@ """ An internal module for working with pytrees. """ - -from typing import Any, Callable, Tuple +from dataclasses import dataclass, field +from typing import Any, Callable, List, Optional, Tuple has_jax = True try: @@ -30,6 +30,57 @@ UnflattenFn = Callable[[Leaves, Metadata], Any] +def flatten_list(obj: list): + """Flatten a list.""" + return obj, None + + +def flatten_tuple(obj: tuple): + """Flatten a tuple.""" + return obj, None + + +def flatten_dict(obj: dict): + """Flatten a dictionary.""" + return obj.values(), tuple(obj.keys()) + + +flatten_registrations: dict[type, FlattenFn] = { + list: flatten_list, + tuple: flatten_tuple, + dict: flatten_dict, +} + + +def unflatten_list(data, _) -> list: + """Unflatten a list.""" + return data if isinstance(data, list) else list(data) + + +def unflatten_tuple(data, _) -> tuple: + """Unflatten a tuple.""" + return tuple(data) + + +def unflatten_dict(data, metadata) -> dict: + """Unflatten a dictinoary.""" + return dict(zip(metadata, data)) + + +unflatten_registrations: dict[type, UnflattenFn] = { + list: unflatten_list, + tuple: unflatten_tuple, + dict: unflatten_dict, +} + + +def _register_pytree_with_pennylane( + pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn +): + flatten_registrations[pytree_type] = flatten_fn + unflatten_registrations[pytree_type] = unflatten_fn + + def _register_pytree_with_jax(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn): def jax_unflatten(aux, parameters): return unflatten_fn(parameters, aux) @@ -40,7 +91,8 @@ def jax_unflatten(aux, parameters): def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: UnflattenFn): """Register a type with all available pytree backends. - Current backends is jax. + Current backends are jax and pennylane. + Args: pytree_type (type): the type to register, such as ``qml.RX`` flatten_fn (Callable): a function that splits an object into trainable leaves and hashable metadata. @@ -52,7 +104,115 @@ def register_pytree(pytree_type: type, flatten_fn: FlattenFn, unflatten_fn: Unfl Side Effects: ``pytree`` type becomes registered with available backends. + .. seealso:: :func:`~.flatten`, :func:`~.unflatten`. + """ + _register_pytree_with_pennylane(pytree_type, flatten_fn, unflatten_fn) + if has_jax: _register_pytree_with_jax(pytree_type, flatten_fn, unflatten_fn) + + +@dataclass(repr=False) +class PyTreeStructure: + """A pytree data structure, holding the type, metadata, and child pytree structures. + + >>> op = qml.adjoint(qml.RX(0.1, 0)) + >>> data, structure = qml.pytrees.flatten(op) + >>> structure + PyTree(AdjointOperation, (), [PyTree(RX, (, ()), [Leaf])]) + + A leaf is defined as just a ``PyTreeStructure`` with ``type=None``. + """ + + type: Optional[type] = None + """The type corresponding to the node. If ``None``, then the structure is a leaf.""" + + metadata: Metadata = () + """Any metadata needed to reproduce the original object.""" + + children: list["PyTreeStructure"] = field(default_factory=list) + """The children of the pytree node. Can be either other structures or terminal leaves.""" + + @property + def is_leaf(self) -> bool: + """Whether or not the structure is a leaf.""" + return self.type is None + + def __repr__(self): + if self.is_leaf: + return "PyTreeStructure()" + return f"PyTreeStructure({self.type.__name__}, {self.metadata}, {self.children})" + + def __str__(self): + if self.is_leaf: + return "Leaf" + children_string = ", ".join(str(c) for c in self.children) + return f"PyTree({self.type.__name__}, {self.metadata}, [{children_string}])" + + +leaf = PyTreeStructure(None, (), []) + + +def flatten(obj) -> tuple[list[Any], PyTreeStructure]: + """Flattens a pytree into leaves and a structure. + + Args: + obj (Any): any object + + Returns: + List[Any], Union[Structure, Leaf]: a list of leaves and a structure representing the object + + >>> op = qml.adjoint(qml.Rot(1.2, 2.3, 3.4, wires=0)) + >>> data, structure = flatten(op) + >>> data + [1.2, 2.3, 3.4] + >>> structure + , ()), (Leaf, Leaf, Leaf))>,))> + + See also :function:`~.unflatten`. + + """ + flatten_fn = flatten_registrations.get(type(obj), None) + if flatten_fn is None: + return [obj], leaf + leaves, metadata = flatten_fn(obj) + + flattened_leaves = [] + child_structures = [] + for l in leaves: + child_leaves, child_structure = flatten(l) + flattened_leaves += child_leaves + child_structures.append(child_structure) + + structure = PyTreeStructure(type(obj), metadata, child_structures) + return flattened_leaves, structure + + +def unflatten(data: List[Any], structure: PyTreeStructure) -> Any: + """Bind data to a structure to reconstruct a pytree object. + + Args: + data (Iterable): iterable of numbers and numeric arrays + structure (Structure, Leaf): The pytree structure object + + Returns: + A repacked pytree. + + .. see-also:: :function:`~.flatten` + + >>> op = qml.adjoint(qml.Rot(1.2, 2.3, 3.4, wires=0)) + >>> data, structure = flatten(op) + >>> unflatten([-2, -3, -4], structure) + Adjoint(Rot(-2, -3, -4, wires=[0])) + + """ + return _unflatten(iter(data), structure) + + +def _unflatten(new_data, structure): + if structure.is_leaf: + return next(new_data) + children = tuple(_unflatten(new_data, s) for s in structure.children) + return unflatten_registrations[structure.type](children, structure.metadata) diff --git a/tests/test_pytrees.py b/tests/test_pytrees.py new file mode 100644 index 00000000000..c2829c4385f --- /dev/null +++ b/tests/test_pytrees.py @@ -0,0 +1,144 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for the pennylane pytrees module +""" +import pennylane as qml +from pennylane.pytrees import PyTreeStructure, flatten, leaf, register_pytree, unflatten + + +def test_structure_repr_str(): + """Test the repr of the structure class.""" + op = qml.RX(0.1, wires=0) + _, structure = qml.pytrees.flatten(op) + expected = "PyTreeStructure(RX, (, ()), [PyTreeStructure()])" + assert repr(structure) == expected + expected_str = "PyTree(RX, (, ()), [Leaf])" + assert str(structure) == expected_str + + +def test_register_new_class(): + """Test that new objects can be registered, flattened, and unflattened.""" + + # pylint: disable=too-few-public-methods + class MyObj: + """a dummy object.""" + + def __init__(self, a): + self.a = a + + def obj_flatten(obj): + return (obj.a,), None + + def obj_unflatten(data, _): + return MyObj(data[0]) + + register_pytree(MyObj, obj_flatten, obj_unflatten) + + obj = MyObj(0.5) + + data, structure = flatten(obj) + assert data == [0.5] + assert structure == PyTreeStructure(MyObj, None, [leaf]) + + new_obj = unflatten([1.0], structure) + assert isinstance(new_obj, MyObj) + assert new_obj.a == 1.0 + + +def test_list(): + """Test that pennylane treats list as a pytree.""" + + x = [1, 2, [3, 4]] + + data, structure = flatten(x) + assert data == [1, 2, 3, 4] + assert structure == PyTreeStructure( + list, None, [leaf, leaf, PyTreeStructure(list, None, [leaf, leaf])] + ) + + new_x = unflatten([5, 6, 7, 8], structure) + assert new_x == [5, 6, [7, 8]] + + +def test_tuple(): + """Test that pennylane can handle tuples as pytrees.""" + x = (1, 2, (3, 4)) + + data, structure = flatten(x) + assert data == [1, 2, 3, 4] + assert structure == PyTreeStructure( + tuple, None, [leaf, leaf, PyTreeStructure(tuple, None, [leaf, leaf])] + ) + + new_x = unflatten([5, 6, 7, 8], structure) + assert new_x == (5, 6, (7, 8)) + + +def test_dict(): + """Test that pennylane can handle dictionaries as pytees.""" + + x = {"a": 1, "b": {"c": 2, "d": 3}} + + data, structure = flatten(x) + assert data == [1, 2, 3] + assert structure == PyTreeStructure( + dict, ("a", "b"), [leaf, PyTreeStructure(dict, ("c", "d"), [leaf, leaf])] + ) + new_x = unflatten([5, 6, 7], structure) + assert new_x == {"a": 5, "b": {"c": 6, "d": 7}} + + +def test_nested_pl_object(): + """Test that we can flatten and unflatten nested pennylane object.""" + + tape = qml.tape.QuantumScript( + [qml.adjoint(qml.RX(0.1, wires=0))], + [qml.expval(2 * qml.X(0))], + shots=50, + trainable_params=(0, 1), + ) + + data, structure = flatten(tape) + assert data == [0.1, 2, None] + + wires0 = qml.wires.Wires(0) + op_structure = PyTreeStructure( + tape[0].__class__, (), [PyTreeStructure(qml.RX, (wires0, ()), [leaf])] + ) + list_op_struct = PyTreeStructure(list, None, [op_structure]) + + sprod_structure = PyTreeStructure( + qml.ops.SProd, (), [leaf, PyTreeStructure(qml.X, (wires0, ()), [])] + ) + meas_structure = PyTreeStructure( + qml.measurements.ExpectationMP, (("wires", None),), [sprod_structure, leaf] + ) + list_meas_struct = PyTreeStructure(list, None, [meas_structure]) + tape_structure = PyTreeStructure( + qml.tape.QuantumScript, + (tape.shots, tape.trainable_params), + [list_op_struct, list_meas_struct], + ) + + assert structure == tape_structure + + new_tape = unflatten([3, 4, None], structure) + expected_new_tape = qml.tape.QuantumScript( + [qml.adjoint(qml.RX(3, wires=0))], + [qml.expval(4 * qml.X(0))], + shots=50, + trainable_params=(0, 1), + ) + assert qml.equal(new_tape, expected_new_tape)