Skip to content

Commit

Permalink
feat: dispatch neg_p (#185)
Browse files Browse the repository at this point in the history
* feat: dispatch neg_p
* ci: exclude sybil 7.1.0
See simplistix/sybil#131

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Sep 16, 2024
1 parent 0a37d4e commit 964e879
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 163 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
"pytest-arraydiff",
"pytest-cov >=3",
"pytest-env",
"sybil",
"sybil != 7.1.0",
]

[project.urls]
Expand Down
7 changes: 3 additions & 4 deletions src/coordinax/_coordinax/base/base_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
__all__ = ["AbstractAcceleration"]

from abc import abstractmethod
from dataclasses import replace
from functools import partial
from typing import TYPE_CHECKING, Any, TypeVar

import jax
from quax import register

import quaxed.array_api as xp
import quaxed.numpy as jnp
from dataclassish import field_items
from quaxed import lax as qlax
from unxt import Quantity
Expand Down Expand Up @@ -107,7 +106,7 @@ def __neg__(self) -> "Self":
Quantity['angular acceleration'](Array(-1., dtype=float32), unit='mas / yr2')
"""
return replace(self, **{k: -v for k, v in field_items(self)})
return jax.tree.map(jnp.negative, self)

# ===============================================================
# Convenience methods
Expand Down Expand Up @@ -191,7 +190,7 @@ def _mul_acc_time(lhs: AbstractAcceleration, rhs: Quantity["time"]) -> AbstractV
"""
# TODO: better access to corresponding fields
return lhs.integral_cls.constructor(
{k.replace("2", ""): xp.multiply(v, rhs) for k, v in field_items(lhs)}
{k.replace("2", ""): jnp.multiply(v, rhs) for k, v in field_items(lhs)}
)


Expand Down
63 changes: 33 additions & 30 deletions src/coordinax/_coordinax/base/base_pos.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from dataclasses import replace
from functools import partial
from inspect import isabstract
from typing import TYPE_CHECKING, Any, TypeVar
from typing import Any, TypeVar
from typing_extensions import override

import equinox as eqx
import jax
Expand All @@ -22,11 +23,9 @@
from .base import AbstractVector
from .mixins import AvalMixin
from coordinax._coordinax import typing as ct
from coordinax._coordinax.funcs import represent_as
from coordinax._coordinax.utils import classproperty

if TYPE_CHECKING:
from typing_extensions import Self

PosT = TypeVar("PosT", bound="AbstractPosition")

# TODO: figure out public API for this
Expand Down Expand Up @@ -87,36 +86,14 @@ def differential_cls(cls) -> type["AbstractVelocity"]:
raise NotImplementedError

# ===============================================================
# Array

# -----------------------------------------------------
# Unary operations

def __neg__(self) -> "Self":
"""Negate the vector.
The default implementation is to go through Cartesian coordinates.
Examples
--------
>>> import coordinax as cx
>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> -vec
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m"))
)
>>> (-vec).x
Quantity['length'](Array(-1., dtype=float32), unit='m')
"""
cart = self.represent_as(self._cartesian_cls)
return (-cart).represent_as(type(self))
__neg__ = jnp.negative

# ===============================================================
# Convenience methods

@override
def represent_as(self, target: type[PosT], /, *args: Any, **kwargs: Any) -> PosT:
"""Represent the vector as another type.
Expand Down Expand Up @@ -149,8 +126,6 @@ def represent_as(self, target: type[PosT], /, *args: Any, **kwargs: Any) -> PosT
Distance(Array(3.7416575, dtype=float32), unit='m')
"""
from coordinax import represent_as # pylint: disable=import-outside-toplevel

return represent_as(self, target, *args, **kwargs)

@partial(jax.jit, inline=True)
Expand Down Expand Up @@ -403,6 +378,34 @@ def _mul_pos_pos(lhs: AbstractPosition, rhs: AbstractPosition, /) -> Quantity:
# ------------------------------------------------


