"""Module to describe the contrast transfer function (CTF) and the related apertures."""
from __future__ import annotations
import copy
from abc import abstractmethod
from collections import defaultdict
from functools import reduce
from typing import TYPE_CHECKING, Any, Mapping, Optional, SupportsFloat
import numpy as np
from abtem.core.axes import AxisMetadata, OrdinalAxis, ParameterAxis
from abtem.core.backend import cp, get_array_module
from abtem.core.complex import complex_exponential
from abtem.core.energy import (
Accelerator,
HasAcceleratorMixin,
energy2wavelength,
reciprocal_space_sampling_to_angular_sampling,
)
from abtem.core.fft import fft_crop
from abtem.core.grid import Grid, HasGrid2DMixin, polar_spatial_frequencies
from abtem.core.utils import expand_dims_to_broadcast, get_dtype
from abtem.distributions import (
BaseDistribution,
_unpack_distributions,
validate_distribution,
)
from abtem.measurements import ReciprocalSpaceLineProfiles
from abtem.transform import ReciprocalSpaceMultiplication
if TYPE_CHECKING:
from abtem.measurements import DiffractionPatterns, Images
from abtem.visualize import Visualization
from abtem.waves import BaseWaves, Waves
[docs]
class BaseTransferFunction(
    ReciprocalSpaceMultiplication, HasAcceleratorMixin, HasGrid2DMixin
):
    """Base class for transfer functions.

    Subclasses implement `_evaluate_from_angular_grid`, which maps a polar grid of
    scattering angles to the array that is multiplied with wave functions in
    reciprocal space.
    """

    def __init__(
        self,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
        distributions: tuple[str, ...] = (),
    ):
        self._accelerator = Accelerator(energy=energy)
        self._grid = Grid(extent=extent, gpts=gpts, sampling=sampling)
        super().__init__(distributions=distributions)

    @abstractmethod
    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        """Evaluate the transfer function from the radial angles `alpha` and the
        azimuthal angles `phi` produced by `_angular_grid`."""

    @property
    def angular_sampling(self) -> tuple[float, float]:
        """The sampling in scattering angles of the transfer function [mrad]."""
        return reciprocal_space_sampling_to_angular_sampling(
            self.reciprocal_space_sampling, self._valid_energy
        )

    def _angular_grid(self, device: str) -> tuple[np.ndarray, np.ndarray]:
        """Polar grid of the transfer function on the given device.

        The radial spatial frequencies are multiplied by the wavelength to obtain
        scattering angles; the azimuthal component is returned unchanged.
        """
        xp = get_array_module(device)
        alpha, phi = polar_spatial_frequencies(
            self._valid_gpts, self._valid_sampling, xp=xp
        )
        alpha *= self.wavelength
        return alpha, phi

    def _evaluate_kernel(self, waves: Optional[BaseWaves] = None) -> np.ndarray:
        """
        Evaluate the array to be multiplied with the waves in reciprocal space.

        Parameters
        ----------
        waves : BaseWaves, optional
            If given, the accelerator and grid of this transfer function are
            matched to the waves (NOTE(review): `match` may modify this instance)
            and the array is evaluated on the device of the waves. Otherwise the
            array is evaluated on the CPU.

        Returns
        -------
        kernel : np.ndarray or dask.array.Array
        """
        if waves is None:
            device = "cpu"
        else:
            self.accelerator.match(waves)
            self.grid.match(waves)
            device = waves.device

        self.grid.check_is_defined()
        self.accelerator.check_is_defined()

        alpha, phi = self._angular_grid(device)
        return self._evaluate_from_angular_grid(alpha, phi)

    def to_diffraction_patterns(
        self,
        max_angle: Optional[float] = None,
        gpts: Optional[int | tuple[int, int]] = None,
    ) -> DiffractionPatterns:
        """Converts the transfer function instance to DiffractionPatterns.

        Parameters
        ----------
        max_angle : float, optional
            The maximum diffraction angle [mrad]. The real-space sampling is chosen
            so that this angle lies at the Nyquist frequency. If not provided and
            the instance has no sampling, the `_max_semiangle_cutoff` attribute of
            the instance is used; if that is unavailable, or the resulting angle is
            not finite, a `RuntimeError` is raised. If not provided and the
            instance has a sampling, that sampling is used.
        gpts : int | tuple[int, int], optional
            The number of grid points in reciprocal space for performing Fourier
            Transform. If not provided, the grid points of the instance are used,
            falling back to 128 if those are undefined.

        Returns
        -------
        abtem.measurements.DiffractionPatterns
            The diffraction patterns obtained from the conversion.

        Raises
        ------
        RuntimeError
            If no finite maximum angle can be determined when one is needed.
        """
        from abtem.measurements import DiffractionPatterns

        if self.sampling is None or max_angle is not None:
            if max_angle is None and hasattr(self, "_max_semiangle_cutoff"):
                max_angle = self._max_semiangle_cutoff

            if max_angle is None:
                raise RuntimeError(
                    "max_angle must be provided when the transfer function has "
                    "neither a sampling nor a semiangle cutoff"
                )

            if not np.isfinite(max_angle):
                raise RuntimeError(
                    "cannot determine the sampling from a non-finite maximum angle "
                    f"({max_angle}); provide a finite max_angle"
                )

            # Real-space sampling that places max_angle [mrad] at the Nyquist
            # frequency: d = wavelength / (2 * alpha_max).
            sampling = 1 / (max_angle * 1e-3) / 2 * self.wavelength
        else:
            sampling = self.sampling

        # An explicitly requested number of grid points takes precedence.
        if gpts is None:
            gpts = 128 if self.gpts is None else self.gpts

        ctf = self.copy()
        ctf.gpts = gpts
        ctf.sampling = sampling

        array = ctf._evaluate_kernel()
        xp = get_array_module(array)

        diffraction_patterns = DiffractionPatterns(
            xp.fft.fftshift(array, axes=(-2, -1)),
            sampling=ctf.reciprocal_space_sampling,
            ensemble_axes_metadata=ctf.ensemble_axes_metadata,
            fftshift=False,
            metadata={"energy": self.energy},
        )
        return diffraction_patterns

    def show(self, max_angle: Optional[float] = None, **kwargs: Any) -> Visualization:
        """Show the transfer function as diffraction patterns.

        Parameters
        ----------
        max_angle : float, optional
            Maximum diffraction angle [mrad], passed to `to_diffraction_patterns`.
        kwargs :
            Passed to the `show` method of the resulting diffraction patterns.
        """
        return self.to_diffraction_patterns(max_angle=max_angle).show(**kwargs)
[docs]
class BaseAperture(BaseTransferFunction):
    """Base class for apertures. Documented in the subclasses."""

    def __init__(
        self,
        semiangle_cutoff: float | BaseDistribution = np.inf,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
        distributions: tuple[str, ...] = (),
    ):
        # The cutoff is stored before the base initializer is called.
        self._semiangle_cutoff = semiangle_cutoff
        super().__init__(
            energy=energy,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
            distributions=distributions,
        )

    @property
    def metadata(self) -> dict:
        """Scalar cutoffs are recorded as metadata; distributions are not."""
        cutoff = self.semiangle_cutoff
        if isinstance(cutoff, BaseDistribution):
            return {}
        return {"semiangle_cutoff": cutoff}

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        return []

    @property
    def _max_semiangle_cutoff(self) -> float:
        """Largest cutoff [mrad]; the maximum value for a distribution."""
        cutoff = self.semiangle_cutoff
        return max(cutoff.values) if isinstance(cutoff, BaseDistribution) else cutoff

    @property
    def nyquist_sampling(self) -> float:
        """Nyquist sampling corresponding to the semiangle cutoff of the
        aperture [Å]."""
        max_cutoff = self._max_semiangle_cutoff
        return 1 / (4 * max_cutoff / self.wavelength * 1e-3)

    @property
    def semiangle_cutoff(self) -> float | BaseDistribution:
        """Semiangle cutoff of the aperture [mrad]."""
        return self._semiangle_cutoff

    @semiangle_cutoff.setter
    def semiangle_cutoff(self, semiangle_cutoff: float | BaseDistribution) -> None:
        self._semiangle_cutoff = semiangle_cutoff

    def _cropped_aperture(self) -> BaseAperture:
        """Copy of the aperture whose grid just covers the largest cutoff.

        An unbounded (infinite) cutoff cannot be cropped, so `self` is returned.
        """
        max_cutoff = self._max_semiangle_cutoff
        if max_cutoff == np.inf:
            return self

        sampling_x, sampling_y = self.angular_sampling
        cropped_gpts = (
            int(2 * np.ceil(max_cutoff / sampling_x + 1)),
            int(2 * np.ceil(max_cutoff / sampling_y + 1)),
        )

        cropped = self.copy()
        cropped.gpts = cropped_gpts
        return cropped

    def _evaluate_from_cropped(self, waves: Waves) -> np.ndarray:
        """Evaluate on the cropped grid, then resize to the grid of `waves`."""
        kernel = self._cropped_aperture()._evaluate_kernel(waves)
        return fft_crop(kernel, waves._valid_gpts)
