"""Module for simulating beam tilt."""
from __future__ import annotations
from typing import TYPE_CHECKING
import dask.array as da
import numpy as np
from abtem.core.axes import AxisMetadata, TiltAxis, AxisAlignedTiltAxis
from abtem.core.backend import get_array_module
from abtem.transform import CompositeArrayObjectTransform, ArrayObjectTransform
from abtem.distributions import (
BaseDistribution,
MultidimensionalDistribution,
EnsembleFromDistributions,
from_values,
validate_distribution,
)
if TYPE_CHECKING:
from abtem.waves import Waves
def _validate_tilt(
    tilt: BaseDistribution | tuple[float, float] | np.ndarray,
) -> BeamTilt | CompositeArrayObjectTransform:
    """
    Validate that the given tilt is correctly defined and wrap it in a transform.

    Parameters
    ----------
    tilt : BaseDistribution or tuple of float or np.ndarray
        The tilt [mrad]. A distribution or a NumPy array is wrapped in a
        `BeamTilt`. A tuple or list must have exactly two entries (x and y);
        each entry becomes an `AxisAlignedBeamTilt` and the two are combined in a
        `CompositeArrayObjectTransform`. Any other object is returned unchanged.

    Returns
    -------
    validated_tilt : BeamTilt or CompositeArrayObjectTransform
        The tilt as a transform (or the input itself if it is of another type).

    Raises
    ------
    NotImplementedError
        If `tilt` is a `MultidimensionalDistribution`.
    ValueError
        If `tilt` is a tuple or list that does not have exactly two entries.
    """
    if isinstance(tilt, MultidimensionalDistribution):
        raise NotImplementedError(
            "Tilt given as a MultidimensionalDistribution is not supported."
        )

    if isinstance(tilt, BaseDistribution):
        return BeamTilt(tilt)

    if isinstance(tilt, (tuple, list)):
        # An explicit exception instead of `assert`, which is stripped under -O.
        if len(tilt) != 2:
            raise ValueError(
                f"tilt must have exactly two components (x, y), got {len(tilt)}"
            )
        # Each component may independently be a scalar or a distribution, so
        # the x and y tilts are applied as separate axis-aligned transforms.
        transforms = [
            AxisAlignedBeamTilt(tilt=tilt_component, direction=direction)
            for tilt_component, direction in zip(tilt, ("x", "y"))
        ]
        return CompositeArrayObjectTransform(transforms)

    if isinstance(tilt, np.ndarray):
        return BeamTilt(tilt)

    # Anything else (e.g. an already constructed transform) is passed through.
    return tilt
def _get_tilt_axes(waves):
return tuple(
i
for i, axis in enumerate(waves.ensemble_axes_metadata)
if hasattr(axis, "tilt")
)
[docs]
def precession_tilts(
    precession_angle: float,
    num_samples: int,
    min_azimuth: float = 0.0,
    max_azimuth: float = 2 * np.pi,
    endpoint: bool = False,
) -> np.ndarray:
    """
    Tilts for electron precession at a given precession angle.

    The tilts lie on a circle of radius `precession_angle` in the (x, y) tilt
    plane, sampled at evenly spaced azimuthal angles.

    Parameters
    ----------
    precession_angle : float
        Precession angle [mrad].
    num_samples : int
        Number of tilt samples.
    min_azimuth : float, optional
        Minimum azimuthal angle [rad]. Default is 0.
    max_azimuth : float, optional
        Maximum azimuthal angle [rad]. Default is $2 \\pi$.
    endpoint : bool, optional
        If True, `max_azimuth` is the last sample. Otherwise, it is not included.
        Default is False.

    Returns
    -------
    array_of_tilts : np.ndarray
        Array of shape (num_samples, 2) holding the (x, y) tilt angles. The
        angles have the same units as `precession_angle` [mrad]; only the
        azimuthal angles are in radians.
    """
    azimuthal_angles = np.linspace(
        min_azimuth, max_azimuth, num=num_samples, endpoint=endpoint
    )
    tilt_x = precession_angle * np.cos(azimuthal_angles)
    tilt_y = precession_angle * np.sin(azimuthal_angles)
    # Transpose so that each row is one (x, y) tilt pair.
    return np.array([tilt_x, tilt_y], dtype=float).T
