Skip to content

Commit

Permalink
Only test numba on Python < 3.12
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Oct 19, 2023
1 parent e3f400b commit 602b521
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 28 deletions.
18 changes: 13 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,30 @@ classifiers = [
dependencies = ["typing-extensions", "black", "graphviz"]

[project.optional-dependencies]
dev = ["pre-commit", "black", "mypy", "flake8", "isort", "anywidget[dev]"]

test = ["pytest", "mypy", "scikit-learn", "array_api_compat", "syrupy", "numba"]
array = ["scikit-learn", "array_api_compat", 'numba; python_version<"3.12"']
dev = [
"pre-commit",
"black",
"mypy",
"flake8",
"isort",
"anywidget[dev]",
"egglog[docs,test]",
]

test = ["pytest", "mypy", "syrupy", "egglog[array]"]

docs = [
"pydata-sphinx-theme",
"myst-nb",
"sphinx-autodoc-typehints",
"sphinx-gallery",
"nbconvert",
"scikit-learn",
"array_api_compat",
"matplotlib",
"anywidget",
"numba",
"seaborn",
"egglog[array]",
]


Expand Down
43 changes: 23 additions & 20 deletions python/egglog/exp/array_api_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
"""

from __future__ import annotations
from ast import Import

import operator

from egglog import *
from egglog.exp.array_api import *
from llvmlite import ir
from numba.core import types
from numba.core.imputils import impl_ret_untracked, lower_builtin
from numba.core.typing.templates import AbstractTemplate, infer_global, signature

array_api_numba_module = Module([array_api_module])

Expand Down Expand Up @@ -63,19 +60,25 @@ def _std(y: NDArray, x: NDArray, i: Int):

# Inline these changes until this PR is released to add suport for checking dtypes equal
# https://github.com/numba/numba/pull/9249


@infer_global(operator.eq)
class DtypeEq(AbstractTemplate):
def generic(self, args, kws):
[lhs, rhs] = args
if isinstance(lhs, types.DType) and isinstance(rhs, types.DType):
return signature(types.boolean, lhs, rhs)


@lower_builtin(operator.eq, types.DType, types.DType)
def const_eq_impl(context, builder, sig, args):
arg1, arg2 = sig.args
val = 1 if arg1 == arg2 else 0
res = ir.Constant(ir.IntType(1), val)
return impl_ret_untracked(context, builder, sig.return_type, res)
try:
from llvmlite import ir
from numba.core import types
from numba.core.imputils import impl_ret_untracked, lower_builtin
from numba.core.typing.templates import AbstractTemplate, infer_global, signature
except ImportError:
pass
else:
@infer_global(operator.eq)
class DtypeEq(AbstractTemplate):
def generic(self, args, kws):
[lhs, rhs] = args
if isinstance(lhs, types.DType) and isinstance(rhs, types.DType):
return signature(types.boolean, lhs, rhs)


@lower_builtin(operator.eq, types.DType, types.DType)
def const_eq_impl(context, builder, sig, args):
arg1, arg2 = sig.args
val = 1 if arg1 == arg2 else 0
res = ir.Constant(ir.IntType(1), val)
return impl_ret_untracked(context, builder, sig.return_type, res)
11 changes: 8 additions & 3 deletions python/tests/test_array_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import numba
import pytest
from egglog.exp.array_api import *
from egglog.exp.array_api_numba import array_api_numba_module
Expand Down Expand Up @@ -253,8 +252,14 @@ def test_sklearn_lda_runs():
optimized_res = fn(X_np, y_np) # type: ignore
assert np.allclose(real_res, optimized_res)

numba_res = numba.njit(fn)(X_np, y_np)
assert np.allclose(real_res, numba_res)
# Numba isn't supported on all platforms, so only test this if we can import
try:
import numba
except ImportError:
pass
else:
numba_res = numba.njit(fn)(X_np, y_np)
assert np.allclose(real_res, numba_res)


def test_reshape_index():
Expand Down

0 comments on commit 602b521

Please sign in to comment.