[docs]
def soft_aperture(
    alpha: np.ndarray,
    phi: np.ndarray,
    semiangle_cutoff: float | np.ndarray,
    angular_sampling: tuple[float, float],
) -> np.ndarray:
    """
    Calculates an array with a disk of ones and a soft edge.

    The edge is a linear ramp from one to zero whose width is the angular size of
    one reciprocal-space pixel measured along the azimuthal direction `phi`.

    Parameters
    ----------
    alpha : 2D array
        Array of radial angles. Must be in the same unit as `semiangle_cutoff`.
        Because `angular_sampling` is converted from mrad to radians, radians are
        expected [rad].
    phi : 2D array
        Array of azimuthal angles [rad].
    semiangle_cutoff : float or 1D array
        Semiangle cutoff(s) of the aperture(s) [rad]. If given as an array, a 3D
        array is returned where the first dimension represents a different
        aperture for each item in the array of semiangle cutoffs.
    angular_sampling : tuple of float
        Reciprocal-space sampling in units of scattering angles [mrad].

    Returns
    -------
    soft_aperture_array : 2D or 3D np.ndarray
    """
    xp = get_array_module(alpha)
    semiangle_cutoff_array = xp.asarray(semiangle_cutoff, dtype=get_dtype(complex=False))
    # Number of dimensions of the angular grid before any leading axes are added.
    base_ndims = len(alpha.shape)
    # Make the cutoff(s) broadcastable against the angular grids
    # (NOTE(review): relies on `expand_dims_to_broadcast` adding singleton axes).
    semiangle_cutoff_array, alpha = expand_dims_to_broadcast(
        semiangle_cutoff_array, alpha
    )
    semiangle_cutoff, phi = expand_dims_to_broadcast(
        semiangle_cutoff_array, phi, match_dims=((-2, -1), (-2, -1))
    )
    # Convert the angular sampling from mrad to rad.
    angular_sampling = xp.asarray(angular_sampling, dtype=get_dtype(complex=False)) * 1e-3
    # Angular width of one pixel along the direction phi; the width of the ramp.
    denominator = xp.sqrt(
        (xp.cos(phi) * angular_sampling[0]) ** 2
        + (xp.sin(phi) * angular_sampling[1]) ** 2
    )
    ndims = len(alpha.shape)
    # Index of the first grid pixel in every leading slice (presumably the
    # zero-frequency pixel -- TODO confirm against polar_spatial_frequencies).
    zeros = (slice(None),) * (ndims - base_ndims) + (0,) * base_ndims
    denominator[zeros] = 1.0
    # Linear ramp: 0.5 exactly at the cutoff, clipped to [0, 1].
    array = xp.clip(
        (semiangle_cutoff - alpha) / denominator + 0.5, a_min=0.0, a_max=1.0
    )
    # The first grid pixel is always fully transmitted.
    array[zeros] = 1.0
    return array
[docs]
def hard_aperture(
    alpha: np.ndarray, semiangle_cutoff: float | np.ndarray
) -> np.ndarray:
    """
    Calculates an array with a disk of ones and a hard (abrupt) edge.

    Parameters
    ----------
    alpha : 2D array
        Array of radial angles. Must be in the same unit as `semiangle_cutoff`.
    semiangle_cutoff : float or 1D array
        Semiangle cutoff(s) of the aperture(s). If given as an array, a 3D array is
        returned where the first dimension represents a different aperture for each
        item in the array of semiangle cutoffs.

    Returns
    -------
    hard_aperture_array : 2D or 3D np.ndarray
        Ones where `alpha <= semiangle_cutoff` and zeros elsewhere.
    """
    xp = get_array_module(alpha)
    alpha = xp.asarray(alpha)
    cutoff = xp.asarray(semiangle_cutoff)

    # Append one singleton axis per dimension of `alpha`. A scalar cutoff then
    # broadcasts to the shape of `alpha`. Each entry of a 1D array of cutoffs
    # produces its own aperture along the leading axis, instead of being
    # compared against the last axis of `alpha`.
    cutoff = cutoff.reshape(cutoff.shape + (1,) * alpha.ndim)

    return xp.asarray(alpha <= cutoff).astype(get_dtype(complex=False))
[docs]
class Aperture(BaseAperture):
    """
    A circular aperture cutting off the wave function at a specified angle, employed in
    both STEM and HRTEM. The abrupt cutoff may be softened by tapering it.

    Parameters
    ----------
    semiangle_cutoff : float or BaseDistribution
        The cutoff semiangle of the aperture [mrad]. Alternatively, a distribution of
        angles may be provided.
    soft : bool, optional
        If True, the edge of the aperture is softened (default is True).
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a single
        float is given, both are set equal.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be
        ignored.
    """

    def __init__(
        self,
        semiangle_cutoff: float | BaseDistribution,
        soft: bool = True,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
    ):
        cutoff = validate_distribution(semiangle_cutoff)
        self._soft = soft
        super().__init__(
            semiangle_cutoff=cutoff,
            energy=energy,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
            distributions=("semiangle_cutoff",),
        )

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        """One parameter axis when the cutoff is a distribution, otherwise none."""
        cutoff = self.semiangle_cutoff
        if not isinstance(cutoff, BaseDistribution):
            return []
        return [
            ParameterAxis(
                label="semiangle_cutoff",
                values=tuple(cutoff),
                units="mrad",
                tex_label="$\\alpha_{cut}$",
                _ensemble_mean=cutoff.ensemble_mean,
            )
        ]

    @property
    def soft(self) -> bool:
        """True if the aperture has a soft edge."""
        return self._soft

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        xp = get_array_module(alpha)

        # An unbounded aperture transmits everything.
        if self.semiangle_cutoff == xp.inf:
            return xp.ones_like(alpha)

        # mrad -> rad, to match the angular grid.
        cutoff = xp.asarray(self.semiangle_cutoff) * 1e-3

        # The soft edge needs a defined grid (for the angular sampling) and
        # array-valued angles.
        use_soft_edge = (
            self.soft
            and self.grid.check_is_defined(False)
            and not (np.isscalar(alpha) or np.isscalar(phi))
        )
        if not use_soft_edge:
            return hard_aperture(alpha, cutoff)

        return soft_aperture(alpha, phi, cutoff, self.angular_sampling)
[docs]
class Bullseye(BaseAperture):
    """
    Bullseye aperture.

    A hard-edged circular aperture that is additionally blocked by radial spokes
    and concentric rings.

    Parameters
    ----------
    num_spokes : int
        Number of spokes.
    spoke_width : float
        Width of spokes [deg].
    num_rings : int
        Number of rings. If zero, no rings are added.
    ring_width : float
        Width of rings [mrad].
    semiangle_cutoff : float
        The cutoff semiangle of the aperture [mrad].
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a single
        float is given, both are set equal.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be
        ignored.
    """

    def __init__(
        self,
        num_spokes: int,
        spoke_width: float,
        num_rings: int,
        ring_width: float,
        semiangle_cutoff: float,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
    ):
        self._spoke_num = num_spokes
        self._spoke_width = spoke_width
        self._num_rings = num_rings
        self._ring_width = ring_width
        super().__init__(
            energy=energy,
            semiangle_cutoff=semiangle_cutoff,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
        )

    @property
    def soft(self) -> bool:
        """True if the aperture has a soft edge."""
        return False

    @property
    def num_spokes(self) -> int:
        """Number of spokes."""
        return self._spoke_num

    @property
    def spoke_width(self) -> float:
        """Width of spokes [deg]."""
        return self._spoke_width

    @property
    def num_rings(self) -> int:
        """Number of rings."""
        return self._num_rings

    @property
    def ring_width(self) -> float:
        """Width of rings [mrad]."""
        return self._ring_width

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        xp = get_array_module(alpha)
        alpha = xp.array(alpha)
        semiangle_cutoff = self.semiangle_cutoff
        assert isinstance(semiangle_cutoff, SupportsFloat)
        # The parameters are in mrad; the angular grid is in rad.
        semiangle_cutoff = semiangle_cutoff / 1e3

        array = alpha < semiangle_cutoff

        # Block `num_spokes` spokes of angular width `spoke_width` [deg], centred
        # on the azimuths phi = 2 * pi * k / num_spokes.
        array = array * (
            ((phi + np.pi * self.spoke_width / (180 * 2)) * self.num_spokes)
            % (2 * np.pi)
            > (np.pi * self.spoke_width / 180 * self.num_spokes)
        )

        # Block rings of width `ring_width` [mrad] whose outer edges are equally
        # spaced up to the cutoff. Skipped without rings to avoid dividing by zero.
        if self.num_rings > 0:
            end_edges = np.linspace(
                semiangle_cutoff / self.num_rings, semiangle_cutoff, self.num_rings
            )
            start_edges = end_edges - self.ring_width / 1e3
            for start_edge, end_edge in zip(start_edges, end_edges):
                array[(alpha > start_edge) & (alpha < end_edge)] = False

        # Real-valued like the other apertures, rather than a boolean mask.
        return xp.asarray(array, dtype=get_dtype(complex=False))
