Skip to content

Commit

Permalink
➕ dep-add(quax-blocks): replace quaxed.experimental with quax-blocks …
Browse files Browse the repository at this point in the history
…v0.1 (#395)
  • Loading branch information
nstarman authored Feb 7, 2025
1 parent a2d278b commit 66e4118
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 27 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ repos:
additional_dependencies:
- plum-dispatch>=2.5.6
- quaxed>=0.8.1
- quax-blocks>=0.1

- repo: https://github.com/codespell-project/codespell
rev: "v2.3.0"
Expand Down
3 changes: 1 addition & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,7 @@ Arithmetic will raise an error if the units are incompatible:
>>> z = u.Quantity(5.0, "second")
>>> try: x + z
... except Exception as e: print(e)
...
's' (time) and 'm' (length) are not convertible
unsupported operand type(s) for +: 'Quantity[PhysicalType('length')]' and 'Quantity[PhysicalType('time')]'
```

### Converting Units
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"optional-dependencies>=0.3.2",
"plum-dispatch>=2.5.6",
"quax>=0.0.5",
"quax-blocks>=0.1",
"quaxed>=0.8.1",
"xmmutablemap>=0.1",
"zeroth>=1.0.0",
Expand Down
24 changes: 12 additions & 12 deletions src/unxt/_src/quantity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import equinox as eqx
import jax
import jax.core
import quax_blocks
from astropy.units import UnitConversionError
from jax._src.numpy.array_methods import _IndexUpdateHelper, _IndexUpdateRef
from jaxtyping import Array, ArrayLike, Bool, ScalarLike, Shaped
Expand All @@ -20,7 +21,6 @@

import quaxed.numpy as jnp
from dataclassish import replace
from quaxed.experimental import arrayish

from .api import is_unit_convertible, uconvert, ustrip
from .mixins import AstropyQuantityCompatMixin, IPythonReprMixin, NumPyCompatMixin
Expand All @@ -39,15 +39,15 @@ class AbstractQuantity(
NumPyCompatMixin,
IPythonReprMixin,
ArrayValue,
arrayish.NumpyBinaryOpsMixin[Any, "AbstractQuantity"],
arrayish.NumpyComparisonMixin[Any, Bool[Array, "*shape"]], # TODO: shape hint
arrayish.NumpyUnaryMixin["AbstractQuantity"],
arrayish.NumpyRoundMixin["AbstractQuantity"],
arrayish.NumpyTruncMixin["AbstractQuantity"],
arrayish.NumpyFloorMixin["AbstractQuantity"],
arrayish.NumpyCeilMixin["AbstractQuantity"],
arrayish.LaxLenMixin,
arrayish.LaxLengthHintMixin,
quax_blocks.NumpyBinaryOpsMixin[Any, "AbstractQuantity"],
quax_blocks.NumpyComparisonMixin[Any, Bool[Array, "*shape"]], # TODO: shape hint
quax_blocks.NumpyUnaryMixin["AbstractQuantity"],
quax_blocks.NumpyRoundMixin["AbstractQuantity"],
quax_blocks.NumpyTruncMixin["AbstractQuantity"],
quax_blocks.NumpyFloorMixin["AbstractQuantity"],
quax_blocks.NumpyCeilMixin["AbstractQuantity"],
quax_blocks.LaxLenMixin,
quax_blocks.LaxLengthHintMixin,
):
"""Represents an array, with each axis bound to a name.
Expand Down Expand Up @@ -315,7 +315,7 @@ def __rmod__(self, other: Any) -> Any:
return self % other

# required to override mixin methods
__eq__ = arrayish.NumpyEqMixin.__eq__
__eq__ = quax_blocks.NumpyEqMixin.__eq__

# ---------------------------------------------------------------
# methods
Expand Down Expand Up @@ -898,7 +898,7 @@ def __repr__(self) -> str:
return super().__repr__().replace("_IndexUpdateRef", "_QuantityIndexUpdateRef")

@override
def get( # type: ignore[override]
def get(
self,
*,
indices_are_sorted: bool = False,
Expand Down
12 changes: 6 additions & 6 deletions src/unxt/_src/quantity/register_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def add_p_vaq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity:
... x + y
... except Exception as e:
... print(e)
'km' (length) and '' (dimensionless) are not convertible
unsupported operand type(s) for +: 'jaxlib.xla_extension.ArrayImpl' and 'BareQuantity'
>>> y = BareQuantity(100.0, "")
>>> jnp.add(x, y)
Expand All @@ -215,7 +215,7 @@ def add_p_vaq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity:
... x + q2
... except Exception as e:
... print(e)
'km' (length) and '' (dimensionless) are not convertible
unsupported operand type(s) for +: 'jaxlib.xla_extension.ArrayImpl' and 'Quantity[PhysicalType('length')]'
>>> q2 = Quantity(100.0, "")
>>> jnp.add(x, q2)
Expand All @@ -229,7 +229,7 @@ def add_p_vaq(x: ArrayLike, y: AbstractQuantity) -> AbstractQuantity:
>>> jnp.add(x, q2 / q3)
Quantity['dimensionless'](Array(501., dtype=float32, weak_type=True), unit='')
"""
""" # noqa: E501
y = uconvert(one, y)
return replace(y, value=qlax.add(x, ustrip(y)))

Expand Down Expand Up @@ -259,7 +259,7 @@ def add_p_aqv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity:
... q1 + y
... except Exception as e:
... print(e)
'km' (length) and '' (dimensionless) are not convertible
unsupported operand type(s) for +: 'BareQuantity' and 'jaxlib.xla_extension.ArrayImpl'
>>> q1 = BareQuantity(100.0, "")
>>> jnp.add(q1, y)
Expand Down Expand Up @@ -288,7 +288,7 @@ def add_p_aqv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity:
... q1 + y
... except Exception as e:
... print(e)
'km' (length) and '' (dimensionless) are not convertible
unsupported operand type(s) for +: 'Quantity[PhysicalType('length')]' and 'jaxlib.xla_extension.ArrayImpl'
>>> q1 = Quantity(100.0, "")
>>> jnp.add(q1, y)
Expand All @@ -302,7 +302,7 @@ def add_p_aqv(x: AbstractQuantity, y: ArrayLike) -> AbstractQuantity:
>>> jnp.add(q2 / q3, y)
Quantity['dimensionless'](Array(501., dtype=float32, weak_type=True), unit='')
"""
""" # noqa: E501
x = uconvert(one, x)
return replace(x, value=qlax.add(ustrip(x), y))

Expand Down
7 changes: 3 additions & 4 deletions tests/unit/test_quantity.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,9 +335,8 @@ def test_eq():
assert np.array_equal(q == u.Quantity(2, "m"), [False, True, False])

# Test with incompatible units
# TODO: better equinox exception matching
with pytest.raises(Exception): # noqa: B017, PT011
_ = q == u.Quantity(0, "s")
_ = jnp.equal(q, u.Quantity(0, "s"))

# Test special case w/out units
assert u.Quantity(0, "m") == 0
Expand All @@ -361,9 +360,9 @@ def test_ne():
# Test with incompatible units
# TODO: better equinox exception matching
with pytest.raises(Exception): # noqa: B017, PT011
_ = q != u.Quantity(0, "s")
_ = jnp.not_equal(q, u.Quantity(0, "s"))
with pytest.raises(Exception): # noqa: B017, PT011
_ = q != u.Quantity(4, "s")
_ = jnp.not_equal(q, u.Quantity(4, "s"))

# Test special case w/out units
assert u.Quantity(1, "m") != 0
Expand Down
26 changes: 23 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 66e4118

Please sign in to comment.