# Source code for abtem.prism.s_matrix

"""Module describing the scattering matrix used in the PRISM algorithm."""
from __future__ import annotations
import inspect
import operator
import warnings
from abc import abstractmethod
from functools import partial, reduce

import dask.array as da
import numpy as np
from ase import Atoms
from dask.graph_manipulation import wait_on

from abtem.array import _validate_lazy, ArrayObject, ComputableList
from abtem.core.axes import (
    OrdinalAxis,
    AxisMetadata,
    ScanAxis,
    UnknownAxis,
    WaveVectorAxis,
)
from abtem.core.backend import get_array_module, cp, validate_device, copy_to_device
from abtem.core.chunks import chunk_ranges, validate_chunks, equal_sized_chunks, Chunks
from abtem.core.complex import complex_exponential
from abtem.core.energy import Accelerator
from abtem.core.ensemble import Ensemble, _wrap_with_array
from abtem.core.grid import Grid, GridUndefinedError
from abtem.core.utils import (
    safe_ceiling_int,
    expand_dims_to_broadcast,
    ensure_list,
    CopyMixin,
    EqualityMixin, tuple_range,
)
from abtem.detectors import (
    BaseDetector,
    _validate_detectors,
    WavesDetector,
    FlexibleAnnularDetector, AnnularDetector,
)
from abtem.measurements import BaseMeasurements
from abtem.multislice import (
    allocate_multislice_measurements,
    multislice_and_detect,
)
from abtem.potentials.iam import BasePotential, _validate_potential
from abtem.prism.utils import (
    plane_waves,
    wrapped_crop_2d,
    minimum_crop,
    batch_crop_2d,
)
from abtem.scan import BaseScan, _validate_scan, GridScan
from abtem.transfer import CTF
from abtem.waves import BaseWaves, _antialias_cutoff_gpts
from abtem.waves import Waves, Probe


def _extract_measurement(array, index):
    """Pull the array of one measurement out of a single-item container block.

    Parameters
    ----------
    array : array-like
        Block exposing ``size`` and ``item()``. A block of size zero is returned
        unchanged; otherwise ``item()`` must return an indexable collection.
    index : int
        Position, within the collection returned by ``array.item()``, of the
        measurement to extract.

    Returns
    -------
    The ``.array`` attribute of the selected measurement, or ``array`` itself when
    the block is empty.
    """
    has_content = array.size != 0
    if has_content:
        stored = array.item()
        return stored[index].array
    return array


def _wrap_measurements(measurements):
    """Return a lone measurement as-is, otherwise wrap them in a ``ComputableList``."""
    if len(measurements) == 1:
        return measurements[0]
    return ComputableList(measurements)


def _finalize_lazy_measurements(
    arrays, waves, detectors, extra_ensemble_axes_metadata=None, chunks=None
):
    """Turn a lazy array of packed measurement blocks into one measurement per detector.

    For every detector, ``_extract_measurement`` is mapped over the blocks of
    ``arrays`` to pull out that detector's array, and the result is wrapped in the
    measurement type the detector reports for ``waves``.

    Parameters
    ----------
    arrays : lazy array
        Array supporting ``map_blocks`` with ``shape`` and ``chunks``; each block is
        expected to be consumable by ``_extract_measurement``.
    waves : object
        Passed to the detectors' ``_out_*`` methods to describe their output.
    detectors : iterable
        Detectors whose measurements are unpacked, in the order they were stored.
    extra_ensemble_axes_metadata : list, optional
        Axes metadata prepended to each detector's own axes metadata.
    chunks : tuple, optional
        Chunks of the leading (non-base) axes; defaults to ``arrays.chunks``.

    Returns
    -------
    list
        One measurement per detector.
    """
    if extra_ensemble_axes_metadata is None:
        extra_ensemble_axes_metadata = []

    leading_dims = len(arrays.shape)

    finalized = []
    for detector_index, detector in enumerate(detectors):
        base_shape = detector._out_base_shape(waves)
        if isinstance(detector, AnnularDetector):
            # TODO (kept from the original): annular detectors get no base axes here
            base_shape = ()

        output_meta = detector._out_meta(waves)
        appended_axes = tuple(range(leading_dims, leading_dims + len(base_shape)))

        if chunks is None:
            chunks = arrays.chunks

        # Each base axis is added as a single chunk spanning its whole length.
        base_chunks = tuple((length,) for length in base_shape)

        unpacked = arrays.map_blocks(
            _extract_measurement,
            detector_index,
            chunks=chunks + base_chunks,
            new_axis=appended_axes,
            meta=output_meta,
        )

        detector_axes_metadata = detector._out_ensemble_axes_metadata(
            waves
        ) + detector._out_base_axes_metadata(waves)
        metadata = detector._out_metadata(waves)
        measurement_type = detector._out_type(waves)

        measurement = measurement_type.from_array_and_metadata(
            unpacked,
            axes_metadata=extra_ensemble_axes_metadata + detector_axes_metadata,
            metadata=metadata,
        )

        if hasattr(measurement, "reduce_ensemble"):
            measurement = measurement.reduce_ensemble()

        finalized.append(measurement)

    return finalized


def _round_gpts_to_multiple_of_interpolation(
    gpts: tuple[int, int], interpolation: tuple[int, int]
) -> tuple[int, int]:
    """Round each grid-point count up to the next multiple of its interpolation factor.

    Values that are already a multiple are returned unchanged.
    """
    rounded = []
    for factor, points in zip(interpolation, gpts):
        # (-points) % factor is the distance up to the next multiple of ``factor``.
        shortfall = -points % factor
        rounded.append(points + shortfall)
    return tuple(rounded)  # noqa


class BaseSMatrix(BaseWaves):
    """Base class for scattering matrices."""

    _device: str
    ensemble_axes_metadata: list[AxisMetadata]
    ensemble_shape: tuple[int, ...]
    # One plane-wave axis followed by the two spatial axes (see base_axes_metadata).
    _base_dims = 3

    @property
    def device(self):
        """The device where the S-Matrix is created and reduced."""
        return self._device

    @property
    @abstractmethod
    def interpolation(self):
        """Interpolation factor in the `x` and `y` directions."""
        pass

    @property
    @abstractmethod
    def wave_vectors(self) -> np.ndarray:
        """The wave vectors corresponding to each plane wave."""
        pass

    @property
    @abstractmethod
    def semiangle_cutoff(self) -> float:
        """The radial cutoff of the plane-wave expansion [mrad]."""
        pass

    @property
    @abstractmethod
    def window_extent(self):
        """The cropping window extent of the waves."""
        pass

    @property
    @abstractmethod
    def window_gpts(self):
        """The number of grid points describing the cropping window of the wave functions."""
        pass

    def __len__(self) -> int:
        """Return the number of plane waves, i.e. the number of wave vectors."""
        return len(self.wave_vectors)

    @property
    def base_axes_metadata(self) -> list[AxisMetadata]:
        """Axis metadata of the base axes: a wave-vector axis, then the two wave axes."""
        wave_axes_metadata = super().base_axes_metadata
        return [
            WaveVectorAxis(
                label="q",
                values=tuple(tuple(value) for value in self.wave_vectors),
            ),
            wave_axes_metadata[0],
            wave_axes_metadata[1],
        ]

    def dummy_probes(
        self, scan: BaseScan = None, ctf: CTF = None, plane: str = "entrance", **kwargs
    ) -> Probe:
        """
        Create a probe (or ensemble of probes) matching the window of this S-matrix.

        The probe is built from the window extent, window grid points and energy of
        the S-matrix, with the cutoff limited to the S-matrix plane-wave cutoff.

        Parameters
        ----------
        scan : BaseScan, optional
            If given, stored on the returned probe as its positions.
        ctf : CTF or dict, optional
            Contrast transfer function of the probe. A ``CTF`` is copied before being
            modified, a dict is used as keyword arguments for a new ``CTF``, and
            ``None`` (default) creates a ``CTF`` from the energy and semiangle cutoff
            of the S-matrix.
        plane : str
            If "exit", the defocus of the CTF is reduced by the thickness of
            ``self.potential`` when that attribute exists and is not None, or by the
            "accumulated_defocus" metadata entry when the object has no ``potential``
            attribute. Any other value leaves the defocus unchanged. Default is
            "entrance".
        kwargs :
            Extra keyword arguments for the probe constructor. They take precedence
            over the defaults ``device`` and ``metadata`` taken from the S-matrix.

        Returns
        -------
        dummy_probes : Probe

        Raises
        ------
        ValueError
            If ``ctf`` is not None, a dict or a ``CTF``.
        """
        if ctf is None:
            ctf = CTF(energy=self.energy, semiangle_cutoff=self.semiangle_cutoff)
        elif isinstance(ctf, dict):
            ctf = CTF(energy=self.energy, semiangle_cutoff=self.semiangle_cutoff, **ctf)
        elif isinstance(ctf, CTF):
            # Copy so the adjustments below do not leak into the caller's object.
            ctf = ctf.copy()
        else:
            raise ValueError(
                "ctf must be None, a dict of CTF keyword arguments or a CTF instance"
            )

        if plane == "exit":
            defocus = 0.0
            # NOTE(review): the metadata fallback is attached to the outer check, so it
            # applies only to objects without a ``potential`` attribute — confirm
            # against the upstream source, whose indentation was lost here.
            if hasattr(self, "potential"):
                if self.potential is not None:
                    defocus = self.potential.thickness
            elif "accumulated_defocus" in self.metadata:
                defocus = self.metadata["accumulated_defocus"]

            ctf.defocus = ctf.defocus - defocus

        ctf.semiangle_cutoff = min(ctf.semiangle_cutoff, self.semiangle_cutoff)

        default_kwargs = {"device": self.device, "metadata": {**self.metadata}}
        kwargs = {**default_kwargs, **kwargs}

        probes = Probe._from_ctf(
            extent=self.window_extent,
            gpts=self.window_gpts,
            ctf=ctf,
            energy=self.energy,
            **kwargs,
        )

        if scan is not None:
            probes._positions = scan

        return probes
