Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Capture] unitary_to_rot is plxpr compatible #6916

Merged
merged 42 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from 36 commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
52d94fc
initial commit
andrijapau Feb 3, 2025
8081adc
add tests
andrijapau Feb 3, 2025
c1c3b6b
add test
andrijapau Feb 3, 2025
b2df9c6
Merge branch 'master' into unitary-to-rot-interpreter
andrijapau Feb 3, 2025
23596a1
adjust test_unitary_to_rot_plxpr_to_plxpr
andrijapau Feb 3, 2025
6abfc6e
changelog
andrijapau Feb 3, 2025
297fce9
changelog
andrijapau Feb 3, 2025
4f45f4d
update tests
andrijapau Feb 3, 2025
09aabf5
update tests
andrijapau Feb 3, 2025
1d49179
update tests
andrijapau Feb 3, 2025
ef782f7
Merge branch 'master' into unitary-to-rot-interpreter
andrijapau Feb 3, 2025
bb5c487
Merge branch 'unitary-to-rot-interpreter' of github.com:PennyLaneAI/p…
andrijapau Feb 3, 2025
7a6d7c0
update tests
andrijapau Feb 3, 2025
6ac7eee
add higher order prim tests
andrijapau Feb 4, 2025
11cbba8
add more higher order prim tests
andrijapau Feb 4, 2025
b325412
add more higher order prim tests
andrijapau Feb 4, 2025
127973d
Merge branch 'master' into unitary-to-rot-interpreter
andrijapau Feb 4, 2025
8c46bcd
Merge branch 'master' into unitary-to-rot-interpreter
andrijapau Feb 5, 2025
2c82d12
use qml.capture.pause()
andrijapau Feb 5, 2025
2e88b53
add execution to qnode integration tests
andrijapau Feb 5, 2025
bceb306
seperate jac and grad tests
andrijapau Feb 6, 2025
89b4401
Merge branch 'master' into unitary-to-rot-interpreter
andrijapau Feb 6, 2025
9141090
Update pennylane/transforms/unitary_to_rot.py
andrijapau Feb 7, 2025
0959aa5
Update pennylane/transforms/unitary_to_rot.py
andrijapau Feb 7, 2025
6126f9b
Update pennylane/transforms/unitary_to_rot.py
andrijapau Feb 7, 2025
f3e51a3
Update pennylane/transforms/unitary_to_rot.py
andrijapau Feb 7, 2025
e0d0d3a
black
andrijapau Feb 7, 2025
3c7cc0f
Merge branch 'master' into unitary-to-rot-interpreter
andrijapau Feb 7, 2025
8b474e9
Update tests/capture/transforms/test_capture_unitary_to_rot.py
andrijapau Feb 7, 2025
3dc5147
Update tests/capture/transforms/test_capture_unitary_to_rot.py
andrijapau Feb 7, 2025
566ecea
replace TODOs in tests
andrijapau Feb 7, 2025
9a8659c
use for loop to clean up tests
andrijapau Feb 7, 2025
52ddb00
fix: Update to just use interpret_operation
andrijapau Feb 7, 2025
ebe1d28
Update pennylane/transforms/unitary_to_rot.py
andrijapau Feb 7, 2025
d5e8632
fix mudits suggestion
andrijapau Feb 7, 2025
ad3c9f5
Update pennylane/transforms/unitary_to_rot.py
andrijapau Feb 7, 2025
5094fa6
add expand_plxpr_transforms tests
andrijapau Feb 10, 2025
f437a93
add three qubit case lol
andrijapau Feb 10, 2025
dbef1fd
Update tests/capture/transforms/test_capture_unitary_to_rot.py
andrijapau Feb 10, 2025
2af57a5
Update doc/releases/changelog-dev.md
andrijapau Feb 10, 2025
00e460e
Merge branch 'master' into unitary-to-rot-interpreter
andrijapau Feb 11, 2025
ea4a582
improve tests
andrijapau Feb 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

<h3>New features since last release</h3>

* Added class `qml.capture.transforms.UnitaryToRotInterpreter` that decomposes pennylane operators
following the same API as `qml.transforms.unitary_to_rot` when experimental program capture is enabled.
[(#6916)](https://github.com/PennyLaneAI/pennylane/pull/6916)
mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
andrijapau marked this conversation as resolved.
Show resolved Hide resolved

<h3>Improvements 🛠</h3>

* Add a `qml.capture.pause()` context manager for pausing program capture in an error-safe way.
Expand Down
65 changes: 64 additions & 1 deletion pennylane/transforms/unitary_to_rot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
A transform for decomposing arbitrary single-qubit QubitUnitary gates into elementary gates.
"""
from functools import lru_cache, partial

import pennylane as qml
from pennylane.ops.op_math.decompositions import one_qubit_decomposition, two_qubit_decomposition
Expand All @@ -23,7 +24,69 @@
from pennylane.typing import PostprocessingFn


@transform
@lru_cache
def _get_plxpr_unitary_to_rot():
try:
# pylint: disable=import-outside-toplevel
from jax import make_jaxpr

from pennylane.capture import PlxprInterpreter
from pennylane.operation import Operator
except ImportError: # pragma: no cover
return None, None

# pylint: disable=redefined-outer-name, too-few-public-methods
class UnitaryToRotInterpreter(PlxprInterpreter):
"""Plxpr Interpreter for applying the ``unitary_to_rot``
transform when program capture is enabled."""

def interpret_operation(self, op: Operator):
"""Decompose a PennyLane operation instance if it is a QubitUnitary.

Args:
op (Operator): a pennylane operator instance

Returns:
list: The decomposed operations.

This method is only called when the operator's output is a dropped variable,
so the output will not affect later equations in the circuit.

See also: :meth:`~.interpret_operation_eqn`, :meth:`~.interpret_operation`.
"""
if isinstance(op, qml.QubitUnitary):
ops = []
with qml.capture.pause():
matrix_shape = qml.math.shape(op.parameters[0])
if matrix_shape == (2, 2):
ops = one_qubit_decomposition(op.parameters[0], op.wires[0])
elif matrix_shape == (4, 4):
ops = two_qubit_decomposition(op.parameters[0], op.wires)
# List comprehensions are run in a separate scope.
# The automatic insertion of __class__ and self for zero-argument super does not work in such a nested scope.
# pylint: disable=super-with-arguments
return [super(UnitaryToRotInterpreter, self).interpret_operation(o) for o in ops]

return super().interpret_operation(op)

# pylint: disable=redefined-outer-name
def unitary_to_rot_plxpr_to_plxpr(
jaxpr, consts, _, __, *args
): # pylint: disable=unused-argument
interpreter = UnitaryToRotInterpreter()

def wrapper(*inner_args):
return interpreter.eval(jaxpr, consts, *inner_args)

return make_jaxpr(wrapper)(*args)

return UnitaryToRotInterpreter, unitary_to_rot_plxpr_to_plxpr


UnitaryToRotInterpreter, unitary_to_rot_plxpr_to_plxpr = _get_plxpr_unitary_to_rot()


@partial(transform, plxpr_transform=unitary_to_rot_plxpr_to_plxpr)
def unitary_to_rot(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]:
r"""Quantum function transform to decomposes all instances of single-qubit and
select instances of two-qubit :class:`~.QubitUnitary` operations to
Expand Down
Loading