Skip to content

Commit

Permalink
fix(ci): pyright is now happy
Browse files Browse the repository at this point in the history
  • Loading branch information
jeertmans committed Jan 10, 2024
1 parent 1d1d608 commit a4fd1c5
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 19 deletions.
4 changes: 4 additions & 0 deletions python/differt/_core/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .._core import geometry, rt

__all__ = ("geometry", "rt", "__all__")

__version__: str
Empty file removed python/differt/_core/geometry.pyi
Empty file.
3 changes: 3 additions & 0 deletions python/differt/_core/geometry/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ..._core.geometry import triangle_mesh

__all__ = ("triangle_mesh",)
Empty file removed python/differt/_core/rt.pyi
Empty file.
3 changes: 3 additions & 0 deletions python/differt/_core/rt/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from ..._core.rt import utils

__all__ = ("utils",)
12 changes: 6 additions & 6 deletions python/differt/plotting/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
from vispy.scene.canvas import SceneCanvas as Canvas


@dispatch
@dispatch # type: ignore
def draw_mesh(
vertices: Float[np.ndarray, "num_vertices 3"],
triangles: UInt[np.ndarray, "num_triangles 3"],
**kwargs: Any,
) -> Canvas | MplFigure | Figure:
) -> Canvas | MplFigure | Figure: # type: ignore
"""
Plot a 3D mesh made of triangles.
Expand Down Expand Up @@ -87,10 +87,10 @@ def _(
return fig.add_mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, **kwargs)


@dispatch
@dispatch # type: ignore
def draw_paths(
paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any
) -> Canvas | MplFigure | Figure:
) -> Canvas | MplFigure | Figure: # type: ignore
"""
Plot a batch of paths of the same length.
Expand Down Expand Up @@ -142,13 +142,13 @@ def _(paths: Float[np.ndarray, "*batch path_length 3"], **kwargs: Any) -> Figure
return fig


@dispatch
@dispatch # type: ignore
def draw_markers(
markers: Float[np.ndarray, "num_markers 3"],
labels: Sequence[str] | None = None,
text_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Canvas | MplFigure | Figure:
) -> Canvas | MplFigure | Figure: # type: ignore
"""
Plot markers and, optionally, their label.
Expand Down
22 changes: 11 additions & 11 deletions tests/plotting/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,18 @@ def monkey_import_module(name: str, *args: Any, **kwargs: Any) -> ModuleType:
m.setattr(builtins, "__import__", monkey_import)
m.setattr(importlib, "import_module", monkey_import_module)

yield m
yield m # type: ignore

return ctx


@dispatch
def my_plot_unimplemented(**kwargs: dict[str, Any]) -> SceneCanvas | MplFigure | Figure:
@dispatch # type: ignore
def my_plot_unimplemented(**kwargs: dict[str, Any]) -> SceneCanvas | MplFigure | Figure: # type: ignore
"""A plot function with no backend implementation."""


@dispatch
def my_plot(**kwargs: dict[str, Any]) -> SceneCanvas | MplFigure | Figure:
@dispatch # type: ignore
def my_plot(**kwargs: dict[str, Any]) -> SceneCanvas | MplFigure | Figure: # type: ignore
"""A plot function with dummy backend implementations."""


Expand All @@ -90,7 +90,7 @@ def _(**kwargs): # type: ignore[no-untyped-def]
def test_unimplemented(backend: str | None) -> None:
with pytest.raises(NotImplementedError, match="No backend implementation for"):
if backend:
_ = my_plot_unimplemented(backend=backend)
_ = my_plot_unimplemented(backend=backend) # type: ignore
else:
_ = my_plot_unimplemented()

Expand Down Expand Up @@ -138,20 +138,20 @@ def test_missing_backend_module(
ImportError,
match=f"An import error occured when dispatching plot utility to backend '{backend}'.",
):
_ = my_plot(backend=backend)
_ = my_plot(backend=backend) # type: ignore


@pytest.mark.parametrize(
"backend,rtype",
(("vispy", SceneCanvas), ("matplotlib", MplFigure), ("plotly", Figure)),
)
def test_return_type(backend: str, rtype: type) -> None:
ret = my_plot(backend=backend)
ret = my_plot(backend=backend) # type: ignore
assert isinstance(ret, rtype), f"{ret!r} is not of type {rtype}"


def test_process_vispy_kwargs() -> None:
kwargs = {"color": "red"}
kwargs: dict[str, Any] = {"color": "red"}
canvas, view = process_vispy_kwargs(kwargs)
assert view == view_from_canvas(canvas)

Expand Down Expand Up @@ -180,7 +180,7 @@ def test_process_vispy_kwargs() -> None:


def test_process_matplotlib_kwargs() -> None:
kwargs = {"color": "green"}
kwargs: dict[str, Any] = {"color": "green"}
fig, ax = process_matplotlib_kwargs(kwargs)

kwargs["figure"] = fig
Expand Down Expand Up @@ -210,7 +210,7 @@ def test_process_matplotlib_kwargs() -> None:


def test_process_plotly_kwargs() -> None:
kwargs = {"color": "blue"}
kwargs: dict[str, Any] = {"color": "blue"}
fig = process_plotly_kwargs(kwargs)

kwargs["figure"] = fig
Expand Down
4 changes: 3 additions & 1 deletion tests/rt/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

import chex
import jax.numpy as jnp
import pytest
Expand All @@ -7,7 +9,7 @@
from differt.utils import sorted_array2


def uint_array(array_like: Array) -> Array:
def uint_array(array_like: Any) -> Array:
return jnp.array(array_like, dtype=jnp.uint32)


Expand Down
3 changes: 2 additions & 1 deletion tests/test_differt_core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import differt
from differt import _core


def test_same_version() -> None:
assert differt.__version__ == differt._core.__version__
assert differt.__version__ == _core.__version__

0 comments on commit a4fd1c5

Please sign in to comment.