Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve default dtype selection #254

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
5 changes: 3 additions & 2 deletions cheetah/accelerator/aperture.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
from cheetah.utils import UniqueNameGenerator
from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand All @@ -31,8 +31,9 @@ def __init__(
is_active: bool = True,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([], [x_max, y_max], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
33 changes: 22 additions & 11 deletions cheetah/accelerator/cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.track_methods import base_rmatrix
from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors
from cheetah.utils import (
UniqueNameGenerator,
compute_relativistic_factors,
verify_device_and_dtype,
)

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand All @@ -36,8 +40,11 @@ def __init__(
frequency: Optional[Union[torch.Tensor, nn.Parameter]] = None,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[length], [voltage, phase, frequency], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down Expand Up @@ -248,13 +255,12 @@ def _track_beam(self, incoming: Beam) -> Beam:

def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
"""Produces an R-matrix for a cavity when it is on, i.e. voltage > 0.0."""
device = self.length.device
dtype = self.length.dtype
factory_kwargs = {"device": self.length.device, "dtype": self.length.dtype}

phi = torch.deg2rad(self.phase)
delta_energy = self.voltage * torch.cos(phi)
# Comment from Ocelot: Pure pi-standing-wave case
eta = torch.tensor(1.0, device=device, dtype=dtype)
eta = torch.tensor(1.0, **factory_kwargs)
Ei = energy / electron_mass_eV
Ef = (energy + delta_energy) / electron_mass_eV
Ep = (Ef - Ei) / self.length # Derivative of the energy
Expand Down Expand Up @@ -288,12 +294,17 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
)
)

r56 = torch.tensor(0.0)
beta0 = torch.tensor(1.0)
beta1 = torch.tensor(1.0)
r56 = torch.tensor(0.0, **factory_kwargs)
beta0 = torch.tensor(1.0, **factory_kwargs)
beta1 = torch.tensor(1.0, **factory_kwargs)

k = 2 * torch.pi * self.frequency / torch.tensor(constants.speed_of_light)
r55_cor = torch.tensor(0.0)
k = (
2
* torch.pi
* self.frequency
/ torch.tensor(constants.speed_of_light, **factory_kwargs)
)
r55_cor = torch.tensor(0.0, **factory_kwargs)
if torch.any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if?
beta0 = torch.sqrt(1 - 1 / Ei**2)
beta1 = torch.sqrt(1 - 1 / Ef**2)
Expand All @@ -320,7 +331,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor:
r11, r12, r21, r22, r55_cor, r56, r65, r66
)

