Skip to content

Commit

Permalink
feat: convert to unchecked quantity (#194)
Browse files Browse the repository at this point in the history
* refactor: move base pos compat
* feat: convert to UncheckedQuantity
* docs: fix

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Sep 23, 2024
1 parent e0a70d5 commit 7b1a499
Show file tree
Hide file tree
Showing 9 changed files with 374 additions and 213 deletions.
5 changes: 4 additions & 1 deletion src/coordinax/_src/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,7 @@
from .base_vel import AbstractVelocity

# isort: split
from . import register_primitives # noqa: F401
from . import (
compat, # noqa: F401
register_primitives, # noqa: F401
)
4 changes: 2 additions & 2 deletions src/coordinax/_src/base/base_acc.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _mul_acc_time(lhs: AbstractAcceleration, rhs: Quantity["time"]) -> AbstractV
>>> d2r = cx.RadialAcceleration(Quantity(1, "m/s2"))
>>> vec = lax.mul(d2r, Quantity(2, "s"))
>>> vec
RadialVelocity( d_r=Quantity[...]( value=weak_i32[], unit=Unit("m / s") ) )
RadialVelocity( d_r=Quantity[...]( value=...i32[], unit=Unit("m / s") ) )
>>> vec.d_r
Quantity['speed'](Array(2, dtype=int32, ...), unit='m / s')
Expand All @@ -209,7 +209,7 @@ def _mul_time_acc(lhs: Quantity["time"], rhs: AbstractAcceleration) -> AbstractV
>>> d2r = cx.RadialAcceleration(Quantity(1, "m/s2"))
>>> vec = lax.mul(Quantity(2, "s"), d2r)
>>> vec
RadialVelocity( d_r=Quantity[...]( value=weak_i32[], unit=Unit("m / s") ) )
RadialVelocity( d_r=Quantity[...]( value=...i32[], unit=Unit("m / s") ) )
>>> vec.d_r
Quantity['speed'](Array(2, dtype=int32, ...), unit='m / s')
Expand Down
184 changes: 184 additions & 0 deletions src/coordinax/_src/base/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""Intra-ecosystem Compatibility."""

__all__: list[str] = []


from jaxtyping import Shaped
from plum import conversion_method, convert

import quaxed.numpy as xp
from dataclassish import field_values
from unxt import AbstractQuantity, Distance, Quantity, UncheckedQuantity

from .base_pos import AbstractPosition
from coordinax._src.utils import full_shaped

#####################################################################
# Convert to Quantity


@conversion_method(type_from=AbstractPosition, type_to=AbstractQuantity) # type: ignore[misc]
def convert_pos_to_absquantity(obj: AbstractPosition, /) -> AbstractQuantity:
"""`coordinax.AbstractPosition` -> `unxt.AbstractQuantity`.
Examples
--------
>>> import coordinax as cx
>>> from unxt import AbstractQuantity, Quantity
>>> pos = cx.CartesianPosition1D.constructor([1.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1.], dtype=float32), unit='km')
>>> pos = cx.RadialPosition.constructor([1.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition2D.constructor([1.0, 2.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 2.], dtype=float32), unit='km')
>>> pos = cx.PolarPosition(Quantity(1.0, "km"), Quantity(0, "deg"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 0.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition3D.constructor([1.0, 2.0, 3.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 2., 3.], dtype=float32), unit='km')
>>> pos = cx.SphericalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "deg"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([0., 0., 1.], dtype=float32), unit='km')
>>> pos = cx.CylindricalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "km"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 0., 0.], dtype=float32), unit='km')
""" # noqa: E501
cart = full_shaped(obj.represent_as(obj._cartesian_cls)) # noqa: SLF001
return xp.stack(tuple(field_values(cart)), axis=-1)


@conversion_method(type_from=AbstractPosition, type_to=Quantity) # type: ignore[misc]
def convert_pos_to_q(obj: AbstractPosition, /) -> Quantity["length"]:
"""`coordinax.AbstractPosition` -> `unxt.Quantity`.
Examples
--------
>>> import coordinax as cx
>>> from unxt import AbstractQuantity, Quantity
>>> pos = cx.CartesianPosition1D.constructor([1.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1.], dtype=float32), unit='km')
>>> pos = cx.RadialPosition.constructor([1.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition2D.constructor([1.0, 2.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 2.], dtype=float32), unit='km')
>>> pos = cx.PolarPosition(Quantity(1.0, "km"), Quantity(0, "deg"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 0.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition3D.constructor([1.0, 2.0, 3.0], "km")
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 2., 3.], dtype=float32), unit='km')
>>> pos = cx.SphericalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "deg"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([0., 0., 1.], dtype=float32), unit='km')
>>> pos = cx.CylindricalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "km"))
>>> convert(pos, AbstractQuantity)
Quantity['length'](Array([1., 0., 0.], dtype=float32), unit='km')
""" # noqa: E501
return convert(convert(obj, AbstractQuantity), Quantity)


@conversion_method(type_from=AbstractPosition, type_to=UncheckedQuantity) # type: ignore[misc]
def convert_pos_to_uncheckedq(
obj: AbstractPosition, /
) -> Shaped[UncheckedQuantity, "*batch 1"]:
"""`coordinax.AbstractPosition` -> `unxt.UncheckedQuantity`.
Examples
--------
>>> import coordinax as cx
>>> from unxt import AbstractQuantity, Quantity, UncheckedQuantity
>>> pos = cx.CartesianPosition1D.constructor([1.0], "km")
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1.], dtype=float32), unit='km')
>>> pos = cx.RadialPosition.constructor([1.0], "km")
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition2D.constructor([1.0, 2.0], "km")
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1., 2.], dtype=float32), unit='km')
>>> pos = cx.PolarPosition(Quantity(1.0, "km"), Quantity(0, "deg"))
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1., 0.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition3D.constructor([1.0, 2.0, 3.0], "km")
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1., 2., 3.], dtype=float32), unit='km')
>>> pos = cx.SphericalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "deg"))
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([0., 0., 1.], dtype=float32), unit='km')
>>> pos = cx.CylindricalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "km"))
>>> convert(pos, UncheckedQuantity)
UncheckedQuantity(Array([1., 0., 0.], dtype=float32), unit='km')
""" # noqa: E501
return convert(convert(obj, AbstractQuantity), UncheckedQuantity)


@conversion_method(type_from=AbstractPosition, type_to=Distance) # type: ignore[misc]
def convert_pos_to_distance(obj: AbstractPosition, /) -> Shaped[Distance, "*batch 1"]:
"""`coordinax.AbstractPosition` -> `unxt.Distance`.
Examples
--------
>>> import coordinax as cx
>>> from unxt import AbstractQuantity, Quantity, Distance
>>> pos = cx.CartesianPosition1D.constructor([1.0], "km")
>>> convert(pos, Distance)
Distance(Array([1.], dtype=float32), unit='km')
>>> pos = cx.RadialPosition.constructor([1.0], "km")
>>> convert(pos, Distance)
Distance(Array([1.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition2D.constructor([1.0, 2.0], "km")
>>> convert(pos, Distance)
Distance(Array([1., 2.], dtype=float32), unit='km')
>>> pos = cx.PolarPosition(Quantity(1.0, "km"), Quantity(0, "deg"))
>>> convert(pos, Distance)
Distance(Array([1., 0.], dtype=float32), unit='km')
>>> pos = cx.CartesianPosition3D.constructor([1.0, 2.0, 3.0], "km")
>>> convert(pos, Distance)
Distance(Array([1., 2., 3.], dtype=float32), unit='km')
>>> pos = cx.SphericalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "deg"))
>>> convert(pos, Distance)
Distance(Array([0., 0., 1.], dtype=float32), unit='km')
>>> pos = cx.CylindricalPosition(Quantity(1.0, "km"), Quantity(0, "deg"), Quantity(0, "km"))
>>> convert(pos, Distance)
Distance(Array([1., 0., 0.], dtype=float32), unit='km')
""" # noqa: E501
return convert(convert(obj, AbstractQuantity), Distance)
Loading

0 comments on commit 7b1a499

Please sign in to comment.