Skip to content

Commit

Permalink
[Capture] unitary_to_rot is plxpr compatible (#6916)
Browse files Browse the repository at this point in the history
**Context:**

This PR adds a `UnitaryToRotInterpreter` to apply the `unitary_to_rot`
transform natively to `plxpr`.

**Description of the Change:**

* Add `UnitaryToRotInterpreter` to transform `plxpr`

**Benefits:**

`unitary_to_rot` can be applied natively to plxpr.

```python
qml.capture.enable()
import jax

U1 = qml.Rot(1.0, 2.0, 3.0, wires=0)

@qml.capture.expand_plxpr_transforms
@qml.transforms.unitary_to_rot
def f(U1):
    qml.X(0)
    qml.QubitUnitary(U1, 0)
    qml.Y(0)
    return qml.expval(qml.Z(0))

>>> jaxpr = jax.make_jaxpr(f)(U1.matrix())
>>> tape = qml.tape.plxpr_to_tape(jaxpr.jaxpr, jaxpr.consts, U1.matrix()).operations
>>> pprint.pprint(tape)
[X(0),
 RZ(Array(1., dtype=float32), wires=[0]),
 RY(Array(2., dtype=float32), wires=[0]),
 RZ(Array(3., dtype=float32), wires=[0]),
 Y(0)]
```

**Possible Drawbacks:** None identified.

[sc-83556]

---------

Co-authored-by: Pietropaolo Frisoni <pietropaolo.frisoni@xanadu.ai>
Co-authored-by: Mudit Pandey <mudit.pandey@xanadu.ai>
  • Loading branch information
3 people authored Feb 11, 2025
1 parent 67fbcba commit 65c52a8
Show file tree
Hide file tree
Showing 3 changed files with 601 additions and 1 deletion.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

<h3>New features since last release</h3>

* Added class `qml.capture.transforms.UnitaryToRotInterpreter` that decomposes `qml.QubitUnitary` operators
following the same API as `qml.transforms.unitary_to_rot` when experimental program capture is enabled.
[(#6916)](https://github.com/PennyLaneAI/pennylane/pull/6916)

<h3>Improvements 🛠</h3>

* Add a decomposition for multi-controlled global phases into a one-less-controlled phase shift.
Expand Down
67 changes: 66 additions & 1 deletion pennylane/transforms/unitary_to_rot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
A transform for decomposing arbitrary single-qubit QubitUnitary gates into elementary gates.
"""
from functools import lru_cache, partial

import pennylane as qml
from pennylane.ops.op_math.decompositions import one_qubit_decomposition, two_qubit_decomposition
Expand All @@ -23,7 +24,71 @@
from pennylane.typing import PostprocessingFn


@transform
@lru_cache
def _get_plxpr_unitary_to_rot():
try:
# pylint: disable=import-outside-toplevel
from jax import make_jaxpr

from pennylane.capture import PlxprInterpreter
from pennylane.operation import Operator
except ImportError: # pragma: no cover
return None, None

# pylint: disable=redefined-outer-name, too-few-public-methods
class UnitaryToRotInterpreter(PlxprInterpreter):
"""Plxpr Interpreter for applying the ``unitary_to_rot``
transform when program capture is enabled."""

def interpret_operation(self, op: Operator):
"""Decompose a PennyLane operation instance if it is a QubitUnitary.
Args:
op (Operator): a pennylane operator instance
Returns:
list: The decomposed operations.
This method is only called when the operator's output is a dropped variable,
so the output will not affect later equations in the circuit.
See also: :meth:`~.interpret_operation_eqn`, :meth:`~.interpret_operation`.
"""
if isinstance(op, qml.QubitUnitary):
ops = []
with qml.capture.pause():
matrix_shape = qml.math.shape(op.parameters[0])
if matrix_shape == (2, 2):
ops = one_qubit_decomposition(op.parameters[0], op.wires[0])
elif matrix_shape == (4, 4):
ops = two_qubit_decomposition(op.parameters[0], op.wires)
else:
ops = [op]
# List comprehensions are run in a separate scope.
# The automatic insertion of __class__ and self for zero-argument super does not work in such a nested scope.
# pylint: disable=super-with-arguments
return [super(UnitaryToRotInterpreter, self).interpret_operation(o) for o in ops]

return super().interpret_operation(op)

# pylint: disable=redefined-outer-name
def unitary_to_rot_plxpr_to_plxpr(
jaxpr, consts, _, __, *args
): # pylint: disable=unused-argument
interpreter = UnitaryToRotInterpreter()

def wrapper(*inner_args):
return interpreter.eval(jaxpr, consts, *inner_args)

return make_jaxpr(wrapper)(*args)

return UnitaryToRotInterpreter, unitary_to_rot_plxpr_to_plxpr


UnitaryToRotInterpreter, unitary_to_rot_plxpr_to_plxpr = _get_plxpr_unitary_to_rot()


@partial(transform, plxpr_transform=unitary_to_rot_plxpr_to_plxpr)
def unitary_to_rot(tape: QuantumScript) -> tuple[QuantumScriptBatch, PostprocessingFn]:
r"""Quantum function transform to decomposes all instances of single-qubit and
select instances of two-qubit :class:`~.QubitUnitary` operations to
Expand Down
Loading

0 comments on commit 65c52a8

Please sign in to comment.