[docs]
class BaseBeamTilt(EnsembleFromDistributions, ArrayObjectTransform):
    """Common base for transforms that add beam-tilt ensemble axes to waves."""

    def apply(self, waves: Waves, in_place: bool = False) -> Waves:
        """
        Apply tilt(s) to (an ensemble of) wave function(s).

        The wave-function array is repeated once for every member of this
        transform's ensemble. The new ensemble axes are placed in front of the
        existing axes of `waves`, and the tilt metadata is merged into the
        metadata of `waves`.

        Parameters
        ----------
        waves : Waves
            The waves to transform.
        in_place: bool, optional
            If True, the array representing the waves may be modified in-place.
            It is not used by this implementation; a new array is always built.

        Returns
        -------
        waves_with_tilt : Waves
        """
        array_module = get_array_module(waves.device)
        ensemble_shape = self.ensemble_shape

        # One leading singleton axis per ensemble dimension, then repeat the
        # array along those axes so each ensemble member gets its own copy.
        expanded = waves.array[(None,) * len(ensemble_shape)]
        repetitions = ensemble_shape + (1,) * len(waves.shape)

        if not waves.is_lazy:
            tiled = array_module.tile(expanded, repetitions)
        else:
            tiled = da.tile(expanded, repetitions)

        kwargs = waves._copy_kwargs(exclude=("array",))
        kwargs.update(
            array=tiled,
            # Tilt metadata takes precedence over existing entries with equal keys.
            metadata={**kwargs["metadata"], **self.metadata},
            # The tilt axes come first, matching the layout of `tiled`.
            ensemble_axes_metadata=(
                self.ensemble_axes_metadata + kwargs["ensemble_axes_metadata"]
            ),
        )
        return type(waves)(**kwargs)
[docs]
class BeamTilt(BaseBeamTilt):
    """
    Class describing beam tilt.

    Parameters
    ----------
    tilt : tuple of float or BaseDistribution or np.ndarray
        Tilt along the `x` and `y` axes [mrad] with an optional spread of values.
        A tuple gives a single fixed (x, y) tilt; a distribution gives an ensemble
        of tilts. A NumPy array is converted to a distribution with `from_values`
        (presumably an array of (x, y) pairs -- TODO confirm).
    """

    def __init__(self, tilt: tuple[float, float] | BaseDistribution | np.ndarray):
        if isinstance(tilt, np.ndarray):
            tilt = from_values(tilt)
        self._tilt = tilt
        super().__init__(distributions=("tilt",))

    @property
    def tilt(self) -> tuple[float, float] | BaseDistribution:
        """Beam tilt angle [mrad]."""
        return self._tilt

    @property
    def metadata(self) -> dict:
        """
        Metadata describing the tilt.

        A fixed tilt is recorded as the base tilt. For a distribution the base
        tilt is zero, because the tilt values are carried by the ensemble axis.
        """
        if isinstance(self.tilt, BaseDistribution):
            return {"base_tilt_x": 0.0, "base_tilt_y": 0.0}
        return {"base_tilt_x": self.tilt[0], "base_tilt_y": self.tilt[1]}

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        """
        Metadata describing (an ensemble of) tilted wave function(s).

        A fixed tilt adds no ensemble axis, so an empty list is returned.
        """
        if not isinstance(self.tilt, BaseDistribution):
            # Previously this fell through and returned None, which made
            # BaseBeamTilt.apply fail when concatenating it with the waves'
            # ensemble axes metadata.
            return []
        return [
            TiltAxis(
                label="tilt",
                values=tuple(tuple(value) for value in self.tilt.values),
                units="mrad",
                _ensemble_mean=self.tilt.ensemble_mean,
            )
        ]
[docs]
class AxisAlignedBeamTilt(BaseBeamTilt):
    """
    Class describing tilt(s) aligned with an axis.

    Parameters
    ----------
    tilt : float or BaseDistribution
        Tilt along the given direction [mrad] with an optional spread of values.
        NumPy arrays, lists and tuples are converted with `validate_distribution`;
        any other non-distribution value is converted to float.
    direction : str
        Cartesian axis, should be either 'x' or 'y'.

    Raises
    ------
    ValueError
        If `direction` is neither 'x' nor 'y'.
    """

    def __init__(self, tilt: float | BaseDistribution = 0.0, direction: str = "x"):
        # Reject invalid directions early; otherwise they would silently produce
        # metadata keys such as "base_tilt_z".
        if direction not in ("x", "y"):
            raise ValueError(
                f"direction must be either 'x' or 'y', got {direction!r}"
            )
        if isinstance(tilt, (np.ndarray, list, tuple)):
            tilt = validate_distribution(tilt)
        if not isinstance(tilt, BaseDistribution):
            tilt = float(tilt)
        self._tilt = tilt
        self._direction = direction
        super().__init__(distributions=("tilt",))

    @property
    def direction(self) -> str:
        """Tilt direction."""
        return self._direction

    @property
    def tilt(self) -> float | BaseDistribution:
        """Beam tilt [mrad]."""
        return self._tilt

    @property
    def metadata(self) -> dict:
        """
        Metadata describing the tilt along `direction`.

        A fixed tilt is recorded as the base tilt. For a distribution the base
        tilt is zero, because the tilt values are carried by the ensemble axis.
        """
        if isinstance(self.tilt, BaseDistribution):
            return {f"base_tilt_{self._direction}": 0.0}
        return {f"base_tilt_{self._direction}": self._tilt}

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        """
        Metadata describing (an ensemble of) tilted wave function(s).

        A fixed tilt adds no ensemble axis, so an empty list is returned.
        """
        if not isinstance(self.tilt, BaseDistribution):
            return []
        return [
            AxisAlignedTiltAxis(
                label=f"tilt_{self._direction}",
                values=tuple(self.tilt.values),
                direction=self._direction,
                units="mrad",
                _ensemble_mean=self.tilt.ensemble_mean,
            )
        ]