Source code for abtem.waves

"""Module for describing wave functions of the incoming electron beam and the exit wave."""

from __future__ import annotations

import itertools
from abc import abstractmethod
from copy import copy
from functools import partial
from numbers import Number
from typing import Sequence

import dask.array as da
import numpy as np
from ase import Atoms

import abtem
from abtem.array import ArrayObject, ComputableList, _expand_dims, _validate_lazy
from abtem.core.axes import (
    RealSpaceAxis,
    ReciprocalSpaceAxis,
    AxisMetadata,
    AxesMetadataList,
    OrdinalAxis,
)
from abtem.core.backend import (
    get_array_module,
    validate_device,
    device_name_from_array_module,
)
from abtem.core.chunks import validate_chunks
from abtem.core.complex import abs2
from abtem.core.energy import Accelerator
from abtem.core.energy import HasAcceleratorMixin
from abtem.core.ensemble import (
    EmptyEnsemble,
    Ensemble,
    _wrap_with_array,
    unpack_blockwise_args,
)
from abtem.core.fft import fft2, ifft2, fft_crop, fft_interpolate
from abtem.core.grid import Grid, validate_gpts, polar_spatial_frequencies
from abtem.core.grid import HasGridMixin
from abtem.core.utils import (
    safe_floor_int,
    CopyMixin,
    EqualityMixin,
    tuple_range,
    get_dtype,
)
from abtem.detectors import (
    BaseDetector,
    FlexibleAnnularDetector,
)
from abtem.distributions import BaseDistribution
from abtem.inelastic.core_loss import (
    BaseTransitionPotential,
    _validate_transition_potentials,
)
from abtem.measurements import (
    DiffractionPatterns,
    Images,
    BaseMeasurements,
    RealSpaceLineProfiles,
)
from abtem.multislice import (
    MultisliceTransform,
    transition_potential_multislice_and_detect,
)
from abtem.potentials.iam import BasePotential, _validate_potential
from abtem.scan import BaseScan, GridScan, _validate_scan, CustomScan
from abtem.slicing import SliceIndexedAtoms
from abtem.tilt import _validate_tilt
from abtem.transfer import Aberrations, CTF, Aperture, BaseAperture
from abtem.transform import (
    ArrayObjectTransform,
)


def _ensure_parity(n, even, v=1):
    assert (v == 1) or (v == -1)
    assert isinstance(even, bool)

    if n % 2 == 0 and not even:
        return n + v
    elif not n % 2 == 0 and even:
        return n + v
    return n


def _ensure_parity_of_gpts(new_gpts, old_gpts, parity):
    if parity == "same":
        return (
            _ensure_parity(new_gpts[0], old_gpts[0] % 2 == 0),
            _ensure_parity(new_gpts[1], old_gpts[1] % 2 == 0),
        )
    elif parity == "odd":
        return (
            _ensure_parity(new_gpts[0], even=False),
            _ensure_parity(new_gpts[1], even=False),
        )
    elif parity == "even":
        return (
            _ensure_parity(new_gpts[0], even=True),
            _ensure_parity(new_gpts[1], even=True),
        )
    elif parity != "none":
        raise ValueError(
            f"parity must be one of 'same', 'odd', 'even' or 'none', got {parity!r}"
        )

    # Fall through for parity == "none": leave the number of grid points unchanged.
    return new_gpts


def _antialias_cutoff_gpts(gpts, sampling):
    kcut = 2.0 / 3.0 / max(sampling)
    extent = gpts[0] * sampling[0], gpts[1] * sampling[1]
    new_gpts = safe_floor_int(kcut * extent[0]), safe_floor_int(kcut * extent[1])
    return _ensure_parity_of_gpts(new_gpts, gpts, parity="same")
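# Illustrative example (not part of the original source): the antialiasing cutoff
# keeps spatial frequencies below 2/3 of the maximum frequency supported by the grid.
# For a 512 x 512 grid sampled at 0.05 Å this gives 341 usable grid points per side,
# rounded to 342 to match the (even) parity of the input grid; the exact value assumes
# that ``safe_floor_int`` simply floors its argument:
#
#   >>> _antialias_cutoff_gpts((512, 512), (0.05, 0.05))
#   (342, 342)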


class BaseWaves(HasGridMixin, HasAcceleratorMixin):
    """Base class of all wave functions. Documented in the subclasses."""

    @property
    @abstractmethod
    def device(self):
        """The device where the waves are built or stored."""
        pass

    @property
    def dtype(self):
        """The datatype of waves."""
        return get_dtype(complex=True)

    @property
    @abstractmethod
    def metadata(self) -> dict:
        """Metadata stored as a dictionary."""
        pass

    @property
    def base_axes_metadata(self) -> list[AxisMetadata]:
        """List of AxisMetadata for the base axes in real space."""
        self.grid.check_is_defined()
        return [
            RealSpaceAxis(
                label="x", sampling=self.sampling[0], units="Å", endpoint=False
            ),
            RealSpaceAxis(
                label="y", sampling=self.sampling[1], units="Å", endpoint=False
            ),
        ]

    @property
    def reciprocal_space_axes_metadata(self) -> list[AxisMetadata]:
        """List of AxisMetadata for base axes in reciprocal space."""
        self.grid.check_is_defined()
        self.accelerator.check_is_defined()
        return [
            ReciprocalSpaceAxis(
                label="scattering angle x",
                sampling=self.angular_sampling[0],
                units="mrad",
            ),
            ReciprocalSpaceAxis(
                label="scattering angle y",
                sampling=self.angular_sampling[1],
                units="mrad",
            ),
        ]

    @property
    def antialias_cutoff_gpts(self) -> tuple[int, int]:
        """
        The number of grid points along the x and y direction in the simulation grid
        at the antialiasing cutoff scattering angle.
        """
        if "adjusted_antialias_cutoff_gpts" in self.metadata:
            n = min(self.metadata["adjusted_antialias_cutoff_gpts"][0], self.gpts[0])
            m = min(self.metadata["adjusted_antialias_cutoff_gpts"][1], self.gpts[1])
            return n, m

        self.grid.check_is_defined()
        return _antialias_cutoff_gpts(self.gpts, self.sampling)

    @property
    def antialias_valid_gpts(self) -> tuple[int, int]:
        """
        The number of grid points along the x and y direction in the simulation grid
        for the largest rectangle that fits within antialiasing cutoff scattering
        angle.
        """
        cutoff_gpts = self.antialias_cutoff_gpts
        valid_gpts = (
            safe_floor_int(cutoff_gpts[0] / np.sqrt(2)),
            safe_floor_int(cutoff_gpts[1] / np.sqrt(2)),
        )

        valid_gpts = _ensure_parity_of_gpts(valid_gpts, self.gpts, parity="same")

        if "adjusted_antialias_cutoff_gpts" in self.metadata:
            n = min(self.metadata["adjusted_antialias_cutoff_gpts"][0], valid_gpts[0])
            m = min(self.metadata["adjusted_antialias_cutoff_gpts"][1], valid_gpts[1])
            return n, m

        return valid_gpts

    def _gpts_within_angle(
        self, angle: float | str, parity: str = "same"
    ) -> tuple[int, int]:
        if angle is None or angle == "full":
            return self.gpts

        elif isinstance(angle, (Number, float)):
            gpts = (
                int(2 * np.ceil(angle / self.angular_sampling[0])) + 1,
                int(2 * np.ceil(angle / self.angular_sampling[1])) + 1,
            )

        elif angle == "cutoff":
            gpts = self.antialias_cutoff_gpts

        elif angle == "valid":
            gpts = self.antialias_valid_gpts

        else:
            raise ValueError(
                "Angle must be a number or one of 'cutoff', 'valid' or 'full'"
            )

        return _ensure_parity_of_gpts(gpts, self.gpts, parity=parity)

    @property
    def cutoff_angles(self) -> tuple[float, float]:
        """Scattering angles at the antialias cutoff [mrad]."""
        return (
            self.antialias_cutoff_gpts[0] // 2 * self.angular_sampling[0],
            self.antialias_cutoff_gpts[1] // 2 * self.angular_sampling[1],
        )

    @property
    def rectangle_cutoff_angles(self) -> tuple[float, float]:
        """Scattering angles corresponding to the sides of the largest rectangle within
        the antialias cutoff [mrad]."""
        return (
            self.antialias_valid_gpts[0] // 2 * self.angular_sampling[0],
            self.antialias_valid_gpts[1] // 2 * self.angular_sampling[1],
        )

    @property
    def full_cutoff_angles(self) -> tuple[float, float]:
        """Scattering angles corresponding to the full wave function size [mrad]."""
        return (
            self.gpts[0] // 2 * self.angular_sampling[0],
            self.gpts[1] // 2 * self.angular_sampling[1],
        )

    @property
    def angular_sampling(self) -> tuple[float, float]:
        """Reciprocal-space sampling in units of scattering angles [mrad]."""
        self.accelerator.check_is_defined()
        fourier_space_sampling = self.reciprocal_space_sampling
        return (
            fourier_space_sampling[0] * self.wavelength * 1e3,
            fourier_space_sampling[1] * self.wavelength * 1e3,
        )

    def _angular_grid(self):
        xp = get_array_module(self.device)
        alpha, phi = polar_spatial_frequencies(self.gpts, self.sampling, xp=xp)
        alpha *= self.wavelength
        return alpha, phi
class _WaveRenormalization(EmptyEnsemble, ArrayObjectTransform):
    def _calculate_new_array(
        self, array_object
    ) -> np.ndarray | tuple[np.ndarray, ...]:
        array = array_object.normalize().array
        return array
class Waves(BaseWaves, ArrayObject):
    """
    Waves define a batch of arbitrary 2D wave functions defined by a complex array.

    Parameters
    ----------
    array : array
        Complex array defining one or more 2D wave functions. The second-to-last and
        last dimensions are the wave function `y`- and `x`-axes, respectively.
    energy : float
        Electron energy [eV].
    extent : one or two float
        Extent of wave functions in `x` and `y` [Å].
    sampling : one or two float
        Sampling of wave functions in `x` and `y` [Å].
    reciprocal_space : bool, optional
        If True, the wave functions are assumed to be represented in reciprocal space
        instead of real space (default is False).
    ensemble_axes_metadata : list of AxisMetadata
        Axis metadata for each ensemble axis. The axis metadata must be compatible
        with the shape of the array.
    metadata : dict
        A dictionary defining wave function metadata. All items will be added to the
        metadata of measurements derived from the waves.
    """

    _base_dims = 2
[docs] def __init__( self, array: np.ndarray, energy: float, extent: float | tuple[float, float] = None, sampling: float | tuple[float, float] = None, reciprocal_space: bool = False, ensemble_axes_metadata: list[AxisMetadata] = None, metadata: dict = None, ): if sampling is not None and extent is not None: extent = None self._grid = Grid( extent=extent, gpts=array.shape[-2:], sampling=sampling, lock_gpts=True ) self._accelerator = Accelerator(energy=energy) self._reciprocal_space = reciprocal_space super().__init__( array=array, ensemble_axes_metadata=ensemble_axes_metadata, metadata=metadata, )
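# Illustrative example (not part of the original source): wrapping a plain complex
# NumPy array as a Waves object; the grid size, sampling and energy below are
# arbitrary choices:
#
#   >>> import numpy as np
#   >>> import abtem
#   >>> array = np.ones((512, 512), dtype=np.complex64)
#   >>> waves = abtem.Waves(array=array, energy=300e3, sampling=0.05)
#   >>> waves.extent   # (25.6, 25.6) Å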
@property def device(self) -> str: """The device where the array is stored.""" return device_name_from_array_module(get_array_module(self.array)) @property def base_tilt(self): """ The base small-angle beam tilt (i.e. the beam tilt not associated with an ensemble axis) applied to the Fresnel propagator [mrad]. """ return ( self.metadata.get("base_tilt_x", 0.0), self.metadata.get("base_tilt_y", 0.0), ) @property def reciprocal_space(self): """True if the waves are represented in reciprocal space.""" return self._reciprocal_space @property def metadata(self) -> dict: self._metadata["energy"] = self.energy self._metadata["reciprocal_space"] = self.reciprocal_space return self._metadata
[docs] @classmethod def from_array_and_metadata( cls, array: np.ndarray, axes_metadata: list[AxisMetadata], metadata: dict = None ) -> Waves: """ Creates wave functions from a given array and metadata. Parameters ---------- array : array Complex array defining one or more 2D wave functions. The second-to-last and last dimensions are the wave function `y`- and `x`-axis, respectively. axes_metadata : list of AxesMetadata Axis metadata for each axis. The axis metadata must be compatible with the shape of the array. The last two axes must be RealSpaceAxis. metadata : A dictionary defining wave function metadata. All items will be added to the metadata of measurements derived from the waves. The metadata must contain the electron energy [eV]. Returns ------- wave_functions : Waves The created wave functions. """ energy = metadata["energy"] reciprocal_space = metadata.get("reciprocal_space", False) x_axis, y_axis = axes_metadata[-2], axes_metadata[-1] if isinstance(x_axis, RealSpaceAxis) and isinstance(y_axis, RealSpaceAxis): sampling = x_axis.sampling, y_axis.sampling else: raise ValueError() return cls( array, sampling=sampling, energy=energy, reciprocal_space=reciprocal_space, ensemble_axes_metadata=axes_metadata[:-2], metadata=metadata, )
[docs] def convolve( self, kernel: np.ndarray, axes_metadata: list[AxisMetadata] = None, out_space: str = "in_space", in_place: bool = False, ): """ Convolve the wave-function array with a given array. Parameters ---------- kernel : np.ndarray Array to be convolved with. axes_metadata : list of AxisMetadata, optional Metadata for the resulting convolved array. Needed only if the given array has more than two dimensions. out_space : str, optional Space in which the convolved array is represented. Options are 'reciprocal_space' and 'real_space' (default is the space of the wave functions). in_place : bool, optional If True, the array representing the waves may be modified in-place. Returns ------- convolved : Waves The convolved wave functions. """ if out_space == "in_space": fourier_space_out = self.reciprocal_space elif out_space in ("reciprocal_space", "real_space"): fourier_space_out = out_space == "reciprocal_space" else: raise ValueError if axes_metadata is None: axes_metadata = [] if (len(kernel.shape) - 2) != len(axes_metadata): raise ValueError("provide axes metadata for each ensemble axis") waves = self.ensure_reciprocal_space(overwrite_x=in_place) waves_dims = tuple(range(len(kernel.shape) - 2)) kernel_dims = tuple( range( len(kernel.shape) - 2, len(waves.array.shape) - 2 + len(kernel.shape) - 2, ) ) kernel = _expand_dims(kernel, axis=kernel_dims) array = _expand_dims(waves._array, axis=waves_dims) xp = get_array_module(self.device) kernel = xp.array(kernel) if in_place and (array.shape == kernel.shape): array *= kernel else: array = array * kernel if not fourier_space_out: array = ifft2(array, overwrite_x=in_place) d = waves._copy_kwargs(exclude=("array",)) d["reciprocal_space"] = fourier_space_out d["array"] = array d["ensemble_axes_metadata"] = axes_metadata + d["ensemble_axes_metadata"] return waves.__class__(**d)
[docs] def normalize(self, space: str = "reciprocal", in_place: bool = False): """ Normalize the wave functions in real or reciprocal space. Parameters ---------- space : str Should be one of 'real' or 'reciprocal' (default is 'reciprocal'). Defines whether the wave function should be normalized such that the intensity sums to one in real or reciprocal space. in_place : bool, optional If True, the array representing the waves may be modified in-place. Returns ------- normalized_waves : Waves The normalized wave functions. """ if self.is_lazy: return self.apply_transform(_WaveRenormalization()) xp = get_array_module(self.device) reciprocal_space = self.reciprocal_space if space == "reciprocal": waves = self.ensure_reciprocal_space(overwrite_x=in_place) f = xp.sqrt(abs2(waves.array).sum((-2, -1), keepdims=True)) if in_place: waves._array /= f else: waves._array = waves._array / f if not reciprocal_space: waves = waves.ensure_real_space(overwrite_x=in_place) elif space == "real": raise NotImplementedError else: raise ValueError() return waves
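# Illustrative example (not part of the original source): after normalize(), the
# reciprocal-space intensity of each wave function sums to one. The plane-wave grid
# below is an arbitrary choice:
#
#   >>> import abtem
#   >>> waves = abtem.PlaneWave(energy=200e3, extent=10, gpts=128).build(lazy=False)
#   >>> normalized = waves.normalize()
#   >>> intensity = abs(normalized.ensure_reciprocal_space().array) ** 2
#   >>> float(intensity.sum())   # close to 1.0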
[docs] def tile(self, repetitions: tuple[int, int], renormalize: bool = False) -> Waves: """ Tile the wave functions. Can only be applied in real space. Parameters ---------- repetitions : two int The number of repetitions of the wave functions along the `x`- and `y`-axes. renormalize : bool, optional If True, preserve the total intensity of the wave function (default is False). Returns ------- tiled_wave_functions : Waves The tiled wave functions. """ xp = get_array_module(self.device) if self.reciprocal_space: raise NotImplementedError if self.is_lazy: tile_func = da.tile else: tile_func = xp.tile array = tile_func(self.array, (1,) * len(self.ensemble_shape) + repetitions) if hasattr(array, "rechunk"): array = array.rechunk(array.chunks[:-2] + (-1, -1)) kwargs = self._copy_kwargs(exclude=("array", "extent")) kwargs["array"] = array if renormalize: kwargs["array"] /= xp.asarray(np.prod(repetitions)) return self.__class__(**kwargs)
[docs] def ensure_reciprocal_space(self, overwrite_x: bool = False): """ Transform to reciprocal space if the wave functions are represented in real space. Parameters ---------- overwrite_x : bool, optional If True, modify the array in place; otherwise a copy is created (default is False). Returns ------- waves_in_reciprocal_space : Waves The wave functions in reciprocal space. """ if self.reciprocal_space: return self d = self._copy_kwargs(exclude=("array",)) d["array"] = fft2(self.array, overwrite_x=overwrite_x) d["reciprocal_space"] = True return self.__class__(**d)
[docs] def ensure_real_space(self, overwrite_x: bool = False): """ Transform to real space if the wave functions are represented in reciprocal space. Parameters ---------- overwrite_x : bool, optional If True, modify the array in place; otherwise a copy is created (default is False). Returns ------- waves_in_real_space : Waves The wave functions in real space. """ if not self.reciprocal_space: return self d = self._copy_kwargs(exclude=("array",)) d["array"] = ifft2(self.array, overwrite_x=overwrite_x) d["reciprocal_space"] = False waves = self.__class__(**d) return waves
[docs] def phase_shift(self, amount: float): """ Shift the phase of the wave functions. Parameters ---------- amount : float Amount of phase shift [rad]. Returns ------- phase_shifted_waves : Waves The shifted wave functions. """ def _phase_shift(array): xp = get_array_module(self.array) return xp.exp(1.0j * amount) * array d = self._copy_kwargs(exclude=("array",)) d["array"] = _phase_shift(self.array) d["reciprocal_space"] = False return self.__class__(**d)
[docs] def to_images(self, convert_complex: str = None) -> Images: """ The complex array of the wave functions at the image plane. Returns ------- images : Images The wave functions as an image. """ array = self.array.copy() metadata = copy(self.metadata) metadata["label"] = "intensity" metadata["units"] = "arb. unit" images = Images( array, sampling=self.sampling, ensemble_axes_metadata=self.ensemble_axes_metadata, metadata=metadata, ) if not convert_complex: return images if convert_complex in ("intensity", "phase", "real", "imag"): return getattr(images, convert_complex)() else: raise ValueError( f"convert_complex must be one of 'intensity', 'phase', 'real', 'imag'" )
[docs] def intensity(self) -> Images: """ Calculate the intensity of the wave functions. Returns ------- intensity_images : Images The intensity of the wave functions. """ return self.to_images(convert_complex="intensity")
[docs] def phase(self) -> Images: """ Calculate the phase of the wave functions. Returns ------- phase_images : Images The phase of the wave functions. """ return self.to_images(convert_complex="phase")
[docs] def real(self) -> Images: """ Calculate the real part of the wave functions. Returns ------- real_images : Images The real part of the wave functions. """ return self.to_images(convert_complex="real")
[docs] def imag(self) -> Images: """ Calculate the imaginary part of the wave functions. Returns ------- imaginary_images : Images The imaginary part of the wave functions. """ return self.to_images(convert_complex="imag")
[docs] def downsample( self, max_angle: str | float = "cutoff", gpts: tuple[int, int] = None, normalization: str = "values", ) -> Waves: """ Downsample the wave functions to a lower maximum scattering angle. Parameters ---------- max_angle : {'cutoff', 'valid'} or float, optional Controls the downsampling of the wave functions. ``cutoff`` : Downsample to the antialias cutoff scattering angle (default). ``valid`` : Downsample to the largest rectangle that fits inside the circle with a radius defined by the antialias cutoff scattering angle. float : Downsample to a maximum scattering angle specified by a float [mrad]. gpts : two int, optional Number of grid points of the wave functions after downsampling. If given, `max_angle` is not used. normalization : {'values', 'amplitude'} The normalization parameter determines the preserved quantity after normalization. ``values`` : The pixel-wise values of the wave function are preserved (default). ``amplitude`` : The total amplitude of the wave function is preserved. Returns ------- downsampled_waves : Waves The downsampled wave functions. """ xp = get_array_module(self.array) if gpts is None: gpts = self._gpts_within_angle(max_angle) if self.is_lazy: array = self.array.map_blocks( fft_interpolate, new_shape=gpts, normalization=normalization, chunks=self.array.chunks[:-2] + gpts, meta=xp.array((), dtype=get_dtype(complex=True)), ) else: array = fft_interpolate( self.array, new_shape=gpts, normalization=normalization ) kwargs = self._copy_kwargs(exclude=("array",)) kwargs["array"] = array kwargs["sampling"] = (self.extent[0] / gpts[0], self.extent[1] / gpts[1]) kwargs["metadata"][ "adjusted_antialias_cutoff_gpts" ] = self.antialias_cutoff_gpts return self.__class__(**kwargs)
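# Illustrative example (not part of the original source): exit waves are often
# downsampled to the antialias-limited region to save memory before further analysis.
# Assumes ``atoms`` is an ASE structure defined elsewhere:
#
#   >>> import abtem
#   >>> exit_wave = abtem.PlaneWave(energy=300e3, sampling=0.04).multislice(atoms)
#   >>> exit_wave = exit_wave.compute()
#   >>> small = exit_wave.downsample(max_angle="valid")
#   >>> small.gpts   # smaller than exit_wave.gpts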
    def diffraction_patterns(
        self,
        max_angle: str | float = "cutoff",
        # max_frequency: str | float = None,
        block_direct: bool | float = False,
        fftshift: bool = True,
        parity: str = "odd",
        return_complex: bool = False,
        renormalize: bool = True,
    ) -> DiffractionPatterns:
        """
        Calculate the intensity of the wave functions at the diffraction plane.

        Parameters
        ----------
        max_angle : {'cutoff', 'valid', 'full'} or float
            Control the maximum scattering angle of the diffraction patterns.

            ``cutoff`` : Downsample to the antialias cutoff scattering angle (default).

            ``valid`` : Downsample to the largest rectangle that fits inside the circle
            with a radius defined by the antialias cutoff scattering angle.

            ``full`` : The diffraction patterns are not cropped, and hence the
            antialiased region is included.

            float : Downsample to a maximum scattering angle specified by a float
            [mrad].
        block_direct : bool or float, optional
            If True, the direct beam is masked (default is False). If given as a float,
            masks up to that scattering angle [mrad].
        fftshift : bool, optional
            If False, do not shift the direct beam to the center of the diffraction
            patterns (default is True).
        parity : {'same', 'even', 'odd', 'none'}
            The parity of the shape of the diffraction patterns. Default is 'odd', so
            that the shape of the diffraction pattern is odd with the zero at the
            middle.
        renormalize : bool, optional
            If True and the wave-function intensities were normalized to sum to the
            number of pixels in real space (i.e. the default normalization of a plane
            wave), the intensities are rescaled to sum to one in reciprocal space.
        return_complex : bool
            If True, return complex-valued diffraction patterns (i.e. the wave function
            in reciprocal space) (default is False).

        Returns
        -------
        diffraction_patterns : DiffractionPatterns
            The diffraction pattern(s).
        """

        def _diffraction_pattern(array, new_gpts, return_complex, fftshift, normalize):
            xp = get_array_module(array)

            if normalize:
                array = array / float(np.prod(array.shape[-2:]))

            array = fft2(array, overwrite_x=False)

            if array.shape[-2:] != new_gpts:
                array = fft_crop(array, new_shape=array.shape[:-2] + new_gpts)

            if not return_complex:
                array = abs2(array)

            if fftshift:
                return xp.fft.fftshift(array, axes=(-1, -2))

            return array

        xp = get_array_module(self.array)

        if max_angle is None:
            max_angle = "full"

        new_gpts = self._gpts_within_angle(max_angle, parity=parity)

        metadata = copy(self.metadata)
        metadata["label"] = "intensity"
        metadata["units"] = "arb. unit"

        normalize = False
        if renormalize and "normalization" in metadata:
            if metadata["normalization"] == "values":
                normalize = True
            elif metadata["normalization"] != "reciprocal_space":
                raise RuntimeError(
                    f"normalization {metadata['normalization']} not recognized"
                )

        validate_gpts(new_gpts)

        if self.is_lazy:
            dtype = get_dtype(complex=return_complex)

            pattern = self.array.map_blocks(
                _diffraction_pattern,
                new_gpts=new_gpts,
                fftshift=fftshift,
                return_complex=return_complex,
                normalize=normalize,
                chunks=self.array.chunks[:-2] + ((new_gpts[0],), (new_gpts[1],)),
                meta=xp.array((), dtype=dtype),
            )
        else:
            pattern = _diffraction_pattern(
                self.array,
                new_gpts=new_gpts,
                return_complex=return_complex,
                fftshift=fftshift,
                normalize=normalize,
            )

        diffraction_patterns = DiffractionPatterns(
            pattern,
            sampling=(
                self.reciprocal_space_sampling[0],
                self.reciprocal_space_sampling[1],
            ),
            fftshift=fftshift,
            ensemble_axes_metadata=self.ensemble_axes_metadata,
            metadata=metadata,
        )

        if block_direct:
            diffraction_patterns = diffraction_patterns.block_direct(
                radius=block_direct
            )

        return diffraction_patterns
[docs] def apply_ctf( self, ctf: CTF = None, max_batch: int | str = "auto", **kwargs ) -> Waves: """ Apply the aberrations and apertures of a contrast transfer function to the wave functions. Parameters ---------- ctf : CTF, optional Contrast transfer function to be applied. max_batch : int, optional The number of wave functions in each chunk of the Dask array. If 'auto' (default), the batch size is automatically chosen based on the abtem user configuration settings "dask.chunk-size" and "dask.chunk-size-gpu". kwargs : Provide the parameters of the contrast transfer function as keyword arguments (see :class:`.CTF`). Returns ------- aberrated_waves : Waves The wave functions with the contrast transfer function applied. """ if ctf is None: ctf = CTF(**kwargs) if not ctf.accelerator.energy: ctf.accelerator.match(self.accelerator) self.accelerator.match(ctf.accelerator, check_match=True) self.accelerator.check_is_defined() return self.apply_transform(ctf, max_batch=max_batch)
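# Illustrative example (not part of the original source): simulating an HRTEM image by
# applying a contrast transfer function to an exit wave. The defocus, spherical
# aberration and aperture values are arbitrary, and ``exit_wave`` is assumed to come
# from a plane-wave multislice simulation:
#
#   >>> image_wave = exit_wave.apply_ctf(defocus=100, Cs=1e4, semiangle_cutoff=25)
#   >>> image_wave.intensity().show()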
def transition_potential_multislice( self, potential: BasePotential, transition_potentials: BaseTransitionPotential | list[BaseTransitionPotential], detectors: BaseDetector | list[BaseDetector] = None, sites: SliceIndexedAtoms | Atoms = None, ) -> Waves | BaseMeasurements: transition_potentials = _validate_transition_potentials(transition_potentials) potential = _validate_potential(potential, self) measurements = [] for transition_potential in transition_potentials: multislice_transform = MultisliceTransform( potential=potential, detectors=detectors, multislice_func=transition_potential_multislice_and_detect, transition_potential=transition_potential, sites=sites, ) measurements.append(self.apply_transform(transform=multislice_transform)) if len(measurements) > 1: axis_metadata = OrdinalAxis( label="Z, n, l", values=[ ",".join( ( str(transition_potential.metadata["Z"]), str(transition_potential.metadata["n"]), str(transition_potential.metadata["l"]), ) ) for transition_potential in transition_potentials ], tex_label="$Z, n, \ell$", ) measurements = abtem.stack( measurements, axis_metadata, ) else: measurements = measurements[0] return _reduce_ensemble(measurements)
[docs] def multislice( self, potential: BasePotential, detectors: BaseDetector | list[BaseDetector] = None, ) -> Waves: """ Propagate and transmit wave function through the provided potential using the multislice algorithm. When detector(s) are given, output will be the corresponding measurement. Parameters ---------- potential : BasePotential or ASE.Atoms The potential through which to propagate the wave function. Optionally atoms can be directly given. detectors : BaseDetector or list of BaseDetector, optional A detector or a list of detectors defining how the wave functions should be converted to measurements after running the multislice algorithm. See `abtem.measurements.detect` for a list of implemented detectors. If not given, returns the wave functions themselves. Returns ------- detected_waves : BaseMeasurements or list of BaseMeasurement The detected measurement (if detector(s) given). exit_waves : Waves Wave functions at the exit plane(s) of the potential (if no detector(s) given). """ potential = _validate_potential(potential, self) multislice_transform = MultisliceTransform( potential=potential, detectors=detectors ) waves = self.apply_transform(transform=multislice_transform) return _reduce_ensemble(waves)
[docs] def scan( self, scan: BaseScan | np.ndarray | Sequence, potential: Atoms | BasePotential = None, detectors: BaseDetector | Sequence[BaseDetector] = None, max_batch: int | str = "auto", ) -> BaseMeasurements | Waves | list[BaseMeasurements | Waves]: """ Run the multislice algorithm from probe wave functions over the provided scan. Parameters ---------- potential : BasePotential or Atoms The scattering potential. scan : BaseScan Positions of the probe wave functions. If not given, scans across the entire potential at Nyquist sampling. detectors : BaseDetector, list of BaseDetector, optional A detector or a list of detectors defining how the wave functions should be converted to measurements after running the multislice algorithm. See abtem.measurements.detect for a list of implemented detectors. max_batch : int, optional The number of wave functions in each chunk of the Dask array. If 'auto' (default), the batch size is automatically chosen based on the abtem user configuration settings "dask.chunk-size" and "dask.chunk-size-gpu". Returns ------- detected_waves : BaseMeasurements or list of BaseMeasurement The detected measurement (if detector(s) given). exit_waves : Waves Wave functions at the exit plane(s) of the potential (if no detector(s) given). """ scan = _validate_scan(scan) waves = self.apply_transform(scan, max_batch=max_batch) if potential is None: return waves measurements = waves.multislice( potential=potential, detectors=detectors, ) return measurements
[docs] def show(self, complex_images: bool = False, **kwargs): """ Show the wave-function intensities. kwargs : Keyword arguments for `abtem.measurements.Images.show`. """ if complex_images: return self.complex_images().show(**kwargs) else: return self.intensity().show(**kwargs)
def _reduce_ensemble(ensemble): if isinstance(ensemble, (list, tuple)): outputs = [_reduce_ensemble(x) for x in ensemble] if hasattr(ensemble, "compute"): outputs = ComputableList(outputs) return outputs squeeze = () for i, axes_metadata in enumerate(ensemble.ensemble_axes_metadata): if axes_metadata._squeeze: squeeze += (i,) output = ensemble.squeeze(squeeze) if hasattr(output, "reduce_ensemble"): output = output.reduce_ensemble() return output class _WavesBuilder(BaseWaves, Ensemble, CopyMixin, EqualityMixin): def __init__(self, ensemble_names: tuple[str, ...], device: str): self._ensemble_names = ensemble_names self._device = device super().__init__() def apply_transform( self, transform, max_batch: int | str = "auto", lazy: bool = True ): return self.build(lazy=lazy).apply_transform(transform, max_batch=max_batch) def check_can_build(self, potential: BasePotential = None): """Check whether the wave functions can be built.""" if potential is not None: self.grid.match(potential) self.grid.check_is_defined() self.accelerator.check_is_defined() @property def _ensembles(self): return {name: getattr(self, name) for name in self._ensemble_names} @property def _ensemble_shapes(self): return tuple(ensemble.ensemble_shape for ensemble in self._ensembles.values()) @property def ensemble_shape(self): """Shape of the ensemble axes of the waves.""" return tuple(itertools.chain(*self._ensemble_shapes)) @property def ensemble_axes_metadata(self) -> list[AxisMetadata]: """List of AxisMetadata of the ensemble axes.""" return list( itertools.chain( *tuple( ensemble.ensemble_axes_metadata for ensemble in self._ensembles.values() ) ) ) def _chunk_splits(self): shapes = (0,) + tuple( len(ensemble_shape) for ensemble_shape in self._ensemble_shapes ) cumulative_shapes = np.cumsum(shapes) return [ (cumulative_shapes[i], cumulative_shapes[i + 1]) for i in range(len(cumulative_shapes) - 1) ] def _arg_splits(self): shapes = (0,) for arg_split, ensemble in zip(self._chunk_splits(), self._ensembles.values()): shapes += (len(ensemble._partition_args(1, lazy=True)),) cumulative_shapes = np.cumsum(shapes) return [ (cumulative_shapes[i], cumulative_shapes[i + 1]) for i in range(len(cumulative_shapes) - 1) ] def _partition_args(self, chunks=(1,), lazy: bool = True): if chunks is None: chunks = self._default_ensemble_chunks chunks = validate_chunks( self.ensemble_shape, chunks, limit="auto", dtype=get_dtype(complex=True) ) chunks = validate_chunks(self.ensemble_shape, chunks) args = () for arg_split, ensemble in zip(self._chunk_splits(), self._ensembles.values()): arg_chunks = chunks[slice(*arg_split)] args += ensemble._partition_args(arg_chunks, lazy=lazy) return args @classmethod def _from_partitioned_args_func( cls, *args, partials, arg_splits, **kwargs, ): args = unpack_blockwise_args(args) for arg_split, (name, partial) in zip(arg_splits, partials.items()): kwargs[name] = partial(*args[slice(*arg_split)]).item() new_probe = cls( **kwargs, ) new_probe = _wrap_with_array(new_probe) return new_probe def _from_partitioned_args(self, *args, **kwargs): partials = { name: ensemble._from_partitioned_args() for name, ensemble in self._ensembles.items() } kwargs = self._copy_kwargs(exclude=tuple(self._ensembles.keys())) return partial( self._from_partitioned_args_func, partials=partials, arg_splits=self._arg_splits(), **kwargs, ) @property def _default_ensemble_chunks(self): return ("auto",) * len(self.ensemble_shape) @property def device(self): """The device where the waves are created.""" return self._device @property def 
shape(self): """Shape of the waves.""" return self.ensemble_shape + self.base_shape @property def base_shape(self) -> tuple[int, int]: """Shape of the base axes of the waves.""" return self.gpts @property def axes_metadata(self) -> AxesMetadataList: """List of AxisMetadata.""" return AxesMetadataList( self.ensemble_axes_metadata + self.base_axes_metadata, self.shape ) @staticmethod @abstractmethod def _build_waves(waves_builder: _WavesBuilder, wrapped: int): pass @staticmethod def _lazy_build_waves(waves_builder: _WavesBuilder, max_batch: int) -> Waves: if isinstance(max_batch, int): max_batch = int(max_batch * np.prod(waves_builder.gpts)) chunks = waves_builder._default_ensemble_chunks + waves_builder.gpts chunks = validate_chunks( shape=waves_builder.ensemble_shape + waves_builder.gpts, chunks=chunks + (-1, -1), limit=max_batch, dtype=waves_builder.dtype, ) blocks = waves_builder.ensemble_blocks(chunks=chunks[:-2]) xp = get_array_module(waves_builder.device) array = blocks.map_blocks( waves_builder._build_waves, meta=xp.array((), dtype=get_dtype(complex=True)), new_axis=tuple_range(2, len(waves_builder.ensemble_shape)), chunks=blocks.chunks + waves_builder.gpts, wrapped=False, ) return Waves( array, energy=waves_builder.energy, extent=waves_builder.extent, reciprocal_space=False, metadata=waves_builder.metadata, ensemble_axes_metadata=waves_builder.ensemble_axes_metadata, )
class PlaneWave(_WavesBuilder):
    """
    Represents plane-wave electron wave functions for simulating experiments with
    parallel illumination, such as HRTEM and SAED.

    Parameters
    ----------
    extent : two float, optional
        Lateral extent of the wave function [Å].
    gpts : two int, optional
        Number of grid points describing the wave function.
    sampling : two float, optional
        Lateral sampling of the wave functions [Å]. If 'gpts' is also given, will be
        ignored.
    energy : float, optional
        Electron energy [eV]. If not given, it must be set before the wave functions
        are built.
    normalize : bool, optional
        If True, normalizes the wave function such that its reciprocal-space intensity
        sums to one. If False, the wave function takes a value of one everywhere
        (default is False).
    tilt : two float, optional
        Small-angle beam tilt [mrad] (default is (0., 0.)). Implemented by shifting the
        wave functions at every slice.
    device : str, optional
        The wave functions are stored on this device ('cpu' or 'gpu'). The default is
        determined by the user configuration.
    """
[docs] def __init__( self, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, energy: float = None, normalize: bool = False, tilt: tuple[float, float] = (0.0, 0.0), device: str = None, ): self._grid = Grid(extent=extent, gpts=gpts, sampling=sampling) self._accelerator = Accelerator(energy=energy) self._tilt = _validate_tilt(tilt=tilt) self._normalize = normalize device = validate_device(device) super().__init__(ensemble_names=("tilt",), device=device)
@property def tilt(self): """The small-angle tilt of applied to the Fresnel propagator [mrad].""" return self._tilt @tilt.setter def tilt(self, value): self._tilt = _validate_tilt(value) @property def metadata(self): metadata = { "energy": self.energy, **self._tilt.metadata, "normalization": ("reciprocal_space" if self._normalize else "values"), } return metadata @property def normalize(self): """True if the created waves are normalized in reciprocal space.""" return self._normalize @staticmethod def _build_waves(waves_builder, wrapped: bool = True): if hasattr(waves_builder, "item"): waves_builder = waves_builder.item() xp = get_array_module(waves_builder.device) if waves_builder.normalize: array = xp.full( waves_builder.gpts, 1 / np.prod(waves_builder.gpts), dtype=get_dtype(complex=True), ) else: array = xp.ones(waves_builder.gpts, dtype=get_dtype(complex=True)) waves = Waves( array, energy=waves_builder.energy, extent=waves_builder.extent, metadata=waves_builder.metadata, reciprocal_space=False, ) waves = waves.apply_transform(waves_builder.tilt) if not wrapped: waves = waves.array return waves
[docs] def build( self, lazy: bool = None, max_batch: int | str = "auto", ) -> Waves: """ Build plane-wave wave functions. Parameters ---------- lazy : bool, optional If True, create the wave functions lazily, otherwise, calculate instantly. If not given, defaults to the setting in the user configuration file. max_batch : int or str, optional The number of wave functions in each chunk of the Dask array. If 'auto' (default), the batch size is automatically chosen based on the abtem user configuration settings "dask.chunk-size" and "dask.chunk-size-gpu". Returns ------- plane_waves : Waves The wave functions. """ self.check_can_build() lazy = _validate_lazy(lazy) if not lazy: probes = self._build_waves(self) else: probes = self._lazy_build_waves(self, max_batch) return _reduce_ensemble(probes)
[docs] def multislice( self, potential: BasePotential | Atoms, detectors: BaseDetector = None, max_batch: int | str = "auto", lazy: bool = None, ) -> Waves: """ Run the multislice algorithm, after building the plane-wave wave function as needed. The grid of the wave functions will be set to the grid of the potential. Parameters ---------- potential : BasePotential, Atoms The potential through which to propagate the wave function. Optionally atoms can be directly given. detectors : Detector, list of detectors, optional A detector or a list of detectors defining how the wave functions should be converted to measurements after running the multislice algorithm. max_batch : int, optional The number of wave functions in each chunk of the Dask array. If 'auto' (default), the batch size is automatically chosen based on the abtem user configuration settings "dask.chunk-size" and "dask.chunk-size-gpu". lazy : bool, optional If True, create the wave functions lazily, otherwise, calculate instantly. If None, this defaults to the setting in the user configuration file. Returns ------- detected_waves : BaseMeasurements or list of BaseMeasurement The detected measurement (if detector(s) given). exit_waves : Waves Wave functions at the exit plane(s) of the potential (if no detector(s) given). """ potential = _validate_potential(potential) lazy = _validate_lazy(lazy) self.check_can_build(potential) if not lazy: probes = self._build_waves(self) else: probes = self._lazy_build_waves(self, max_batch) multislice = MultisliceTransform(potential, detectors) measurements = probes.apply_transform(multislice) return _reduce_ensemble(measurements)
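# Illustrative example (not part of the original source): when a potential is given,
# the plane-wave grid is taken from the potential, so only the energy needs to be set.
# Assumes ``atoms`` is an ASE structure defined elsewhere; the 0.05 Å potential
# sampling is an arbitrary choice:
#
#   >>> import abtem
#   >>> potential = abtem.Potential(atoms, sampling=0.05)
#   >>> exit_wave = abtem.PlaneWave(energy=80e3).multislice(potential)
#   >>> exit_wave = exit_wave.compute()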
class Probe(_WavesBuilder):
    """
    Represents electron-probe wave functions for simulating experiments with a
    convergent beam, such as CBED and STEM.

    Parameters
    ----------
    semiangle_cutoff : float, optional
        The cutoff semiangle of the aperture [mrad] (default is 30). Ignored if a
        custom aperture is given.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a single
        float is given, both are set equal.
    gpts : two int, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [Å]. If 'gpts' is also given, will be
        ignored.
    energy : float, optional
        Electron energy [eV]. If not given, it must be set before the wave functions
        are built.
    soft : bool, optional
        If True, the edge of the default aperture is softened (default is True).
        Ignored if a custom aperture is given.
    tilt : two float, two 1D :class:`.BaseDistribution` or one 2D :class:`.BaseDistribution`, optional
        Small-angle beam tilt [mrad]. This value should generally not exceed one
        degree.
    device : str, optional
        The probe wave functions will be built and stored on this device ('cpu' or
        'gpu'). The default is determined by the user configuration.
    aperture : BaseAperture, optional
        An optional custom aperture. The provided aperture should be a subtype of
        :class:`.BaseAperture`.
    aberrations : dict or Aberrations
        The phase aberrations as a dictionary.
    positions : BaseScan, optional
        The scan defining the positions of the probe wave functions. If not given, the
        positions are provided when the probe is built (e.g. through
        :meth:`.Probe.build`, :meth:`.Probe.multislice` or :meth:`.Probe.scan`).
    metadata : dict, optional
        A dictionary of additional metadata attached to the probe wave functions.
    kwargs :
        Provide the aberrations as keyword arguments, forwarded to the
        :class:`.Aberrations`.
    """
[docs] def __init__( self, semiangle_cutoff: float = None, extent: float | tuple[float, float] = None, gpts: int | tuple[int, int] = None, sampling: float | tuple[float, float] = None, energy: float = None, soft: bool = True, tilt: ( tuple[float | BaseDistribution, float | BaseDistribution] | BaseDistribution ) = ( 0.0, 0.0, ), device: str = None, aperture: BaseAperture = None, aberrations: Aberrations | dict = None, positions: BaseScan = None, metadata: dict = None, **kwargs, ): self._accelerator = Accelerator(energy=energy) # if not ((semiangle_cutoff is None) + (aperture is None) == 1): # raise ValueError("provide exactly one of `semiangle_cutoff` or `aperture`") # elif semiangle_cutoff is None: # semiangle_cutoff = 30.0 if semiangle_cutoff is None and aperture is None: semiangle_cutoff = 30 if aperture is None: aperture = Aperture(semiangle_cutoff=semiangle_cutoff, soft=soft) aperture._accelerator = self._accelerator if aberrations is None: aberrations = {} if isinstance(aberrations, dict): aberrations = Aberrations(energy=energy, **aberrations, **kwargs) aberrations._accelerator = self._accelerator self._grid = Grid(extent=extent, gpts=gpts, sampling=sampling) self._aperture = aperture self._aberrations = aberrations self.tilt = tilt self._metadata = {} if metadata is None else metadata if positions is None: positions = abtem.CustomScan(np.zeros((0, 2)), squeeze=True) self._positions = positions self.accelerator.match(self.aperture) ensemble_names = ( "tilt", "aberrations", "aperture", "positions", ) super().__init__(ensemble_names=ensemble_names, device=device)
@property def tilt(self): """The small-angle tilt of applied to the Fresnel propagator [mrad].""" return self._tilt @tilt.setter def tilt(self, value): self._tilt = _validate_tilt(value) @property def positions(self) -> BaseScan: """The position(s) of the probe.""" return self._positions @property def soft(self): """True if the aperture has a soft edge.""" return self.aperture.soft @classmethod def _from_ctf(cls, ctf, **kwargs): return cls( semiangle_cutoff=ctf.semiangle_cutoff, soft=ctf.soft, aberrations=ctf.aberration_coefficients, **kwargs, ) @property def ctf(self): """Contrast transfer function describing the probe.""" return CTF( aberration_coefficients=self.aberrations.aberration_coefficients, semiangle_cutoff=self.semiangle_cutoff, energy=self.energy, ) @property def semiangle_cutoff(self): """The semiangle cutoff [mrad].""" return self.aperture.semiangle_cutoff @semiangle_cutoff.setter def semiangle_cutoff(self, value): self.aperture.semiangle_cutoff = value @property def aperture(self) -> Aperture: """Condenser or probe-forming aperture.""" return self._aperture @aperture.setter def aperture(self, aperture: Aperture): self._aperture = aperture @property def aberrations(self) -> Aberrations: """Phase aberrations of the probe wave functions.""" return self._aberrations @aberrations.setter def aberrations(self, aberrations: Aberrations): self._aberrations = aberrations @property def metadata(self) -> dict: """Metadata describing the probe wave functions.""" return { **self._metadata, "energy": self.energy, **self.aperture.metadata, # **self._tilt.metadata, } @staticmethod def _build_waves(waves_builder, wrapped: bool = True): if hasattr(waves_builder, "item"): waves_builder = waves_builder.item() array = waves_builder.positions._evaluate_kernel(waves_builder) waves = Waves( array, energy=waves_builder.energy, extent=waves_builder.extent, metadata=waves_builder.metadata, reciprocal_space=True, ensemble_axes_metadata=waves_builder.positions.ensemble_axes_metadata, ) waves = waves.apply_transform(waves_builder.aperture) waves = waves.apply_transform(waves_builder.tilt) waves = waves.apply_transform(waves_builder.aberrations) waves = waves.normalize() waves = waves.ensure_real_space() if not wrapped: waves = waves.array return waves def _validate_and_build( self, scan: Sequence | BaseScan = None, max_batch: int | str = "auto", lazy: bool = None, potential=None, ): self.check_can_build(potential) lazy = _validate_lazy(lazy) probe = self.copy() if potential is not None: probe.grid.match(potential) scan = _validate_scan(scan, probe) if isinstance(scan, CustomScan): squeeze = True else: squeeze = False probe._positions = scan if not lazy: probes = self._build_waves(probe) else: probes = self._lazy_build_waves(probe, max_batch) if squeeze: probes = probes.squeeze(axis=(-3,)) return probes
[docs] def build( self, scan: Sequence | BaseScan = None, max_batch: int | str = "auto", lazy: bool = None, ) -> Waves: """ Build probe wave functions at the provided positions. Parameters ---------- scan : array of `xy`-positions or BaseScan, optional Positions of the probe wave functions. If not given, scans across the entire potential at Nyquist sampling. max_batch : int, optional The number of wave functions in each chunk of the Dask array. If 'auto' (default), the batch size is automatically chosen based on the abtem user configuration settings "dask.chunk-size" and "dask.chunk-size-gpu". lazy : bool, optional If True, create the wave functions lazily, otherwise, calculate instantly. If not given, defaults to the setting in the user configuration file. Returns ------- probe_wave_functions : Waves The built probe wave functions. """ return self._validate_and_build(scan=scan, max_batch=max_batch, lazy=lazy)
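# Illustrative example (not part of the original source): building a single probe wave
# function at one position, i.e. a CBED-style calculation; all parameter values are
# arbitrary:
#
#   >>> import abtem
#   >>> probe = abtem.Probe(energy=100e3, semiangle_cutoff=20, extent=20, gpts=256)
#   >>> wave = probe.build(scan=(10.0, 10.0), lazy=False)
#   >>> wave.diffraction_patterns().show()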
[docs] def multislice( self, potential: BasePotential | Atoms, scan: tuple | BaseScan = None, detectors: BaseDetector = None, max_batch: int | str = "auto", lazy: bool = None, ) -> BaseMeasurements | Waves | list[BaseMeasurements | Waves]: """ Run the multislice algorithm for probe wave functions at the provided positions. Parameters ---------- potential : BasePotential or Atoms The scattering potential. Optionally atoms can be directly given. scan : array of xy-positions or BaseScan, optional Positions of the probe wave functions. If not given, scans across the entire potential at Nyquist sampling. detectors : BaseDetector or list of BaseDetector, optional A detector or a list of detectors defining how the wave functions should be converted to measurements after running the multislice algorithm. If not given, defaults to the flexible annular detector. max_batch : int, optional The number of wave functions in each chunk of the Dask array. If 'auto' (default), the batch size is automatically chosen based on the abtem user configuration settings "dask.chunk-size" and "dask.chunk-size-gpu". lazy : bool, optional If True, create the wave functions lazily, otherwise, calculate instantly. If None, this defaults to the setting in the user configuration file. Returns ------- measurements : BaseMeasurements or Waves or list of BaseMeasurement """ potential = _validate_potential(potential) probes = self._validate_and_build( scan=scan, max_batch=max_batch, lazy=lazy, potential=potential ) multislice = MultisliceTransform(potential, detectors) measurements = probes.apply_transform(multislice) return _reduce_ensemble(measurements)
[docs] def transition_potential_scan( self, potential: BasePotential | Atoms, transition_potentials: BaseTransitionPotential | list[BaseTransitionPotential], scan: tuple | BaseScan = None, detectors: BaseDetector | list[BaseDetector] = None, sites: SliceIndexedAtoms | Atoms = None, # detectors_elastic: BaseDetector | list[BaseDetector] = None, double_channel: bool = True, threshold: float = 1.0, max_batch: int | str = "auto", lazy: bool = None, ) -> Waves | BaseMeasurements: """ Parameters ---------- potential : BasePotential | Atoms The potential to be used for calculating the transition potentials. It can be an instance of `BasePotential` or an `Atoms` object. transition_potentials : BaseTransitionPotential | list[BaseTransitionPotential] The transition potentials to be used for multislice calculations. It can be an instance of `BaseTransitionPotential` or a list of `BaseTransitionPotential` objects. scan : tuple | BaseScan, optional The scan parameters. It can be a tuple or an instance of `BaseScan`. Defaults to None. detectors : BaseDetector | list[BaseDetector], optional A detector or a list of detectors defining how the wave functions should be converted to measurements after running the multislice algorithm. See abtem.measurements.detect for a list of implemented detectors. Defaults to None, which sites : SliceIndexedAtoms | Atoms, optional The slice indexed atoms to be used for multislice calculations. It can be an instance of `SliceIndexedAtoms` or an `Atoms` object. Defaults to None. detectors_elastic : BaseDetector | list[BaseDetector], optional The elastic detectors to be used for recording the measurements. It can be an instance of `BaseDetector` or a list of `BaseDetector` objects. Defaults to None. double_channel : bool, optional A boolean indicating whether to use double channel for recording the measurements. Defaults to True. max_batch : int | str, optional The maximum batch size for parallel processing. It can be an integer or the string "auto". Defaults to "auto". lazy : bool, optional A boolean indicating whether to use lazy evaluation for the calculations. Defaults to None. Returns ------- Waves | BaseMeasurements The calculated waves or measurements, depending on the value of `lazy`. """ if scan is None: scan = GridScan() if detectors is None: detectors = FlexibleAnnularDetector() potential = _validate_potential(potential) probes = self._validate_and_build( scan=scan, max_batch=max_batch, lazy=lazy, potential=potential ) multislice = MultisliceTransform( potential, detectors, multislice_func=transition_potential_multislice_and_detect, transition_potential=transition_potentials, sites=sites, double_channel=double_channel, threshold=threshold, ) measurements = probes.apply_transform(multislice) return _reduce_ensemble(measurements)
[docs] def scan( self, potential: Atoms | BasePotential, scan: BaseScan | np.ndarray | Sequence = None, detectors: BaseDetector | Sequence[BaseDetector] = None, max_batch: int | str = "auto", lazy: bool = None, ) -> BaseMeasurements | Waves | list[BaseMeasurements | Waves]: """ Run the multislice algorithm from probe wave functions over the provided scan. Parameters ---------- potential : BasePotential or Atoms The scattering potential. scan : BaseScan Positions of the probe wave functions. If not given, scans across the entire potential at Nyquist sampling. detectors : BaseDetector, list of BaseDetector, optional A detector or a list of detectors defining how the wave functions should be converted to measurements after running the multislice algorithm. See abtem.measurements.detect for a list of implemented detectors. max_batch : int, optional The number of wave functions in each chunk of the Dask array. If 'auto' (default), the batch size is automatically chosen based on the abtem user configuration settings "dask.chunk-size" and "dask.chunk-size-gpu". lazy : bool, optional If True, create the measurements lazily, otherwise, calculate instantly. If None, this defaults to the value set in the configuration file. Returns ------- detected_waves : BaseMeasurements or list of BaseMeasurement The detected measurement (if detector(s) given). exit_waves : Waves Wave functions at the exit plane(s) of the potential (if no detector(s) given). """ if scan is None: scan = GridScan() if detectors is None: detectors = FlexibleAnnularDetector() measurements = self.multislice( scan=scan, potential=potential, detectors=detectors, lazy=lazy, max_batch=max_batch, ) return measurements
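# Illustrative example (not part of the original source): a STEM scan over an ASE
# structure with an annular (ADF-style) detector; the detector angles and probe
# parameters are arbitrary, and ``atoms`` is assumed to be defined elsewhere:
#
#   >>> import abtem
#   >>> probe = abtem.Probe(energy=200e3, semiangle_cutoff=24)
#   >>> detector = abtem.AnnularDetector(inner=70, outer=200)
#   >>> measurement = probe.scan(atoms, detectors=detector)
#   >>> measurement.compute().show()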
[docs] def profiles(self, angle: float = 0.0) -> RealSpaceLineProfiles: """ Create a line profile through the center of the probe. Parameters ---------- angle : float, optional Angle with respect to the `x`-axis of the line profile [degree]. """ def _line_intersect_rectangle(point0, point1, lower_corner, upper_corner): if point0[0] == point1[0]: return (point0[0], lower_corner[1]), (point0[0], upper_corner[1]) m = (point1[1] - point0[1]) / (point1[0] - point0[0]) def _y(x): return m * (x - point0[0]) + point0[1] def _x(y): return (y - point0[1]) / m + point0[0] if _y(0) < lower_corner[1]: intersect0 = (_x(lower_corner[1]), _y(_x(lower_corner[1]))) else: intersect0 = (0, _y(lower_corner[0])) if _y(upper_corner[0]) > upper_corner[1]: intersect1 = (_x(upper_corner[1]), _y(_x(upper_corner[1]))) else: intersect1 = (upper_corner[0], _y(upper_corner[0])) return intersect0, intersect1 point1 = (self.extent[0] / 2, self.extent[1] / 2) measurement = self.build(point1).intensity() point2 = point1 + np.array( [np.cos(np.pi * angle / 180), np.sin(np.pi * angle / 180)] ) point1, point2 = _line_intersect_rectangle( point1, point2, (0.0, 0.0), self.extent ) return measurement.interpolate_line(point1, point2)
[docs] def show(self, complex_images: bool = False, **kwargs): """ Show the intensity of the probe wave function. Parameters ---------- complex_images : bool If true shows complex images using domain-coloring instead of the intensity. kwargs : Keyword arguments for the :func:`.Images.show` function. """ wave = self.build((self.extent[0] / 2, self.extent[1] / 2)) if complex_images: images = wave.complex_images() else: images = wave.intensity() return images.show(**kwargs)