[docs]
class Vortex(BaseAperture):
    """
    Vortex-beam aperture.

    A circular aperture multiplied by the azimuthal phase winding
    exp(i * quantum_number * phi).

    Parameters
    ----------
    quantum_number : int
        Quantum number of vortex beam.
    semiangle_cutoff : float
        The cutoff semiangle of the aperture [mrad].
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a single
        float is given, both are set equal.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be
        ignored.
    soft : bool, optional
        If True, the edge of the aperture is softened (default is False).
    """

    def __init__(
        self,
        quantum_number: int,
        semiangle_cutoff: float,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
        soft: bool = False,
    ):
        self._quantum_number = quantum_number
        self._soft = soft
        super().__init__(
            energy=energy,
            semiangle_cutoff=semiangle_cutoff,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
        )

    @property
    def soft(self) -> bool:
        """True if the aperture has a soft edge."""
        return self._soft

    @property
    def quantum_number(self) -> int:
        """Quantum number of vortex beam."""
        return self._quantum_number

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        xp = get_array_module(alpha)
        alpha = xp.array(alpha)
        semiangle_cutoff = self.semiangle_cutoff
        assert isinstance(semiangle_cutoff, SupportsFloat)
        # mrad -> rad, to match the angular grid.
        semiangle_cutoff = semiangle_cutoff / 1e3

        if self._soft:
            array = soft_aperture(alpha, phi, semiangle_cutoff, self.angular_sampling)
        else:
            array = alpha < semiangle_cutoff

        # Use the array module of the grid (NumPy or CuPy) for the phase winding.
        array = array * xp.exp(1j * phi * self.quantum_number)
        return array
[docs]
class AnnularAperture(BaseAperture):
    """
    Annular aperture.

    Transmits only scattering angles strictly between the inner cutoff and the
    outer (semiangle) cutoff.

    Parameters
    ----------
    inner_cutoff : float
        The cutoff semiangle of inner radius of the aperture [mrad].
    semiangle_cutoff : float
        The cutoff semiangle of the aperture [mrad].
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a single
        float is given, both are set equal.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be
        ignored.
    """

    def __init__(
        self,
        inner_cutoff: float,
        semiangle_cutoff: float,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
    ):
        self._inner_cutoff = inner_cutoff
        super().__init__(
            energy=energy,
            semiangle_cutoff=semiangle_cutoff,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
        )

    @property
    def inner_cutoff(self) -> float:
        """The cutoff semiangle of inner radius of the aperture [mrad]."""
        return self._inner_cutoff

    @property
    def soft(self) -> bool:
        """True if the aperture has a soft edge."""
        return False

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        xp = get_array_module(alpha)
        alpha = xp.array(alpha)

        outer = self.semiangle_cutoff
        inner = self._inner_cutoff
        assert isinstance(outer, SupportsFloat)

        # Both cutoffs are in mrad; the angular grid is in rad.
        outer = outer / 1e3
        inner = inner / 1e3

        transmitted = (alpha > inner) * (alpha < outer)
        return xp.asarray(transmitted, dtype=get_dtype(complex=False))
[docs]
class Zernike(BaseAperture):
    """
    Zernike aperture.

    A circular aperture with a phase-shifting film everywhere except in a central
    hole.

    Parameters
    ----------
    center_hole_cutoff : float
        Cutoff semiangle of aperture hole [mrad].
    phase_shift: float
        Phase shift of Zernike film [rad]
    semiangle_cutoff : float
        The cutoff semiangle of the aperture [mrad].
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a single
        float is given, both are set equal.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be
        ignored.
    """

    def __init__(
        self,
        center_hole_cutoff: float,
        phase_shift: float,
        semiangle_cutoff: float,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
    ):
        self._center_hole_cutoff = center_hole_cutoff
        self._phase_shift = phase_shift
        super().__init__(
            energy=energy,
            semiangle_cutoff=semiangle_cutoff,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
        )

    @property
    def center_hole_cutoff(self) -> float:
        """Cutoff semiangle of aperture hole [mrad]."""
        return self._center_hole_cutoff

    @property
    def soft(self) -> bool:
        """True if the aperture has a soft edge."""
        return False

    @property
    def phase_shift(self) -> float:
        """Phase shift of Zernike film [rad]."""
        return self._phase_shift

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        xp = get_array_module(alpha)
        alpha = xp.array(alpha)

        cutoff = self.semiangle_cutoff
        assert isinstance(cutoff, SupportsFloat)
        # mrad -> rad, to match the angular grid.
        cutoff = cutoff / 1e3
        hole = self.center_hole_cutoff / 1e3
        shift = self.phase_shift

        real_dtype = get_dtype(complex=False)
        inside_aperture = alpha < cutoff

        amplitude = xp.asarray(inside_aperture, dtype=real_dtype)
        # The film covers the aperture outside the central hole.
        film = xp.asarray(
            xp.logical_and(alpha > hole, inside_aperture), dtype=real_dtype
        )
        return amplitude * xp.exp(1.0j * shift * film)
[docs]
class RadialPhasePlate(BaseAperture):
    """
    Radial phase plate.

    Applies a phase shift in concentric zones. The radial angle is clamped to
    [shift_central_semiangle, semiangle_cutoff], normalized to [0, 1], and the
    phase shift is applied wherever sin(pi * (num_flips + 1) * normalized_angle)
    is negative. No amplitude cutoff is applied.

    Parameters
    ----------
    num_flips : int
        Number of phase flips.
    semiangle_cutoff : float
        Outer angle of the phase plate pattern [mrad].
    phase_shift : float, optional
        Phase shift applied in the flipped zones [rad] (default is pi).
    power_law : float, optional
        Power law of the phase plate (default is 2.0). Stored but not used by the
        evaluation in this class.
    shift_central_semiangle : float, optional
        Inner angle below which the radial angle is clamped [mrad] (default is 0).
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a single
        float is given, both are set equal.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be
        ignored.
    """

    def __init__(
        self,
        num_flips: int,
        semiangle_cutoff: float,
        phase_shift: float = np.pi,
        power_law: float = 2.0,
        shift_central_semiangle: float = 0.0,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
    ):
        self._num_flips = num_flips
        self._shift_central_semiangle = shift_central_semiangle
        self._power_law = power_law
        self._phase_shift = phase_shift
        super().__init__(
            distributions=(),
            semiangle_cutoff=semiangle_cutoff,
            energy=energy,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
        )

    @property
    def soft(self) -> bool:
        """True if the aperture has a soft edge."""
        return False

    @property
    def num_flips(self) -> int:
        """Number of phase flips."""
        return self._num_flips

    @property
    def phase_shift(self) -> float:
        """Phase shift of the phase plate."""
        return self._phase_shift

    @property
    def power_law(self) -> float:
        """Power law of the phase plate."""
        return self._power_law

    @property
    def shift_central_semiangle(self) -> float:
        """Shift central semiangle of the phase plate."""
        return self._shift_central_semiangle

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        xp = get_array_module(alpha)
        semiangle_cutoff = self.semiangle_cutoff
        assert isinstance(semiangle_cutoff, SupportsFloat)
        max_alpha = semiangle_cutoff / 1e3
        min_alpha = self.shift_central_semiangle / 1e3

        # Clamp to [min_alpha, max_alpha]. The upper bound is applied first and the
        # lower bound second, so the lower bound wins if min_alpha > max_alpha.
        # Unlike masked assignment on `<`/`>`, this also clamps alpha == max_alpha.
        alpha_normed = xp.maximum(xp.minimum(alpha, max_alpha), min_alpha)

        alpha_low = alpha_normed.min()
        alpha_range = alpha_normed.max() - alpha_low
        if alpha_range > 0:
            alpha_normed = (alpha_normed - alpha_low) / alpha_range
        else:
            # Constant clamped angle: avoid 0 / 0 (no zone is flipped either way).
            alpha_normed = xp.zeros_like(alpha_normed)

        flipped = xp.sin(alpha_normed * np.pi * (self.num_flips + 1)) < 0.0
        return xp.exp(1.0j * self.phase_shift * flipped)