def _validate_interpolation(interpolation: int | tuple[int, int]):
    """Normalize an interpolation factor to a tuple of two values.

    An integer is repeated for both directions; any other input must have length
    two and is converted to a tuple.

    Raises
    ------
    ValueError
        If a non-integer input does not have exactly two items.
    """
    if isinstance(interpolation, int):
        interpolation = (interpolation,) * 2
    elif not len(interpolation) == 2:
        raise ValueError(
            "Interpolation factor must be an integer or a sequence of two integers."
        )
    return tuple(interpolation)


def _common_kwargs(a, b):
    """Return the set of parameter names shared by the signatures of ``a`` and ``b``."""
    a_kwargs = inspect.signature(a).parameters.keys()
    b_kwargs = inspect.signature(b).parameters.keys()
    return set(a_kwargs).intersection(b_kwargs)


def _pack_wave_vectors(wave_vectors):
    """Convert wave vectors to a tuple of ``(float, float)`` pairs of built-in floats."""
    return tuple(
        (float(wave_vector[0]), float(wave_vector[1])) for wave_vector in wave_vectors
    )


def _chunked_axis(s_matrix_array):
    """Choose which spatial axis to chunk and which to leave whole.

    Returns
    -------
    tuple of int
        ``(chunked_axis, nochunks_axis)``; the chunked axis is the one with the
        larger ratio of grid points to window margin (axis 0 on ties).
    """
    window_margin = s_matrix_array._window_margin
    # Negated so that argsort puts the axis with the largest ratio first.
    argsort = np.argsort(
        (
            -s_matrix_array.gpts[0] // window_margin[0],
            -s_matrix_array.gpts[1] // window_margin[1],
        )
    )
    return int(argsort[0]), int(argsort[1])


def _grouped_chunks(partitions, first_center):
    """Merge ``partitions`` into groups of three, starting at ``first_center``.

    Partitions before the first group form one leading chunk (omitted when
    ``first_center`` is 1) and partitions after the last group form one trailing
    chunk, which may have size zero.

    Returns
    -------
    tuple
        ``(chunks, indices)`` where ``indices`` holds, for every group of three,
        the index of its middle partition minus one.
    """
    consumed = first_center - 1
    chunks = (sum(partitions[:consumed]),) if consumed else ()
    indices = ()
    for center in range(first_center, len(partitions) - 1, 3):
        chunks += (sum(partitions[center - 1 : center + 2]),)
        indices += (center - 1,)
        consumed = center + 2
    # Track the consumed count explicitly instead of reusing the loop variable, so the
    # tail is right even when no group of three fits.
    chunks += (sum(partitions[consumed:]),)
    assert sum(chunks) == sum(partitions)
    return chunks, indices


def _chunks_for_multiple_rechunk_reduce(partitions):
    """Compute the three staggered chunkings used by the multiple-rechunk reduction.

    Each chunking merges the partitions in groups of three, with the groups of the
    second and third chunking shifted by one and two partitions respectively.

    Returns
    -------
    tuple
        ``((chunks_1, chunks_2, chunks_3), (indices_1, indices_2, indices_3))``.
    """
    chunks_1, chunk_indices_1 = _grouped_chunks(partitions, 1)
    chunks_2, chunk_indices_2 = _grouped_chunks(partitions, 2)
    chunks_3, chunk_indices_3 = _grouped_chunks(partitions, 3)

    # Every partition except the first and last must be the middle of exactly one group.
    assert len(chunk_indices_1 + chunk_indices_2 + chunk_indices_3) == (
        len(partitions) - 2
    )

    return (chunks_1, chunks_2, chunks_3), (
        chunk_indices_1,
        chunk_indices_2,
        chunk_indices_3,
    )


def _lazy_reduce(
    array,
    waves_partial,
    ensemble_axes_metadata,
    from_waves_kwargs,
    scan,
    ctf,
    detectors,
    max_batch_reduction,
):
    """Reduce one block of an S-matrix array to measurements (used with ``map_blocks``).

    The block is rebuilt into waves with ``waves_partial``, converted to an
    ``SMatrixArray`` and reduced for ``scan``, ``ctf`` and ``detectors``.

    Returns
    -------
    np.ndarray
        Object array with one dimension fewer than ``array``, every dimension of
        length one, whose single item is the tuple of measurements.
    """
    args = (array, ensemble_axes_metadata)
    waves = waves_partial(args).item()
    s_matrix = SMatrixArray._from_waves(waves, **from_waves_kwargs)
    measurements = s_matrix._batch_reduce_to_measurements(
        scan, ctf, detectors, max_batch_reduction
    )
    arr = np.zeros((1,) * (len(array.shape) - 1), dtype=object)
    # ndarray.itemset was removed in NumPy 2.0; a full integer index stores the object.
    arr[(0,) * arr.ndim] = measurements
    return arr


def _map_blocks(array, scans, block_indices, window_offset=(0, 0), **kwargs):
    """Lazily reduce selected spatial blocks of ``array``, one scan per block.

    Parameters
    ----------
    array : dask array
        S-matrix array chunked along its last two axes.
    scans : sequence
        The scan to reduce for each block.
    block_indices : sequence of tuple of int
        Block index along the last two axes for each scan.
    window_offset : tuple of int
        Offset added to the grid-point offset of each block.
    kwargs :
        Keyword arguments forwarded to ``_lazy_reduce``; must contain "ctf" and
        "from_waves_kwargs".

    Returns
    -------
    tuple
        One lazy array per scan; scans of length zero give an all-zero-length array.
    """
    ctf_chunks = tuple((n,) for n in kwargs["ctf"].ensemble_shape)
    n_dims = len(array.shape)

    blocks = ()
    for i, scan in zip(block_indices, scans):
        block = array.blocks[(slice(None),) * (n_dims - 2) + i]
        new_chunks = array.chunks[:-3] + ctf_chunks + scan.shape

        # Give each block its own kwargs so that later iterations cannot alter the
        # window offset seen by an earlier, not yet computed, block.
        block_kwargs = dict(kwargs)
        block_kwargs["from_waves_kwargs"] = {
            **kwargs["from_waves_kwargs"],
            "window_offset": (
                window_offset[0] + sum(array.chunks[-2][: i[0]]),
                window_offset[1] + sum(array.chunks[-1][: i[1]]),
            ),
        }

        if len(scan.shape) == 1:
            drop_axis = (n_dims - 3, n_dims - 1)
        elif len(scan.shape) == 2:
            drop_axis = (n_dims - 3,)
        else:
            raise NotImplementedError

        block = da.map_blocks(
            _lazy_reduce,
            block,
            scan=scan,
            drop_axis=drop_axis,
            chunks=new_chunks,
            **block_kwargs,
            meta=np.array((), dtype=np.complex64),
        )

        if len(scan) == 0:
            block = da.zeros(
                (0,) * len(block.shape),
                dtype=np.complex64,
            )

        blocks += (block,)

    return blocks