R = torch.eye(7, device=device, dtype=dtype).repeat((*r11.shape, 1, 1))
R = torch.eye(7, **factory_kwargs).repeat((*r11.shape, 1, 1))
R[..., 0, 0] = r11
R[..., 0, 1] = r12
R[..., 1, 0] = r21
Expand Down
5 changes: 3 additions & 2 deletions cheetah/accelerator/custom_transfer_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from cheetah.accelerator.element import Element
from cheetah.particles import Beam
from cheetah.utils import UniqueNameGenerator
from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand All @@ -23,8 +23,9 @@ def __init__(
length: Optional[torch.Tensor] = None,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([transfer_map], [length], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
36 changes: 32 additions & 4 deletions cheetah/accelerator/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
from cheetah.track_methods import base_rmatrix, rotation_matrix
from cheetah.utils import UniqueNameGenerator, bmadx
from cheetah.utils import UniqueNameGenerator, bmadx, verify_device_and_dtype

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand Down Expand Up @@ -62,8 +62,24 @@ def __init__(
tracking_method: Literal["cheetah", "bmadx"] = "cheetah",
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
):
device, dtype = verify_device_and_dtype(
[length],
[
angle,
k1,
e1,
e2,
tilt,
gap,
gap_exit,
fringe_integral,
fringe_integral_exit,
],
device,
dtype,
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down Expand Up @@ -203,7 +219,13 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:

# Begin Bmad-X tracking
x, px, y, py = bmadx.offset_particle_set(
torch.tensor(0.0), torch.tensor(0.0), self.tilt, x, px, y, py
torch.zeros_like(self.tilt),
torch.zeros_like(self.tilt),
self.tilt,
x,
px,
y,
py,
)

if self.fringe_at == "entrance" or self.fringe_at == "both":
Expand All @@ -215,7 +237,13 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam:
px, py = self._bmadx_fringe_linear("exit", x, px, y, py)

x, px, y, py = bmadx.offset_particle_unset(
torch.tensor(0.0), torch.tensor(0.0), self.tilt, x, px, y, py
torch.zeros_like(self.tilt),
torch.zeros_like(self.tilt),
self.tilt,
x,
px,
y,
py,
)
# End of Bmad-X tracking

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/drift.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
tracking_method: Literal["cheetah", "bmadx"] = "cheetah",
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)
Expand Down
9 changes: 7 additions & 2 deletions cheetah/accelerator/horizontal_corrector.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,11 @@
from torch import nn

from cheetah.accelerator.element import Element
from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors
from cheetah.utils import (
UniqueNameGenerator,
compute_relativistic_factors,
verify_device_and_dtype,
)

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand All @@ -29,8 +33,9 @@ def __init__(
angle: Optional[Union[torch.Tensor, nn.Parameter]] = None,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype([length], [angle], device, dtype)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
7 changes: 5 additions & 2 deletions cheetah/accelerator/quadrupole.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
from cheetah.track_methods import base_rmatrix, misalignment_matrix
from cheetah.utils import UniqueNameGenerator, bmadx
from cheetah.utils import UniqueNameGenerator, bmadx, verify_device_and_dtype

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand Down Expand Up @@ -42,8 +42,11 @@ def __init__(
tracking_method: Literal["cheetah", "bmadx"] = "cheetah",
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[length], [k1, misalignment, tilt], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/rbend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(
gap: Optional[Union[torch.Tensor, nn.Parameter]] = None,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
):
super().__init__(
length=length,
Expand Down
15 changes: 12 additions & 3 deletions cheetah/accelerator/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParameterBeam, ParticleBeam
from cheetah.utils import UniqueNameGenerator, kde_histogram_2d
from cheetah.utils import UniqueNameGenerator, kde_histogram_2d, verify_device_and_dtype

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand Down Expand Up @@ -53,8 +53,15 @@ def __init__(
is_active: bool = False,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[], # No required tensor arguments
# Excludes resolution and binning, since those are integer valued, not float
[pixel_size, misalignment, kde_bandwidth],
device,
dtype,
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down Expand Up @@ -203,7 +210,9 @@ def reading(self) -> torch.Tensor:
read_beam = self.get_read_beam()
if read_beam is Beam.empty or read_beam is None:
image = torch.zeros(
(int(self.effective_resolution[1]), int(self.effective_resolution[0]))
(int(self.effective_resolution[1]), int(self.effective_resolution[0])),
device=self.misalignment.device,
dtype=self.misalignment.dtype,
)
elif isinstance(read_beam, ParameterBeam):
if torch.numel(read_beam._mu[..., 0]) > 1:
Expand Down
11 changes: 9 additions & 2 deletions cheetah/accelerator/solenoid.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@

from cheetah.accelerator.element import Element
from cheetah.track_methods import misalignment_matrix
from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors
from cheetah.utils import (
UniqueNameGenerator,
compute_relativistic_factors,
verify_device_and_dtype,
)

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand Down Expand Up @@ -36,8 +40,11 @@ def __init__(
misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[length], [k, misalignment], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down
15 changes: 13 additions & 2 deletions cheetah/accelerator/space_charge_kick.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
from cheetah.utils import verify_device_and_dtype


class SpaceChargeKick(Element):
Expand Down Expand Up @@ -56,8 +57,14 @@ def __init__(
grid_extend_tau: Union[torch.Tensor, nn.Parameter] = 3,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[effect_length],
[], # TODO: Add grid_extend_{x,y,tau}, needs torch.Tensor default
device,
dtype,
)
self.factory_kwargs = {"device": device, "dtype": dtype}

super().__init__(name=name)
Expand Down Expand Up @@ -589,7 +596,11 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam:
],
dim=-1,
)
cell_size = 2 * grid_dimensions / torch.tensor(self.grid_shape)
cell_size = (
2
* grid_dimensions
/ torch.tensor(self.grid_shape, **self.factory_kwargs)
)
dt = flattened_length_effect / (
speed_of_light * flattened_incoming.relativistic_beta
)
Expand Down
9 changes: 6 additions & 3 deletions cheetah/accelerator/transverse_deflecting_cavity.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from cheetah.accelerator.element import Element
from cheetah.particles import Beam, ParticleBeam
from cheetah.utils import UniqueNameGenerator, bmadx
from cheetah.utils import UniqueNameGenerator, bmadx, verify_device_and_dtype

generate_unique_name = UniqueNameGenerator(prefix="unnamed_element")

Expand Down Expand Up @@ -44,8 +44,11 @@ def __init__(
tracking_method: Literal["bmadx"] = "bmadx",
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
device, dtype = verify_device_and_dtype(
[length], [voltage, phase, frequency, misalignment, tilt], device, dtype
)
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)

Expand Down Expand Up @@ -87,7 +90,7 @@ def __init__(
(
torch.as_tensor(tilt, **factory_kwargs)
if tilt is not None
else torch.zeros_like(self.length)
else torch.tensor(0.0, **factory_kwargs)
),
)
self.num_steps = num_steps
Expand Down
2 changes: 1 addition & 1 deletion cheetah/accelerator/undulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
is_active: bool = False,
name: Optional[str] = None,
device=None,
dtype=torch.float32,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(name=name)
Expand Down
Loading