[docs]
class TemporalEnvelope(BaseTransferFunction):
    """
    Envelope function for simulating partial temporal coherence in the quasi-coherent
    approximation.

    Parameters
    ----------
    focal_spread: float or 1D array or BaseDistribution
        The standard deviation of the focal spread due to chromatic aberration and lens
        current instability [Å].
        Alternatively, a distribution of values may be provided.
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a single
        float is given, both are set equal.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [1 / Å]. If 'gpts' is also given, will be
        ignored.
    """

    def __init__(
        self,
        focal_spread: float | BaseDistribution,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
    ):
        # Set before the base initializer, which creates the accelerator and grid.
        self._focal_spread = validate_distribution(focal_spread)
        super().__init__(
            distributions=("focal_spread",),
            energy=energy,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
        )

    @property
    def focal_spread(self) -> float | BaseDistribution:
        """The standard deviation of the focal spread [Å]."""
        return self._focal_spread

    @focal_spread.setter
    def focal_spread(self, value: float | BaseDistribution) -> None:
        self._focal_spread = validate_distribution(value)

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        # The focal spread is a length, so its ensemble axis is labelled in Å.
        return self._get_axes_metadata_from_distributions(
            focal_spread={"units": "Å"}
        )

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        xp = get_array_module(alpha)
        unpacked, _ = _unpack_distributions(self.focal_spread, shape=alpha.shape, xp=xp)
        (focal_spread,) = unpacked

        # Prepend one axis per ensemble axis so that alpha broadcasts against the
        # unpacked focal spread (assumes `_unpack_distributions` shapes it with
        # leading ensemble axes -- TODO confirm).
        alpha = xp.array(alpha)
        alpha = xp.expand_dims(alpha, axis=tuple(range(0, self._num_ensemble_axes)))

        array = xp.exp(
            -((0.5 * xp.pi / self.wavelength * focal_spread * alpha**2) ** 2)
        ).astype(get_dtype(complex=False))
        return array
[docs]
def symbol_to_tex_symbol(symbol: str) -> str:
    """Convert an aberration symbol to an inline TeX label.

    Everything after "C" or "phi" becomes a subscript, e.g. "C10" becomes
    "$C_{10}$" and "phi12" becomes "$\\phi_{12}$".
    """
    subscripted = symbol.replace("C", "C_{")
    subscripted = subscripted.replace("phi", "\\phi_{")
    return "$" + subscripted + "}$"
# Descriptive aberration names (aliases) mapped to polar aberration symbols.
# Insertion order matters: it defines the order of `polar_symbols`, which is
# used to order the aberration coefficients.
polar_aliases = dict(
    defocus="C10",
    Cs="C30",
    C5="C50",
    astigmatism="C12",
    astigmatism_angle="phi12",
    astigmatism3="C32",
    astigmatism3_angle="phi32",
    astigmatism5="C52",
    astigmatism5_angle="phi52",
    coma="C21",
    coma_angle="phi21",
    coma4="C41",
    coma4_angle="phi41",
    trefoil="C23",
    trefoil_angle="phi23",
    trefoil4="C43",
    trefoil4_angle="phi43",
    quadrafoil="C34",
    quadrafoil_angle="phi34",
    quadrafoil5="C54",
    quadrafoil5_angle="phi54",
    pentafoil="C45",
    pentafoil_angle="phi45",
    hexafoil="C56",
    hexafoil_angle="phi56",
)

# Inverse mapping (polar symbol -> alias), in the same order as `polar_aliases`.
polar_symbols = dict(zip(polar_aliases.values(), polar_aliases.keys()))
class _HasAberrations(HasAcceleratorMixin):
    """Mixin storing polar aberration coefficients.

    Coefficients are kept in `_aberration_coefficients`, keyed by polar symbol
    (e.g. "C10", "phi12") and initialized to zero. They can be read and written
    as attributes by symbol or by alias (e.g. "Cs", "astigmatism_angle") through
    `__getattr__` and `__setattr__`. `defocus` is a property equal to -C10.
    """

    C10: float | BaseDistribution
    C12: float | BaseDistribution
    phi12: float | BaseDistribution
    C21: float | BaseDistribution
    phi21: float | BaseDistribution
    C23: float | BaseDistribution
    phi23: float | BaseDistribution
    C30: float | BaseDistribution
    C32: float | BaseDistribution
    phi32: float | BaseDistribution
    C34: float | BaseDistribution
    phi34: float | BaseDistribution
    C41: float | BaseDistribution
    phi41: float | BaseDistribution
    C43: float | BaseDistribution
    phi43: float | BaseDistribution
    C45: float | BaseDistribution
    phi45: float | BaseDistribution
    C50: float | BaseDistribution
    C52: float | BaseDistribution
    phi52: float | BaseDistribution
    C54: float | BaseDistribution
    phi54: float | BaseDistribution
    C56: float | BaseDistribution
    phi56: float | BaseDistribution
    Cs: float | BaseDistribution
    C5: float | BaseDistribution
    astigmatism: float | BaseDistribution
    astigmatism_angle: float | BaseDistribution
    astigmatism3: float | BaseDistribution
    astigmatism3_angle: float | BaseDistribution
    astigmatism5: float | BaseDistribution
    astigmatism5_angle: float | BaseDistribution
    coma: float | BaseDistribution
    coma_angle: float | BaseDistribution
    coma4: float | BaseDistribution
    coma4_angle: float | BaseDistribution
    trefoil: float | BaseDistribution
    trefoil_angle: float | BaseDistribution
    trefoil4: float | BaseDistribution
    trefoil4_angle: float | BaseDistribution
    quadrafoil: float | BaseDistribution
    quadrafoil_angle: float | BaseDistribution
    quadrafoil5: float | BaseDistribution
    quadrafoil5_angle: float | BaseDistribution
    pentafoil: float | BaseDistribution
    pentafoil_angle: float | BaseDistribution
    hexafoil: float | BaseDistribution
    hexafoil_angle: float | BaseDistribution

    def __init__(self, *args, **kwargs):
        self._aberration_coefficients = {symbol: 0.0 for symbol in polar_symbols.keys()}
        super().__init__(*args, **kwargs)

    def __getattr__(self, name: str) -> float | BaseDistribution:
        # Only called when normal attribute lookup fails: resolve aberration
        # symbols and aliases from the coefficient dictionary.
        name = polar_aliases.get(name, name)
        if name not in polar_symbols:
            raise AttributeError(
                f"'{self.__class__.__name__}' object has no attribute '{name}'"
            )
        return self._aberration_coefficients.get(name, 0.0)

    def __setattr__(self, name: str, value: float | BaseDistribution) -> None:
        # "defocus" is a property (stored as -C10), so it bypasses the alias table.
        if name == "defocus":
            super().__setattr__(name, value)
            return

        name = polar_aliases.get(name, name)
        if name in polar_symbols:
            self._aberration_coefficients[name] = validate_distribution(value)
        else:
            super().__setattr__(name, value)

    @property
    def defocus(self) -> float | BaseDistribution:
        """Defocus equivalent to negative C10."""
        return -self.C10

    @defocus.setter
    def defocus(self, value: float | BaseDistribution) -> None:
        self.C10 = -value

    def _nonzero_coefficients(self, symbols: tuple[str, ...]) -> bool:
        """True if any of `symbols` is a non-scalar (e.g. a distribution) or is
        non-zero."""
        for symbol in symbols:
            if not np.isscalar(self._aberration_coefficients[symbol]):
                return True

            if not self._aberration_coefficients[symbol] == 0.0:
                return True

        return False

    @property
    def _phase_aberrations_ensemble_axes_metadata(self) -> list[AxisMetadata]:
        """One parameter axis per coefficient given as a distribution."""
        axes_metadata: list[AxisMetadata] = []
        for parameter_name, value in self._aberration_coefficients.items():
            if isinstance(value, BaseDistribution):
                axes_metadata += [
                    ParameterAxis(
                        label=parameter_name,
                        values=tuple(value.values),
                        units="Å",
                        _ensemble_mean=value.ensemble_mean,
                        tex_label=symbol_to_tex_symbol(parameter_name),
                    )
                ]
        return axes_metadata

    @property
    def aberration_coefficients(self) -> Mapping[str, float | BaseDistribution]:
        """The aberration coefficients as a dictionary (a deep copy)."""
        return copy.deepcopy(self._aberration_coefficients)

    @property
    def _has_aberrations(self) -> bool:
        if np.all(
            [np.all(value == 0.0) for value in self._aberration_coefficients.values()]
        ):
            return False
        else:
            return True

    @staticmethod
    def _is_scherzer(value: Any) -> bool:
        """True if `value` requests the Scherzer defocus (case-insensitive)."""
        return isinstance(value, str) and value.lower() == "scherzer"

    def set_aberrations(
        self, aberration_coefficients: Mapping[str, str | float | BaseDistribution]
    ) -> None:
        """
        Set the aberration coefficients.

        Parameters
        ----------
        aberration_coefficients : dict
            Mapping from aberration symbols (e.g. "C30") or aliases (e.g. "Cs",
            "defocus") to their corresponding values. The defocus ("defocus" or
            "C10") may be given as the string "scherzer" (case-insensitive). It is
            then evaluated from the C30 coefficient, including a C30/Cs given in
            the same mapping, and the energy.

        Raises
        ------
        ValueError
            If a key is not a recognized aberration symbol or alias, or if a string
            value is given for anything other than a Scherzer defocus. All keys are
            checked before any coefficient is changed.
        RuntimeError
            If a Scherzer defocus is requested while the energy is undefined.
        """
        # Validate everything first so that a bad entry leaves the state untouched
        # (and unknown names are not silently set as ordinary attributes).
        for symbol, value in aberration_coefficients.items():
            if polar_aliases.get(symbol, symbol) not in polar_symbols:
                raise ValueError(f"{symbol} not a recognized parameter")

            if isinstance(value, str) and not (
                symbol in ("defocus", "C10") and self._is_scherzer(value)
            ):
                raise ValueError("string values only allowed for defocus")

        # Scherzer requests go last (sorted is stable), so they use the C30 given
        # in this same call regardless of key order.
        ordered = sorted(
            aberration_coefficients.items(),
            key=lambda item: self._is_scherzer(item[1]),
        )

        for symbol, value in ordered:
            if self._is_scherzer(value):
                if self.energy is None:
                    raise RuntimeError(
                        "energy undefined, Scherzer defocus cannot be evaluated"
                    )
                C30 = self._aberration_coefficients["C30"]
                assert isinstance(C30, SupportsFloat)
                value = scherzer_defocus(float(C30), self._valid_energy)

            value = validate_distribution(value)
            setattr(self, symbol, value)
