ENH: ndonnx device() support; TST: better ndonnx test coverage #232

Merged
merged 2 commits into from
Feb 5, 2025
2 changes: 1 addition & 1 deletion array_api_compat/common/_helpers.py
@@ -776,7 +776,7 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
     device : Hardware device the array data resides on.

     """
-    if is_numpy_array(x):
+    if is_numpy_array(x) or is_ndonnx_array(x):
Contributor

What I don't quite understand is this: https://github.com/data-apis/array-api-compat/pull/232/files#diff-d5d8fa69860e03769cc235a54027efad4376a8c48eeab71aaa11365ea308123bR141 states that the array API support is internal to ndonnx, which seems to imply there's no need to special-case it here?

The note reads "Similar to JAX, ndonnx Array API support is contained directly in ndonnx." --- and indeed, there are no workarounds for jax in the codebase.

Contributor Author

Because they support the standard imperfectly. The array-api-tests for .device and .to_device are XFAILed:
https://github.com/Quantco/ndonnx/actions/workflows/array-api.yml

api-coverage-tests/array_api_tests/test_has_names.py::test_has_names[array_method-to_device] 
[gw1] [ 13%] XFAIL api-coverage-tests/array_api_tests/test_has_names.py::test_has_names[array_method-to_device] 
api-coverage-tests/array_api_tests/test_has_names.py::test_has_names[array_attribute-T] 
[gw1] [ 14%] PASSED api-coverage-tests/array_api_tests/test_has_names.py::test_has_names[array_attribute-T] 
api-coverage-tests/array_api_tests/test_has_names.py::test_has_names[array_attribute-device] 
[gw1] [ 14%] XFAIL api-coverage-tests/array_api_tests/test_has_names.py::test_has_names[array_attribute-device] 

Adding .device and .to_device won't break any library's backwards compatibility, so it would be best if each library just fixed it on their side. And yet here we have a function that works around the quirks of every library, not least to support less-than-newest versions of them.

Contributor
@adityagoel4512, Jan 24, 2025

I will address this upstream. I should say that it might be deceptive to allow this to work with "cpu".

A serialized ONNX graph from ndonnx could later be executed on many different hardware targets or runtimes. ONNX is one level removed from the execution target by design.

Based on a quick read of the device page of the specification, we may do something similar to what you're doing with Dask.

         if stream is not None:
             raise ValueError("The stream argument to to_device() is not supported")
         if device == 'cpu':
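
The branch discussed in this thread can be sketched in isolation. The following is a minimal illustration, not the code in `array_api_compat`: `HostOnlyArray` and `to_device_host_only` are hypothetical names, and the final `raise` is an assumption, since the hunk above is cut off right after the `device == 'cpu'` check.

```python
from typing import Any, Optional


class HostOnlyArray:
    """Hypothetical stand-in for an array whose data always lives on the host."""

    def __init__(self, values):
        self.values = list(values)


def to_device_host_only(
    x: HostOnlyArray, device: Any, /, *, stream: Optional[Any] = None
) -> HostOnlyArray:
    # A host-only library has no streams to synchronise on.
    if stream is not None:
        raise ValueError("The stream argument to to_device() is not supported")
    # "cpu" is the only place such an array can live, so the move is a no-op.
    if device == "cpu":
        return x
    # Assumption: any other device is rejected (not visible in the hunk above).
    raise ValueError(f"Unsupported device {device!r}")
```

As the last comment in the thread points out, reporting "cpu" for an ndonnx array may be misleading, because the ONNX graph it builds is not tied to one execution target.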
5 changes: 5 additions & 0 deletions docs/supported-array-libraries.md
@@ -138,6 +138,11 @@ The minimum supported Dask version is 2023.12.0.

 Similar to JAX, `sparse` Array API support is contained directly in `sparse`.

+(ndonnx-support)=
+## [ndonnx](https://github.com/quantco/ndonnx)
+
+Similar to JAX, `ndonnx` Array API support is contained directly in `ndonnx`.
+
 (array-api-strict-support)=
 ## [array-api-strict](https://data-apis.org/array-api-strict/)

18 changes: 15 additions & 3 deletions tests/_helpers.py
@@ -3,11 +3,12 @@
 import pytest

 wrapped_libraries = ["numpy", "cupy", "torch", "dask.array"]
-all_libraries = wrapped_libraries + ["array_api_strict", "jax.numpy", "sparse"]
-
+all_libraries = wrapped_libraries + [
+    "array_api_strict", "jax.numpy", "ndonnx", "sparse"
Contributor

This too. Why are we testing libraries here that we don't wrap?

Contributor Author

Because _helpers.py is chock full of special cases and exceptions for them.

+]

 def import_(library, wrapper=False):
-    if library == 'cupy':
+    if library in ('cupy', 'ndonnx'):
         pytest.importorskip(library)
     if wrapper:
         if 'jax' in library:
@@ -20,3 +21,14 @@ def import_(library, wrapper=False):
             library = 'array_api_compat.' + library

     return import_module(library)
+
+
+def xfail(request: pytest.FixtureRequest, reason: str) -> None:
+    """
+    XFAIL the currently running test.
+
+    Unlike ``pytest.xfail``, allow rest of test to execute instead of immediately
+    halting it, so that it may result in a XPASS.
+    xref https://github.com/pandas-dev/pandas/issues/38902
+    """
+    request.node.add_marker(pytest.mark.xfail(reason=reason))
3 changes: 3 additions & 0 deletions tests/test_array_namespace.py
@@ -22,6 +22,9 @@ def test_array_namespace(library, api_version, use_compat):
     if use_compat and library not in wrapped_libraries:
         pytest.raises(ValueError, lambda: array_namespace(array, use_compat=use_compat))
         return
+    if library == "ndonnx" and api_version in ("2021.12", "2022.12"):
+        pytest.skip("Unsupported API version")
+
     namespace = array_namespace(array, api_version=api_version, use_compat=use_compat)

     if use_compat is False or use_compat is None and library not in wrapped_libraries:
29 changes: 19 additions & 10 deletions tests/test_common.py
@@ -8,15 +8,17 @@
 from array_api_compat import (  # noqa: F401
     is_numpy_array, is_cupy_array, is_torch_array,
     is_dask_array, is_jax_array, is_pydata_sparse_array,
+    is_ndonnx_array,
     is_numpy_namespace, is_cupy_namespace, is_torch_namespace,
     is_dask_namespace, is_jax_namespace, is_pydata_sparse_namespace,
-    is_array_api_strict_namespace,
+    is_array_api_strict_namespace, is_ndonnx_namespace,
 )

 from array_api_compat import (
     device, is_array_api_obj, is_lazy_array, is_writeable_array, size, to_device
 )
-from ._helpers import import_, wrapped_libraries, all_libraries
+from ._helpers import all_libraries, import_, wrapped_libraries, xfail
+

 is_array_functions = {
     'numpy': 'is_numpy_array',
@@ -25,6 +27,7 @@
     'dask.array': 'is_dask_array',
     'jax.numpy': 'is_jax_array',
     'sparse': 'is_pydata_sparse_array',
+    'ndonnx': 'is_ndonnx_array',
 }

 is_namespace_functions = {
@@ -35,6 +38,7 @@
     'jax.numpy': 'is_jax_namespace',
     'sparse': 'is_pydata_sparse_namespace',
     'array_api_strict': 'is_array_api_strict_namespace',
+    'ndonnx': 'is_ndonnx_namespace',
 }


@@ -185,7 +189,10 @@ class C:


 @pytest.mark.parametrize("library", all_libraries)
-def test_device(library):
+def test_device(library, request):
+    if library == "ndonnx":
+        xfail(request, reason="Needs ndonnx >=0.9.4")
+
     xp = import_(library, wrapper=True)

     # We can't test much for device() and to_device() other than that
@@ -223,17 +230,19 @@ def test_to_device_host(library):
 @pytest.mark.parametrize("target_library", is_array_functions.keys())
 @pytest.mark.parametrize("source_library", is_array_functions.keys())
 def test_asarray_cross_library(source_library, target_library, request):
-    def _xfail(reason: str) -> None:
-        # Allow rest of test to execute instead of immediately xfailing
-        # xref https://github.com/pandas-dev/pandas/issues/38902
-        request.node.add_marker(pytest.mark.xfail(reason=reason))
-
     if source_library == "dask.array" and target_library == "torch":
         # TODO: remove xfail once
         # https://github.com/dask/dask/issues/8260 is resolved
-        _xfail(reason="Bug in dask raising error on conversion")
+        xfail(request, reason="Bug in dask raising error on conversion")
+    elif (
+        source_library == "ndonnx"
+        and target_library not in ("array_api_strict", "ndonnx", "numpy")
+    ):
+        xfail(request, reason="The truth value of lazy Array Array(dtype=Boolean) is unknown")
+    elif source_library == "ndonnx" and target_library == "numpy":
+        xfail(request, reason="produces numpy array of ndonnx scalar arrays")
     elif source_library == "jax.numpy" and target_library == "torch":
-        _xfail(reason="casts int to float")
+        xfail(request, reason="casts int to float")
     elif source_library == "cupy" and target_library != "cupy":
         # cupy explicitly disallows implicit conversions to CPU
         pytest.skip(reason="cupy does not support implicit conversion to CPU")
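
To make the branching in this hunk easier to read, the visible conditions can be restated as a pure function. This is only a sketch of the lines shown above: `expected_outcome` is a hypothetical helper that does not exist in the test suite, and the real test may contain further branches beyond the truncated hunk.

```python
from typing import Optional, Tuple


def expected_outcome(
    source_library: str, target_library: str
) -> Optional[Tuple[str, str]]:
    """Return ("xfail" | "skip", reason), or None if the pair runs normally.

    Mirrors only the branches visible in the hunk above, in the same order.
    """
    if source_library == "dask.array" and target_library == "torch":
        return ("xfail", "Bug in dask raising error on conversion")
    elif (
        source_library == "ndonnx"
        and target_library not in ("array_api_strict", "ndonnx", "numpy")
    ):
        return (
            "xfail",
            "The truth value of lazy Array Array(dtype=Boolean) is unknown",
        )
    elif source_library == "ndonnx" and target_library == "numpy":
        return ("xfail", "produces numpy array of ndonnx scalar arrays")
    elif source_library == "jax.numpy" and target_library == "torch":
        return ("xfail", "casts int to float")
    elif source_library == "cupy" and target_library != "cupy":
        # cupy explicitly disallows implicit conversions to CPU
        return ("skip", "cupy does not support implicit conversion to CPU")
    return None
```

Read this way, ndonnx as a source is expected to convert cleanly only to `array_api_strict` and to itself; every other target is marked XFAIL.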