@register(jax.lax.neg_p) # type: ignore[misc]
def _neg_pos(obj: AbstractPosition, /) -> AbstractPosition:
"""Negate the vector.
The default implementation is to go through Cartesian coordinates.
Examples
--------
>>> import coordinax as cx
>>> vec = cx.CartesianPosition3D.constructor([1, 2, 3], "m")
>>> -vec
CartesianPosition3D(
x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
y=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
z=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m"))
)
>>> (-vec).x
Quantity['length'](Array(-1., dtype=float32), unit='m')
"""
cart = represent_as(obj, obj._cartesian_cls) # noqa: SLF001
negcart = jnp.negative(cart)
return represent_as(negcart, type(obj))


# ------------------------------------------------


@register(jax.lax.reshape_p) # type: ignore[misc]
def _reshape_pos(
operand: AbstractPosition, *, new_sizes: tuple[int, ...], **kwargs: Any
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_coordinax/base/base_vel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
__all__ = ["AbstractVelocity"]

from abc import abstractmethod
from dataclasses import replace
from functools import partial
from typing import TYPE_CHECKING, Any, TypeVar

import jax
from quax import register

import quaxed.numpy as jnp
from dataclassish import field_items
from unxt import Quantity

Expand Down Expand Up @@ -123,7 +123,7 @@ def __neg__(self) -> "Self":
Quantity['angular frequency'](Array(-1., dtype=float32), unit='mas / yr')
"""
return replace(self, **{k: -v for k, v in field_items(self)})
return jax.tree.map(jnp.negative, self)

# ===============================================================
# Convenience methods
Expand Down
67 changes: 32 additions & 35 deletions src/coordinax/_coordinax/d1/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from jaxtyping import ArrayLike
from quax import register

import quaxed.array_api as xp
import quaxed.numpy as jnp
from quaxed import lax as qlax
from unxt import Quantity

Expand Down Expand Up @@ -63,24 +63,6 @@ class CartesianPosition1D(AbstractPosition1D):
def differential_cls(cls) -> type["CartesianVelocity1D"]:
return CartesianVelocity1D

# -----------------------------------------------------
# Unary operations

def __neg__(self) -> "Self":
"""Negate the vector.
Examples
--------
>>> import coordinax as cx
>>> q = cx.CartesianPosition1D.constructor([1], "kpc")
>>> -q
CartesianPosition1D(
x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("kpc"))
)
"""
return replace(self, x=-self.x)


# -------------------------------------------------------------------
# Method dispatches
Expand All @@ -92,13 +74,13 @@ def _add_qq(lhs: CartesianPosition1D, rhs: AbstractPosition, /) -> CartesianPosi
Examples
--------
>>> import quaxed.array_api as xp
>>> import quaxed.numpy as jnp
>>> import coordinax as cx
>>> q = cx.CartesianPosition1D.constructor([1], "kpc")
>>> r = cx.RadialPosition.constructor([1], "kpc")
>>> qpr = xp.add(q, r)
>>> qpr = jnp.add(q, r)
>>> qpr
CartesianPosition1D(
x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("kpc"))
Expand All @@ -120,12 +102,12 @@ def _mul_ac1(lhs: ArrayLike, rhs: CartesianPosition1D, /) -> CartesianPosition1D
Examples
--------
>>> import quaxed.array_api as xp
>>> import quaxed.numpy as jnp
>>> from unxt import Quantity
>>> import coordinax as cx
>>> v = cx.CartesianPosition1D(x=Quantity(1, "m"))
>>> xp.multiply(2, v).x
>>> jnp.multiply(2, v).x
Quantity['length'](Array(2., dtype=float32), unit='m')
>>> (2 * v).x
Expand All @@ -141,6 +123,21 @@ def _mul_ac1(lhs: ArrayLike, rhs: CartesianPosition1D, /) -> CartesianPosition1D
return replace(rhs, x=lhs * rhs.x)


@register(jax.lax.neg_p) # type: ignore[misc]
def _neg_p_cart1d_pos(obj: CartesianPosition1D, /) -> CartesianPosition1D:
"""Negate the `coordinax.CartesianPosition1D`.
Examples
--------
>>> import coordinax as cx
>>> q = cx.CartesianPosition1D.constructor([1], "km")
>>> (-q).x
Quantity['length'](Array(-1., dtype=float32), unit='km')
"""
return jax.tree.map(qlax.neg, obj)


@register(jax.lax.sub_p) # type: ignore[misc]
def _sub_q1d_pos(
self: CartesianPosition1D, other: AbstractPosition, /
Expand All @@ -149,14 +146,14 @@ def _sub_q1d_pos(
Examples
--------
>>> import quaxed.array_api as xp
>>> import quaxed.numpy as jnp
>>> from unxt import Quantity
>>> import coordinax as cx
>>> q = cx.CartesianPosition1D.constructor(Quantity([1], "kpc"))
>>> r = cx.RadialPosition.constructor(Quantity([1], "kpc"))
>>> qmr = xp.subtract(q, r)
>>> qmr = jnp.subtract(q, r)
>>> qmr
CartesianPosition1D(
x=Quantity[PhysicalType('length')](value=f32[], unit=Unit("kpc"))
Expand Down Expand Up @@ -205,7 +202,7 @@ def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableSpeed:
Quantity['speed'](Array(1, dtype=int32), unit='km / s')
"""
return xp.abs(self.d_x)
return jnp.abs(self.d_x)


