From 602b52183f1993ad69950956f5de0785bbe7e615 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 19 Oct 2023 12:49:27 -0400 Subject: [PATCH] Only test numba on Python < 3.12 --- pyproject.toml | 18 ++++++++---- python/egglog/exp/array_api_numba.py | 43 +++++++++++++++------------- python/tests/test_array_api.py | 11 +++++-- 3 files changed, 44 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c4388c69..df4f22b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,9 +28,19 @@ classifiers = [ dependencies = ["typing-extensions", "black", "graphviz"] [project.optional-dependencies] -dev = ["pre-commit", "black", "mypy", "flake8", "isort", "anywidget[dev]"] -test = ["pytest", "mypy", "scikit-learn", "array_api_compat", "syrupy", "numba"] +array = ["scikit-learn", "array_api_compat", 'numba; python_version<"3.12"'] +dev = [ + "pre-commit", + "black", + "mypy", + "flake8", + "isort", + "anywidget[dev]", + "egglog[docs,test]", +] + +test = ["pytest", "mypy", "syrupy", "egglog[array]"] docs = [ "pydata-sphinx-theme", @@ -38,12 +48,10 @@ docs = [ "sphinx-autodoc-typehints", "sphinx-gallery", "nbconvert", - "scikit-learn", - "array_api_compat", "matplotlib", "anywidget", - "numba", "seaborn", + "egglog[array]", ] diff --git a/python/egglog/exp/array_api_numba.py b/python/egglog/exp/array_api_numba.py index 306c5a1d..a3fc3e5f 100644 --- a/python/egglog/exp/array_api_numba.py +++ b/python/egglog/exp/array_api_numba.py @@ -3,15 +3,12 @@ """ from __future__ import annotations +from ast import Import import operator from egglog import * from egglog.exp.array_api import * -from llvmlite import ir -from numba.core import types -from numba.core.imputils import impl_ret_untracked, lower_builtin -from numba.core.typing.templates import AbstractTemplate, infer_global, signature array_api_numba_module = Module([array_api_module]) @@ -63,19 +60,25 @@ def _std(y: NDArray, x: NDArray, i: Int): # Inline these changes until this PR is released to add suport for checking dtypes equal # https://github.com/numba/numba/pull/9249 - - -@infer_global(operator.eq) -class DtypeEq(AbstractTemplate): - def generic(self, args, kws): - [lhs, rhs] = args - if isinstance(lhs, types.DType) and isinstance(rhs, types.DType): - return signature(types.boolean, lhs, rhs) - - -@lower_builtin(operator.eq, types.DType, types.DType) -def const_eq_impl(context, builder, sig, args): - arg1, arg2 = sig.args - val = 1 if arg1 == arg2 else 0 - res = ir.Constant(ir.IntType(1), val) - return impl_ret_untracked(context, builder, sig.return_type, res) +try: + from llvmlite import ir + from numba.core import types + from numba.core.imputils import impl_ret_untracked, lower_builtin + from numba.core.typing.templates import AbstractTemplate, infer_global, signature +except ImportError: + pass +else: + @infer_global(operator.eq) + class DtypeEq(AbstractTemplate): + def generic(self, args, kws): + [lhs, rhs] = args + if isinstance(lhs, types.DType) and isinstance(rhs, types.DType): + return signature(types.boolean, lhs, rhs) + + + @lower_builtin(operator.eq, types.DType, types.DType) + def const_eq_impl(context, builder, sig, args): + arg1, arg2 = sig.args + val = 1 if arg1 == arg2 else 0 + res = ir.Constant(ir.IntType(1), val) + return impl_ret_untracked(context, builder, sig.return_type, res) diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 8e10a020..de0f71b5 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -1,4 +1,3 @@ -import numba import pytest from egglog.exp.array_api import * from egglog.exp.array_api_numba import array_api_numba_module @@ -253,8 +252,14 @@ def test_sklearn_lda_runs(): optimized_res = fn(X_np, y_np) # type: ignore assert np.allclose(real_res, optimized_res) - numba_res = numba.njit(fn)(X_np, y_np) - assert np.allclose(real_res, numba_res) + # Numba isn't supported on all platforms, so only test this if we can import + try: + import numba + except ImportError: + pass + else: + numba_res = numba.njit(fn)(X_np, y_np) + assert np.allclose(real_res, numba_res) def test_reshape_index():