class SpatialEnvelope(BaseTransferFunction, _HasAberrations):
    """
    Envelope function for simulating partial spatial coherence in the quasi-coherent
    approximation.

    Parameters
    ----------
    angular_spread: float or 1D array or BaseDistribution
        The standard deviation of the angular deviations due to source size [mrad].
        Alternatively, a distribution of standard deviations may be provided.
    aberration_coefficients: dict, optional
        Mapping from aberration symbols to their corresponding values. All aberration
        magnitudes should be given in [Å] and angles should be given in [radian].
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a single
        float is given, both are set equal.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [Å]. If 'gpts' is also given, will be
        ignored.
    kwargs : dict, optional
        Optionally provide the aberration coefficients as keyword arguments. These
        take precedence over entries with the same symbol in
        `aberration_coefficients`.
    """

    def __init__(
        self,
        angular_spread: float | BaseDistribution,
        aberration_coefficients: Optional[
            Mapping[str, str | float | BaseDistribution]
        ] = None,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
        **kwargs: str | float | BaseDistribution,
    ):
        # Every polar aberration symbol plus the angular spread may be a distribution.
        distributions = tuple(polar_symbols.keys()) + ("angular_spread",)
        super().__init__(
            distributions=distributions,
            energy=energy,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
        )
        self._angular_spread = validate_distribution(angular_spread)
        aberration_coefficients = (
            {} if aberration_coefficients is None else aberration_coefficients
        )
        # Keyword arguments override entries of the mapping with the same symbol.
        aberration_coefficients = {**aberration_coefficients, **kwargs}
        self.set_aberrations(aberration_coefficients)

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        ensemble_axes_metadata = [
            *self._phase_aberrations_ensemble_axes_metadata,
            *self._get_axes_metadata_from_distributions(
                angular_spread={"units": "mrad"}
            ),
        ]
        return ensemble_axes_metadata

    @property
    def angular_spread(self) -> float | BaseDistribution:
        """The standard deviation of the angular deviations due to source size
        [mrad]."""
        return self._angular_spread

    @angular_spread.setter
    def angular_spread(self, value: float | BaseDistribution) -> None:
        self._angular_spread = value

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        """
        Evaluate the spatial envelope on a grid of polar scattering angles.

        Parameters
        ----------
        alpha : np.ndarray
            Radial scattering angles [rad].
        phi : np.ndarray
            Azimuthal angles [rad], broadcastable against `alpha`.

        Returns
        -------
        envelope : np.ndarray
            exp(-(theta / 2)^2 * ((dchi/dalpha)^2 + (1/alpha * dchi/dphi)^2)),
            where theta is the angular spread in radians. One leading axis is
            prepended for each ensemble axis.
        """
        xp = get_array_module(alpha)

        # The angular spread is unpacked together with the aberration coefficients;
        # it is the last entry of the unpacked tuple.
        args = tuple(self.aberration_coefficients.values()) + (self.angular_spread,)
        unpacked, _ = _unpack_distributions(*args, shape=alpha.shape, xp=xp)
        angular_spread = unpacked[-1] / 1e3  # [mrad] -> [rad]
        parameters = dict(zip(polar_symbols, unpacked[:-1]))

        # Prepend one singleton axis per ensemble axis so that alpha broadcasts
        # against the unpacked parameters. The array module is unchanged by this.
        alpha = xp.array(alpha)
        alpha = xp.expand_dims(alpha, axis=tuple(range(0, self._num_ensemble_axes)))

        # Derivative of the aberration phase chi with respect to alpha.
        dchi_dk = (
            2
            * xp.pi
            / self.wavelength
            * (
                (
                    parameters["C12"] * xp.cos(2.0 * (phi - parameters["phi12"]))
                    + parameters["C10"]
                )
                * alpha
                + (
                    parameters["C23"] * xp.cos(3.0 * (phi - parameters["phi23"]))
                    + parameters["C21"] * xp.cos(1.0 * (phi - parameters["phi21"]))
                )
                * alpha**2
                + (
                    parameters["C34"] * xp.cos(4.0 * (phi - parameters["phi34"]))
                    + parameters["C32"] * xp.cos(2.0 * (phi - parameters["phi32"]))
                    + parameters["C30"]
                )
                * alpha**3
                + (
                    parameters["C45"] * xp.cos(5.0 * (phi - parameters["phi45"]))
                    + parameters["C43"] * xp.cos(3.0 * (phi - parameters["phi43"]))
                    + parameters["C41"] * xp.cos(1.0 * (phi - parameters["phi41"]))
                )
                * alpha**4
                + (
                    parameters["C56"] * xp.cos(6.0 * (phi - parameters["phi56"]))
                    + parameters["C54"] * xp.cos(4.0 * (phi - parameters["phi54"]))
                    + parameters["C52"] * xp.cos(2.0 * (phi - parameters["phi52"]))
                    + parameters["C50"]
                )
                * alpha**5
            )
        )

        # Azimuthal derivative of chi divided by alpha (the angular component of
        # the gradient in polar coordinates).
        dchi_dphi = (
            -2
            * xp.pi
            / self.wavelength
            * (
                1
                / 2.0
                * (2.0 * parameters["C12"] * xp.sin(2.0 * (phi - parameters["phi12"])))
                * alpha
                + 1
                / 3.0
                * (
                    3.0 * parameters["C23"] * xp.sin(3.0 * (phi - parameters["phi23"]))
                    + 1.0
                    * parameters["C21"]
                    * xp.sin(1.0 * (phi - parameters["phi21"]))
                )
                * alpha**2
                + 1
                / 4.0
                * (
                    4.0 * parameters["C34"] * xp.sin(4.0 * (phi - parameters["phi34"]))
                    + 2.0
                    * parameters["C32"]
                    * xp.sin(2.0 * (phi - parameters["phi32"]))
                )
                * alpha**3
                + 1
                / 5.0
                * (
                    5.0 * parameters["C45"] * xp.sin(5.0 * (phi - parameters["phi45"]))
                    + 3.0
                    * parameters["C43"]
                    * xp.sin(3.0 * (phi - parameters["phi43"]))
                    + 1.0
                    * parameters["C41"]
                    * xp.sin(1.0 * (phi - parameters["phi41"]))
                )
                * alpha**4
                + (1 / 6.0)
                * (
                    6.0 * parameters["C56"] * xp.sin(6.0 * (phi - parameters["phi56"]))
                    + 4.0
                    * parameters["C54"]
                    * xp.sin(4.0 * (phi - parameters["phi54"]))
                    + 2.0
                    * parameters["C52"]
                    * xp.sin(2.0 * (phi - parameters["phi52"]))
                )
                * alpha**5
            )
        )

        # NOTE(review): the sign factor makes a negative angular spread produce a
        # growing (positive) exponent; zero spread yields exactly 1.
        array = xp.exp(
            -xp.sign(angular_spread)
            * (angular_spread / 2) ** 2
            * (dchi_dk**2 + dchi_dphi**2)
        )
        return array
