Skip to content

Commit

Permalink
First draft of device preprocessing integration
Browse files Browse the repository at this point in the history
  • Loading branch information
mudit2812 committed Feb 4, 2025
1 parent b5ad835 commit 845fdd9
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
35 changes: 33 additions & 2 deletions pennylane/devices/default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@

def stopping_condition(op: qml.operation.Operator) -> bool:
"""Specify whether or not an Operator object is supported by the device."""
if qml.capture.enabled():
return op.has_matrix

if op.name == "QFT" and len(op.wires) >= 6:
return False
if op.name == "GroverOperator" and len(op.wires) >= 13:
Expand Down Expand Up @@ -521,10 +524,35 @@ def supports_derivatives(
return _supports_adjoint(circuit, device_wires=self.wires, device_name=self.name)
return False

def _preprocess_capture(
self, execution_config=DefaultExecutionConfig
) -> tuple[TransformProgram, ExecutionConfig]:
updated_values = {}

if execution_config.gradient_method == "best":
updated_values["gradient_method"] = "backprop"

updated_values["device_options"] = dict(execution_config.device_options) # copy
for option in self._device_options:
if option not in updated_values["device_options"]:
updated_values["device_options"][option] = getattr(self, f"_{option}")

execution_config = replace(execution_config, **updated_values)

for option, value in execution_config.device_options.items():
if option not in self._device_options:
raise qml.DeviceError(f"device option {option} not present on {self}")

if option == "max_workers" and value is not None:
raise qml.DeviceError("Cannot set 'max_workers' if program capture is enabled.")

transform_program = TransformProgram()
transform_program.add_transform(qml.transforms.decompose, gate_set=stopping_condition)
return transform_program, execution_config

@debug_logger
def preprocess(
self,
execution_config: ExecutionConfig = DefaultExecutionConfig,
self, execution_config: ExecutionConfig = DefaultExecutionConfig
) -> tuple[TransformProgram, ExecutionConfig]:
"""This function defines the device transform program to be applied and an updated device configuration.
Expand All @@ -540,6 +568,9 @@ def preprocess(
This device supports any qubit operations that provide a matrix
"""
if qml.capture.enabled():
return self._preprocess_capture(execution_config=execution_config)

config = self._setup_execution_config(execution_config)
transform_program = TransformProgram()

Expand Down
2 changes: 1 addition & 1 deletion pennylane/transforms/core/transform_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ def __call__(
self, tapes: QuantumScriptBatch
) -> tuple[QuantumScriptBatch, BatchPostprocessingFn]: ...
def __call__(self, *args, **kwargs):
if isinstance(args[0], QuantumScriptBatch):
if isinstance(args[0], Sequence):
return self.__call_tapes(args[0])
return self.__call_jaxpr(*args, **kwargs)

Expand Down
9 changes: 8 additions & 1 deletion pennylane/workflow/_capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,14 @@ def user_transform_wrapper(*inner_args):
return jax.core.eval_jaxpr(qfunc_jaxpr, consts, *inner_args)

user_jaxpr = jax.make_jaxpr(user_transform_wrapper)(*non_const_args)
final_jaxpr = qnode.transform_program(user_jaxpr.jaxpr, user_jaxpr.consts, *non_const_args)
transformed_jaxpr = qnode.transform_program(
user_jaxpr.jaxpr, user_jaxpr.consts, *non_const_args
)

preprocess_program, _ = device.preprocess()
final_jaxpr = preprocess_program(
transformed_jaxpr.jaxpr, transformed_jaxpr.consts, *non_const_args
)

if batch_dims is None:
return device.eval_jaxpr(final_jaxpr.jaxpr, final_jaxpr.consts, *non_const_args)
Expand Down

0 comments on commit 845fdd9

Please sign in to comment.