From e043de2e692661916a2d76c957654fdfef261b53 Mon Sep 17 00:00:00 2001 From: Romain Moyard Date: Thu, 30 Nov 2023 16:17:00 -0500 Subject: [PATCH] Update drawer catalyst and add draw_mpl (#4815) **Description of the Change:** - Update how to check if Catalyst available - Add draw mpl --------- Co-authored-by: Ali Asadi Co-authored-by: Josh Izaac Co-authored-by: David Ittah Co-authored-by: Matthew Silverman --- .github/workflows/interface-unit-tests.yml | 2 +- pennylane/_grad.py | 4 +- pennylane/drawer/draw.py | 2 + tests/drawer/test_draw_catalyst.py | 123 +++++++++++++++++++-- 4 files changed, 117 insertions(+), 14 deletions(-) diff --git a/.github/workflows/interface-unit-tests.yml b/.github/workflows/interface-unit-tests.yml index 7a9db91ccd5..0aa27f71139 100644 --- a/.github/workflows/interface-unit-tests.yml +++ b/.github/workflows/interface-unit-tests.yml @@ -370,7 +370,7 @@ jobs: install_pennylane_lightning_master: false pytest_coverage_flags: ${{ inputs.pytest_coverage_flags }} pytest_markers: external - additional_pip_packages: git+https://github.com/Quantomatic/pyzx.git@master pennylane-catalyst + additional_pip_packages: git+https://github.com/Quantomatic/pyzx.git@master pennylane-catalyst matplotlib qcut-tests: diff --git a/pennylane/_grad.py b/pennylane/_grad.py index 1f3fd6f7732..cf1615713fe 100644 --- a/pennylane/_grad.py +++ b/pennylane/_grad.py @@ -234,8 +234,8 @@ def jacobian(func, argnum=None, method=None, h=None): functions. Returns: - function: the function that returns the Jacobian of the input - function with respect to the arguments in argnum + function: the function that returns the Jacobian of the input function with respect to the + arguments in argnum .. note:: diff --git a/pennylane/drawer/draw.py b/pennylane/drawer/draw.py index 2b993cf7d98..4ead7723eef 100644 --- a/pennylane/drawer/draw.py +++ b/pennylane/drawer/draw.py @@ -521,6 +521,8 @@ def circuit2(x, y): :target: javascript:void(0); """ + if catalyst_qjit(qnode): + qnode = qnode.user_function if hasattr(qnode, "construct"): return _draw_mpl_qnode( qnode, diff --git a/tests/drawer/test_draw_catalyst.py b/tests/drawer/test_draw_catalyst.py index 8053d72fee5..1075c30b0ba 100644 --- a/tests/drawer/test_draw_catalyst.py +++ b/tests/drawer/test_draw_catalyst.py @@ -12,23 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. """Test the pennylane drawer with Catalyst.""" -# pylint: disable=import-outside-toplevel +# pylint: disable=import-outside-toplevel,protected-access import pytest import pennylane as qml -pyzx = pytest.importorskip("catalyst") +catalyst = pytest.importorskip("catalyst") +mpl = pytest.importorskip("matplotlib") pytestmark = pytest.mark.external -class TestCatalyst: +class TestCatalystDraw: """Drawer integration test with Catalyst jitted QNodes.""" def test_simple_circuit(self): """Test a simple circuit that does not use Catalyst features.""" - import catalyst - @catalyst.qjit + @qml.qjit @qml.qnode(qml.device("lightning.qubit", wires=(0, "a", 1.234))) def circuit(x, y, z): """A quantum circuit on three wires.""" @@ -43,9 +43,10 @@ def circuit(x, y, z): @pytest.mark.parametrize("c", [0, 1]) def test_cond_circuit(self, c): """Test a circuit with a Catalyst conditional.""" - import catalyst - @catalyst.qjit + import catalyst # pylint: disable=redefined-outer-name + + @qml.qjit @qml.qnode(qml.device("lightning.qubit", wires=(0, "a", 1.234))) def circuit(x, y, z, c): """A quantum circuit on three wires.""" @@ -69,9 +70,10 @@ def conditional_flip(): @pytest.mark.parametrize("c", [1, 2]) def test_for_loop_circuit(self, c): """Test a circuit with a Catalyst for_loop""" - import catalyst - @catalyst.qjit + import catalyst # pylint: disable=redefined-outer-name + + @qml.qjit @qml.qnode(qml.device("lightning.qubit", wires=3)) def circuit(x, y, z, c): """A quantum circuit on three wires.""" @@ -95,9 +97,10 @@ def loop(i): @pytest.mark.parametrize("c", [0, 1]) def test_while_loop_circuit(self, c): """Test a circuit with a Catalyst while_loop""" - import catalyst - @catalyst.qjit + import catalyst # pylint: disable=redefined-outer-name + + @qml.qjit @qml.qnode(qml.device("lightning.qubit", wires=3)) def circuit(x, y, z, c): """A quantum circuit on three wires.""" @@ -120,3 +123,101 @@ def loop_rx(x): "0: ──RX──RX─┤ \n1: ──RY─────┤ \n2: ──RZ─────┤ ", ] assert qml.draw(circuit, decimals=None)(1.234, 2.345, 3.456, c) == expected[c] + + +class TestCatalystDrawMpl: + """MPL Drawer integration test with Catalyst jitted QNodes.""" + + def test_simple_circuit(self): + """Test a simple circuit that does not use Catalyst features.""" + + @qml.qjit + @qml.qnode(qml.device("lightning.qubit", wires=(0, "a", 1.234))) + def circuit(x, y, z): + """A quantum circuit on three wires.""" + qml.RX(x, wires=0) + qml.RY(y, wires="a") + qml.RZ(z, wires=1.234) + return qml.expval(qml.PauliZ(0)) + + fig, ax = qml.draw_mpl(circuit, decimals=None)(1.234, 2.345, 3.456) + assert isinstance(fig, mpl.figure.Figure) + assert isinstance(ax, mpl.axes._axes.Axes) + + @pytest.mark.parametrize("c", [0, 1]) + def test_cond_circuit(self, c): + """Test a circuit with a Catalyst conditional.""" + + import catalyst # pylint: disable=redefined-outer-name + + @qml.qjit + @qml.qnode(qml.device("lightning.qubit", wires=(0, "a", 1.234))) + def circuit(x, y, z, c): + """A quantum circuit on three wires.""" + + @catalyst.cond(c) + def conditional_flip(): + qml.PauliX(wires=0) + + qml.RX(x, wires=0) + conditional_flip() + qml.RY(y, wires="a") + qml.RZ(z, wires=1.234) + return qml.expval(qml.PauliZ(0)) + + fig, ax = qml.draw_mpl(circuit, decimals=None)(1.234, 2.345, 3.456, c) + assert isinstance(fig, mpl.figure.Figure) + assert isinstance(ax, mpl.axes._axes.Axes) + + @pytest.mark.parametrize("c", [1, 2]) + def test_for_loop_circuit(self, c): + """Test a circuit with a Catalyst for_loop""" + + import catalyst # pylint: disable=redefined-outer-name + + @qml.qjit + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(x, y, z, c): + """A quantum circuit on three wires.""" + + @catalyst.for_loop(0, c, 1) + def loop(i): + qml.Hadamard(wires=i) + + qml.RX(x, wires=0) + loop() # pylint: disable=no-value-for-parameter + qml.RY(y, wires=1) + qml.RZ(z, wires=2) + return qml.expval(qml.PauliZ(0)) + + fig, ax = qml.draw_mpl(circuit, decimals=None)(1.234, 2.345, 3.456, c) + assert isinstance(fig, mpl.figure.Figure) + assert isinstance(ax, mpl.axes._axes.Axes) + + @pytest.mark.parametrize("c", [0, 1]) + def test_while_loop_circuit(self, c): + """Test a circuit with a Catalyst while_loop""" + + import catalyst # pylint: disable=redefined-outer-name + + @qml.qjit + @qml.qnode(qml.device("lightning.qubit", wires=3)) + def circuit(x, y, z, c): + """A quantum circuit on three wires.""" + + @catalyst.while_loop(lambda x: x < 2.0) + def loop_rx(x): + # perform some work and update (some of) the arguments + qml.RX(x, wires=0) + return x + 1 + + # apply the while loop + qml.RX(x, wires=0) + qml.RY(y, wires=1) + loop_rx(c) + qml.RZ(z, wires=2) + return qml.expval(qml.PauliZ(0)) + + fig, ax = qml.draw_mpl(circuit, decimals=None)(1.234, 2.345, 3.456, c) + assert isinstance(fig, mpl.figure.Figure) + assert isinstance(ax, mpl.axes._axes.Axes)