class Aberrations(BaseTransferFunction, _HasAberrations):
    """
    Phase aberrations.

    Parameters
    ----------
    aberration_coefficients: dict, optional
        Mapping from aberration symbols to their corresponding values. All aberration
        magnitudes should be given in [Å] and angles should be given in [radian].
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a single
        float is given, both are set equal.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [Å]. If 'gpts' is also given, will be
        ignored.
    kwargs : dict, optional
        Optionally provide the aberration coefficients as keyword arguments. These
        take precedence over entries with the same symbol in
        `aberration_coefficients`.
    """

    def __init__(
        self,
        aberration_coefficients: Optional[
            Mapping[str, str | float | BaseDistribution]
        ] = None,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
        **kwargs: Any,
    ):
        super().__init__(
            distributions=tuple(polar_symbols.keys()),
            energy=energy,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
        )
        aberration_coefficients = (
            {} if aberration_coefficients is None else aberration_coefficients
        )
        # Keyword arguments override entries of the mapping with the same symbol.
        aberration_coefficients = {**aberration_coefficients, **kwargs}
        self.set_aberrations(aberration_coefficients)

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        return self._phase_aberrations_ensemble_axes_metadata

    @property
    def defocus(self) -> float | BaseDistribution:
        """The defocus [Å], defined as the negative of the C10 coefficient."""
        return -self._aberration_coefficients["C10"]

    @defocus.setter
    def defocus(self, value: float | BaseDistribution) -> None:
        self.C10 = -value

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        """
        Evaluate exp(-i chi) on a grid of polar scattering angles.

        Parameters
        ----------
        alpha : np.ndarray
            Radial scattering angles [rad].
        phi : np.ndarray
            Azimuthal angles [rad], same shape as `alpha`.

        Returns
        -------
        array : np.ndarray
            Complex phase factor with one leading axis per ensemble axis. If any
            distribution weights are returned, they are multiplied in.
        """
        xp = get_array_module(alpha)

        if not self._has_aberrations:
            return xp.ones(
                self.ensemble_shape + alpha.shape, dtype=get_dtype(complex=True)
            )

        parameter_values, weights = _unpack_distributions(
            *tuple(self.aberration_coefficients.values()), shape=alpha.shape, xp=xp
        )
        parameters = {
            symbol: value for symbol, value in zip(polar_symbols, parameter_values)
        }

        # Prepend one singleton axis per ensemble axis so the angular grid broadcasts
        # against the unpacked parameters.
        axis = tuple(range(0, len(self.ensemble_shape)))
        alpha = xp.expand_dims(alpha, axis=axis)
        phi = xp.expand_dims(phi, axis=axis).astype(get_dtype(complex=False))

        # Accumulate chi * wavelength / (2 pi) order by order, skipping orders whose
        # coefficients are all zero.
        array = xp.zeros(alpha.shape, dtype=get_dtype(complex=False))
        if self._nonzero_coefficients(("C10", "C12", "phi12")):
            array = array + (
                1
                / 2
                * alpha**2
                * (
                    parameters["C10"]
                    + parameters["C12"] * xp.cos(2 * (phi - parameters["phi12"]))
                )
            )

        if self._nonzero_coefficients(("C21", "phi21", "C23", "phi23")):
            array = array + (
                1
                / 3
                * alpha**3
                * (
                    parameters["C21"] * xp.cos(phi - parameters["phi21"])
                    + parameters["C23"] * xp.cos(3 * (phi - parameters["phi23"]))
                )
            )

        if self._nonzero_coefficients(("C30", "C32", "phi32", "C34", "phi34")):
            array = array + (
                1
                / 4
                * alpha**4
                * (
                    parameters["C30"]
                    + parameters["C32"] * xp.cos(2 * (phi - parameters["phi32"]))
                    + parameters["C34"] * xp.cos(4 * (phi - parameters["phi34"]))
                )
            )

        if self._nonzero_coefficients(("C41", "phi41", "C43", "phi43", "C45", "phi45")):
            array = array + (
                1
                / 5
                * alpha**5
                * (
                    parameters["C41"] * xp.cos((phi - parameters["phi41"]))
                    + parameters["C43"] * xp.cos(3 * (phi - parameters["phi43"]))
                    + parameters["C45"] * xp.cos(5 * (phi - parameters["phi45"]))
                )
            )

        if self._nonzero_coefficients(
            ("C50", "C52", "phi52", "C54", "phi54", "C56", "phi56")
        ):
            array = array + (
                1
                / 6
                * alpha**6
                * (
                    parameters["C50"]
                    + parameters["C52"] * xp.cos(2 * (phi - parameters["phi52"]))
                    + parameters["C54"] * xp.cos(4 * (phi - parameters["phi54"]))
                    + parameters["C56"] * xp.cos(6 * (phi - parameters["phi56"]))
                )
            )

        dtype = get_dtype(complex=False)
        array *= xp.array(2 * xp.pi / self.wavelength, dtype=dtype)
        array = complex_exponential(-array)

        # Check for missing weights BEFORE any conversion: cp.asnumpy(None) returns a
        # 0-d object array, which would slip past an `is not None` test.
        if weights is not None:
            if cp is not None and xp is not cp:
                # Result lives on the host; bring (possibly device) weights there.
                weights = cp.asnumpy(weights)
            array = xp.asarray(weights, dtype=dtype) * array

        return array
