diff --git a/pyproject.toml b/pyproject.toml index 90c6674..2657b14 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ "pytest-arraydiff", "pytest-cov >=3", "pytest-env", - "sybil", + "sybil != 7.1.0", ] [project.urls] diff --git a/src/coordinax/_coordinax/base/base_acc.py b/src/coordinax/_coordinax/base/base_acc.py index babb792..215d380 100644 --- a/src/coordinax/_coordinax/base/base_acc.py +++ b/src/coordinax/_coordinax/base/base_acc.py @@ -3,14 +3,13 @@ __all__ = ["AbstractAcceleration"] from abc import abstractmethod -from dataclasses import replace from functools import partial from typing import TYPE_CHECKING, Any, TypeVar import jax from quax import register -import quaxed.array_api as xp +import quaxed.numpy as jnp from dataclassish import field_items from quaxed import lax as qlax from unxt import Quantity @@ -107,7 +106,7 @@ def __neg__(self) -> "Self": Quantity['angular acceleration'](Array(-1., dtype=float32), unit='mas / yr2') """ - return replace(self, **{k: -v for k, v in field_items(self)}) + return jax.tree.map(jnp.negative, self) # =============================================================== # Convenience methods @@ -191,7 +190,7 @@ def _mul_acc_time(lhs: AbstractAcceleration, rhs: Quantity["time"]) -> AbstractV """ # TODO: better access to corresponding fields return lhs.integral_cls.constructor( - {k.replace("2", ""): xp.multiply(v, rhs) for k, v in field_items(lhs)} + {k.replace("2", ""): jnp.multiply(v, rhs) for k, v in field_items(lhs)} ) diff --git a/src/coordinax/_coordinax/base/base_pos.py b/src/coordinax/_coordinax/base/base_pos.py index 28c8657..2da52cc 100644 --- a/src/coordinax/_coordinax/base/base_pos.py +++ b/src/coordinax/_coordinax/base/base_pos.py @@ -6,7 +6,8 @@ from dataclasses import replace from functools import partial from inspect import isabstract -from typing import TYPE_CHECKING, Any, TypeVar +from typing import Any, TypeVar +from typing_extensions import override import equinox as eqx import jax @@ -22,11 +23,9 @@ from .base import AbstractVector from .mixins import AvalMixin from coordinax._coordinax import typing as ct +from coordinax._coordinax.funcs import represent_as from coordinax._coordinax.utils import classproperty -if TYPE_CHECKING: - from typing_extensions import Self - PosT = TypeVar("PosT", bound="AbstractPosition") # TODO: figure out public API for this @@ -87,36 +86,14 @@ def differential_cls(cls) -> type["AbstractVelocity"]: raise NotImplementedError # =============================================================== - # Array - - # ----------------------------------------------------- # Unary operations - def __neg__(self) -> "Self": - """Negate the vector. - - The default implementation is to go through Cartesian coordinates. - - Examples - -------- - >>> import coordinax as cx - >>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m") - >>> -vec - CartesianPosition3D( - x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), - y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), - z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")) - ) - >>> (-vec).x - Quantity['length'](Array(-1., dtype=float32), unit='m') - - """ - cart = self.represent_as(self._cartesian_cls) - return (-cart).represent_as(type(self)) + __neg__ = jnp.negative # =============================================================== # Convenience methods + @override def represent_as(self, target: type[PosT], /, *args: Any, **kwargs: Any) -> PosT: """Represent the vector as another type. @@ -149,8 +126,6 @@ def represent_as(self, target: type[PosT], /, *args: Any, **kwargs: Any) -> PosT Distance(Array(3.7416575, dtype=float32), unit='m') """ - from coordinax import represent_as # pylint: disable=import-outside-toplevel - return represent_as(self, target, *args, **kwargs) @partial(jax.jit, inline=True) @@ -403,6 +378,34 @@ def _mul_pos_pos(lhs: AbstractPosition, rhs: AbstractPosition, /) -> Quantity: # ------------------------------------------------ +@register(jax.lax.neg_p) # type: ignore[misc] +def _neg_pos(obj: AbstractPosition, /) -> AbstractPosition: + """Negate the vector. + + The default implementation is to go through Cartesian coordinates. + + Examples + -------- + >>> import coordinax as cx + >>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m") + >>> -vec + CartesianPosition3D( + x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), + y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")), + z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")) + ) + >>> (-vec).x + Quantity['length'](Array(-1., dtype=float32), unit='m') + + """ + cart = represent_as(obj, obj._cartesian_cls) # noqa: SLF001 + negcart = jnp.negative(cart) + return represent_as(negcart, type(obj)) + + +# ------------------------------------------------ + + @register(jax.lax.reshape_p) # type: ignore[misc] def _reshape_pos( operand: AbstractPosition, *, new_sizes: tuple[int, ...], **kwargs: Any diff --git a/src/coordinax/_coordinax/base/base_vel.py b/src/coordinax/_coordinax/base/base_vel.py index 7e9cbc5..b7f2341 100644 --- a/src/coordinax/_coordinax/base/base_vel.py +++ b/src/coordinax/_coordinax/base/base_vel.py @@ -3,13 +3,13 @@ __all__ = ["AbstractVelocity"] from abc import abstractmethod -from dataclasses import replace from functools import partial from typing import TYPE_CHECKING, Any, TypeVar import jax from quax import register +import quaxed.numpy as jnp from dataclassish import field_items from unxt import Quantity @@ -123,7 +123,7 @@ def __neg__(self) -> "Self": Quantity['angular frequency'](Array(-1., dtype=float32), unit='mas / yr') """ - return replace(self, **{k: -v for k, v in field_items(self)}) + return jax.tree.map(jnp.negative, self) # =============================================================== # Convenience methods diff --git a/src/coordinax/_coordinax/d1/cartesian.py b/src/coordinax/_coordinax/d1/cartesian.py index 2ae2458..5c54032 100644 --- a/src/coordinax/_coordinax/d1/cartesian.py +++ b/src/coordinax/_coordinax/d1/cartesian.py @@ -15,7 +15,7 @@ from jaxtyping import ArrayLike from quax import register -import quaxed.array_api as xp +import quaxed.numpy as jnp from quaxed import lax as qlax from unxt import Quantity @@ -63,24 +63,6 @@ class CartesianPosition1D(AbstractPosition1D): def differential_cls(cls) -> type["CartesianVelocity1D"]: return CartesianVelocity1D - # ----------------------------------------------------- - # Unary operations - - def __neg__(self) -> "Self": - """Negate the vector. - - Examples - -------- - >>> import coordinax as cx - >>> q = cx.CartesianPosition1D.constructor([1], "kpc") - >>> -q - CartesianPosition1D( - x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("kpc")) - ) - - """ - return replace(self, x=-self.x) - # ------------------------------------------------------------------- # Method dispatches @@ -92,13 +74,13 @@ def _add_qq(lhs: CartesianPosition1D, rhs: AbstractPosition, /) -> CartesianPosi Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> import coordinax as cx >>> q = cx.CartesianPosition1D.constructor([1], "kpc") >>> r = cx.RadialPosition.constructor([1], "kpc") - >>> qpr = xp.add(q, r) + >>> qpr = jnp.add(q, r) >>> qpr CartesianPosition1D( x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("kpc")) @@ -120,12 +102,12 @@ def _mul_ac1(lhs: ArrayLike, rhs: CartesianPosition1D, /) -> CartesianPosition1D Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx >>> v = cx.CartesianPosition1D(x=Quantity(1, "m")) - >>> xp.multiply(2, v).x + >>> jnp.multiply(2, v).x Quantity['length'](Array(2., dtype=float32), unit='m') >>> (2 * v).x @@ -141,6 +123,21 @@ def _mul_ac1(lhs: ArrayLike, rhs: CartesianPosition1D, /) -> CartesianPosition1D return replace(rhs, x=lhs * rhs.x) +@register(jax.lax.neg_p) # type: ignore[misc] +def _neg_p_cart1d_pos(obj: CartesianPosition1D, /) -> CartesianPosition1D: + """Negate the `coordinax.CartesianPosition1D`. + + Examples + -------- + >>> import coordinax as cx + >>> q = cx.CartesianPosition1D.constructor([1], "km") + >>> (-q).x + Quantity['length'](Array(-1., dtype=float32), unit='km') + + """ + return jax.tree.map(qlax.neg, obj) + + @register(jax.lax.sub_p) # type: ignore[misc] def _sub_q1d_pos( self: CartesianPosition1D, other: AbstractPosition, / @@ -149,14 +146,14 @@ def _sub_q1d_pos( Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx >>> q = cx.CartesianPosition1D.constructor(Quantity([1], "kpc")) >>> r = cx.RadialPosition.constructor(Quantity([1], "kpc")) - >>> qmr = xp.subtract(q, r) + >>> qmr = jnp.subtract(q, r) >>> qmr CartesianPosition1D( x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("kpc")) @@ -205,7 +202,7 @@ def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableSpeed: Quantity['speed'](Array(1, dtype=int32), unit='km / s') """ - return xp.abs(self.d_x) + return jnp.abs(self.d_x) # ------------------------------------------------------------------- @@ -220,12 +217,12 @@ def _add_pp( Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx >>> v = cx.CartesianVelocity1D.constructor([1], "km/s") - >>> vec = xp.add(v, v) + >>> vec = jnp.add(v, v) >>> vec CartesianVelocity1D( d_x=Quantity[...]( value=i32[], unit=Unit("km / s") ) @@ -246,12 +243,12 @@ def _mul_vcart(lhs: ArrayLike, rhs: CartesianVelocity1D, /) -> CartesianVelocity Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx >>> v = cx.CartesianVelocity1D(d_x=Quantity(1, "m/s")) - >>> vec = xp.multiply(2, v) + >>> vec = jnp.multiply(2, v) >>> vec CartesianVelocity1D( d_x=Quantity[...]( value=i32[], unit=Unit("m / s") ) @@ -304,7 +301,7 @@ def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableAcc: Quantity['acceleration'](Array(1, dtype=int32), unit='km / s2') """ - return xp.abs(self.d2_x) + return jnp.abs(self.d2_x) @register(jax.lax.add_p) # type: ignore[misc] @@ -315,12 +312,12 @@ def _add_aa( Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx >>> v = cx.CartesianAcceleration1D.constructor([1], "km/s2") - >>> vec = xp.add(v, v) + >>> vec = jnp.add(v, v) >>> vec CartesianAcceleration1D( d2_x=Quantity[...](value=i32[], unit=Unit("km / s2")) @@ -341,12 +338,12 @@ def _mul_aq(lhs: ArrayLike, rhs: CartesianAcceleration1D, /) -> CartesianAcceler Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx >>> v = cx.CartesianAcceleration1D(d2_x=Quantity(1, "m/s2")) - >>> vec = xp.multiply(2, v) + >>> vec = jnp.multiply(2, v) >>> vec CartesianAcceleration1D( d2_x=Quantity[...](value=i32[], unit=Unit("m / s2")) diff --git a/src/coordinax/_coordinax/d2/cartesian.py b/src/coordinax/_coordinax/d2/cartesian.py index beffa8d..30256a0 100644 --- a/src/coordinax/_coordinax/d2/cartesian.py +++ b/src/coordinax/_coordinax/d2/cartesian.py @@ -15,7 +15,7 @@ from jaxtyping import ArrayLike, Shaped from quax import register -import quaxed.array_api as xp +import quaxed.numpy as jnp from quaxed import lax as qlax from unxt import AbstractQuantity, Quantity @@ -45,24 +45,6 @@ class CartesianPosition2D(AbstractPosition2D): def differential_cls(cls) -> type["CartesianVelocity2D"]: return CartesianVelocity2D - # ----------------------------------------------------- - # Unary operations - - def __neg__(self) -> "Self": - """Negate the vector. - - Examples - -------- - >>> from unxt import Quantity - >>> import coordinax as cx - - >>> q = cx.CartesianPosition2D.constructor([1, 2], "kpc") - >>> (-q).x - Quantity['length'](Array(-1., dtype=float32), unit='kpc') - - """ - return replace(self, x=-self.x, y=-self.y) - # ----------------------------------------------------- @@ -101,7 +83,7 @@ def _add_cart2d_pos( Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx @@ -110,7 +92,7 @@ def _add_cart2d_pos( >>> (cart + polr).x Quantity['length'](Array(0.9999999, dtype=float32), unit='kpc') - >>> xp.add(cart, polr).x + >>> jnp.add(cart, polr).x Quantity['length'](Array(0.9999999, dtype=float32), unit='kpc') """ @@ -124,12 +106,12 @@ def _mul_v_cart2d(lhs: ArrayLike, rhs: CartesianPosition2D, /) -> CartesianPosit Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx >>> v = cx.CartesianPosition2D.constructor(Quantity([3, 4], "m")) - >>> xp.multiply(5, v).x + >>> jnp.multiply(5, v).x Quantity['length'](Array(15., dtype=float32), unit='m') """ @@ -142,6 +124,21 @@ def _mul_v_cart2d(lhs: ArrayLike, rhs: CartesianPosition2D, /) -> CartesianPosit return replace(rhs, x=lhs * rhs.x, y=lhs * rhs.y) +@register(jax.lax.neg_p) # type: ignore[misc] +def _neg_p_cart2d_pos(obj: CartesianPosition2D, /) -> CartesianPosition2D: + """Negate the `coordinax.CartesianPosition2D`. + + Examples + -------- + >>> import coordinax as cx + >>> q = cx.CartesianPosition2D.constructor([1, 2], "km") + >>> (-q).x + Quantity['length'](Array(-1., dtype=float32), unit='km') + + """ + return jax.tree.map(qlax.neg, obj) + + @register(jax.lax.sub_p) # type: ignore[misc] def _sub_cart2d_pos2d( lhs: CartesianPosition2D, rhs: AbstractPosition, / @@ -228,7 +225,7 @@ def _add_pp( Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx @@ -236,7 +233,7 @@ def _add_pp( >>> (v + v).d_x Quantity['speed'](Array(2., dtype=float32), unit='km / s') - >>> xp.add(v, v).d_x + >>> jnp.add(v, v).d_x Quantity['speed'](Array(2., dtype=float32), unit='km / s') """ @@ -249,7 +246,7 @@ def _mul_vp(lhs: ArrayLike, rhts: CartesianVelocity2D, /) -> CartesianVelocity2D Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx @@ -257,7 +254,7 @@ def _mul_vp(lhs: ArrayLike, rhts: CartesianVelocity2D, /) -> CartesianVelocity2D >>> (5 * v).d_x Quantity['speed'](Array(15., dtype=float32), unit='m / s') - >>> xp.multiply(5, v).d_x + >>> jnp.multiply(5, v).d_x Quantity['speed'](Array(15., dtype=float32), unit='m / s') """ @@ -307,7 +304,7 @@ def norm(self, _: AbstractVelocity2D | None = None, /) -> ct.BatchableAcc: Quantity['acceleration'](Array(5., dtype=float32), unit='km / s2') """ - return xp.sqrt(self.d2_x**2 + self.d2_y**2) + return jnp.sqrt(self.d2_x**2 + self.d2_y**2) # ----------------------------------------------------- @@ -316,7 +313,7 @@ def norm(self, _: AbstractVelocity2D | None = None, /) -> ct.BatchableAcc: @CartesianAcceleration2D.constructor._f.dispatch # type: ignore[attr-defined, misc] # noqa: SLF001 def constructor( cls: type[CartesianAcceleration2D], - obj: Shaped[AbstractQuantity, "*batch 2"], + obj: AbstractQuantity, /, ) -> CartesianAcceleration2D: """Construct a 2D Cartesian velocity. @@ -349,7 +346,7 @@ def _add_aa( Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx @@ -357,7 +354,7 @@ def _add_aa( >>> (v + v).d2_x Quantity['acceleration'](Array(6., dtype=float32), unit='km / s2') - >>> xp.add(v, v).d2_x + >>> jnp.add(v, v).d2_x Quantity['acceleration'](Array(6., dtype=float32), unit='km / s2') """ @@ -372,12 +369,12 @@ def _mul_va( Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx >>> v = cx.CartesianAcceleration2D.constructor(Quantity([3, 4], "m/s2")) - >>> xp.multiply(5, v).d2_x + >>> jnp.multiply(5, v).d2_x Quantity['acceleration'](Array(15., dtype=float32), unit='m / s2') >>> (5 * v).d2_x @@ -391,24 +388,3 @@ def _mul_va( # Scale the components return replace(rhts, d2_x=lhs * rhts.d2_x, d2_y=lhs * rhts.d2_y) - - -@register(jax.lax.sub_p) # type: ignore[misc] -def _sub_cart2d_pos2d( - self: CartesianPosition2D, other: AbstractPosition, / -) -> CartesianPosition2D: - """Subtract two vectors. - - Examples - -------- - >>> from unxt import Quantity - >>> from coordinax import CartesianPosition2D, PolarPosition - >>> cart = CartesianPosition2D.constructor(Quantity([1, 2], "kpc")) - >>> polr = PolarPosition(r=Quantity(3, "kpc"), phi=Quantity(90, "deg")) - - >>> (cart - polr).x - Quantity['length'](Array(1.0000001, dtype=float32), unit='kpc') - - """ - cart = other.represent_as(CartesianPosition2D) - return jax.tree.map(qlax.sub, self, cart) diff --git a/src/coordinax/_coordinax/d3/cartesian.py b/src/coordinax/_coordinax/d3/cartesian.py index 9303225..a1682ae 100644 --- a/src/coordinax/_coordinax/d3/cartesian.py +++ b/src/coordinax/_coordinax/d3/cartesian.py @@ -9,14 +9,15 @@ from dataclasses import fields, replace from functools import partial from typing import final +from typing_extensions import override import equinox as eqx import jax from jaxtyping import ArrayLike, Shaped from quax import register -import quaxed.array_api as xp import quaxed.lax as qlax +import quaxed.numpy as jnp from dataclassish import field_items from unxt import AbstractQuantity, Quantity @@ -26,6 +27,9 @@ from coordinax._coordinax.base.mixins import AvalMixin from coordinax._coordinax.utils import classproperty +##################################################################### +# Position + @final class CartesianPosition3D(AbstractPosition3D): @@ -46,27 +50,13 @@ class CartesianPosition3D(AbstractPosition3D): ) r"""Z coordinate :math:`z \in (-\infty,+\infty)`.""" + @override @classproperty @classmethod def differential_cls(cls) -> type["CartesianVelocity3D"]: + """Return the differential of the class.""" return CartesianVelocity3D - # ----------------------------------------------------- - # Unary operations - - def __neg__(self) -> "Self": - """Negate the vector. - - Examples - -------- - >>> import coordinax as cx - >>> q = cx.CartesianPosition3D.constructor([1, 2, 3], "kpc") - >>> (-q).x - Quantity['length'](Array(-1., dtype=float32), unit='kpc') - - """ - return replace(self, x=-self.x, y=-self.y, z=-self.z) - # ----------------------------------------------------- @@ -124,6 +114,21 @@ def _add_cart3d_pos( ) +@register(jax.lax.neg_p) # type: ignore[misc] +def _neg_p_cart3d_pos(obj: CartesianPosition3D, /) -> CartesianPosition3D: + """Negate the `coordinax.CartesianPosition3D`. + + Examples + -------- + >>> import coordinax as cx + >>> q = cx.CartesianPosition3D.constructor([1, 2, 3], "kpc") + >>> (-q).x + Quantity['length'](Array(-1., dtype=float32), unit='kpc') + + """ + return jax.tree.map(qlax.neg, obj) + + @register(jax.lax.sub_p) # type: ignore[misc] def _sub_cart3d_pos( lhs: CartesianPosition3D, rhs: AbstractPosition, / @@ -146,6 +151,7 @@ def _sub_cart3d_pos( ##################################################################### +# Velocity @final @@ -190,7 +196,7 @@ def norm(self, _: AbstractPosition3D | None = None, /) -> ct.BatchableSpeed: Quantity['speed'](Array(3.7416575, dtype=float32), unit='km / s') """ - return xp.sqrt(self.d_x**2 + self.d_y**2 + self.d_z**2) + return jnp.sqrt(self.d_x**2 + self.d_y**2 + self.d_z**2) # ----------------------------------------------------- @@ -264,6 +270,7 @@ def _sub_v3_v3( ##################################################################### +# Acceleration @final @@ -306,7 +313,7 @@ def norm(self, _: AbstractVelocity3D | None = None, /) -> ct.BatchableAcc: Quantity['acceleration'](Array(3.7416575, dtype=float32), unit='km / s2') """ - return xp.sqrt(self.d2_x**2 + self.d2_y**2 + self.d2_z**2) + return jnp.sqrt(self.d2_x**2 + self.d2_y**2 + self.d2_z**2) # ----------------------------------------------------- @@ -354,12 +361,12 @@ def _mul_ac3(lhs: ArrayLike, rhs: CartesianPosition3D, /) -> CartesianPosition3D Examples -------- - >>> import quaxed.array_api as xp + >>> import quaxed.numpy as jnp >>> from unxt import Quantity >>> import coordinax as cx >>> v = cx.CartesianPosition3D.constructor([1, 2, 3], "kpc") - >>> xp.multiply(2, v).x + >>> jnp.multiply(2, v).x Quantity['length'](Array(2., dtype=float32), unit='kpc') """ diff --git a/src/coordinax/_coordinax/dn/cartesian.py b/src/coordinax/_coordinax/dn/cartesian.py index 3373d02..260bd0e 100644 --- a/src/coordinax/_coordinax/dn/cartesian.py +++ b/src/coordinax/_coordinax/dn/cartesian.py @@ -13,6 +13,7 @@ from plum import conversion_method from quax import register +import quaxed.lax as qlax import quaxed.numpy as jnp from unxt import Quantity @@ -95,22 +96,7 @@ def differential_cls(cls) -> type["CartesianVelocityND"]: # type: ignore[overri # ----------------------------------------------------- # Unary operations - def __neg__(self) -> "Self": - """Negate the vector. - - Examples - -------- - >>> from unxt import Quantity - >>> import coordinax as cx - - A 3D vector: - - >>> vec = cx.CartesianPositionND(Quantity([1, 2, 3], "kpc")) - >>> (-vec).q - Quantity['length'](Array([-1., -2., -3.], dtype=float32), unit='kpc') - - """ - return replace(self, q=-self.q) + __neg__ = jnp.negative # ----------------------------------------------------- @@ -251,6 +237,25 @@ def _mul_vcnd(lhs: ArrayLike, rhs: CartesianPositionND, /) -> CartesianPositionN return replace(rhs, q=lhs * rhs.q) +@register(jax.lax.neg_p) # type: ignore[misc] +def _neg_p_cartnd_pos(obj: CartesianPositionND, /) -> CartesianPositionND: + """Negate the `coordinax.CartesianPositionND`. + + Examples + -------- + >>> from unxt import Quantity + >>> import coordinax as cx + + A 3D vector: + + >>> vec = cx.CartesianPositionND(Quantity([1, 2, 3], "kpc")) + >>> (-vec).q + Quantity['length'](Array([-1., -2., -3.], dtype=float32), unit='kpc') + + """ + return jax.tree.map(qlax.neg, obj) + + @register(jax.lax.sub_p) # type: ignore[misc] def _sub_cnd_pos( lhs: CartesianPositionND, rhs: AbstractPosition, / diff --git a/src/coordinax/_coordinax/space.py b/src/coordinax/_coordinax/space.py index aef21e3..8aa4a20 100644 --- a/src/coordinax/_coordinax/space.py +++ b/src/coordinax/_coordinax/space.py @@ -451,6 +451,7 @@ def asdict( """ return dict_factory(self._data) + @override @classproperty @classmethod def components(cls) -> tuple[str, ...]: @@ -462,6 +463,7 @@ def units(self) -> MappingProxyType[str, Unit]: """Get the units of the vector's components.""" raise NotImplementedError # TODO: implement this + @override @property def dtypes(self) -> MappingProxyType[str, MappingProxyType[str, jnp.dtype]]: """Get the dtypes of the vector's components. @@ -483,6 +485,7 @@ def dtypes(self) -> MappingProxyType[str, MappingProxyType[str, jnp.dtype]]: """ # noqa: E501 return MappingProxyType({k: v.dtypes for k, v in self.items()}) + @override @property def devices(self) -> MappingProxyType[str, MappingProxyType[str, Device]]: """Get the devices of the vector's components.