# -------------------------------------------------------------------
Expand All @@ -220,12 +217,12 @@ def _add_pp(
Examples
--------
>>> import quaxed.array_api as xp
>>> import quaxed.numpy as jnp
>>> from unxt import Quantity
>>> import coordinax as cx
>>> v = cx.CartesianVelocity1D.constructor([1], "km/s")
>>> vec = xp.add(v, v)
>>> vec = jnp.add(v, v)
>>> vec
CartesianVelocity1D(
d_x=Quantity[...]( value=i32[], unit=Unit("km / s") )
Expand All @@ -246,12 +243,12 @@ def _mul_vcart(lhs: ArrayLike, rhs: CartesianVelocity1D, /) -> CartesianVelocity
Examples
--------
>>> import quaxed.array_api as xp
>>> import quaxed.numpy as jnp
>>> from unxt import Quantity
>>> import coordinax as cx
>>> v = cx.CartesianVelocity1D(d_x=Quantity(1, "m/s"))
>>> vec = xp.multiply(2, v)
>>> vec = jnp.multiply(2, v)
>>> vec
CartesianVelocity1D(
d_x=Quantity[...]( value=i32[], unit=Unit("m / s") )
Expand Down Expand Up @@ -304,7 +301,7 @@ def norm(self, _: AbstractPosition1D | None = None, /) -> ct.BatchableAcc:
Quantity['acceleration'](Array(1, dtype=int32), unit='km / s2')
"""
return xp.abs(self.d2_x)
return jnp.abs(self.d2_x)


@register(jax.lax.add_p) # type: ignore[misc]
Expand All @@ -315,12 +312,12 @@ def _add_aa(
Examples
--------
>>> import quaxed.array_api as xp
>>> import quaxed.numpy as jnp
>>> from unxt import Quantity
>>> import coordinax as cx
>>> v = cx.CartesianAcceleration1D.constructor([1], "km/s2")
>>> vec = xp.add(v, v)
>>> vec = jnp.add(v, v)
>>> vec
CartesianAcceleration1D(
d2_x=Quantity[...](value=i32[], unit=Unit("km / s2"))
Expand All @@ -341,12 +338,12 @@ def _mul_aq(lhs: ArrayLike, rhs: CartesianAcceleration1D, /) -> CartesianAcceler
Examples
--------
>>> import quaxed.array_api as xp
>>> import quaxed.numpy as jnp
>>> from unxt import Quantity
>>> import coordinax as cx
>>> v = cx.CartesianAcceleration1D(d2_x=Quantity(1, "m/s2"))
>>> vec = xp.multiply(2, v)
>>> vec = jnp.multiply(2, v)
>>> vec
CartesianAcceleration1D(
d2_x=Quantity[...](value=i32[], unit=Unit("m / s2"))
Expand Down
Loading

0 comments on commit 964e879

Please sign in to comment.