class CTF(_HasAberrations, BaseAperture):
    """
    The contrast transfer function (CTF) describes the aberrations of the objective lens
    in HRTEM and specifies how the condenser system shapes the probe in STEM.

    abTEM implements phase aberrations up to 5th order using polar coefficients.
    See Eq. 2.22 in the reference [1]_.

    Cartesian coefficients can be converted to polar using the utility function
    `abtem.transfer.cartesian2polar`.

    Partial coherence is included as envelopes in the quasi-coherent approximation.
    See Chapter 3.2 in reference [1]_.

    Parameters
    ----------
    semiangle_cutoff: float, optional
        The semiangle cutoff describes the sharp reciprocal-space cutoff due to the
        objective aperture [mrad] (default is no cutoff).
    soft : bool, optional
        If True, the edge of the aperture is softened (default is True).
    focal_spread: float, optional
        The standard deviation of the focal spread due to chromatic aberration and lens
        current instability [Å] (default is 0).
    angular_spread: float, optional
        The standard deviation of the angular deviations due to source size [mrad]
        (default is 0).
    aberration_coefficients: dict, optional
        Mapping from aberration symbols to their corresponding values. All aberration
        magnitudes should be given in [Å] and angles should be given in [radian].
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a single
        float is given, both are set equal.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [Å]. If 'gpts' is also given, will be
        ignored.
    flip_phase : bool, optional
        Changes the sign of all negative parts of the CTF to positive
        (following doi:10.1016/j.ultramic.2008.03.004) (default is False).
        Ignored when a nonzero `wiener_snr` is given.
    wiener_snr : float, optional
        Applies a Wiener filter to the CTF (following doi:10.1016/j.ultramic.2008.03.004)
        with a given SNR value. If no value is given, the default value of 0.0 means
        that no filter is applied.
    kwargs : dict, optional
        Optionally provide the aberration coefficients as keyword arguments. These
        take precedence over entries with the same symbol in
        `aberration_coefficients`.

    References
    ----------
    .. [1] Kirkland, E. J. (2010). Advanced Computing in Electron Microscopy (2nd ed.).
       Springer.
    """

    def __init__(
        self,
        semiangle_cutoff: float | BaseDistribution = np.inf,
        soft: bool = True,
        focal_spread: float | BaseDistribution = 0.0,
        angular_spread: float | BaseDistribution = 0.0,
        aberration_coefficients: Optional[
            Mapping[str, float | BaseDistribution]
        ] = None,
        energy: Optional[float] = None,
        extent: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        sampling: Optional[float | tuple[float, float]] = None,
        flip_phase: bool = False,
        wiener_snr: float = 0.0,
        **kwargs: Any,
    ):
        distributions = (
            *tuple(polar_symbols.keys()),
            "angular_spread",
            "focal_spread",
            "semiangle_cutoff",
        )

        semiangle_cutoff = validate_distribution(semiangle_cutoff)

        super().__init__(
            distributions=distributions,
            energy=energy,
            semiangle_cutoff=semiangle_cutoff,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
        )

        aberration_coefficients = (
            {} if aberration_coefficients is None else aberration_coefficients
        )
        # Keyword arguments override entries of the mapping with the same symbol.
        aberration_coefficients = {**aberration_coefficients, **kwargs}
        self.set_aberrations(aberration_coefficients)

        self._angular_spread = validate_distribution(angular_spread)
        self._focal_spread = validate_distribution(focal_spread)
        self._soft = soft
        self._flip_phase = flip_phase
        self._wiener_snr = wiener_snr

    @property
    def scherzer_defocus(self) -> float:
        """The Scherzer defocus [Å].

        Raises
        ------
        ValueError
            If the spherical aberration `Cs` is zero.
        """
        if self.Cs == 0.0:
            raise ValueError("Cs must be defined to calculate Scherzer defocus")

        Cs = self.Cs
        assert isinstance(Cs, SupportsFloat)
        # Module-level helper of the same name (method scope does not see the
        # class attribute).
        return scherzer_defocus(Cs, self._valid_energy)

    @property
    def crossover_angle(self) -> float:
        """The first zero-crossing of the phase at Scherzer defocus [mrad]."""
        return 1e3 * energy2wavelength(self._valid_energy) / self.point_resolution

    @property
    def point_resolution(self) -> float:
        """The Scherzer point resolution [Å]."""
        Cs = self.Cs
        assert isinstance(Cs, SupportsFloat)
        return point_resolution(Cs, self._valid_energy)

    @property
    def _aberrations(self) -> Aberrations:
        # A new component object is built on every access.
        return Aberrations(
            aberration_coefficients=self.aberration_coefficients,
            energy=self.energy,
            extent=self.extent,
            gpts=self.gpts,
        )

    @property
    def _aperture(self) -> Aperture:
        return Aperture(
            semiangle_cutoff=self.semiangle_cutoff,
            soft=self._soft,
            energy=self.energy,
            extent=self.extent,
            gpts=self.gpts,
        )

    @property
    def _spatial_envelope(self) -> SpatialEnvelope:
        return SpatialEnvelope(
            aberration_coefficients=self.aberration_coefficients,
            angular_spread=self.angular_spread,
            energy=self.energy,
        )

    @property
    def _temporal_envelope(self) -> TemporalEnvelope:
        return TemporalEnvelope(
            focal_spread=self.focal_spread,
            energy=self.energy,
        )

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        return (
            self._spatial_envelope.ensemble_axes_metadata
            + self._temporal_envelope.ensemble_axes_metadata
            + self._aperture.ensemble_axes_metadata
        )

    @property
    def soft(self) -> bool:
        """True if the aperture has a soft edge."""
        return self._soft

    @property
    def semiangle_cutoff(self) -> float | BaseDistribution:
        """The semiangle cutoff [mrad]."""
        return self._semiangle_cutoff

    @semiangle_cutoff.setter
    def semiangle_cutoff(self, value: float) -> None:
        self._semiangle_cutoff = value

    @property
    def focal_spread(self) -> float | BaseDistribution:
        """The standard deviation of the focal spread [Å]."""
        return self._focal_spread

    @focal_spread.setter
    def focal_spread(self, value: float) -> None:
        self._focal_spread = value

    @property
    def angular_spread(self) -> float | BaseDistribution:
        """The standard deviation of the angular deviations due to source size
        [mrad]."""
        return self._angular_spread

    @angular_spread.setter
    def angular_spread(self, value: float) -> None:
        self._angular_spread = value

    @property
    def flip_phase(self) -> bool:
        """If true the signs of all negative parts of the CTF are changed to
        positive."""
        return self._flip_phase

    @flip_phase.setter
    def flip_phase(self, value: bool) -> None:
        self._flip_phase = value

    @property
    def wiener_snr(self) -> float:
        """The signal-to-noise ratio of the Wiener filter applied to the CTF. A value
        of 0.0 disables the filter."""
        return self._wiener_snr

    @wiener_snr.setter
    def wiener_snr(self, value: float) -> None:
        self._wiener_snr = value

    def _evaluate_to_match(
        self,
        component: Aberrations | Aperture | SpatialEnvelope | TemporalEnvelope,
        alpha: np.ndarray,
        phi: np.ndarray,
    ) -> np.ndarray:
        """Evaluate a single component and insert singleton axes for every CTF
        ensemble axis that the component does not carry, so the results of all
        components broadcast against each other."""
        expanded_axes: tuple[int, ...] = ()
        for i, axis_metadata in enumerate(self.ensemble_axes_metadata):
            if all(a != axis_metadata for a in component.ensemble_axes_metadata):
                expanded_axes += (i,)

        array = component._evaluate_from_angular_grid(alpha, phi)
        return np.expand_dims(array, expanded_axes)

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray, keep_all: bool = False
    ) -> np.ndarray:
        """
        Evaluate the full CTF (aberrations times envelopes times aperture).

        Parameters
        ----------
        alpha : np.ndarray
            Radial scattering angles [rad].
        phi : np.ndarray
            Azimuthal angles [rad].
        keep_all : bool
            Accepted for interface compatibility; not used here.
        """
        # Each component property constructs a new object; build each one once.
        aberrations = self._aberrations
        spatial_envelope = self._spatial_envelope
        temporal_envelope = self._temporal_envelope
        aperture = self._aperture

        # The trailing (angular-grid) dimensions are always aligned.
        match_dims = tuple(range(-len(alpha.shape), 0))

        array = aberrations._evaluate_from_angular_grid(alpha, phi)

        if spatial_envelope.angular_spread != 0.0:
            # The spatial envelope shares the aberration ensemble axes, optionally
            # followed by one axis for a distribution of angular spreads.
            new_aberrations_dims = tuple(range(len(aberrations.ensemble_shape)))
            old_match_dims = new_aberrations_dims + match_dims
            added_dims = int(hasattr(spatial_envelope.angular_spread, "values"))
            new_match_dims = (
                tuple(range(len(spatial_envelope.ensemble_shape) - added_dims))
                + match_dims
            )
            new_array = spatial_envelope._evaluate_from_angular_grid(alpha, phi)
            array, new_array = expand_dims_to_broadcast(
                array, new_array, match_dims=(old_match_dims, new_match_dims)
            )
            array = array * new_array

        if temporal_envelope.focal_spread != 0.0:
            new_array = temporal_envelope._evaluate_from_angular_grid(alpha, phi)
            array, new_array = expand_dims_to_broadcast(
                array, new_array, match_dims=(match_dims, match_dims)
            )
            array = array * new_array

        if aperture.semiangle_cutoff != np.inf:
            new_array = aperture._evaluate_from_angular_grid(alpha, phi)
            array, new_array = expand_dims_to_broadcast(
                array, new_array, match_dims=(match_dims, match_dims)
            )
            array = array * new_array

        # The Wiener filter takes precedence over phase flipping.
        if self._wiener_snr != 0.0:
            return (
                (1 + 1 / self._wiener_snr)
                * array**2
                / (array**2 + 1 / self._wiener_snr)
            )
        elif self._flip_phase:
            return array.real - 1j * np.abs(array.imag)
        else:
            return array

    def to_point_spread_functions(
        self, gpts: int | tuple[int, int], extent: float | tuple[float, float]
    ) -> Images:
        """
        Calculate the point spread functions as the images of a probe that uses this
        CTF as its aperture.

        Parameters
        ----------
        gpts : int or two int
            Number of grid points of the images.
        extent : float or two float
            Lateral extent of the images [Å].

        Returns
        -------
        point_spread_functions : Images
        """
        from abtem.waves import Probe

        return (
            Probe(gpts=gpts, extent=extent, energy=self.energy, aperture=self)
            .build()
            .to_images()
        )

    def profiles(
        self,
        gpts: int = 1000,
        max_angle: Optional[float] = None,
        phi: float | np.ndarray = 0.0,
    ) -> ReciprocalSpaceLineProfiles:
        """
        Calculate radial line profiles for each included component (phase aberrations,
        aperture, temporal and spatial envelopes) of the contrast transfer function.

        Parameters
        ----------
        gpts: int
            Number of grid points along the line profiles (at least 2).
        max_angle : float
            The maximum scattering angle included in the radial line profiles [mrad].
            The default is 1.6 times the (maximum) semiangle cutoff or 50 mrad if no
            semiangle cutoff is set.
        phi : float or np.ndarray
            The azimuthal angle of the radial line profiles [rad]. Default is 0.

        Returns
        -------
        ctf_profiles : ReciprocalSpaceLineProfiles
            Ensemble of reciprocal space line profiles sampled from 0 to `max_angle`.
            When more than one component is included, the last ensemble dimension
            indexes the components ("ctf", "spatial envelope", "temporal envelope",
            "aperture"), where "ctf" is the product of all included components.

        Raises
        ------
        ValueError
            If `gpts` is smaller than 2.
        """
        if gpts < 2:
            raise ValueError("gpts must be at least 2 to sample a line profile")

        if max_angle is None:
            if self.semiangle_cutoff == np.inf:
                max_angle = 50.0
            else:
                max_angle = self._max_semiangle_cutoff * 1.6

        self.accelerator.check_is_defined()

        sampling = max_angle / (gpts - 1) / (self.wavelength * 1e3)
        alpha = np.linspace(0, max_angle * 1e-3, gpts).astype(get_dtype(complex=False))
        phi = np.array(phi)

        # Each component property constructs a new object; build each one once.
        aberrations = self._aberrations
        spatial_envelope = self._spatial_envelope
        temporal_envelope = self._temporal_envelope
        aperture = self._aperture

        components = {}
        components["ctf"] = self._evaluate_to_match(aberrations, alpha, phi).imag

        if spatial_envelope.angular_spread != 0.0:
            components["spatial envelope"] = self._evaluate_to_match(
                spatial_envelope, alpha, phi
            )

        if temporal_envelope.focal_spread != 0.0:
            components["temporal envelope"] = self._evaluate_to_match(
                temporal_envelope, alpha, phi
            )

        if aperture.semiangle_cutoff != np.inf:
            components["aperture"] = self._evaluate_to_match(aperture, alpha, phi)

        # Replace the bare phase term by the product of all included components;
        # the "ctf" key keeps its first position.
        components["ctf"] = reduce(lambda x, y: x * y, tuple(components.values()))

        ensemble_axes_metadata: list[AxisMetadata] = self.ensemble_axes_metadata

        if len(components) > 1:
            # Stack components just before the profile axis, i.e. as the last
            # ensemble axis.
            profiles = np.stack(
                np.broadcast_arrays(*list(components.values())),
                axis=-2,
            )
            component_metadata: list[AxisMetadata] = [
                OrdinalAxis(
                    label="",
                    values=tuple(components.keys()),
                )
            ]
            ensemble_axes_metadata = ensemble_axes_metadata + component_metadata
        else:
            profiles = components["ctf"]

        metadata = {"energy": self.energy}

        profiles = ReciprocalSpaceLineProfiles(
            profiles,
            sampling=sampling,
            metadata=metadata,
            ensemble_axes_metadata=ensemble_axes_metadata,
        )
        return profiles
