Skip to content

Commit

Permalink
refactor: consolidate constructors (#172)
Browse files Browse the repository at this point in the history
* refactor: misc
* refactor: consolidate dispatches

Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman authored Aug 29, 2024
1 parent bf4acf2 commit f891863
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 40 deletions.
1 change: 0 additions & 1 deletion src/coordinax/_coordinax/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,6 @@ def __str__(self) -> str:
# Register additional constructors


# TODO: move to the class in py3.11+
@AbstractVector.constructor._f.dispatch # type: ignore[attr-defined, misc] # noqa: SLF001
def constructor(cls: type[AbstractVector], obj: AbstractVector, /) -> AbstractVector:
"""Construct a vector from another vector.
Expand Down
56 changes: 17 additions & 39 deletions src/coordinax/_interop/coordinax_interop_astropy/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
__all__: list[str] = []

from collections.abc import Mapping
from dataclasses import fields

import astropy.coordinates as apyc
import astropy.units as u
import equinox as eqx
from jaxtyping import Shaped
from plum import convert

from unxt import Quantity

import coordinax as cx

Expand Down Expand Up @@ -524,36 +524,21 @@ def constructor(cls: type[cx.AbstractVector], obj: u.Quantity, /) -> cx.Abstract
>>> vec.x
Quantity['length'](Array([1., 4.], dtype=float32), unit='m')
"""
_ = eqx.error_if(
obj,
obj.shape[-1] != len(fields(cls)),
f"Cannot construct {cls} from array with shape {obj.shape}.",
>>> vec = cx.CartesianVelocity3D.constructor(Quantity([1, 2, 3], "m/s"))
>>> vec
CartesianVelocity3D(
d_x=Quantity[...]( value=f32[], unit=Unit("m / s") ),
d_y=Quantity[...]( value=f32[], unit=Unit("m / s") ),
d_z=Quantity[...]( value=f32[], unit=Unit("m / s") )
)
return cls(**{f.name: obj[..., i] for i, f in enumerate(fields(cls))})


@cx.FourVector.constructor._f.dispatch # noqa: SLF001
def constructor(
cls: type[cx.FourVector], obj: Shaped[u.Quantity, "*batch 4"], /
) -> cx.FourVector:
"""Construct a vector from a Quantity array.
The array is expected to have the components as the last dimension.
Parameters
----------
cls : type[FourVector]
The class.
obj : Quantity[Any, (*#batch, 4), "..."]
The array of components.
The 4 components are the (c x) time, x, y, z.
Examples
--------
>>> import jax.numpy as jnp
>>> from astropy.units import Quantity
>>> import coordinax as cx
>>> vec = cx.CartesianAcceleration3D.constructor(Quantity([1, 2, 3], "m/s2"))
>>> vec
CartesianAcceleration3D(
d2_x=Quantity[...](value=f32[], unit=Unit("m / s2")),
d2_y=Quantity[...](value=f32[], unit=Unit("m / s2")),
d2_z=Quantity[...](value=f32[], unit=Unit("m / s2"))
)
>>> xs = Quantity([0, 1, 2, 3], "meter") # [ct, x, y, z]
>>> vec = cx.FourVector.constructor(xs)
Expand All @@ -570,15 +555,8 @@ def constructor(
t=Quantity[PhysicalType('time')](value=f32[2], unit=Unit("m s / km")),
q=CartesianPosition3D( ... )
)
>>> vec.x
Quantity['length'](Array([1., 4.], dtype=float32), unit='m')
"""
_ = eqx.error_if(
obj,
obj.shape[-1] != 4,
f"Cannot construct {cls} from array with shape {obj.shape}.",
)
c = cls.__dataclass_fields__["c"].default
return cls(t=obj[..., 0] / c, q=obj[..., 1:])
return cls.constructor(convert(obj, Quantity))

0 comments on commit f891863

Please sign in to comment.