diff --git a/cheetah/accelerator/aperture.py b/cheetah/accelerator/aperture.py index cdf5d256..7214e49b 100644 --- a/cheetah/accelerator/aperture.py +++ b/cheetah/accelerator/aperture.py @@ -7,7 +7,7 @@ from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParticleBeam -from cheetah.utils import UniqueNameGenerator +from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -31,8 +31,9 @@ def __init__( is_active: bool = True, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype([], [x_max, y_max], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/cavity.py b/cheetah/accelerator/cavity.py index c7a89e05..e213a298 100644 --- a/cheetah/accelerator/cavity.py +++ b/cheetah/accelerator/cavity.py @@ -10,7 +10,11 @@ from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParameterBeam, ParticleBeam from cheetah.track_methods import base_rmatrix -from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors +from cheetah.utils import ( + UniqueNameGenerator, + compute_relativistic_factors, + verify_device_and_dtype, +) generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -36,8 +40,11 @@ def __init__( frequency: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [length], [voltage, phase, frequency], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -248,13 +255,12 @@ def _track_beam(self, incoming: Beam) -> Beam: def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: """Produces an R-matrix for a cavity when it is on, i.e. voltage > 0.0.""" - device = self.length.device - dtype = self.length.dtype + factory_kwargs = {"device": self.length.device, "dtype": self.length.dtype} phi = torch.deg2rad(self.phase) delta_energy = self.voltage * torch.cos(phi) # Comment from Ocelot: Pure pi-standing-wave case - eta = torch.tensor(1.0, device=device, dtype=dtype) + eta = torch.tensor(1.0, **factory_kwargs) Ei = energy / electron_mass_eV Ef = (energy + delta_energy) / electron_mass_eV Ep = (Ef - Ei) / self.length # Derivative of the energy @@ -288,12 +294,17 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: ) ) - r56 = torch.tensor(0.0) - beta0 = torch.tensor(1.0) - beta1 = torch.tensor(1.0) + r56 = torch.tensor(0.0, **factory_kwargs) + beta0 = torch.tensor(1.0, **factory_kwargs) + beta1 = torch.tensor(1.0, **factory_kwargs) - k = 2 * torch.pi * self.frequency / torch.tensor(constants.speed_of_light) - r55_cor = torch.tensor(0.0) + k = ( + 2 + * torch.pi + * self.frequency + / torch.tensor(constants.speed_of_light, **factory_kwargs) + ) + r55_cor = torch.tensor(0.0, **factory_kwargs) if torch.any((self.voltage != 0) & (energy != 0)): # TODO: Do we need this if? beta0 = torch.sqrt(1 - 1 / Ei**2) beta1 = torch.sqrt(1 - 1 / Ef**2) @@ -320,7 +331,7 @@ def _cavity_rmatrix(self, energy: torch.Tensor) -> torch.Tensor: r11, r12, r21, r22, r55_cor, r56, r65, r66 ) - R = torch.eye(7, device=device, dtype=dtype).repeat((*r11.shape, 1, 1)) + R = torch.eye(7, **factory_kwargs).repeat((*r11.shape, 1, 1)) R[..., 0, 0] = r11 R[..., 0, 1] = r12 R[..., 1, 0] = r21 diff --git a/cheetah/accelerator/custom_transfer_map.py b/cheetah/accelerator/custom_transfer_map.py index 5baf87d2..abaeae3e 100644 --- a/cheetah/accelerator/custom_transfer_map.py +++ b/cheetah/accelerator/custom_transfer_map.py @@ -7,7 +7,7 @@ from cheetah.accelerator.element import Element from cheetah.particles import Beam -from cheetah.utils import UniqueNameGenerator +from cheetah.utils import UniqueNameGenerator, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -23,8 +23,9 @@ def __init__( length: Optional[torch.Tensor] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype([transfer_map], [length], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/dipole.py b/cheetah/accelerator/dipole.py index 5e919fbf..bd803626 100644 --- a/cheetah/accelerator/dipole.py +++ b/cheetah/accelerator/dipole.py @@ -10,7 +10,7 @@ from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParticleBeam from cheetah.track_methods import base_rmatrix, rotation_matrix -from cheetah.utils import UniqueNameGenerator, bmadx +from cheetah.utils import UniqueNameGenerator, bmadx, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -62,8 +62,24 @@ def __init__( tracking_method: Literal["cheetah", "bmadx"] = "cheetah", name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ): + device, dtype = verify_device_and_dtype( + [length], + [ + angle, + k1, + e1, + e2, + tilt, + gap, + gap_exit, + fringe_integral, + fringe_integral_exit, + ], + device, + dtype, + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -203,7 +219,13 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: # Begin Bmad-X tracking x, px, y, py = bmadx.offset_particle_set( - torch.tensor(0.0), torch.tensor(0.0), self.tilt, x, px, y, py + torch.zeros_like(self.tilt), + torch.zeros_like(self.tilt), + self.tilt, + x, + px, + y, + py, ) if self.fringe_at == "entrance" or self.fringe_at == "both": @@ -215,7 +237,13 @@ def _track_bmadx(self, incoming: ParticleBeam) -> ParticleBeam: px, py = self._bmadx_fringe_linear("exit", x, px, y, py) x, px, y, py = bmadx.offset_particle_unset( - torch.tensor(0.0), torch.tensor(0.0), self.tilt, x, px, y, py + torch.zeros_like(self.tilt), + torch.zeros_like(self.tilt), + self.tilt, + x, + px, + y, + py, ) # End of Bmad-X tracking diff --git a/cheetah/accelerator/drift.py b/cheetah/accelerator/drift.py index eb5fb187..d9d62d3f 100644 --- a/cheetah/accelerator/drift.py +++ b/cheetah/accelerator/drift.py @@ -32,7 +32,7 @@ def __init__( tracking_method: Literal["cheetah", "bmadx"] = "cheetah", name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/horizontal_corrector.py b/cheetah/accelerator/horizontal_corrector.py index a00c2cbe..a19aaa15 100644 --- a/cheetah/accelerator/horizontal_corrector.py +++ b/cheetah/accelerator/horizontal_corrector.py @@ -7,7 +7,11 @@ from torch import nn from cheetah.accelerator.element import Element -from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors +from cheetah.utils import ( + UniqueNameGenerator, + compute_relativistic_factors, + verify_device_and_dtype, +) generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -29,8 +33,9 @@ def __init__( angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype([length], [angle], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/quadrupole.py b/cheetah/accelerator/quadrupole.py index db6a559d..b34a9e21 100644 --- a/cheetah/accelerator/quadrupole.py +++ b/cheetah/accelerator/quadrupole.py @@ -10,7 +10,7 @@ from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParticleBeam from cheetah.track_methods import base_rmatrix, misalignment_matrix -from cheetah.utils import UniqueNameGenerator, bmadx +from cheetah.utils import UniqueNameGenerator, bmadx, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -42,8 +42,11 @@ def __init__( tracking_method: Literal["cheetah", "bmadx"] = "cheetah", name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [length], [k1, misalignment, tilt], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/rbend.py b/cheetah/accelerator/rbend.py index b50ef08e..67f25f1d 100644 --- a/cheetah/accelerator/rbend.py +++ b/cheetah/accelerator/rbend.py @@ -39,7 +39,7 @@ def __init__( gap: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ): super().__init__( length=length, diff --git a/cheetah/accelerator/screen.py b/cheetah/accelerator/screen.py index d7b153d6..359a6d88 100644 --- a/cheetah/accelerator/screen.py +++ b/cheetah/accelerator/screen.py @@ -9,7 +9,7 @@ from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParameterBeam, ParticleBeam -from cheetah.utils import UniqueNameGenerator, kde_histogram_2d +from cheetah.utils import UniqueNameGenerator, kde_histogram_2d, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -53,8 +53,15 @@ def __init__( is_active: bool = False, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [], # No required tensor arguments + # Excludes resolution and binning, since those are integer valued, not float + [pixel_size, misalignment, kde_bandwidth], + device, + dtype, + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -203,7 +210,9 @@ def reading(self) -> torch.Tensor: read_beam = self.get_read_beam() if read_beam is Beam.empty or read_beam is None: image = torch.zeros( - (int(self.effective_resolution[1]), int(self.effective_resolution[0])) + (int(self.effective_resolution[1]), int(self.effective_resolution[0])), + device=self.misalignment.device, + dtype=self.misalignment.dtype, ) elif isinstance(read_beam, ParameterBeam): if torch.numel(read_beam._mu[..., 0]) > 1: diff --git a/cheetah/accelerator/solenoid.py b/cheetah/accelerator/solenoid.py index f8faf24c..ed7bc6e7 100644 --- a/cheetah/accelerator/solenoid.py +++ b/cheetah/accelerator/solenoid.py @@ -8,7 +8,11 @@ from cheetah.accelerator.element import Element from cheetah.track_methods import misalignment_matrix -from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors +from cheetah.utils import ( + UniqueNameGenerator, + compute_relativistic_factors, + verify_device_and_dtype, +) generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -36,8 +40,11 @@ def __init__( misalignment: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [length], [k, misalignment], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/space_charge_kick.py b/cheetah/accelerator/space_charge_kick.py index 7bb7d7f3..8fb52e7c 100644 --- a/cheetah/accelerator/space_charge_kick.py +++ b/cheetah/accelerator/space_charge_kick.py @@ -7,6 +7,7 @@ from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParticleBeam +from cheetah.utils import verify_device_and_dtype class SpaceChargeKick(Element): @@ -56,8 +57,14 @@ def __init__( grid_extend_tau: Union[torch.Tensor, nn.Parameter] = 3, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [effect_length], + [], # TODO: Add grid_extend_{x,y,tau}, needs torch.Tensor default + device, + dtype, + ) self.factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -589,7 +596,11 @@ def track(self, incoming: ParticleBeam) -> ParticleBeam: ], dim=-1, ) - cell_size = 2 * grid_dimensions / torch.tensor(self.grid_shape) + cell_size = ( + 2 + * grid_dimensions + / torch.tensor(self.grid_shape, **self.factory_kwargs) + ) dt = flattened_length_effect / ( speed_of_light * flattened_incoming.relativistic_beta ) diff --git a/cheetah/accelerator/transverse_deflecting_cavity.py b/cheetah/accelerator/transverse_deflecting_cavity.py index fd4ed2af..de5ad9e1 100644 --- a/cheetah/accelerator/transverse_deflecting_cavity.py +++ b/cheetah/accelerator/transverse_deflecting_cavity.py @@ -8,7 +8,7 @@ from cheetah.accelerator.element import Element from cheetah.particles import Beam, ParticleBeam -from cheetah.utils import UniqueNameGenerator, bmadx +from cheetah.utils import UniqueNameGenerator, bmadx, verify_device_and_dtype generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -44,8 +44,11 @@ def __init__( tracking_method: Literal["bmadx"] = "bmadx", name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [length], [voltage, phase, frequency, misalignment, tilt], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) @@ -87,7 +90,7 @@ def __init__( ( torch.as_tensor(tilt, **factory_kwargs) if tilt is not None - else torch.zeros_like(self.length) + else torch.tensor(0.0, **factory_kwargs) ), ) self.num_steps = num_steps diff --git a/cheetah/accelerator/undulator.py b/cheetah/accelerator/undulator.py index e7304870..ebf13b37 100644 --- a/cheetah/accelerator/undulator.py +++ b/cheetah/accelerator/undulator.py @@ -32,7 +32,7 @@ def __init__( is_active: bool = False, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/accelerator/vertical_corrector.py b/cheetah/accelerator/vertical_corrector.py index be5ba4e4..4010c4f4 100644 --- a/cheetah/accelerator/vertical_corrector.py +++ b/cheetah/accelerator/vertical_corrector.py @@ -8,7 +8,11 @@ from torch import nn from cheetah.accelerator.element import Element -from cheetah.utils import UniqueNameGenerator, compute_relativistic_factors +from cheetah.utils import ( + UniqueNameGenerator, + compute_relativistic_factors, + verify_device_and_dtype, +) generate_unique_name = UniqueNameGenerator(prefix="unnamed_element") @@ -32,8 +36,9 @@ def __init__( angle: Optional[Union[torch.Tensor, nn.Parameter]] = None, name: Optional[str] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype([length], [angle], device, dtype) factory_kwargs = {"device": device, "dtype": dtype} super().__init__(name=name) diff --git a/cheetah/particles/beam.py b/cheetah/particles/beam.py index 427e0329..629443ae 100644 --- a/cheetah/particles/beam.py +++ b/cheetah/particles/beam.py @@ -52,6 +52,8 @@ def from_parameters( cor_tau: Optional[torch.Tensor] = None, energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, + device=None, + dtype=None, ) -> "Beam": """ Create beam that with given beam parameters. @@ -75,6 +77,8 @@ def from_parameters( :param cor_tau: Correlation between tau and p. :param energy: Reference energy of the beam in eV. :param total_charge: Total charge of the beam in C. + :param device: Device to create the beam on. + :param dtype: Data type of the beam. """ raise NotImplementedError @@ -93,7 +97,7 @@ def from_twiss( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "Beam": """ Create a beam from twiss parameters. @@ -115,14 +119,14 @@ def from_twiss( raise NotImplementedError @classmethod - def from_ocelot(cls, parray) -> "Beam": + def from_ocelot(cls, parray, device=None, dtype=None) -> "Beam": """ Convert an Ocelot ParticleArray `parray` to a Cheetah Beam. """ raise NotImplementedError @classmethod - def from_astra(cls, path: str, **kwargs) -> "Beam": + def from_astra(cls, path: str, device=None, dtype=None) -> "Beam": """Load an Astra particle distribution as a Cheetah Beam.""" raise NotImplementedError @@ -140,6 +144,8 @@ def transformed_to( sigma_p: Optional[torch.Tensor] = None, energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, + device=None, + dtype=None, ) -> "Beam": """ Create version of this beam that is transformed to new beam parameters. @@ -160,7 +166,12 @@ def transformed_to( dimensionless. :param energy: Reference energy of the beam in eV. :param total_charge: Total charge of the beam in C. + :param device: Device to create the beam on. + :param dtype: Data type of the beam. """ + device = device if device is not None else self.mu_x.device + dtype = dtype if dtype is not None else self.mu_x.dtype + # Figure out vector dimensions of the original beam and check that passed # arguments have the same vector dimensions. shape = self.mu_x.shape @@ -213,6 +224,8 @@ def transformed_to( sigma_p=sigma_p, energy=energy, total_charge=total_charge, + device=device, + dtype=dtype, ) @property diff --git a/cheetah/particles/parameter_beam.py b/cheetah/particles/parameter_beam.py index dcdbb8c0..d17022dc 100644 --- a/cheetah/particles/parameter_beam.py +++ b/cheetah/particles/parameter_beam.py @@ -5,6 +5,12 @@ from cheetah.particles.beam import Beam from cheetah.particles.particle_beam import ParticleBeam +from cheetah.utils import ( + extract_argument_device, + extract_argument_dtype, + extract_argument_shape, + verify_device_and_dtype, +) class ParameterBeam(Beam): @@ -26,8 +32,11 @@ def __init__( energy: torch.Tensor, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: + device, dtype = verify_device_and_dtype( + [mu, cov, energy], [total_charge], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} super().__init__() @@ -62,24 +71,70 @@ def from_parameters( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParameterBeam": + # Figure out if arguments were passed + not_nones = [ + argument + for argument in [ + mu_x, + mu_px, + mu_y, + mu_py, + sigma_x, + sigma_px, + sigma_y, + sigma_py, + sigma_tau, + sigma_p, + cor_x, + cor_y, + cor_tau, + energy, + total_charge, + ] + if argument is not None + ] + + # Extract device and dtype from given arguments + device = device if device is not None else extract_argument_device(not_nones) + dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + factory_kwargs = {"device": device, "dtype": dtype} + # Set default values without function call in function signature - mu_x = mu_x if mu_x is not None else torch.tensor(0.0) - mu_px = mu_px if mu_px is not None else torch.tensor(0.0) - mu_y = mu_y if mu_y is not None else torch.tensor(0.0) - mu_py = mu_py if mu_py is not None else torch.tensor(0.0) - sigma_x = sigma_x if sigma_x is not None else torch.tensor(175e-9) - sigma_px = sigma_px if sigma_px is not None else torch.tensor(2e-7) - sigma_y = sigma_y if sigma_y is not None else torch.tensor(175e-9) - sigma_py = sigma_py if sigma_py is not None else torch.tensor(2e-7) - sigma_tau = sigma_tau if sigma_tau is not None else torch.tensor(1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.tensor(1e-6) - cor_x = cor_x if cor_x is not None else torch.tensor(0.0) - cor_y = cor_y if cor_y is not None else torch.tensor(0.0) - cor_tau = cor_tau if cor_tau is not None else torch.tensor(0.0) - energy = energy if energy is not None else torch.tensor(1e8) - total_charge = total_charge if total_charge is not None else torch.tensor(0.0) + mu_x = mu_x if mu_x is not None else torch.tensor(0.0, **factory_kwargs) + mu_px = mu_px if mu_px is not None else torch.tensor(0.0, **factory_kwargs) + mu_y = mu_y if mu_y is not None else torch.tensor(0.0, **factory_kwargs) + mu_py = mu_py if mu_py is not None else torch.tensor(0.0, **factory_kwargs) + sigma_x = ( + sigma_x if sigma_x is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_px = ( + sigma_px if sigma_px is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_y = ( + sigma_y if sigma_y is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_py = ( + sigma_py if sigma_py is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_tau = ( + sigma_tau if sigma_tau is not None else torch.tensor(1e-6, **factory_kwargs) + ) + sigma_p = ( + sigma_p if sigma_p is not None else torch.tensor(1e-6, **factory_kwargs) + ) + cor_x = cor_x if cor_x is not None else torch.tensor(0.0, **factory_kwargs) + cor_y = cor_y if cor_y is not None else torch.tensor(0.0, **factory_kwargs) + cor_tau = ( + cor_tau if cor_tau is not None else torch.tensor(0.0, **factory_kwargs) + ) + energy = energy if energy is not None else torch.tensor(1e8, **factory_kwargs) + total_charge = ( + total_charge + if total_charge is not None + else torch.tensor(0.0, **factory_kwargs) + ) mu_x, mu_px, mu_y, mu_py = torch.broadcast_tensors(mu_x, mu_px, mu_y, mu_py) mu = torch.stack( @@ -116,7 +171,7 @@ def from_parameters( cor_tau, sigma_p, ) - cov = torch.zeros(*sigma_x.shape, 7, 7) + cov = torch.zeros(*sigma_x.shape, 7, 7, **factory_kwargs) cov[..., 0, 0] = sigma_x**2 cov[..., 0, 1] = cor_x cov[..., 1, 0] = cor_x @@ -154,9 +209,9 @@ def from_twiss( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParameterBeam": - # Figure out if arguments were passed, figure out their shape + # Figure out if arguments were passed not_nones = [ argument for argument in [ @@ -174,29 +229,56 @@ def from_twiss( ] if argument is not None ] - shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([1]) - if len(not_nones) > 1: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." + + # Extract shape, device and dtype from given arguments + shape = extract_argument_shape(not_nones) + device = device if device is not None else extract_argument_device(not_nones) + dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature - beta_x = beta_x if beta_x is not None else torch.full(shape, 1.0) - alpha_x = alpha_x if alpha_x is not None else torch.full(shape, 0.0) + beta_x = ( + beta_x if beta_x is not None else torch.full(shape, 1.0, **factory_kwargs) + ) + alpha_x = ( + alpha_x if alpha_x is not None else torch.full(shape, 0.0, **factory_kwargs) + ) emittance_x = ( - emittance_x if emittance_x is not None else torch.full(shape, 7.1971891e-13) + emittance_x + if emittance_x is not None + else torch.full(shape, 7.1971891e-13, **factory_kwargs) + ) + beta_y = ( + beta_y if beta_y is not None else torch.full(shape, 1.0, **factory_kwargs) + ) + alpha_y = ( + alpha_y if alpha_y is not None else torch.full(shape, 0.0, **factory_kwargs) ) - beta_y = beta_y if beta_y is not None else torch.full(shape, 1.0) - alpha_y = alpha_y if alpha_y is not None else torch.full(shape, 0.0) emittance_y = ( - emittance_y if emittance_y is not None else torch.full(shape, 7.1971891e-13) + emittance_y + if emittance_y is not None + else torch.full(shape, 7.1971891e-13, **factory_kwargs) + ) + sigma_tau = ( + sigma_tau + if sigma_tau is not None + else torch.full(shape, 1e-6, **factory_kwargs) + ) + sigma_p = ( + sigma_p + if sigma_p is not None + else torch.full(shape, 1e-6, **factory_kwargs) + ) + cor_tau = ( + cor_tau if cor_tau is not None else torch.full(shape, 0.0, **factory_kwargs) + ) + energy = ( + energy if energy is not None else torch.full(shape, 1e8, **factory_kwargs) ) - sigma_tau = sigma_tau if sigma_tau is not None else torch.full(shape, 1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.full(shape, 1e-6) - cor_tau = cor_tau if cor_tau is not None else torch.full(shape, 0.0) - energy = energy if energy is not None else torch.full(shape, 1e8) total_charge = ( - total_charge if total_charge is not None else torch.full(shape, 0.0) + total_charge + if total_charge is not None + else torch.full(shape, 0.0, **factory_kwargs) ) assert torch.all( @@ -287,7 +369,7 @@ def transformed_to( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParameterBeam": """ Create version of this beam that is transformed to new beam parameters. diff --git a/cheetah/particles/particle_beam.py b/cheetah/particles/particle_beam.py index 5b83bab1..dc980cad 100644 --- a/cheetah/particles/particle_beam.py +++ b/cheetah/particles/particle_beam.py @@ -6,7 +6,13 @@ from torch.distributions import MultivariateNormal from cheetah.particles.beam import Beam -from cheetah.utils import elementwise_linspace +from cheetah.utils import ( + elementwise_linspace, + extract_argument_device, + extract_argument_dtype, + extract_argument_shape, + verify_device_and_dtype, +) speed_of_light = torch.tensor(constants.speed_of_light) # In m/s electron_mass = torch.tensor(constants.electron_mass) # In kg @@ -24,6 +30,7 @@ class ParticleBeam(Beam): :param total_charge: Total charge of the beam in C. :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. + :param dtype: Data type of the generated particles. """ def __init__( @@ -32,10 +39,13 @@ def __init__( energy: torch.Tensor, particle_charges: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> None: - super().__init__() + device, dtype = verify_device_and_dtype( + [particles, energy], [particle_charges], device, dtype + ) factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() assert ( particles.shape[-2] > 0 and particles.shape[-1] == 7 @@ -55,7 +65,7 @@ def __init__( @classmethod def from_parameters( cls, - num_particles: Optional[torch.Tensor] = None, + num_particles: Optional[int] = None, mu_x: Optional[torch.Tensor] = None, mu_y: Optional[torch.Tensor] = None, mu_px: Optional[torch.Tensor] = None, @@ -96,32 +106,77 @@ def from_parameters( :param cor_y: Correlation between y and py. :param cor_tau: Correlation between s and p. :param energy: Energy of the beam in eV. - :total_charge: Total charge of the beam in C. + :param total_charge: Total charge of the beam in C. :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. """ + # Figure out if arguments were passed + not_nones = [ + argument + for argument in [ + mu_x, + mu_px, + mu_y, + mu_py, + sigma_x, + sigma_px, + sigma_y, + sigma_py, + sigma_tau, + sigma_p, + cor_x, + cor_y, + cor_tau, + energy, + total_charge, + ] + if argument is not None + ] + + # Extract device and dtype from given arguments + device = device if device is not None else extract_argument_device(not_nones) + dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature num_particles = ( num_particles if num_particles is not None else torch.tensor(100_000) ) - mu_x = mu_x if mu_x is not None else torch.tensor(0.0) - mu_px = mu_px if mu_px is not None else torch.tensor(0.0) - mu_y = mu_y if mu_y is not None else torch.tensor(0.0) - mu_py = mu_py if mu_py is not None else torch.tensor(0.0) - sigma_x = sigma_x if sigma_x is not None else torch.tensor(175e-9) - sigma_px = sigma_px if sigma_px is not None else torch.tensor(2e-7) - sigma_y = sigma_y if sigma_y is not None else torch.tensor(175e-9) - sigma_py = sigma_py if sigma_py is not None else torch.tensor(2e-7) - sigma_tau = sigma_tau if sigma_tau is not None else torch.tensor(1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.tensor(1e-6) - cor_x = cor_x if cor_x is not None else torch.tensor(0.0) - cor_y = cor_y if cor_y is not None else torch.tensor(0.0) - cor_tau = cor_tau if cor_tau is not None else torch.tensor(0.0) - energy = energy if energy is not None else torch.tensor(1e8) - total_charge = total_charge if total_charge is not None else torch.tensor(0.0) + mu_x = mu_x if mu_x is not None else torch.tensor(0.0, **factory_kwargs) + mu_px = mu_px if mu_px is not None else torch.tensor(0.0, **factory_kwargs) + mu_y = mu_y if mu_y is not None else torch.tensor(0.0, **factory_kwargs) + mu_py = mu_py if mu_py is not None else torch.tensor(0.0, **factory_kwargs) + sigma_x = ( + sigma_x if sigma_x is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_px = ( + sigma_px if sigma_px is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_y = ( + sigma_y if sigma_y is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_py = ( + sigma_py if sigma_py is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_tau = ( + sigma_tau if sigma_tau is not None else torch.tensor(1e-6, **factory_kwargs) + ) + sigma_p = ( + sigma_p if sigma_p is not None else torch.tensor(1e-6, **factory_kwargs) + ) + cor_x = cor_x if cor_x is not None else torch.tensor(0.0, **factory_kwargs) + cor_y = cor_y if cor_y is not None else torch.tensor(0.0, **factory_kwargs) + cor_tau = ( + cor_tau if cor_tau is not None else torch.tensor(0.0, **factory_kwargs) + ) + energy = energy if energy is not None else torch.tensor(1e8, **factory_kwargs) + total_charge = ( + total_charge + if total_charge is not None + else torch.tensor(0.0, **factory_kwargs) + ) particle_charges = ( - torch.ones((*total_charge.shape, num_particles)) + torch.ones((*total_charge.shape, num_particles), **factory_kwargs) * total_charge.unsqueeze(-1) / num_particles ) @@ -153,7 +208,7 @@ def from_parameters( cor_tau, sigma_p, ) - cov = torch.zeros(*sigma_x.shape, 6, 6) + cov = torch.zeros(*sigma_x.shape, 6, 6, **factory_kwargs) cov[..., 0, 0] = sigma_x**2 cov[..., 0, 1] = cor_x cov[..., 1, 0] = cor_x @@ -167,7 +222,7 @@ def from_parameters( cov[..., 5, 4] = cor_tau cov[..., 5, 5] = sigma_p**2 - particles = torch.ones((*mean.shape[:-1], num_particles, 7)) + particles = torch.ones((*mean.shape[:-1], num_particles, 7), **factory_kwargs) distributions = [ MultivariateNormal(sample_mean, covariance_matrix=sample_cov) for sample_mean, sample_cov in zip(mean.view(-1, 6), cov.view(-1, 6, 6)) @@ -188,7 +243,7 @@ def from_parameters( @classmethod def from_twiss( cls, - num_particles: Optional[torch.Tensor] = None, + num_particles: Optional[int] = None, beta_x: Optional[torch.Tensor] = None, alpha_x: Optional[torch.Tensor] = None, emittance_x: Optional[torch.Tensor] = None, @@ -201,9 +256,9 @@ def from_twiss( cor_tau: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParticleBeam": - # Figure out if arguments were passed, figure out their shape + # Figure out if arguments were passed not_nones = [ argument for argument in [ @@ -221,28 +276,57 @@ def from_twiss( ] if argument is not None ] - shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([1]) - if len(not_nones) > 1: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." + + # Extract shape, device and dtype from given arguments + shape = extract_argument_shape(not_nones) + device = device if device is not None else extract_argument_device(not_nones) + dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature - num_particles = ( - num_particles if num_particles is not None else torch.tensor(1_000_000) + num_particles = num_particles if num_particles is not None else 1_000_000 + beta_x = ( + beta_x if beta_x is not None else torch.full(shape, 0.0, **factory_kwargs) + ) + alpha_x = ( + alpha_x if alpha_x is not None else torch.full(shape, 0.0, **factory_kwargs) + ) + emittance_x = ( + emittance_x + if emittance_x is not None + else torch.full(shape, 0.0, **factory_kwargs) + ) + beta_y = ( + beta_y if beta_y is not None else torch.full(shape, 0.0, **factory_kwargs) + ) + alpha_y = ( + alpha_y if alpha_y is not None else torch.full(shape, 0.0, **factory_kwargs) + ) + emittance_y = ( + emittance_y + if emittance_y is not None + else torch.full(shape, 0.0, **factory_kwargs) + ) + energy = ( + energy if energy is not None else torch.full(shape, 1e8, **factory_kwargs) + ) + sigma_tau = ( + sigma_tau + if sigma_tau is not None + else torch.full(shape, 1e-6, **factory_kwargs) + ) + sigma_p = ( + sigma_p + if sigma_p is not None + else torch.full(shape, 1e-6, **factory_kwargs) + ) + cor_tau = ( + cor_tau if cor_tau is not None else torch.full(shape, 0.0, **factory_kwargs) ) - beta_x = beta_x if beta_x is not None else torch.full(shape, 0.0) - alpha_x = alpha_x if alpha_x is not None else torch.full(shape, 0.0) - emittance_x = emittance_x if emittance_x is not None else torch.full(shape, 0.0) - beta_y = beta_y if beta_y is not None else torch.full(shape, 0.0) - alpha_y = alpha_y if alpha_y is not None else torch.full(shape, 0.0) - emittance_y = emittance_y if emittance_y is not None else torch.full(shape, 0.0) - energy = energy if energy is not None else torch.full(shape, 1e8) - sigma_tau = sigma_tau if sigma_tau is not None else torch.full(shape, 1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.full(shape, 1e-6) - cor_tau = cor_tau if cor_tau is not None else torch.full(shape, 0.0) total_charge = ( - total_charge if total_charge is not None else torch.full(shape, 0.0) + total_charge + if total_charge is not None + else torch.full(shape, 0.0, **factory_kwargs) ) sigma_x = torch.sqrt(beta_x * emittance_x) @@ -254,10 +338,10 @@ def from_twiss( return cls.from_parameters( num_particles=num_particles, - mu_x=torch.full(shape, 0.0), - mu_px=torch.full(shape, 0.0), - mu_y=torch.full(shape, 0.0), - mu_py=torch.full(shape, 0.0), + mu_x=torch.full(shape, 0.0, **factory_kwargs), + mu_px=torch.full(shape, 0.0, **factory_kwargs), + mu_y=torch.full(shape, 0.0, **factory_kwargs), + mu_py=torch.full(shape, 0.0, **factory_kwargs), sigma_x=sigma_x, sigma_px=sigma_px, sigma_y=sigma_y, @@ -276,7 +360,7 @@ def from_twiss( @classmethod def uniform_3d_ellipsoid( cls, - num_particles: Optional[torch.Tensor] = None, + num_particles: Optional[int] = None, radius_x: Optional[torch.Tensor] = None, radius_y: Optional[torch.Tensor] = None, radius_tau: Optional[torch.Tensor] = None, @@ -286,7 +370,7 @@ def uniform_3d_ellipsoid( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ): """ Generate a particle beam with spatially uniformly distributed particles inside @@ -316,7 +400,7 @@ def uniform_3d_ellipsoid( :return: ParticleBeam with uniformly distributed particles inside an ellipsoid. """ - # Figure out if arguments were passed, figure out their shape + # Figure out if arguments were passed not_nones = [ argument for argument in [ @@ -331,11 +415,12 @@ def uniform_3d_ellipsoid( ] if argument is not None ] - shape = not_nones[0].shape if len(not_nones) > 0 else torch.Size([]) - if len(not_nones) > 1: - assert all( - argument.shape == shape for argument in not_nones - ), "Arguments must have the same shape." + + # Extract shape, device and dtype from given arguments + shape = extract_argument_shape(not_nones) + device = device if device is not None else extract_argument_device(not_nones) + dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + factory_kwargs = {"device": device, "dtype": dtype} # Expand to vectorised version for beam creation vector_shape = shape if len(shape) > 0 else torch.Size([1]) @@ -349,31 +434,37 @@ def uniform_3d_ellipsoid( radius_x = ( radius_x.expand(vector_shape) if radius_x is not None - else torch.full(vector_shape, 1e-3) + else torch.full(vector_shape, 1e-3, **factory_kwargs) ) radius_y = ( radius_y.expand(vector_shape) if radius_y is not None - else torch.full(vector_shape, 1e-3) + else torch.full(vector_shape, 1e-3, **factory_kwargs) ) radius_tau = ( radius_tau.expand(vector_shape) if radius_tau is not None - else torch.full(vector_shape, 1e-3) + else torch.full(vector_shape, 1e-3, **factory_kwargs) ) # Generate x, y and ss within the ellipsoid - flattened_x = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2) - flattened_y = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2) - flattened_tau = torch.empty(*vector_shape, num_particles).flatten(end_dim=-2) + flattened_x = torch.empty( + *vector_shape, num_particles, **factory_kwargs + ).flatten(end_dim=-2) + flattened_y = torch.empty( + *vector_shape, num_particles, **factory_kwargs + ).flatten(end_dim=-2) + flattened_tau = torch.empty( + *vector_shape, num_particles, **factory_kwargs + ).flatten(end_dim=-2) for i, (r_x, r_y, r_tau) in enumerate( zip(radius_x.flatten(), radius_y.flatten(), radius_tau.flatten()) ): num_successful = 0 while num_successful < num_particles: - x = (torch.rand(num_particles) - 0.5) * 2 * r_x - y = (torch.rand(num_particles) - 0.5) * 2 * r_y - tau = (torch.rand(num_particles) - 0.5) * 2 * r_tau + x = (torch.rand(num_particles, **factory_kwargs) - 0.5) * 2 * r_x + y = (torch.rand(num_particles, **factory_kwargs) - 0.5) * 2 * r_y + tau = (torch.rand(num_particles, **factory_kwargs) - 0.5) * 2 * r_tau is_in_ellipsoid = x**2 / r_x**2 + y**2 / r_y**2 + tau**2 / r_tau**2 < 1 num_to_add = min(num_particles - num_successful, is_in_ellipsoid.sum()) @@ -393,8 +484,8 @@ def uniform_3d_ellipsoid( # Generate an uncorrelated Gaussian beam beam = cls.from_parameters( num_particles=num_particles, - mu_px=torch.full(shape, 0.0), - mu_py=torch.full(shape, 0.0), + mu_px=torch.full(shape, 0.0, **factory_kwargs), + mu_py=torch.full(shape, 0.0, **factory_kwargs), sigma_px=sigma_px, sigma_py=sigma_py, sigma_p=sigma_p, @@ -414,7 +505,7 @@ def uniform_3d_ellipsoid( @classmethod def make_linspaced( cls, - num_particles: Optional[torch.Tensor] = None, + num_particles: Optional[int] = None, mu_x: Optional[torch.Tensor] = None, mu_y: Optional[torch.Tensor] = None, mu_px: Optional[torch.Tensor] = None, @@ -428,7 +519,7 @@ def make_linspaced( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParticleBeam": """ Generate Cheetah Beam of *n* linspaced particles. @@ -451,23 +542,63 @@ def make_linspaced( :param device: Device to move the beam's particle array to. If set to `"auto"` a CUDA GPU is selected if available. The CPU is used otherwise. """ + # Figure out if arguments were passed + not_nones = [ + argument + for argument in [ + mu_x, + mu_px, + mu_y, + mu_py, + sigma_x, + sigma_px, + sigma_y, + sigma_py, + sigma_tau, + sigma_p, + energy, + total_charge, + ] + if argument is not None + ] + + # Extract device and dtype from given arguments + device = device if device is not None else extract_argument_device(not_nones) + dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + factory_kwargs = {"device": device, "dtype": dtype} # Set default values without function call in function signature num_particles = num_particles if num_particles is not None else torch.tensor(10) - mu_x = mu_x if mu_x is not None else torch.tensor(0.0) - mu_px = mu_px if mu_px is not None else torch.tensor(0.0) - mu_y = mu_y if mu_y is not None else torch.tensor(0.0) - mu_py = mu_py if mu_py is not None else torch.tensor(0.0) - sigma_x = sigma_x if sigma_x is not None else torch.tensor(175e-9) - sigma_px = sigma_px if sigma_px is not None else torch.tensor(2e-7) - sigma_y = sigma_y if sigma_y is not None else torch.tensor(175e-9) - sigma_py = sigma_py if sigma_py is not None else torch.tensor(2e-7) - sigma_tau = sigma_tau if sigma_tau is not None else torch.tensor(1e-6) - sigma_p = sigma_p if sigma_p is not None else torch.tensor(1e-6) - energy = energy if energy is not None else torch.tensor(1e8) - total_charge = total_charge if total_charge is not None else torch.tensor(0.0) + mu_x = mu_x if mu_x is not None else torch.tensor(0.0, **factory_kwargs) + mu_px = mu_px if mu_px is not None else torch.tensor(0.0, **factory_kwargs) + mu_y = mu_y if mu_y is not None else torch.tensor(0.0, **factory_kwargs) + mu_py = mu_py if mu_py is not None else torch.tensor(0.0, **factory_kwargs) + sigma_x = ( + sigma_x if sigma_x is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_px = ( + sigma_px if sigma_px is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_y = ( + sigma_y if sigma_y is not None else torch.tensor(175e-9, **factory_kwargs) + ) + sigma_py = ( + sigma_py if sigma_py is not None else torch.tensor(2e-7, **factory_kwargs) + ) + sigma_tau = ( + sigma_tau if sigma_tau is not None else torch.tensor(1e-6, **factory_kwargs) + ) + sigma_p = ( + sigma_p if sigma_p is not None else torch.tensor(1e-6, **factory_kwargs) + ) + energy = energy if energy is not None else torch.tensor(1e8, **factory_kwargs) + total_charge = ( + total_charge + if total_charge is not None + else torch.tensor(0.0, **factory_kwargs) + ) particle_charges = ( - torch.ones((*total_charge.shape, num_particles)) + torch.ones((*total_charge.shape, num_particles), **factory_kwargs) * total_charge.unsqueeze(-1) / num_particles ) @@ -484,7 +615,7 @@ def make_linspaced( sigma_tau.shape, sigma_p.shape, ) - particles = torch.ones((*vector_shape, num_particles, 7)) + particles = torch.ones((*vector_shape, num_particles, 7), **factory_kwargs) particles[..., 0] = elementwise_linspace( mu_x - sigma_x, mu_x + sigma_x, num_particles @@ -559,7 +690,7 @@ def transformed_to( energy: Optional[torch.Tensor] = None, total_charge: Optional[torch.Tensor] = None, device=None, - dtype=torch.float32, + dtype=None, ) -> "ParticleBeam": """ Create version of this beam that is transformed to new beam parameters. diff --git a/cheetah/utils/__init__.py b/cheetah/utils/__init__.py index a29d9ae9..e5c574a0 100644 --- a/cheetah/utils/__init__.py +++ b/cheetah/utils/__init__.py @@ -1,4 +1,10 @@ from . import bmadx # noqa: F401 +from .argument_verification import ( # noqa: F401 + extract_argument_device, + extract_argument_dtype, + extract_argument_shape, + verify_device_and_dtype, +) from .device import is_mps_available_and_functional # noqa: F401 from .elementwise_linspace import elementwise_linspace # noqa: F401 from .kde import kde_histogram_1d, kde_histogram_2d # noqa: F401 diff --git a/cheetah/utils/argument_verification.py b/cheetah/utils/argument_verification.py new file mode 100644 index 00000000..7c67b7a1 --- /dev/null +++ b/cheetah/utils/argument_verification.py @@ -0,0 +1,56 @@ +from typing import Optional + +import torch + + +def extract_argument_device(arguments: list[torch.Tensor]) -> torch.device: + """ + Determines whether all arguments are on the same device and returns the default + pytorch device if no argumente are passed. + """ + if len(arguments) > 1: + assert all( + argument.device == arguments[0].device for argument in arguments + ), "Arguments must be on the same device." + + return arguments[0].device if len(arguments) > 0 else torch.get_default_device() + + +def extract_argument_dtype(arguments: list[torch.Tensor]) -> torch.dtype: + """ + Determines whether all arguments have the same dtype and returns the default + pytorch dtype if no argumente are passed. + """ + if len(arguments) > 1: + assert all( + argument.dtype == arguments[0].dtype for argument in arguments + ), "Arguments must have the same dtype." + + return arguments[0].dtype if len(arguments) > 0 else torch.get_default_dtype() + + +def extract_argument_shape(arguments: list[torch.Tensor]) -> torch.Size: + """Determines whether all arguments have the same shape.""" + if len(arguments) > 1: + assert all( + argument.shape == arguments[0].shape for argument in arguments + ), "Arguments must have the same shape." + + return arguments[0].shape if len(arguments) > 0 else torch.Size([1]) + + +def verify_device_and_dtype( + required: list[torch.Tensor], + optionals: list[Optional[torch.Tensor]], + device: torch.device, + dtype: torch.dtype, +) -> tuple[torch.device, torch.dtype]: + """ + Verifies that all required & given optional arguments have the same device and + dtype if no defaults are provided. + """ + not_nones = required + [argument for argument in optionals if argument is not None] + + device = device if device is not None else extract_argument_device(not_nones) + dtype = dtype if dtype is not None else extract_argument_dtype(not_nones) + return (device, dtype) diff --git a/tests/test_elegant_conversion.py b/tests/test_elegant_conversion.py index 5e4ac94b..01fdfbaf 100644 --- a/tests/test_elegant_conversion.py +++ b/tests/test_elegant_conversion.py @@ -17,14 +17,14 @@ def test_fodo(): cheetah.Quadrupole( name="q1", length=torch.tensor(0.1), k1=torch.tensor(1.5) ), - cheetah.Drift(name="d1", length=torch.tensor(1)), + cheetah.Drift(name="d1", length=torch.tensor(1.0)), cheetah.Marker(name="m1"), cheetah.Dipole(name="s1", length=torch.tensor(0.3), e1=torch.tensor(0.25)), - cheetah.Drift(name="d1", length=torch.tensor(1)), + cheetah.Drift(name="d1", length=torch.tensor(1.0)), cheetah.Quadrupole( - name="q2", length=torch.tensor(0.2), k1=torch.tensor(-3) + name="q2", length=torch.tensor(0.2), k1=torch.tensor(-3.0) ), - cheetah.Drift(name="d2", length=torch.tensor(2)), + cheetah.Drift(name="d2", length=torch.tensor(2.0)), ], name="fodo", ) diff --git a/tests/test_screen.py b/tests/test_screen.py index 0a1a43c3..6981d3d1 100644 --- a/tests/test_screen.py +++ b/tests/test_screen.py @@ -51,7 +51,7 @@ def test_screen_kde_bandwidth(kde_bandwidth): is_active=True, method="kde", name="my_screen", - kde_bandwidth=kde_bandwidth, + kde_bandwidth=torch.tensor(kde_bandwidth), ), ], ) diff --git a/tests/test_space_charge_kick.py b/tests/test_space_charge_kick.py index 8d91cb32..309dc07b 100644 --- a/tests/test_space_charge_kick.py +++ b/tests/test_space_charge_kick.py @@ -269,7 +269,10 @@ def test_space_charge_with_ares_astra_beam(): `IndexError: index -38 is out of bounds for dimension 3 with size 32`. """ segment = cheetah.Segment( - [cheetah.Drift(length=1.0), cheetah.SpaceChargeKick(effect_length=1.0)] + [ + cheetah.Drift(length=torch.tensor(1.0)), + cheetah.SpaceChargeKick(effect_length=torch.tensor(1.0)), + ] ) beam = cheetah.ParticleBeam.from_astra("tests/resources/ACHIP_EA1_2021.1351.001")