def nyquist_sampling(semiangle_cutoff: float, energy: float) -> float:
    """
    Calculate the Nyquist sampling.

    The result equals ``wavelength / (4 * semiangle_cutoff)`` with the semiangle
    cutoff converted from milliradians to radians.

    Parameters
    ----------
    semiangle_cutoff: float
        Semiangle cutoff [mrad].
    energy: float
        Electron energy [eV].

    Returns
    -------
    sampling : float
        The Nyquist sampling [Å].
    """
    denominator = 4 * semiangle_cutoff / energy2wavelength(energy) * 1e-3
    return 1 / denominator
def scherzer_defocus(Cs: float, energy: float) -> float:
    """
    Calculate the Scherzer defocus.

    Computed as ``sign(Cs) * sqrt(3 / 2 * |Cs| * wavelength)``, so the result
    carries the sign of the spherical aberration.

    Parameters
    ----------
    Cs: float
        Spherical aberration [Å].
    energy: float
        Electron energy [eV].

    Returns
    -------
    defocus : float
        The Scherzer defocus [Å].
    """
    wavelength = energy2wavelength(energy)
    magnitude = np.sqrt(3 / 2 * np.abs(Cs) * wavelength)
    return np.sign(Cs) * magnitude
def point_resolution(Cs: float, energy: float) -> float:
    """
    Calculate the Scherzer point resolution.

    Computed as ``(wavelength**3 * |Cs| / 6) ** (1 / 4)``.

    Parameters
    ----------
    Cs: float
        Spherical aberration [Å].
    energy: float
        Electron energy [eV].

    Returns
    -------
    resolution : float
        The point resolution [Å].
    """
    wavelength = energy2wavelength(energy)
    quartic = wavelength**3 * np.abs(Cs) / 6
    return quartic ** (1 / 4)
def polar2cartesian(polar: dict) -> dict:
    """
    Convert between polar and Cartesian aberration coefficients.

    Coefficients missing from `polar` are treated as zero. Only aberrations up to
    third order (C10 through C34) are converted. The conventions match
    `cartesian2polar`, so converting Cartesian coefficients to polar and back
    reproduces them (up to floating point rounding).

    Parameters
    ----------
    polar : dict
        Mapping from polar aberration symbols to their corresponding values.

    Returns
    -------
    cartesian : dict
        Mapping from Cartesian aberration symbols to their corresponding values.
    """
    polar = defaultdict(lambda: 0, polar)

    cartesian = dict()
    cartesian["C10"] = polar["C10"]
    cartesian["C12a"] = -polar["C12"] * np.cos(2 * polar["phi12"])
    cartesian["C12b"] = polar["C12"] * np.sin(2 * polar["phi12"])
    cartesian["C21a"] = polar["C21"] * np.sin(polar["phi21"])
    cartesian["C21b"] = polar["C21"] * np.cos(polar["phi21"])
    cartesian["C23a"] = -polar["C23"] * np.sin(3 * polar["phi23"])
    cartesian["C23b"] = polar["C23"] * np.cos(3 * polar["phi23"])
    cartesian["C30"] = polar["C30"]
    cartesian["C32a"] = -polar["C32"] * np.cos(2 * polar["phi32"])
    # Previously written as cos(pi / 2 - 2 phi32), which equals sin(2 phi32).
    cartesian["C32b"] = polar["C32"] * np.sin(2 * polar["phi32"])
    cartesian["C34a"] = polar["C34"] * np.cos(4 * polar["phi34"])
    # Previously written as 1/4 (1 + k^2)^2 / (k^3 - k) cos(4 arctan(1/k) - 4 phi34)
    # with k = sqrt(3 + sqrt(8)) = 1 + sqrt(2). For that k the prefactor is exactly 1
    # and 4 arctan(1/k) = pi / 2, so the expression reduces to sin(4 phi34).
    cartesian["C34b"] = polar["C34"] * np.sin(4 * polar["phi34"])
    return cartesian
def cartesian2polar(cartesian: dict) -> dict:
    """
    Convert between Cartesian and polar aberration coefficients.

    Coefficients missing from `cartesian` are treated as zero, and the input mapping
    is not modified. Only aberrations up to third order (C10 through C34) are
    converted.

    Parameters
    ----------
    cartesian : dict
        Mapping from Cartesian aberration symbols to their corresponding values.

    Returns
    -------
    polar : dict
        Mapping from polar aberration symbols to their corresponding values.
    """
    c = defaultdict(lambda: 0, cartesian)

    def norm(key_a: str, key_b: str):
        # Length of the (a, b) component vector of one Cartesian aberration.
        return np.sqrt(c[key_a] ** 2 + c[key_b] ** 2)

    return {
        "C10": c["C10"],
        "C12": -norm("C12a", "C12b"),
        "phi12": -np.arctan2(c["C12b"], c["C12a"]) / 2.0,
        "C21": norm("C21a", "C21b"),
        "phi21": np.arctan2(c["C21a"], c["C21b"]),
        "C23": norm("C23a", "C23b"),
        "phi23": -np.arctan2(c["C23a"], c["C23b"]) / 3.0,
        "C30": c["C30"],
        "C32": -norm("C32a", "C32b"),
        "phi32": -np.arctan2(c["C32b"], c["C32a"]) / 2.0,
        "C34": norm("C34a", "C34b"),
        "phi34": np.arctan2(c["C34b"], c["C34a"]) / 4,
    }