Skip to content

Commit

Permalink
Tests with combined params
Browse files Browse the repository at this point in the history
  • Loading branch information
PietropaoloFrisoni committed Feb 9, 2025
1 parent 1986b6a commit 4510e7e
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 8 deletions.
11 changes: 4 additions & 7 deletions pennylane/transforms/decompose.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,7 @@ def sub_interpret_operation(self, op: qml.operation.Operator, current_depth: int
"""

if not op.has_plxpr_decomposition:
return self.interpret_operation(op)

if self.gate_set(op):
if not op.has_plxpr_decomposition or self.gate_set(op):
return super().interpret_operation(op)

return self._evaluate_jaxpr_decomposition(op, current_depth)
Expand Down Expand Up @@ -163,10 +160,10 @@ def decompose_operation(self, op: qml.operation.Operator, current_depth: int = 0
See also: :meth:`~.interpret_operation_eqn`, :meth:`~.interpret_operation`.
"""
if self.gate_set(op):
return self.interpret_operation(op)
return super().interpret_operation(op)

max_expansion = (
self.max_expansion + current_depth if self.max_expansion is not None else None
self.max_expansion - current_depth if self.max_expansion is not None else None
)

depth_tracker = {"current_depth": current_depth}
Expand Down Expand Up @@ -194,7 +191,7 @@ def _evaluate_jaxpr_decomposition(self, op: qml.operation.Operator, current_dept
print(f"_evaluate_jaxpr_decomposition: op={op}, current_depth={current_depth}")

if self.gate_set(op):
return self.interpret_operation(op)
return super().interpret_operation(op)

if self.max_expansion is not None and current_depth >= self.max_expansion:
return super().interpret_operation(op)
Expand Down
107 changes: 106 additions & 1 deletion tests/capture/transforms/test_capture_dynamic_decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,79 @@ def circuit(x, wire):
check_jaxpr_eqns(for_loop_eqns, expected_ops_for_loop)
check_jaxpr_eqns(while_loop_eqns, expected_ops_while_loop)

@pytest.mark.parametrize(
"max_expansion, gate_set, expected_ops, expected_ops_for_loop, expected_ops_while_loop",
[
(
1,
[qml.RX, qml.RY, qml.RZ, qml.CNOT],
# CustomOpNestedOpControlFlow -> Rot, CustomOpNestedOp (before cond)
# Rot -> qml.RZ, qml.RY, qml.RZ
[qml.Rot, CustomOpNestedOp],
# CustomOpNestedOp is in the for loop of the true branch
# CustomOpNestedOp -> RX, SimpleCustomOp
# SimpleCustomOp -> Hadamard, Hadamard
[CustomOpNestedOp],
# SimpleCustomOp is in the while loop of the false branch
[SimpleCustomOp],
),
(
2,
[qml.RX, qml.RY, qml.RZ, CustomOpNestedOp],
# CustomOpNestedOpControlFlow -> Rot, CustomOpNestedOp (before cond)
# Rot -> qml.RZ, qml.RY, qml.RZ, CustomOpNestedOp is in the gate set
[qml.RZ, qml.RY, qml.RZ, CustomOpNestedOp],
# CustomOpNestedOp is in the for loop of the true branch
# CustomOpNestedOp -> RX, SimpleCustomOp
# SimpleCustomOp -> Hadamard, Hadamard
[CustomOpNestedOp],
# SimpleCustomOp is in the while loop of the false branch
# SimpleCustomOp -> Hadamard, Hadamard
# Hadamard -> RZ, RX, RZ
[qml.Hadamard, qml.Hadamard],
),
(
3,
[qml.RX, qml.RY, qml.RZ, qml.Rot, SimpleCustomOp],
# CustomOpNestedOpControlFlow -> Rot, CustomOpNestedOp (before cond)
# Rot -> qml.RZ, qml.RY, qml.RZ
# CustomOpNestedOp -> RX, SimpleCustomOp
[qml.Rot, qml.RX, SimpleCustomOp],
# CustomOpNestedOp is in the for loop of the true branch
# CustomOpNestedOp -> RX, SimpleCustomOp
[qml.RX, SimpleCustomOp],
# SimpleCustomOp is in the while loop of the false branch
[SimpleCustomOp],
),
],
)
def test_nested_decomp_control_flow_max_exp_gate_set(
self, max_expansion, gate_set, expected_ops, expected_ops_for_loop, expected_ops_while_loop
):
"""Test that a nested decomposition custom operation that contains control flow is correctly decomposed using a gate set and max expansion."""

@DecomposeInterpreter(max_expansion=max_expansion, gate_set=gate_set)
def circuit(x, wire):
CustomOpNestedOpControlFlow(x, wires=wire)
return qml.expval(qml.Z(wires=wire))

jaxpr = jax.make_jaxpr(circuit)(0.5, wire=0)
jaxpr_eqns = get_jaxpr_eqns(jaxpr)

ops_before_cond = len(expected_ops)
check_jaxpr_eqns(jaxpr_eqns[0:ops_before_cond], expected_ops)

# The + 1 is for the operation that determines the branches of the cond primitive
cond_eqns = get_eqns_cond_branches(jaxpr_eqns[ops_before_cond + 1])
for_loop_eqns = get_eqns_for_loop(cond_eqns[0][0])
while_loop_eqns = get_eqns_while_loop(cond_eqns[1][0])

for_loop_eqns = [eqn for eqn in for_loop_eqns if eqn.primitive != jax.lax.sin_p]
while_loop_eqns = [eqn for eqn in while_loop_eqns if eqn.primitive != jax.lax.add_p]

check_jaxpr_eqns(for_loop_eqns, expected_ops_for_loop)
check_jaxpr_eqns(while_loop_eqns, expected_ops_while_loop)

@pytest.mark.parametrize(
"max_expansion, expected_ops",
[
Expand All @@ -1090,7 +1163,7 @@ def circuit(x, wire):
],
)
def test_nested_decomp_no_plxpr_decomp_max_exp(self, max_expansion, expected_ops):
"""Test that a QNode with a nested decomposition custom operation that contains an operator with no plxpr decomposition is correctly decomposed."""
"""Test that a nested decomposition custom operation that contains an operator with no plxpr decomposition is correctly decomposed."""

@DecomposeInterpreter(max_expansion=max_expansion)
def circuit(x, wire):
Expand Down Expand Up @@ -1118,6 +1191,7 @@ def circuit(x, wire):
],
)
def test_nested_decomp_no_plxpr_decomposition_gate_set(self, gate_set, expected_ops):
"""Test that a nested decomposition custom operation that contains an operator with no plxpr decomposition is correctly decomposed using a custom gate set."""

