Skip to content

Commit

Permalink
Update drawer catalyst and add draw_mpl (#4815)
Browse files Browse the repository at this point in the history
**Description of the Change:**

- Update how to check if Catalyst available
- Add draw mpl

---------

Co-authored-by: Ali Asadi <ali@xanadu.ai>
Co-authored-by: Josh Izaac <josh146@gmail.com>
Co-authored-by: David Ittah <dime10@users.noreply.github.com>
Co-authored-by: Matthew Silverman <matthews@xanadu.ai>
  • Loading branch information
5 people authored Nov 30, 2023
1 parent 48aeec8 commit e043de2
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/interface-unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ jobs:
install_pennylane_lightning_master: false
pytest_coverage_flags: ${{ inputs.pytest_coverage_flags }}
pytest_markers: external
additional_pip_packages: git+https://github.com/Quantomatic/pyzx.git@master pennylane-catalyst
additional_pip_packages: git+https://github.com/Quantomatic/pyzx.git@master pennylane-catalyst matplotlib


qcut-tests:
Expand Down
4 changes: 2 additions & 2 deletions pennylane/_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,8 @@ def jacobian(func, argnum=None, method=None, h=None):
functions.
Returns:
function: the function that returns the Jacobian of the input
function with respect to the arguments in argnum
function: the function that returns the Jacobian of the input function with respect to the
arguments in argnum
.. note::
Expand Down
2 changes: 2 additions & 0 deletions pennylane/drawer/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ def circuit2(x, y):
:target: javascript:void(0);
"""
if catalyst_qjit(qnode):
qnode = qnode.user_function
if hasattr(qnode, "construct"):
return _draw_mpl_qnode(
qnode,
Expand Down
123 changes: 112 additions & 11 deletions tests/drawer/test_draw_catalyst.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the pennylane drawer with Catalyst."""
# pylint: disable=import-outside-toplevel
# pylint: disable=import-outside-toplevel,protected-access
import pytest
import pennylane as qml

pyzx = pytest.importorskip("catalyst")
catalyst = pytest.importorskip("catalyst")
mpl = pytest.importorskip("matplotlib")

pytestmark = pytest.mark.external


class TestCatalyst:
class TestCatalystDraw:
"""Drawer integration test with Catalyst jitted QNodes."""

def test_simple_circuit(self):
"""Test a simple circuit that does not use Catalyst features."""
import catalyst

@catalyst.qjit
@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=(0, "a", 1.234)))
def circuit(x, y, z):
"""A quantum circuit on three wires."""
Expand All @@ -43,9 +43,10 @@ def circuit(x, y, z):
@pytest.mark.parametrize("c", [0, 1])
def test_cond_circuit(self, c):
"""Test a circuit with a Catalyst conditional."""
import catalyst

@catalyst.qjit
import catalyst # pylint: disable=redefined-outer-name

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=(0, "a", 1.234)))
def circuit(x, y, z, c):
"""A quantum circuit on three wires."""
Expand All @@ -69,9 +70,10 @@ def conditional_flip():
@pytest.mark.parametrize("c", [1, 2])
def test_for_loop_circuit(self, c):
"""Test a circuit with a Catalyst for_loop"""
import catalyst

@catalyst.qjit
import catalyst # pylint: disable=redefined-outer-name

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=3))
def circuit(x, y, z, c):
"""A quantum circuit on three wires."""
Expand All @@ -95,9 +97,10 @@ def loop(i):
@pytest.mark.parametrize("c", [0, 1])
def test_while_loop_circuit(self, c):
"""Test a circuit with a Catalyst while_loop"""
import catalyst

@catalyst.qjit
import catalyst # pylint: disable=redefined-outer-name

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=3))
def circuit(x, y, z, c):
"""A quantum circuit on three wires."""
Expand All @@ -120,3 +123,101 @@ def loop_rx(x):
"0: ──RX──RX─┤ <Z>\n1: ──RY─────┤ \n2: ──RZ─────┤ ",
]
assert qml.draw(circuit, decimals=None)(1.234, 2.345, 3.456, c) == expected[c]


class TestCatalystDrawMpl:
"""MPL Drawer integration test with Catalyst jitted QNodes."""

def test_simple_circuit(self):
"""Test a simple circuit that does not use Catalyst features."""

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=(0, "a", 1.234)))
def circuit(x, y, z):
"""A quantum circuit on three wires."""
qml.RX(x, wires=0)
qml.RY(y, wires="a")
qml.RZ(z, wires=1.234)
return qml.expval(qml.PauliZ(0))

fig, ax = qml.draw_mpl(circuit, decimals=None)(1.234, 2.345, 3.456)
assert isinstance(fig, mpl.figure.Figure)
assert isinstance(ax, mpl.axes._axes.Axes)

@pytest.mark.parametrize("c", [0, 1])
def test_cond_circuit(self, c):
"""Test a circuit with a Catalyst conditional."""

import catalyst # pylint: disable=redefined-outer-name

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=(0, "a", 1.234)))
def circuit(x, y, z, c):
"""A quantum circuit on three wires."""

@catalyst.cond(c)
def conditional_flip():
qml.PauliX(wires=0)

qml.RX(x, wires=0)
conditional_flip()
qml.RY(y, wires="a")
qml.RZ(z, wires=1.234)
return qml.expval(qml.PauliZ(0))

fig, ax = qml.draw_mpl(circuit, decimals=None)(1.234, 2.345, 3.456, c)
assert isinstance(fig, mpl.figure.Figure)
assert isinstance(ax, mpl.axes._axes.Axes)

@pytest.mark.parametrize("c", [1, 2])
def test_for_loop_circuit(self, c):
"""Test a circuit with a Catalyst for_loop"""

import catalyst # pylint: disable=redefined-outer-name

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=3))
def circuit(x, y, z, c):
"""A quantum circuit on three wires."""

@catalyst.for_loop(0, c, 1)
def loop(i):
qml.Hadamard(wires=i)

qml.RX(x, wires=0)
loop() # pylint: disable=no-value-for-parameter
qml.RY(y, wires=1)
qml.RZ(z, wires=2)
return qml.expval(qml.PauliZ(0))

fig, ax = qml.draw_mpl(circuit, decimals=None)(1.234, 2.345, 3.456, c)
assert isinstance(fig, mpl.figure.Figure)
assert isinstance(ax, mpl.axes._axes.Axes)

@pytest.mark.parametrize("c", [0, 1])
def test_while_loop_circuit(self, c):
"""Test a circuit with a Catalyst while_loop"""

import catalyst # pylint: disable=redefined-outer-name

@qml.qjit
@qml.qnode(qml.device("lightning.qubit", wires=3))
def circuit(x, y, z, c):
"""A quantum circuit on three wires."""

@catalyst.while_loop(lambda x: x < 2.0)
def loop_rx(x):
# perform some work and update (some of) the arguments
qml.RX(x, wires=0)
return x + 1

# apply the while loop
qml.RX(x, wires=0)
qml.RY(y, wires=1)
loop_rx(c)
qml.RZ(z, wires=2)
return qml.expval(qml.PauliZ(0))

fig, ax = qml.draw_mpl(circuit, decimals=None)(1.234, 2.345, 3.456, c)
assert isinstance(fig, mpl.figure.Figure)
assert isinstance(ax, mpl.axes._axes.Axes)

0 comments on commit e043de2

Please sign in to comment.