def _tuple_from_indices(*args):
    """Build a tuple from alternating ``index, value`` arguments.

    ``_tuple_from_indices(1, "x", 0, "y")`` returns ``("y", "x")``.
    """
    temp_list = [None] * (len(args) // 2)
    for arg1, arg2 in zip(args[::2], args[1::2]):
        temp_list[arg1] = arg2
    return tuple(temp_list)


def _wait_on_blocks(array, new_blocks, ensemble_shape):
    """Make each ensemble slice of ``array`` wait for the matching slices of ``new_blocks``.

    Slices for which indexing the blocks raises ``IndexError`` are used without
    waiting. The slices are stacked along a new first axis.
    """
    if not new_blocks:
        # Nothing to wait for; ``wait_on`` with one argument would not return a tuple.
        return array

    fp_arrays = []
    for i in np.ndindex(ensemble_shape):
        try:
            fp_new_blocks = tuple(block[i] for block in new_blocks)
            fp_array = wait_on(array[i], *fp_new_blocks)[0]
        except IndexError:
            fp_array = array[i]
        fp_arrays.append(fp_array)

    return da.stack(fp_arrays, axis=0)


def _multiple_rechunk_reduce(s_matrix_array, scan, detectors, ctf, max_batch_reduction):
    """Reduce a periodic S-matrix array in three passes with staggered chunkings.

    The array is padded by the window margin along the chunked axis and rechunked
    three times, each time merging partitions in groups of three shifted by one
    partition, so that every scan block is reduced from a chunk that also holds its
    neighbouring partitions.

    Returns
    -------
    list
        One lazy measurement per detector.
    """
    assert np.all(s_matrix_array.periodic)

    window_margin = s_matrix_array._window_margin
    chunked_axis, nochunks_axis = _chunked_axis(s_matrix_array)

    pad_amounts = _tuple_from_indices(
        chunked_axis, (window_margin[chunked_axis],) * 2, nochunks_axis, (0, 0)
    )

    s_matrix_array = s_matrix_array._pad(pad_amounts)

    chunk_size = window_margin[chunked_axis]
    size = s_matrix_array.shape[-2:][chunked_axis] - window_margin[chunked_axis] * 2
    num_chunks = -(size // -chunk_size)  # ceiling division

    # NOTE(review): the partition of the non-chunked axis is sized from the chunked
    # axis of the array; this only matters for non-square arrays — confirm intent.
    partitions = _tuple_from_indices(
        chunked_axis,
        (chunk_size,) * num_chunks,
        nochunks_axis,
        (s_matrix_array.shape[len(s_matrix_array.shape) - 2 + chunked_axis],),
    )

    chunk_extents = tuple(
        tuple(((cc[0]) * d, (cc[1]) * d) for cc in c)
        for c, d in zip(chunk_ranges(partitions), s_matrix_array.sampling)
    )

    scan, scan_chunks = scan._sort_into_extents(chunk_extents)

    scans = [
        (indices, sub_scan.item())
        for indices, _, sub_scan in scan.generate_blocks(scan_chunks)
    ]

    # Add the leading padding and whatever remains after the regular partitions.
    partitions = (pad_amounts[chunked_axis][0],) + partitions[chunked_axis]
    partitions = partitions + (
        s_matrix_array.shape[len(s_matrix_array.shape) - 2 + chunked_axis]
        - sum(partitions),
    )

    (chunks_1, chunks_2, chunks_3), (
        scan_indices_1,
        scan_indices_2,
        scan_indices_3,
    ) = _chunks_for_multiple_rechunk_reduce(partitions)

    leading_chunks = s_matrix_array.array.chunks[:-3] + (-1,)
    chunks_1 = leading_chunks + _tuple_from_indices(
        chunked_axis, chunks_1, nochunks_axis, -1
    )
    chunks_2 = leading_chunks + _tuple_from_indices(
        chunked_axis, chunks_2, nochunks_axis, -1
    )
    chunks_3 = leading_chunks + _tuple_from_indices(
        chunked_axis, chunks_3, nochunks_axis, -1
    )

    shape = tuple(len(c) for c in scan_chunks)
    blocks = np.zeros(shape, dtype=object)

    kwargs = {
        "waves_partial": s_matrix_array.waves._from_partitioned_args(),
        "ensemble_axes_metadata": s_matrix_array.waves.ensemble_axes_metadata,
        "from_waves_kwargs": s_matrix_array._copy_kwargs(exclude=("array", "extent")),
        "ctf": ctf,
        "detectors": detectors,
        "max_batch_reduction": max_batch_reduction,
    }

    window_offset = s_matrix_array.window_offset

    # (chunking, scan blocks handled in this pass, index of the first reduced chunk).
    # The second and third chunkings start with a leading chunk that is not reduced.
    passes = (
        (chunks_1, scan_indices_1, 0),
        (chunks_2, scan_indices_2, 1),
        (chunks_3, scan_indices_3, 1),
    )

    array = s_matrix_array.array
    new_blocks = ()
    for pass_number, (pass_chunks, scan_indices, first_block) in enumerate(passes):
        if pass_number > 0 and s_matrix_array.ensemble_shape:
            # Order the passes per ensemble member before rechunking again.
            array = _wait_on_blocks(array, new_blocks, s_matrix_array.ensemble_shape)

        array = array.rechunk(pass_chunks)

        block_indices = [
            _tuple_from_indices(chunked_axis, i, nochunks_axis, 0)
            for i in range(first_block, first_block + len(scan_indices))
        ]

        new_blocks = _map_blocks(
            array,
            [scans[i][1] for i in scan_indices],
            block_indices,
            window_offset=window_offset,
            **kwargs,
        )

        for i, block in zip(scan_indices, new_blocks):
            # ndarray.itemset was removed in NumPy 2.0.
            blocks[tuple(scans[i][0])] = block

    array = da.block(blocks.tolist())

    dummy_probes = s_matrix_array.dummy_probes(scan=scan, ctf=ctf)

    measurements = _finalize_lazy_measurements(
        array,
        waves=dummy_probes,
        detectors=detectors,
        extra_ensemble_axes_metadata=s_matrix_array.ensemble_axes_metadata,
    )

    return measurements


def _single_rechunk_reduce(
    s_matrix_array: "SMatrixArray",
    scan: BaseScan,
    detectors: list[BaseDetector],
    ctf: CTF,
    max_batch_reduction: int,
):
    """Reduce a periodic S-matrix array after rechunking it once.

    The array is split into equal chunks along one axis; each scan block is reduced
    from its own chunk together with the two neighbouring chunks (wrapping around at
    the last chunk).

    Returns
    -------
    list
        One lazy measurement per detector.
    """
    chunked_axis, nochunks_axis = _chunked_axis(s_matrix_array)

    num_chunks = (
        s_matrix_array.gpts[chunked_axis] // s_matrix_array._window_margin[chunked_axis]
    )

    chunks = equal_sized_chunks(
        s_matrix_array.shape[-2:][chunked_axis], num_chunks=num_chunks
    )

    assert np.all(np.array(chunks) > s_matrix_array._window_margin[chunked_axis])

    chunks = (
        s_matrix_array.array.chunks[:-3]
        + (-1,)
        + _tuple_from_indices(chunked_axis, chunks, nochunks_axis, -1)
    )

    array = s_matrix_array._array.rechunk(chunks)

    assert all(s_matrix_array.periodic)

    # Convert the (start, stop) grid-point range of each chunk to a real-space extent.
    # The chunk sizes themselves are plain integers and cannot be indexed.
    chunk_extents = tuple(
        tuple(((cc[0]) * d, (cc[1]) * d) for cc in c)
        for c, d in zip(chunk_ranges(array.chunks[-2:]), s_matrix_array.sampling)
    )

    scan, scan_chunks = scan._sort_into_extents(chunk_extents)

    ctf_chunks = tuple((n,) for n in ctf.ensemble_shape)
    chunks = array.chunks[:-3] + ctf_chunks

    n_dims = len(array.shape)
    leading_index = (0,) * len(array.shape[:-3])

    shape = tuple(len(c) for c, p in zip(scan_chunks, s_matrix_array.periodic))
    blocks = np.zeros((1,) * len(array.shape[:-3]) + shape, dtype=object)

    kwargs = {
        "waves_partial": s_matrix_array.waves._from_partitioned_args(),
        "ensemble_axes_metadata": s_matrix_array.waves.ensemble_axes_metadata,
        "from_waves_kwargs": s_matrix_array._copy_kwargs(exclude=("array", "extent")),
        "ctf": ctf,
        "detectors": detectors,
        "max_batch_reduction": max_batch_reduction,
    }

    for indices, _, sub_scan in scan.generate_blocks(scan_chunks):
        sub_scan = sub_scan.item()

        if len(sub_scan) == 0:
            # ndarray.itemset was removed in NumPy 2.0.
            blocks[leading_index + indices] = da.zeros(
                (0,) * len(blocks.shape),
                dtype=np.complex64,
            )
            continue

        slics = (slice(None),) * (n_dims - 2)
        window_offset = ()
        for i, k in enumerate(indices):
            axis_chunks = array.chunks[-2:][i]
            if len(axis_chunks) > 1:
                # Previous, current and next chunk, wrapping at both ends.
                slics += ([k - 1, k, (k + 1) % len(axis_chunks)],)
                window_offset += (sum(axis_chunks[:k]) - axis_chunks[k - 1],)
            else:
                slics += (slice(None),)
                window_offset += (0,)

        new_block = array.blocks[slics]
        new_block = new_block.rechunk(array.chunks[:-2] + (-1, -1))

        new_chunks = chunks + sub_scan.shape

        # Per-block copy so later iterations cannot alter this block's window offset.
        block_kwargs = dict(kwargs)
        block_kwargs["from_waves_kwargs"] = {
            **kwargs["from_waves_kwargs"],
            "window_offset": tuple(window_offset),
        }

        if len(scan.shape) == 1:
            drop_axis = (n_dims - 3, n_dims - 1)
        elif len(scan.shape) == 2:
            drop_axis = (n_dims - 3,)
        else:
            raise NotImplementedError

        new_block = da.map_blocks(
            _lazy_reduce,
            new_block,
            scan=sub_scan,
            drop_axis=drop_axis,
            chunks=new_chunks,
            **block_kwargs,
            meta=np.array((), dtype=np.complex64),
        )

        blocks[leading_index + indices] = new_block

    array = da.block(blocks.tolist())

    dummy_probes = s_matrix_array.dummy_probes(scan=scan, ctf=ctf)

    measurements = _finalize_lazy_measurements(
        array,
        waves=dummy_probes,
        detectors=detectors,
        extra_ensemble_axes_metadata=s_matrix_array.ensemble_axes_metadata,
    )

    return measurements


def _no_chunks_reduce(
    s_matrix_array: "SMatrixArray",
    scan: BaseScan,
    detectors: list[BaseDetector],
    ctf: CTF,
    max_batch_reduction: int,
):
    """Reduce an S-matrix array with its last three axes each held in a single chunk.

    Returns
    -------
    list
        One lazy measurement per detector.
    """
    array = s_matrix_array._array.rechunk(
        s_matrix_array.array.chunks[:-3] + (-1, -1, -1)
    )

    kwargs = {
        "waves_partial": s_matrix_array.waves._from_partitioned_args(),
        "ensemble_axes_metadata": s_matrix_array.waves.ensemble_axes_metadata,
        "from_waves_kwargs": s_matrix_array._copy_kwargs(exclude=("array", "extent")),
        "ctf": ctf,
        "detectors": detectors,
        "max_batch_reduction": max_batch_reduction,
    }

    ctf_chunks = tuple((n,) for n in ctf.ensemble_shape)
    chunks = array.chunks[:-3] + ctf_chunks + scan.shape

    if len(scan.shape) == 1:
        drop_axis = (len(array.shape) - 3, len(array.shape) - 1)
    elif len(scan.shape) == 2:
        drop_axis = (len(array.shape) - 3,)
    else:
        raise NotImplementedError

    array = da.map_blocks(
        _lazy_reduce,
        array,
        scan=scan,
        drop_axis=drop_axis,
        chunks=chunks,
        **kwargs,
        meta=np.array((), dtype=np.complex64),
    )

    dummy_probes = s_matrix_array.dummy_probes(scan=scan, ctf=ctf)

    measurements = _finalize_lazy_measurements(
        array,
        waves=dummy_probes,
        detectors=detectors,
        extra_ensemble_axes_metadata=s_matrix_array.ensemble_axes_metadata,
    )

    return measurements
[docs] class SMatrixArray(BaseSMatrix, ArrayObject): """ A scattering matrix defined by a given array of dimension 3, where the first indexes the probe plane waves and the latter two are the `y` and `x` scan directions. Parameters ---------- array : np.ndarray Array defining the scattering matrix. Must be 3D or higher, dimensions before the last three dimensions should represent ensemble dimensions, the next dimension indexes the plane waves and the last two dimensions represent the spatial extent of the plane waves. wave_vectors : np.ndarray Array defining the wave vectors corresponding to each plane wave. Must have shape Nx2, where N is equal to the number of plane waves. semiangle_cutoff : float The radial cutoff of the plane-wave expansion [mrad]. energy : float Electron energy [eV]. sampling : one or two float, optional Lateral sampling of wave functions [1 / Å]. Provide only if potential is not given. Will be ignored if 'gpts' is also provided. extent : one or two float, optional Lateral extent of wave functions [Å]. Provide only if potential is not given. interpolation : one or two int, optional Interpolation factor in the `x` and `y` directions (default is 1, ie. no interpolation). If a single value is provided, assumed to be the same for both directions. window_gpts : tuple of int The number of grid points describing the cropping window of the wave functions. window_offset : tuple of int The number of grid points from the origin the cropping windows of the wave functions is displaced. periodic: tuple of bool Specifies whether the SMatrix should be assumed to be periodic along the x and y-axis. device : str, optional The calculations will be carried out on this device ('cpu' or 'gpu'). Default is 'cpu'. The default is determined by the user configuration. ensemble_axes_metadata : list of AxesMetadata Axis metadata for each ensemble axis. The axis metadata must be compatible with the shape of the array. 
metadata : dict A dictionary defining wave function metadata. All items will be added to the metadata of measurements derived from the waves. """
[docs] def __init__( self, array: np.ndarray, wave_vectors: np.ndarray, semiangle_cutoff: float, energy: float = None, interpolation: int | tuple[int, int] = (1, 1), sampling: float | tuple[float, float] = None, extent: float | tuple[float, float] = None, window_gpts: tuple[int, int] = (0, 0), window_offset: tuple[int, int] = (0, 0), periodic: tuple[bool, bool] = (True, True), device: str = None, ensemble_axes_metadata: list[AxisMetadata] = None, metadata: dict = None, ): self._grid = Grid( extent=extent, gpts=array.shape[-2:], sampling=sampling, lock_gpts=True ) self._accelerator = Accelerator(energy=energy) self._wave_vectors = wave_vectors super().__init__( array=array, ensemble_axes_metadata=ensemble_axes_metadata, metadata=metadata, ) self._semiangle_cutoff = semiangle_cutoff self._window_gpts = tuple(window_gpts) self._window_offset = tuple(window_offset) self._interpolation = _validate_interpolation(interpolation) self._device = device self._periodic = periodic
@classmethod def _pack_kwargs(cls, kwargs): kwargs["wave_vectors"] = _pack_wave_vectors(kwargs["wave_vectors"]) return super()._pack_kwargs(kwargs) @classmethod def _unpack_kwargs(cls, attrs): kwargs = super()._unpack_kwargs(attrs) kwargs["wave_vectors"] = np.array(kwargs["wave_vectors"], dtype=np.float32) return kwargs # kwargs["wave_vectors"] = _pack_wave_vectors(kwargs["wave_vectors"])
[docs] def copy_to_device(self, device: str) -> "SMatrixArray": """Copy SMatrixArray to specified device.""" s_matrix = super().copy_to_device(device) s_matrix._wave_vectors = copy_to_device(self._wave_vectors, device) return s_matrix
@staticmethod def _packed_wave_vectors(wave_vectors): return _pack_wave_vectors(wave_vectors) @property def device(self): """The device on which the SMatrixArray is reduced.""" return self._device @property def storage_device(self): """The device on which the SMatrixArray is stored.""" return super().device @classmethod def _from_waves(cls, waves: Waves, **kwargs): common_kwargs = _common_kwargs(cls, Waves) kwargs.update({key: getattr(waves, key) for key in common_kwargs}) kwargs["ensemble_axes_metadata"] = kwargs["ensemble_axes_metadata"][:-1] return cls(**kwargs) @property def waves(self) -> Waves: """The wave vectors describing each plane wave.""" kwargs = { key: getattr(self, key) for key in _common_kwargs(self.__class__, Waves) } kwargs["ensemble_axes_metadata"] = ( kwargs["ensemble_axes_metadata"] + self.base_axes_metadata[:-2] ) return Waves(**kwargs) def _copy_with_new_waves(self, waves): keys = set( inspect.signature(self.__class__).parameters.keys() ) - _common_kwargs(self.__class__, Waves) kwargs = {key: getattr(self, key) for key in keys} return self._from_waves(waves, **kwargs) @property def periodic(self) -> tuple[bool, bool]: """If True the SMatrix is assumed to be periodic along corresponding axis.""" return self._periodic @property def metadata(self) -> dict: self._metadata["energy"] = self.energy return self._metadata @property def ensemble_axes_metadata(self) -> list[AxisMetadata]: """Axis metadata for each ensemble axis.""" return self._ensemble_axes_metadata @property def ensemble_shape(self) -> tuple[int, int]: return self.array.shape[:-3] @property def interpolation(self) -> tuple[int, int]: return self._interpolation
[docs] def rechunk(self, chunks: Chunks, in_place: bool = True): array = self.array.rechunk(chunks) if in_place: self._array = array return self else: kwargs = self._copy_kwargs(exclude=("array",)) return self.__class__(array, **kwargs)
@property def semiangle_cutoff(self) -> float: """The cutoff semiangle of the plane wave expansion.""" return self._semiangle_cutoff @property def wave_vectors(self) -> np.ndarray: return self._wave_vectors @property def window_gpts(self) -> tuple[int, int]: return self._window_gpts @property def window_extent(self) -> tuple[float, float]: return ( self.window_gpts[0] * self.sampling[0], self.window_gpts[1] * self.sampling[1], ) @property def window_offset(self) -> tuple[float, float]: """The number of grid points from the origin the cropping windows of the wave functions is displaced.""" return self._window_offset
[docs] def multislice(self, potential: BasePotential = None) -> "SMatrixArray": """ Parameters ---------- potential : Returns ------- """ waves = self.waves.multislice(potential) return self._copy_with_new_waves(waves)
def _reduce_to_waves( self, array, positions, position_coefficients, ): xp = get_array_module(self._device) if self._device == "gpu" and isinstance(array, np.ndarray): array = xp.asarray(array) position_coefficients = xp.array(position_coefficients, dtype=xp.complex64) # return xp.zeros(position_coefficients.shape[:2] + self.window_gpts, dtype=xp.complex64) if self.window_gpts != self.gpts: pixel_positions = positions / xp.array(self.waves.sampling) - xp.asarray( self.window_offset ) crop_corner, size, corners = minimum_crop(pixel_positions, self.window_gpts) array = wrapped_crop_2d(array, crop_corner, size) array = xp.tensordot(position_coefficients, array, axes=[-1, -3]) if len(self.waves.shape) > 3: array = xp.moveaxis(array, -3, 0) array = batch_crop_2d(array, corners, self.window_gpts) else: array = xp.tensordot(position_coefficients, array, axes=[-1, -3]) if len(self.waves.shape) > 3: array = xp.moveaxis(array, -3, 0) return array def _calculate_positions_coefficients(self, scan): xp = get_array_module(self.wave_vectors) if isinstance(scan, GridScan): x = xp.asarray(scan._x_coordinates()) y = xp.asarray(scan._y_coordinates()) coefficients = complex_exponential( -2.0 * xp.pi * x[:, None, None] * self.wave_vectors[None, None, :, 0] ) * complex_exponential( -2.0 * xp.pi * y[None, :, None] * self.wave_vectors[None, None, :, 1] ) else: positions = xp.asarray(scan.get_positions()) coefficients = complex_exponential( -2.0 * xp.pi * positions[..., 0, None] * self.wave_vectors[:, 0][None] - 2.0 * xp.pi * positions[..., 1, None] * self.wave_vectors[:, 1][None] ) return coefficients def _calculate_ctf_coefficients(self, ctf): wave_vectors = self.wave_vectors xp = get_array_module(wave_vectors) alpha = ( xp.sqrt(wave_vectors[:, 0] ** 2 + wave_vectors[:, 1] ** 2) * ctf.wavelength ) phi = xp.arctan2(wave_vectors[:, 1], wave_vectors[:, 0]) array = ctf._evaluate_from_angular_grid(alpha, phi) array = array / xp.sqrt((array**2).sum(axis=-1, keepdims=True)) return array def 
_batch_reduce_to_measurements( self, scan: BaseScan, ctf: CTF, detectors: list[BaseDetector], max_batch_reduction: int, ) -> tuple[BaseMeasurements | Waves, ...]: dummy_probes = self.dummy_probes(scan=scan, ctf=ctf) #print(self.waves.ensemble_axes_metadata[:-1]) measurements = allocate_multislice_measurements( dummy_probes, detectors, extra_ensemble_axes_shape=self.waves.ensemble_shape[:-1], extra_ensemble_axes_metadata=self.waves.ensemble_axes_metadata[:-1], ) xp = get_array_module(self._device) if self._device == "gpu" and isinstance(self.waves.array, np.ndarray): array = cp.asarray(self.waves.array) else: array = self.waves.array for _, ctf_slics, sub_ctf in ctf.generate_blocks(1): sub_ctf = sub_ctf.item() ctf_coefficients = self._calculate_ctf_coefficients(sub_ctf) for _, slics, sub_scan in scan.generate_blocks(max_batch_reduction): sub_scan = sub_scan.item() positions = xp.asarray(sub_scan.get_positions()) positions_coefficients = self._calculate_positions_coefficients( sub_scan ) if ctf_coefficients is not None: ( expanded_ctf_coefficients, positions_coefficients, ) = expand_dims_to_broadcast( ctf_coefficients, positions_coefficients, match_dims=[(-1,), (-1,)], ) coefficients = positions_coefficients * expanded_ctf_coefficients else: coefficients = positions_coefficients ensemble_shape = len(array.shape[:-3]) + len(sub_ctf.ensemble_shape) ensemble_axes_metadata = [] ensemble_axes_metadata.extend( [UnknownAxis() for _ in range(ensemble_shape)] ) ensemble_axes_metadata.extend( [ScanAxis() for _ in range(len(scan.shape))] ) waves_array = self._reduce_to_waves(array, positions, coefficients) waves = Waves( waves_array, sampling=self.sampling, energy=self.energy, ensemble_axes_metadata=ensemble_axes_metadata, metadata=self.metadata, ) indices = ( (slice(None),) * (len(self.waves.shape) - 3) + ctf_slics + slics ) for detector, measurement in zip(detectors, measurements): measurement.array[indices] = detector.detect(waves).array return tuple(measurements) @property 
def _window_margin(self):
    """Return half the window size along each axis, rounded up."""
    # -(n // -2) is integer ceiling division by two.
    return -(self.window_gpts[0] // -2), -(self.window_gpts[1] // -2)


def _pad(self, pad_width):
    """
    Pad the last two axes of the array by wrapping around its edges.

    Parameters
    ----------
    pad_width : sequence of two (before, after) pairs
        Padding applied to the second-to-last and to the last axis.

    Returns
    -------
    padded
        A new instance of the same class wrapping the padded array. Axes that received padding
        are flagged as non-periodic and the window offset is shifted back by the amount padded
        in front of each axis.
    """
    array = self.array
    # Leading (non-spatial) axes are never padded.
    pad_width = ((0,) * 2,) * len(array.shape[:-2]) + tuple(pad_width)
    pad_amounts = sum(pad_width[-2]), sum(pad_width[-1])

    # The two padded axes are declared as a single chunk each.
    pad_chunks = array.chunks[:-2] + (
        array.shape[-2] + pad_amounts[-2],
        array.shape[-1] + pad_amounts[-1],
    )

    array = array.map_blocks(
        np.pad,
        pad_width=pad_width,
        meta=array._meta,
        chunks=pad_chunks,
        mode="wrap",
    )

    kwargs = self._copy_kwargs(exclude=("array", "extent"))

    kwargs["periodic"] = tuple(
        False if pad_amount else periodic
        for periodic, pad_amount in zip(kwargs["periodic"], pad_amounts)
    )

    kwargs["window_offset"] = tuple(
        window_offset - pad_amount[0]
        for window_offset, pad_amount in zip(kwargs["window_offset"], pad_width[-2:])
    )

    return self.__class__(array, **kwargs)


def _chunks_for_reduction(self):
    """
    Return validated chunks for the two spatial axes used when reducing in chunks.

    Only one of the two axes is split: the one that would be divided into more pieces of
    size ``ceil(gpts / (2 * interpolation))``; the other axis is kept as a single chunk.
    """
    chunks = (
        -(self.gpts[0] // -(self.interpolation[0] * 2)),
        -(self.gpts[1] // -(self.interpolation[1] * 2)),
    )

    num_chunks = self.gpts[0] // chunks[0], self.gpts[1] // chunks[1]

    if num_chunks[1] > num_chunks[0]:
        num_chunks = (1, num_chunks[1])
    else:
        num_chunks = (num_chunks[0], 1)

    chunks = tuple(
        equal_sized_chunks(n, num_chunks=nsc)
        for n, nsc in zip(self.shape[-2:], num_chunks)
    )

    # `chunks` is always a tuple at this point, so the former `chunks is None`
    # fallback could never run and has been removed.
    return validate_chunks(self.shape[-2:], chunks)


def _validate_max_batch_reduction(self, scan, max_batch_reduction: int | str = "auto"):
    """
    Resolve ``max_batch_reduction`` to a number of scan positions per reduction batch.

    The value is validated as the chunk size along the first axis of an array of shape
    ``(len(scan),) + window_gpts`` with dtype complex64; the size of the first chunk is returned.
    """
    shape = (len(scan),) + self.window_gpts
    chunks = (max_batch_reduction, -1, -1)
    return validate_chunks(shape, chunks, dtype=np.dtype("complex64"))[0][0]


def _validate_reduction_scheme(self, reduction_scheme):
    """
    Resolve the reduction scheme.

    ``"auto"`` resolves to ``"no-chunks"`` when the largest interpolation factor is at most 2 and
    to ``"multiple-rechunk"`` otherwise. Any other value is returned unchanged.

    Raises
    ------
    NotImplementedError
        If ``"no-chunks"`` is requested explicitly for an interpolation of (1, 1).
    """
    if self.interpolation == (1, 1) and reduction_scheme == "no-chunks":
        raise NotImplementedError(
            "The 'no-chunks' reduction scheme is not implemented for interpolation (1, 1)."
        )

    if reduction_scheme == "auto" and max(self.interpolation) <= 2:
        return "no-chunks"
    elif reduction_scheme == "auto":
        return "multiple-rechunk"

    return reduction_scheme
def reduce(
    self,
    scan: BaseScan = None,
    ctf: CTF = None,
    detectors: BaseDetector | list[BaseDetector] = None,
    max_batch_reduction: int | str = "auto",
    reduction_scheme: str = "auto",
) -> BaseMeasurements | Waves | list[BaseMeasurements | Waves]:
    """
    Scan the probe across the potential and record a measurement for each detector.

    Parameters
    ----------
    scan : BaseScan, optional
        Scan defining the positions of the probe wave functions. If not given, a single position
        at the center of the extent is used.
    ctf : CTF, optional
        The probe contrast transfer function. Default is None (aperture is set by the planewave
        cutoff). Note that a provided CTF is modified in place: its grid and energy are matched
        to the scattering matrix, and an infinite semiangle cutoff is replaced by the planewave
        cutoff.
    detectors : BaseDetector or list of BaseDetector, optional
        The detectors recording the measurements.
    max_batch_reduction : int or str, optional
        Number of positions per reduction operation. A large number of positions better utilize
        thread parallelization, but requires more memory and floating point operations. If 'auto'
        (default), the batch size is automatically chosen based on the abtem user configuration
        settings "dask.chunk-size" and "dask.chunk-size-gpu".
    reduction_scheme : {'auto', 'multiple-rechunk', 'single-rechunk', 'no-chunks'}, optional
        How a lazy scattering matrix is reduced. If 'auto' (default), 'no-chunks' is used when
        the largest interpolation factor is at most 2 and 'multiple-rechunk' otherwise. The
        resolved scheme is only used when the scattering matrix is lazy.

    Returns
    -------
    measurements : BaseMeasurements or Waves or list of BaseMeasurements or list of Waves
        One measurement per detector; a single object is returned if there is only one.

    Raises
    ------
    ValueError
        If the scattering matrix is lazy and the reduction scheme is not recognized.
    """
    self.accelerator.check_is_defined()

    if ctf is None:
        ctf = CTF(semiangle_cutoff=self.semiangle_cutoff)

    ctf.grid.match(self.dummy_probes())
    ctf.accelerator.match(self)

    if ctf.semiangle_cutoff == np.inf:
        ctf.semiangle_cutoff = self.semiangle_cutoff

    # Anything that is not already a BaseScan (including None) gets axis -3 squeezed from
    # every measurement after the reduction.
    squeeze = () if isinstance(scan, BaseScan) else (-3,)

    if scan is None:
        scan = self.extent[0] / 2, self.extent[1] / 2

    scan = _validate_scan(
        scan, Probe._from_ctf(extent=self.extent, ctf=ctf, energy=self.energy)
    )

    detectors = _validate_detectors(detectors)

    max_batch_reduction = self._validate_max_batch_reduction(scan, max_batch_reduction)

    reduction_scheme = self._validate_reduction_scheme(reduction_scheme)

    if self.is_lazy:
        if reduction_scheme == "multiple-rechunk":
            measurements = _multiple_rechunk_reduce(
                self, scan, detectors, ctf, max_batch_reduction
            )
        elif reduction_scheme == "single-rechunk":
            measurements = _single_rechunk_reduce(
                self, scan, detectors, ctf, max_batch_reduction
            )
        elif reduction_scheme == "no-chunks":
            measurements = _no_chunks_reduce(
                self, scan, detectors, ctf, max_batch_reduction
            )
        else:
            raise ValueError(
                f"Unknown reduction scheme {reduction_scheme!r}; expected 'auto', "
                "'multiple-rechunk', 'single-rechunk' or 'no-chunks'."
            )
    else:
        measurements = self._batch_reduce_to_measurements(
            scan, ctf, detectors, max_batch_reduction
        )

    measurements = [measurement.squeeze(squeeze) for measurement in measurements]

    return _wrap_measurements(measurements)
def scan(
    self,
    scan: BaseScan = None,
    detectors: BaseDetector | list[BaseDetector] = None,
    ctf: CTF = None,
    max_batch_reduction: int | str = "auto",
    rechunk: tuple[int, int] | str = "auto",
):
    """
    Reduce the SMatrix using coefficients calculated by a BaseScan and a CTF, to obtain the exit
    wave functions at given initial probe positions and aberrations.

    Parameters
    ----------
    scan : BaseScan, optional
        Positions of the probe wave functions. If not given, a default ``GridScan()`` is used.
    detectors : BaseDetector, list of BaseDetector, optional
        A detector or a list of detectors defining how the wave functions should be converted to
        measurements. If not given, a single ``FlexibleAnnularDetector()`` is used.
    ctf : CTF, optional
        Contrast transfer function used for calculating the expansion coefficients in the
        reduction of the SMatrix.
    max_batch_reduction : int or str, optional
        Number of positions per reduction operation. A large number of positions better utilize
        thread parallelization, but requires more memory and floating point operations. If 'auto'
        (default), the batch size is automatically chosen based on the abtem user configuration
        settings "dask.chunk-size" and "dask.chunk-size-gpu".
    rechunk : str or tuple of int, optional
        Forwarded to ``reduce`` as its ``reduction_scheme`` argument (default is 'auto').

    Returns
    -------
    measurements : BaseMeasurements or Waves or list of BaseMeasurements or list of Waves
        The result of ``reduce``: one measurement per detector.
    """
    if scan is None:
        scan = GridScan()

    if detectors is None:
        detectors = [FlexibleAnnularDetector()]

    # `reduce` has no `rechunk` parameter; passing it by that name raised a TypeError on every
    # call. The value is handed over under the name `reduce` actually accepts.
    return self.reduce(
        scan=scan,
        ctf=ctf,
        detectors=detectors,
        max_batch_reduction=max_batch_reduction,
        reduction_scheme=rechunk,
    )
class SMatrix(BaseSMatrix, Ensemble, CopyMixin, EqualityMixin):
    """
    The scattering matrix is used for simulating STEM experiments using the PRISM algorithm.

    Parameters
    ----------
    semiangle_cutoff : float
        The radial cutoff of the plane-wave expansion [mrad]. Must be positive.
    energy : float
        Electron energy [eV].
    potential : Atoms or AbstractPotential, optional
        Atoms or a potential that the scattering matrix represents. If given as atoms, a default
        potential will be created. If nothing is provided the scattering matrix will represent a
        vacuum potential, in which case the sampling and extent must be provided.
    gpts : one or two int, optional
        Number of grid points describing the scattering matrix. Provide only if potential is not
        given.
    sampling : one or two float, optional
        Lateral sampling of scattering matrix [Å]. Provide only if potential is not given. Will
        be ignored if 'gpts' is also provided.
    extent : one or two float, optional
        Lateral extent of scattering matrix [Å]. Provide only if potential is not given.
    interpolation : one or two int, optional
        Interpolation factor in the `x` and `y` directions (default is 1, ie. no interpolation).
        If a single value is provided, assumed to be the same for both directions.
    downsample : {'cutoff', 'valid'} or float or bool
        Controls whether to downsample the scattering matrix after running the multislice
        algorithm.

            ``cutoff`` or True :
                Downsample to the antialias cutoff scattering angle (default).

            ``valid`` :
                Downsample to the largest rectangle that fits inside the circle with a radius
                defined by the antialias cutoff scattering angle.

            float :
                Downsample to a specified maximum scattering angle [mrad].

            False :
                Do not downsample.
    device : str, optional
        The calculations will be carried out on this device ('cpu' or 'gpu'). If not given, the
        default is determined by the user configuration.
    store_on_host : bool, optional
        If True, store the scattering matrix in host (cpu) memory so that the necessary memory
        is transferred as chunks to the device to run calculations (default is False).

    Raises
    ------
    ValueError
        If `semiangle_cutoff` is not positive, or if no potential is given and the grid is not
        fully defined by the remaining arguments.
    """

    def __init__(
        self,
        semiangle_cutoff: float,
        energy: float,
        potential: Atoms | BasePotential = None,
        gpts: int | tuple[int, int] = None,
        sampling: float | tuple[float, float] = None,
        extent: float | tuple[float, float] = None,
        interpolation: int | tuple[int, int] = 1,
        downsample: bool | str = "cutoff",
        device: str = None,
        store_on_host: bool = False,
    ):
        # Explicit check instead of `assert`, which is stripped when running with -O.
        # Written as `not (x > 0)` so that NaN is rejected as well.
        if not semiangle_cutoff > 0.0:
            raise ValueError("'semiangle_cutoff' must be a positive number.")

        if downsample is True:
            downsample = "cutoff"

        self._device = validate_device(device)
        self._grid = Grid(extent=extent, gpts=gpts, sampling=sampling)

        if potential is None:
            try:
                self.grid.check_is_defined()
            except GridUndefinedError as error:
                raise ValueError(
                    "Provide a potential or provide 'extent' and 'gpts'."
                ) from error
        else:
            potential = _validate_potential(potential)
            self.grid.match(potential)
            # Share the grid object with the potential.
            self._grid = potential.grid

        self._potential = potential
        self._interpolation = _validate_interpolation(interpolation)
        self._semiangle_cutoff = semiangle_cutoff
        self._downsample = downsample

        self._accelerator = Accelerator(energy=energy)

        self._store_on_host = store_on_host

        if not all(n % f == 0 for f, n in zip(self.interpolation, self.gpts)):
            warnings.warn(
                "The interpolation factor does not exactly divide 'gpts', normalization may not be exactly preserved."
            )
@property
def base_shape(self) -> tuple[int, int, int]:
    """Shape of the base axes of the SMatrix: ``(len(self), gpts[0], gpts[1])``."""
    n_items = len(self)
    gpts = self.gpts
    return n_items, gpts[0], gpts[1]


@property
def tilt(self):
    """The small-angle tilt applied to the Fresnel propagator [mrad]; always ``(0.0, 0.0)``."""
    no_tilt = (0.0, 0.0)
    return no_tilt
def round_gpts_to_interpolation(self) -> SMatrix:
    """
    Round the gpts of the SMatrix to a multiple of the interpolation factor.

    The rounding is delegated to ``_round_gpts_to_multiple_of_interpolation``. The SMatrix is
    modified in place when the rounded gpts differ from the current ones.

    Returns
    -------
    s_matrix_with_rounded_gpts : SMatrix
        This same object.
    """
    target_gpts = _round_gpts_to_multiple_of_interpolation(
        self.gpts, self.interpolation
    )

    if target_gpts != self.gpts:
        self.gpts = target_gpts

    return self
@property
def downsample(self) -> str | bool:
    """Downsampling mode applied to the scattering matrix after the multislice algorithm."""
    return self._downsample


@property
def store_on_host(self) -> bool:
    """Whether the SMatrix is stored in host memory. The reduction may still run on the device."""
    return self._store_on_host


@property
def metadata(self):
    """Metadata dictionary holding the energy under the key ``"energy"``."""
    return dict(energy=self.energy)


@property
def shape(self) -> tuple[int, ...]:
    """Shape of the SMatrix: the ensemble shape, then ``len(self)``, then the gpts."""
    n_items = (len(self),)
    return self.ensemble_shape + n_items + self.gpts


@property
def ensemble_shape(self) -> tuple[int, ...]:
    """Shape of the SMatrix ensemble axes (empty without a potential)."""
    potential = self.potential
    return () if potential is None else potential.ensemble_shape


@property
def ensemble_axes_metadata(self):
    """Axis metadata for each ensemble axis (empty list without a potential)."""
    potential = self.potential
    return [] if potential is None else potential.ensemble_axes_metadata


@property
def wave_vectors(self) -> np.ndarray:
    """
    Wave vectors of the plane-wave expansion as an array with one (kx, ky) row per wave.

    The vectors are taken where the aperture of a dummy probe (evaluated on the cpu) is positive,
    scaled by the interpolation factors, and returned on the array module of the SMatrix device.
    """
    self.grid.check_is_defined()
    self.accelerator.check_is_defined()

    probes = self.dummy_probes(device="cpu")
    aperture = probes.aperture._evaluate_kernel(probes)
    inside = np.where(aperture > 0.0)

    n_x, n_y = aperture.shape[0], aperture.shape[1]
    # With d = 1 / n, fftfreq yields the integer frequency indices.
    n = np.fft.fftfreq(n_x, d=1 / n_x)[inside[0]]
    m = np.fft.fftfreq(n_y, d=1 / n_y)[inside[1]]

    width, height = self.extent

    kx = n / width * np.float32(self.interpolation[0])
    ky = m / height * np.float32(self.interpolation[1])

    xp = get_array_module(self.device)
    stacked = xp.asarray([kx, ky])
    return stacked.T


@property
def potential(self) -> BasePotential:
    """The potential described by the SMatrix."""
    return self._potential


@potential.setter
def potential(self, potential: BasePotential):
    # The grid of the SMatrix is replaced by the grid of the new potential.
    self._grid = potential.grid
    self._potential = potential


@property
def semiangle_cutoff(self) -> float:
    """Plane-wave expansion cutoff."""
    return self._semiangle_cutoff


@semiangle_cutoff.setter
def semiangle_cutoff(self, value: float):
    self._semiangle_cutoff = value


@property
def interpolation(self) -> tuple[int, int]:
    """Interpolation factors as stored at construction."""
    return self._interpolation


def _wave_vector_chunks(self, max_batch):
    """
    Chunks of an array of shape ``(len(self),) + gpts`` that is split only along its first axis.

    An integer ``max_batch`` is multiplied by the number of grid points before being passed on
    as the chunk limit; other values are passed on unchanged.
    """
    if isinstance(max_batch, int):
        max_batch *= reduce(operator.mul, self.gpts)

    return validate_chunks(
        shape=(len(self),) + self.gpts,
        chunks=("auto", -1, -1),
        limit=max_batch,
        dtype=np.dtype("complex64"),
        device=self.device,
    )


@property
def downsampled_gpts(self) -> tuple[int, int]:
    """The gpts of the SMatrix after downsampling."""
    if not self.downsample:
        return self.gpts

    within_angle = self._gpts_within_angle(self.downsample)
    return _round_gpts_to_multiple_of_interpolation(within_angle, self.interpolation)


@property
def window_gpts(self):
    """The downsampled gpts divided by the interpolation factors, rounded up."""
    gpts = self.downsampled_gpts
    factors = self.interpolation
    return (
        safe_ceiling_int(gpts[0] / factors[0]),
        safe_ceiling_int(gpts[1] / factors[1]),
    )


@property
def window_extent(self):
    """The window gpts multiplied by the sampling of the downsampled grid."""
    extent = self.extent
    gpts = self.downsampled_gpts
    step_x = extent[0] / gpts[0]
    step_y = extent[1] / gpts[1]

    window = self.window_gpts
    return window[0] * step_x, window[1] * step_y


# @staticmethod
# def _wrapped_build_s_matrix(*args, s_matrix_partial):
#     s_matrix = s_matrix_partial(*tuple(arg.item() for arg in args[:-1]))
#
#     wave_vector_range = slice(*np.squeeze(args[-1]))
#     array = s_matrix._build_s_matrix(wave_vector_range).array
#     return array
#
# def _s_matrix_partial(self):
#     def s_matrix(*args, potential_partial, **kwargs):
#         if potential_partial is not None:
#             potential = potential_partial(*args + (np.array([None], dtype=object),))
#         else:
#             potential = None
#         return SMatrix(potential=potential, **kwargs)
#
#     potential_partial = (
#         self.potential._from_partitioned_args()
#         if self.potential is not None
#         else None
#     )
#     return partial(
#         s_matrix,
#         potential_partial=potential_partial,
#         **self._copy_kwargs(exclude=("potential",)),
#     )
def multislice(
    self,
    potential=None,
    lazy: bool = None,
    max_batch: int | str = "auto",
):
    """
    Build a scattering matrix for the given potential, reusing the settings of this SMatrix.

    A new instance of the same class is created from the copied keyword arguments of this
    SMatrix, with its potential replaced by `potential`, and is then built.

    Parameters
    ----------
    potential : optional
        The potential passed to the new SMatrix in place of the current one.
    lazy : bool, optional
        If True, create the wave functions lazily, otherwise, calculate instantly. If not given,
        defaults to the setting in the user configuration file.
    max_batch : int or str, optional
        The number of expansion plane waves in each run of the multislice algorithm.

    Returns
    -------
    s_matrix_array
        The result of calling ``build`` on the new SMatrix.
    """
    inherited_kwargs = self._copy_kwargs(exclude=("potential",))
    fresh_s_matrix = self.__class__(potential=potential, **inherited_kwargs)
    return fresh_s_matrix.build(lazy=lazy, max_batch=max_batch)
@property
def _default_ensemble_chunks(self):
    """Default ensemble chunks, delegated to the potential."""
    return self.potential._default_ensemble_chunks


def _partition_args(self, chunks=(1,), lazy: bool = True):
    """
    Return the partitioned arguments needed to rebuild this SMatrix block by block.

    With a potential, the call is delegated to the potential. Without one, a single-item object
    array (wrapped as a Dask array when `lazy`) is returned as a one-element tuple.
    """
    if self.potential is not None:
        return self.potential._partition_args(chunks, lazy=lazy)
    else:
        array = np.empty((1,), dtype=object)
        if lazy:
            array = da.from_array(array, chunks=1)
        return (array,)


@staticmethod
def _s_matrix(*args, potential_partial, **kwargs):
    """Rebuild a potential from partitioned args and wrap a new SMatrix for it in an array."""
    potential = potential_partial(*args).item()
    s_matrix = SMatrix(potential=potential, **kwargs)
    return _wrap_with_array(s_matrix)


def _from_partitioned_args(self, *args, **kwargs):
    """
    Return a partial of ``_s_matrix`` that recreates this SMatrix from partitioned arguments.

    The positional and keyword arguments are accepted but not used; the keyword arguments of
    the partial are always taken from ``_copy_kwargs``.
    """
    if self.potential is not None:
        potential_partial = self.potential._from_partitioned_args()
        kwargs = self._copy_kwargs(exclude=("potential", "sampling", "extent"))
    else:
        # Without a potential the partial yields a wrapped None.
        potential_partial = lambda *args, **kwargs: _wrap_with_array(None, 1)
        kwargs = self._copy_kwargs(exclude=("potential",))

    return partial(self._s_matrix, potential_partial=potential_partial, **kwargs)


@staticmethod
def _wave_vector_blocks(wave_vector_chunks, lazy: bool = True):
    """
    Return a 1D object array with one wave-vector range per chunk.

    Parameters
    ----------
    wave_vector_chunks : tuple of tuple of int
        Chunks whose first axis is converted to ranges with ``chunk_ranges``.
    lazy : bool
        If True, the object array is wrapped as a Dask array with one item per chunk.
    """
    wave_vector_blocks = chunk_ranges(wave_vector_chunks)[0]

    array = np.zeros(len(wave_vector_blocks), dtype=object)
    for i, wave_vector_block in enumerate(wave_vector_blocks):
        # A full integer index stores the range object itself in the object array.
        # (`ndarray.itemset` was removed in NumPy 2.0.)
        array[i] = wave_vector_block

    if lazy:
        array = da.from_array(array, chunks=1)

    return array


@staticmethod
def _build_s_matrix(s_matrix, wave_vector_range=slice(None)):
    """
    Build the array of the scattering matrix for a range of wave vectors.

    Parameters
    ----------
    s_matrix : SMatrix or np.ndarray
        The SMatrix, possibly wrapped in a single-item object array.
    wave_vector_range : slice or np.ndarray
        The range of wave vectors to build; a single-item array is unpacked into a slice.

    Returns
    -------
    array
        The plane waves, propagated through the potential if there is one, downsampled if the
        downsampled gpts differ from the gpts, and moved to the cpu if the SMatrix is stored on
        the host while computing on the gpu.
    """
    if isinstance(s_matrix, np.ndarray):
        s_matrix = s_matrix.item()

    if isinstance(wave_vector_range, np.ndarray):
        wave_vector_range = slice(*wave_vector_range.item())

    xp = get_array_module(s_matrix.device)

    wave_vectors = xp.asarray(s_matrix.wave_vectors, dtype=xp.float32)

    array = plane_waves(
        wave_vectors[wave_vector_range], s_matrix.extent, s_matrix.gpts
    )

    array *= np.prod(s_matrix.interpolation) / np.prod(array.shape[-2:])

    waves = Waves(
        array,
        energy=s_matrix.energy,
        extent=s_matrix.extent,
        ensemble_axes_metadata=[OrdinalAxis(values=wave_vectors[wave_vector_range])],
    )

    if s_matrix.potential is not None:
        waves = multislice_and_detect(waves, s_matrix.potential, [WavesDetector()])[0]

    if s_matrix.downsampled_gpts != s_matrix.gpts:
        # Record the cutoff of the full grid before the grid is reduced.
        waves.metadata["adjusted_antialias_cutoff_gpts"] = waves.antialias_cutoff_gpts

        waves = waves.downsample(
            gpts=s_matrix.downsampled_gpts,
            normalization="intensity",
        )

    if s_matrix.store_on_host and s_matrix.device == "gpu":
        waves = waves.to_cpu()

    return waves.array
def build(
    self, lazy: bool = None, max_batch: int | str = "auto", bound: bool = None
) -> SMatrixArray:
    """
    Build the plane waves of the scattering matrix and propagate them through the potential
    using the multislice algorithm.

    Parameters
    ----------
    lazy : bool, optional
        If True, create the wave functions lazily, otherwise, calculate instantly. If not given,
        defaults to the setting in the user configuration file.
    max_batch : int or str, optional
        The number of expansion plane waves in each run of the multislice algorithm.
    bound : optional
        Only used when lazy. If given, the wave-vector blocks are bound to it with
        ``dask.graph_manipulation.bind`` before the scattering matrix is built.

    Returns
    -------
    s_matrix_array : SMatrixArray
        The built scattering matrix.
    """
    lazy = _validate_lazy(lazy)

    downsampled_gpts = self.downsampled_gpts

    s_matrix_blocks = self.ensemble_blocks(1)
    xp = get_array_module(self.device)

    wave_vector_chunks = self._wave_vector_chunks(max_batch)
    # Both the lazy and the eager path start from the same non-lazy blocks.
    wave_vector_blocks = self._wave_vector_blocks(wave_vector_chunks, lazy=False)

    if lazy:
        wave_vector_blocks = np.tile(
            wave_vector_blocks[None], (len(s_matrix_blocks), 1)
        )

        wave_vector_blocks = da.from_array(wave_vector_blocks, chunks=1)

        if bound is not None:
            from dask.graph_manipulation import bind

            wave_vector_blocks = bind(wave_vector_blocks, bound)

        adjust_chunks = {
            1: wave_vector_chunks[0],
            2: (downsampled_gpts[0],),
            3: (downsampled_gpts[1],),
        }

        symbols = (0, 1, 2, 3)
        if self.potential is None or not self.potential.ensemble_shape:
            # No ensemble axis in the output.
            symbols = symbols[1:]

        array = da.blockwise(
            self._build_s_matrix,
            symbols,
            s_matrix_blocks,
            (0,),
            wave_vector_blocks[..., None, None],
            (0, 1, 2, 3),
            concatenate=True,
            adjust_chunks=adjust_chunks,
            meta=xp.array((), dtype=np.complex64),
        )
    else:
        # The result lives in host memory when requested, otherwise on the compute device.
        allocator = np if self.store_on_host else xp
        array = allocator.zeros(
            self.ensemble_shape + (len(self),) + self.downsampled_gpts,
            dtype=np.complex64,
        )

        for i, _, s_matrix in self.generate_blocks(1):
            s_matrix = s_matrix.item()
            for start, stop in wave_vector_blocks:
                items = (slice(start, stop),)
                if self.ensemble_shape:
                    items = i + items

                new_array = self._build_s_matrix(s_matrix, slice(start, stop))

                # NumPy has no `asnumpy`, so the unconditional call failed with an
                # AttributeError on the cpu backend; there the array is already on the host.
                if self.store_on_host and hasattr(xp, "asnumpy"):
                    new_array = xp.asnumpy(new_array)

                array[items] = new_array

    waves = Waves(
        array,
        energy=self.energy,
        extent=self.extent,
        ensemble_axes_metadata=self.ensemble_axes_metadata
        + self.base_axes_metadata[:1],
    )

    if self.downsampled_gpts != self.gpts:
        waves.metadata["adjusted_antialias_cutoff_gpts"] = _antialias_cutoff_gpts(
            self.window_gpts, self.sampling
        )

    s_matrix_array = SMatrixArray._from_waves(
        waves,
        wave_vectors=self.wave_vectors,
        interpolation=self.interpolation,
        semiangle_cutoff=self.semiangle_cutoff,
        window_gpts=self.window_gpts,
        device=self.device,
    )

    return s_matrix_array
def scan(
    self,
    scan: np.ndarray | BaseScan = None,
    detectors: BaseDetector | list[BaseDetector] = None,
    ctf: CTF | dict = None,
    max_batch_multislice: str | int = "auto",
    max_batch_reduction: str | int = "auto",
    reduction_scheme: str = "auto",
    disable_s_matrix_chunks: bool = "auto",
    lazy: bool = None,
) -> BaseMeasurements | Waves | list[BaseMeasurements | Waves]:
    """
    Run the multislice algorithm, then reduce the SMatrix using coefficients calculated by a
    BaseScan and a CTF, to obtain the exit wave functions at given initial probe positions and
    aberrations.

    This is a thin wrapper around ``reduce`` that fills in a default scan and a default
    detector.

    Parameters
    ----------
    scan : BaseScan, optional
        Positions of the probe wave functions. If not given, a grid scan from (0, 0) to the
        extent is used, sampled at the Nyquist sampling of the probe aperture.
    detectors : BaseDetector, list of BaseDetector, optional
        A detector or a list of detectors defining how the wave functions should be converted to
        measurements after running the multislice algorithm. If not given, a
        ``FlexibleAnnularDetector()`` is used.
    ctf : CTF, optional
        Contrast transfer function used for calculating the expansion coefficients in the
        reduction of the SMatrix.
    max_batch_multislice : int or str, optional
        The number of wave functions in each chunk of the Dask array. If 'auto' (default), the
        batch size is automatically chosen based on the abtem user configuration settings
        "dask.chunk-size" and "dask.chunk-size-gpu".
    max_batch_reduction : int or str, optional
        Number of positions per reduction operation. A large number of positions better utilize
        thread parallelization, but requires more memory and floating point operations. If 'auto'
        (default), the batch size is automatically chosen based on the abtem user configuration
        settings "dask.chunk-size" and "dask.chunk-size-gpu".
    reduction_scheme : str, optional
        Passed unchanged to ``reduce`` (default is 'auto').
    disable_s_matrix_chunks : bool or str, optional
        If True, each S-Matrix is kept as a single chunk, thus lowering the communication
        overhead, but providing fewer opportunities for parallelization. Passed unchanged to
        ``reduce`` (default is 'auto').
    lazy : bool, optional
        If True, create the measurements lazily, otherwise, calculate instantly. If None, this
        defaults to the value set in the configuration file.

    Returns
    -------
    measurements : BaseMeasurements or Waves or list of BaseMeasurements or list of Waves
        The result of ``reduce``.
    """
    if scan is None:
        nyquist_sampling = self.dummy_probes().aperture.nyquist_sampling
        scan = GridScan(start=(0, 0), end=self.extent, sampling=nyquist_sampling)

    if detectors is None:
        detectors = FlexibleAnnularDetector()

    reduce_kwargs = dict(
        scan=scan,
        detectors=detectors,
        max_batch_reduction=max_batch_reduction,
        max_batch_multislice=max_batch_multislice,
        ctf=ctf,
        reduction_scheme=reduction_scheme,
        disable_s_matrix_chunks=disable_s_matrix_chunks,
        lazy=lazy,
    )
    return self.reduce(**reduce_kwargs)
def _eager_build_s_matrix_detect(self, scan, ctf, detectors, squeeze):
    """
    Build and reduce the scattering matrix eagerly, one ensemble block at a time.

    Parameters
    ----------
    scan, ctf, detectors
        Passed on to the ``reduce`` method of each built scattering matrix.
    squeeze : bool
        If True, a leading measurement axis flagged as an ensemble mean is squeezed from the
        result.

    Returns
    -------
    measurements : list
        One measurement per detector.
    """
    extra_ensemble_axes_shape = ()
    extra_ensemble_axes_metadata = []
    for shape, axis_metadata in zip(self.ensemble_shape, self.ensemble_axes_metadata):
        extra_ensemble_axes_metadata += [axis_metadata]
        extra_ensemble_axes_shape += (shape,)

        if axis_metadata._ensemble_mean:
            # An averaged ensemble axis only needs a single slot to accumulate into.
            extra_ensemble_axes_shape = (1,) + extra_ensemble_axes_shape[1:]

    if self.potential is not None and len(self.potential.exit_planes) > 1:
        extra_ensemble_axes_shape = extra_ensemble_axes_shape + (
            len(self.potential.exit_planes),
        )
        extra_ensemble_axes_metadata = extra_ensemble_axes_metadata + [
            self.potential.base_axes_metadata[0]
        ]

    detectors = _validate_detectors(detectors)

    if self.ensemble_shape:
        measurements = allocate_multislice_measurements(
            self.build(lazy=True).dummy_probes(scan, ctf),
            detectors,
            extra_ensemble_axes_shape,
            extra_ensemble_axes_metadata,
        )
    else:
        # Without ensemble axes the measurements of the single block are used directly.
        measurements = None

    for i, _, s_matrix in self.generate_blocks(1):
        s_matrix = s_matrix.item()

        s_matrix_array = s_matrix.build(lazy=False)

        new_measurements = s_matrix_array.reduce(
            scan=scan, detectors=detectors, ctf=ctf
        )

        new_measurements = ensure_list(new_measurements)

        if measurements is None:
            measurements = new_measurements
        else:
            for measurement, new_measurement in zip(measurements, new_measurements):
                if measurement.axes_metadata[0]._ensemble_mean:
                    measurement.array[:] += new_measurement.array
                else:
                    measurement.array[i] = new_measurement.array

    for i, measurement in enumerate(measurements):
        if (
            hasattr(measurement.axes_metadata[0], "_ensemble_mean")
            and measurement.axes_metadata[0]._ensemble_mean
        ) and squeeze:
            measurements[i] = measurement.squeeze((0,))

    return measurements


@staticmethod
def _lazy_build_s_matrix_detect(s_matrix, scan, ctf, detectors):
    """
    Run the eager build-and-detect for one wrapped SMatrix and pack the result for Dask.

    Parameters
    ----------
    s_matrix : np.ndarray
        Single-item object array holding the SMatrix.
    scan, ctf, detectors
        Passed on to ``_eager_build_s_matrix_detect``.

    Returns
    -------
    array : np.ndarray
        Object array with one singleton axis plus one singleton axis per scan axis, holding the
        list of measurements as its only item.
    """
    s_matrix = s_matrix.item()

    measurements = s_matrix._eager_build_s_matrix_detect(
        scan=scan, ctf=ctf, detectors=detectors, squeeze=False
    )

    array = np.zeros((1,) + (1,) * len(scan.shape), dtype=object)
    # A full integer index stores the list itself as the single item of the object array.
    # (`ndarray.itemset` was removed in NumPy 2.0.)
    array[(0,) * array.ndim] = measurements
    return array
def reduce(
    self,
    scan: np.ndarray | BaseScan = None,
    detectors: BaseDetector | list[BaseDetector] = None,
    ctf: CTF | dict = None,
    reduction_scheme: str = "auto",
    max_batch_multislice: str | int = "auto",
    max_batch_reduction: str | int = "auto",
    disable_s_matrix_chunks: bool = "auto",
    lazy: bool = None,
) -> BaseMeasurements | Waves | list[BaseMeasurements | Waves]:
    """
    Run the multislice algorithm, then reduce the SMatrix using coefficients calculated by a
    BaseScan and a CTF, to obtain the exit wave functions at given initial probe positions and
    aberrations.

    Parameters
    ----------
    scan : BaseScan, optional
        Positions of the probe wave functions. If not given, a single position at the center of
        the extent is used.
    detectors : BaseDetector, list of BaseDetector, optional
        A detector or a list of detectors defining how the wave functions should be converted to
        measurements after running the multislice algorithm. See abtem.measurements.detect for a
        list of implemented detectors.
    ctf : CTF, optional
        Contrast transfer function used for calculating the expansion coefficients in the
        reduction of the SMatrix.
    reduction_scheme : str, optional
        Passed to the ``reduce`` method of the built scattering matrix. Only used when the
        calculation is lazy and the S-Matrix chunks are not disabled.
    max_batch_multislice : int or str, optional
        The number of wave functions in each chunk of the Dask array. If 'auto' (default), the
        batch size is automatically chosen based on the abtem user configuration settings
        "dask.chunk-size" and "dask.chunk-size-gpu". Only used when the calculation is lazy and
        the S-Matrix chunks are not disabled.
    max_batch_reduction : int or str, optional
        Number of positions per reduction operation. A large number of positions better utilize
        thread parallelization, but requires more memory and floating point operations. If 'auto'
        (default), the batch size is automatically chosen based on the abtem user configuration
        settings "dask.chunk-size" and "dask.chunk-size-gpu". Only used when the calculation is
        lazy and the S-Matrix chunks are not disabled.
    disable_s_matrix_chunks : bool or str, optional
        If True, each S-Matrix is kept as a single chunk, thus lowering the communication
        overhead, but providing fewer opportunities for parallelization. If 'auto' (default),
        this is True on the gpu and False otherwise.
    lazy : bool, optional
        If True, create the measurements lazily, otherwise, calculate instantly. If None, this
        defaults to the value set in the configuration file.

    Returns
    -------
    measurements : BaseMeasurements or Waves or list of BaseMeasurements or list of Waves
        The detected measurement (if detector(s) given).
    """
    detectors = _validate_detectors(detectors)

    if scan is None:
        scan = (self.extent[0] / 2, self.extent[1] / 2)

    lazy = _validate_lazy(lazy)

    if disable_s_matrix_chunks == "auto":
        disable_s_matrix_chunks = self.device == "gpu"

    # Eager calculation: build, reduce and detect block by block, right away.
    if not lazy:
        scan = _validate_scan(scan, self)
        eager_measurements = self._eager_build_s_matrix_detect(
            scan, ctf, detectors, squeeze=True
        )
        return _wrap_measurements(eager_measurements)

    # Lazy calculation with chunking: build the scattering matrix, then let it reduce itself.
    if not disable_s_matrix_chunks:
        built = self.build(max_batch=max_batch_multislice, lazy=lazy)
        return built.reduce(
            scan=scan,
            detectors=detectors,
            reduction_scheme=reduction_scheme,
            max_batch_reduction=max_batch_reduction,
            ctf=ctf,
        )

    # Lazy calculation without chunking: one task per ensemble block does the eager work.
    scan = _validate_scan(scan, self)

    ensemble_blocks = self.ensemble_blocks(1)
    n_scan_axes = len(scan.shape)

    if len(self.ensemble_shape) > 0:
        drop_axis = ()
        chunks = ensemble_blocks.chunks + scan.shape
        new_axis = tuple_range(offset=len(ensemble_blocks.shape), length=n_scan_axes)
    else:
        # The axis of the blocks array is dropped when there are no ensemble axes.
        drop_axis = (0,)
        chunks = (ensemble_blocks.chunks + scan.shape)[1:]
        new_axis = tuple_range(
            offset=len(ensemble_blocks.shape) - 1, length=n_scan_axes
        )

    arrays = ensemble_blocks.map_blocks(
        self._lazy_build_s_matrix_detect,
        drop_axis=drop_axis,
        new_axis=new_axis,
        chunks=chunks,
        scan=scan,
        ctf=ctf,
        detectors=detectors,
        meta=np.array((), dtype=object),
    )

    waves = self.build(lazy=True).dummy_probes(scan=scan)

    if self.potential is None:
        extra_axes_metadata = []
    else:
        extra_axes_metadata = self.potential.ensemble_axes_metadata

    lazy_measurements = _finalize_lazy_measurements(
        arrays, waves, detectors, extra_axes_metadata
    )

    return _wrap_measurements(lazy_measurements)