@DecomposeInterpreter(gate_set=gate_set)
def circuit(x, wire):
Expand All @@ -1128,3 +1202,34 @@ def circuit(x, wire):
jaxpr_eqns = get_jaxpr_eqns(jaxpr)

check_jaxpr_eqns(jaxpr_eqns[0 : len(expected_ops)], expected_ops)

@pytest.mark.parametrize(
"max_expansion, gate_set, expected_ops",
[
(0, [CustomOpNoPlxprDecomposition], [CustomOpNoPlxprDecomposition]),
(1, [CustomOpNoPlxprDecomposition], [CustomOpNoPlxprDecomposition]),
(2, [CustomOpNoPlxprDecomposition], [CustomOpNoPlxprDecomposition]),
(0, [CustomOpNestedOpControlFlow], [CustomOpNoPlxprDecomposition]),
(1, [CustomOpNestedOpControlFlow], [CustomOpNestedOpControlFlow]),
(2, [CustomOpNestedOpControlFlow], [CustomOpNestedOpControlFlow]),
(0, [qml.RX, qml.RY, qml.RZ, qml.S], [CustomOpNoPlxprDecomposition]),
(1, [qml.RX, qml.RY, qml.RZ, qml.S], [CustomOpNestedOpControlFlow]),
(2, [qml.RX, qml.RY, qml.RZ, qml.S], [qml.S]),
(2, [qml.RX, qml.RY, qml.RZ], [qml.S]),
(None, [qml.RX, qml.RY, qml.RZ], [qml.RZ]),
],
)
def test_nested_decomp_no_plxpr_decomposition_max_exp_gate_set(
self, max_expansion, gate_set, expected_ops
):
"""Test that a custom operation that contains an operator with no plxpr decomposition is correctly decomposed using a custom gate set and max_expansion."""

@DecomposeInterpreter(max_expansion=max_expansion, gate_set=gate_set)
def circuit(x, wire):
CustomOpNoPlxprDecomposition(x, wires=wire)
return qml.expval(qml.Z(wires=wire))

jaxpr = jax.make_jaxpr(circuit)(0.5, wire=0)
jaxpr_eqns = get_jaxpr_eqns(jaxpr)

check_jaxpr_eqns(jaxpr_eqns[0 : len(expected_ops)], expected_ops)

0 comments on commit 4510e7e

Please sign in to comment.