"""Module for describing electrostatic potentials using the independent atom model."""
from __future__ import annotations
import warnings
from abc import abstractmethod, ABCMeta
from functools import partial
from functools import reduce
from numbers import Number
from operator import mul
from typing import Sequence, TYPE_CHECKING, Type
import dask
import dask.array as da
import numpy as np
from ase import Atoms
from ase.cell import Cell
from ase.data import chemical_symbols
from abtem.array import ArrayObject, _validate_lazy
from abtem.atoms import (
is_cell_orthogonal,
orthogonalize_cell,
best_orthogonal_cell,
cut_cell,
pad_atoms,
plane_to_axes,
rotate_atoms_to_plane,
)
from abtem.core.axes import (
RealSpaceAxis,
_find_axes_type,
)
from abtem.core.axes import ThicknessAxis, FrozenPhononsAxis, AxisMetadata
from abtem.core.backend import get_array_module, validate_device
from abtem.core.chunks import generate_chunks, Chunks, chunk_ranges
from abtem.core.chunks import validate_chunks
from abtem.core.complex import complex_exponential
from abtem.core.energy import HasAcceleratorMixin, Accelerator, energy2sigma
from abtem.core.ensemble import Ensemble, _wrap_with_array, unpack_blockwise_args
from abtem.core.grid import Grid, HasGridMixin
from abtem.core.utils import EqualityMixin, CopyMixin, get_dtype
from abtem.inelastic.phonons import (
BaseFrozenPhonons,
DummyFrozenPhonons,
_validate_seeds,
AtomsEnsemble,
)
from abtem.integrals import (
ScatteringFactorProjectionIntegrals,
QuadratureProjectionIntegrals,
)
from abtem.measurements import Images
from abtem.slicing import (
_validate_slice_thickness,
SliceIndexedAtoms,
SlicedAtoms,
BaseSlicedAtoms,
)
if TYPE_CHECKING:
from abtem.waves import Waves, BaseWaves
from abtem.parametrizations import Parametrization
from abtem.integrals import FieldIntegrator
[docs]
class BaseField(Ensemble, HasGridMixin, EqualityMixin, CopyMixin, metaclass=ABCMeta):
    """
    Abstract base class of sliced fields, such as electrostatic potentials.

    A field is a stack of 2D slices sharing a lateral grid, optionally with leading ensemble axes. Subclasses
    define the slice thicknesses, the exit planes and how slices are generated and built.
    """

    @property
    @abstractmethod
    def base_shape(self) -> tuple[int, ...]:
        """Shape of the base (non-ensemble) axes."""
        pass

    @property
    @abstractmethod
    def num_configurations(self):
        """Number of frozen phonons in the ensemble of potentials."""
        pass

    @property
    @abstractmethod
    def base_axes_metadata(self):
        """Axis metadata of the base (non-ensemble) axes."""
        pass

    def _get_exit_planes_axes_metadata(self):
        # One thickness value per exit plane.
        return ThicknessAxis(label="z", values=tuple(self.exit_thicknesses))

    @property
    @abstractmethod
    def exit_planes(self) -> tuple[int]:
        """The "exit planes" of the potential. The indices of slices where a measurement is returned."""
        pass

    @property
    def _exit_plane_after(self) -> np.ndarray:
        """
        Boolean mask with one entry per slice; True if a measurement is taken after that slice.

        An exit plane of -1 denotes the entrance plane (before the first slice) and is not associated with any
        slice. Exit plane indices outside the range of slices are ignored. An empty tuple of exit planes gives an
        all-False mask.
        """
        is_exit_plane = np.zeros(len(self), dtype=bool)
        planes = np.asarray(self.exit_planes, dtype=int).ravel()
        planes = planes[(planes >= 0) & (planes < len(is_exit_plane))]
        is_exit_plane[planes] = True
        return is_exit_plane

    @property
    def exit_thicknesses(self) -> tuple[float]:
        """The "exit thicknesses" of the potential. The thicknesses in the potential where a measurement is returned."""
        thicknesses = np.cumsum(self.slice_thickness)
        exit_planes = tuple(self.exit_planes)

        if exit_planes and exit_planes[0] == -1:
            # The entrance plane (-1) corresponds to zero thickness.
            selected = thicknesses[np.array(exit_planes[1:], dtype=int)]
            return tuple(np.insert(selected, 0, 0.0))

        return tuple(thicknesses[np.array(exit_planes, dtype=int)])

    @property
    def num_exit_planes(self) -> int:
        """Number of exit planes."""
        return len(self.exit_planes)

    @abstractmethod
    def generate_slices(self, first_slice: int = 0, last_slice: int = None):
        """Generate the slices of the field from `first_slice` up to (not including) `last_slice`."""
        pass

    @abstractmethod
    def build(
        self,
        first_slice: int = 0,
        last_slice: int = None,
        chunks: int = 1,
        lazy: bool = None,
    ):
        """Build the field as an array object."""
        pass

    def __len__(self) -> int:
        return self.num_slices

    @property
    def num_slices(self) -> int:
        """Number of projected potential slices."""
        return len(self.slice_thickness)

    @property
    @abstractmethod
    def slice_thickness(self) -> np.ndarray:
        """Slice thicknesses for each slice."""
        pass

    @property
    def slice_limits(self) -> list[tuple[float, float]]:
        """The entrance and exit thicknesses of each slice [Å]."""
        cumulative_thickness = np.cumsum(np.concatenate(((0,), self.slice_thickness)))
        return [
            (cumulative_thickness[i], cumulative_thickness[i + 1])
            for i in range(len(cumulative_thickness) - 1)
        ]

    @property
    def thickness(self) -> float:
        """Thickness of the potential [Å]."""
        return sum(self.slice_thickness)

    def __iter__(self):
        yield from self.generate_slices()

    def project(self) -> Images:
        """
        Sum of the potential slices as an image.

        Returns
        -------
        projected : Images
            The projected potential.
        """
        return self.build().project()

    @property
    def _default_ensemble_chunks(self) -> tuple:
        # One ensemble member per chunk along every ensemble axis.
        return validate_chunks(self.ensemble_shape, (1,) * len(self.ensemble_shape))

    def to_images(self):
        """
        Converts the potential to an ensemble of images.

        Returns
        -------
        image_ensemble : Images
            The potential slices as images.
        """
        return self.build().to_images()

    def show(self, project: bool = True, **kwargs):
        """
        Show the potential projection. This requires building all potential slices.

        Parameters
        ----------
        project : bool, optional
            Show the projected potential (True, default) or show all potential slices. It is recommended to index a
            subset of the potential slices when this keyword set to False.
        kwargs :
            Additional keyword arguments for the show method of :class:`.Images`.
        """
        if project:
            return self.project().show(**kwargs)

        # Show each slice in its own panel unless the caller decides otherwise.
        kwargs.setdefault("explode", True)
        return self.to_images().show(**kwargs)
[docs]
class BasePotential(BaseField):
    """Base class of all potentials. Documented in the subclasses."""

    @property
    def base_shape(self):
        """Shape of the base axes of the potential: the slice axis followed by the grid points."""
        slice_count = self.num_slices
        return (slice_count,) + self.gpts

    @property
    def base_axes_metadata(self):
        """List of AxisMetadata for the base axes: depth `z`, then real-space `x` and `y`."""
        depths = tuple(np.cumsum(self.slice_thickness))
        grid_sampling = self.sampling

        z_axis = ThicknessAxis(label="z", values=depths, units="Å")
        x_axis = RealSpaceAxis(
            label="x", sampling=grid_sampling[0], units="Å", endpoint=False
        )
        y_axis = RealSpaceAxis(
            label="y", sampling=grid_sampling[1], units="Å", endpoint=False
        )
        return [z_axis, x_axis, y_axis]
def _validate_potential(
    potential: Atoms | BasePotential, waves: BaseWaves = None
) -> BasePotential:
    """
    Turn atoms (or frozen phonons) into a `Potential` and match its grid to the waves.

    Parameters
    ----------
    potential : Atoms, BaseFrozenPhonons or BasePotential
        Atomic structure or an existing potential. Other values are passed through unchanged.
    waves : BaseWaves, optional
        If given, the new potential is created on the device of the waves, and the potential grid is matched to
        the waves.

    Returns
    -------
    potential : BasePotential
    """
    needs_conversion = isinstance(potential, (Atoms, BaseFrozenPhonons))

    if needs_conversion:
        target_device = None if waves is None else waves.device
        potential = Potential(potential, device=target_device)

    if not (waves is None or potential is None):
        potential.grid.match(waves)

    return potential
def _validate_exit_planes(exit_planes, num_slices):
if isinstance(exit_planes, int):
if exit_planes >= num_slices:
return (num_slices - 1,)
exit_planes = list(range(exit_planes - 1, num_slices, exit_planes))
if exit_planes[-1] != (num_slices - 1):
exit_planes.append(num_slices - 1)
exit_planes = (-1,) + tuple(exit_planes)
elif exit_planes is None:
exit_planes = (num_slices - 1,)
return exit_planes
def _require_cell_transform(cell, box, plane, origin):
if box == tuple(np.diag(cell)):
return False
if not is_cell_orthogonal(cell):
return True
if box is not None:
return True
if plane != "xy":
return True
if origin != (0.0, 0.0, 0.0):
return True
return False
class _FieldBuilder(BaseField):
    """
    Base class of fields that are calculated slice by slice and built into a `FieldArray`.

    Parameters
    ----------
    array_object : type of FieldArray
        The array class used to represent built slices.
    slice_thickness : float or sequence of float
        Thickness of the slices [Å].
    exit_planes : int or sequence of int
        Exit planes, see `_validate_exit_planes`.
    cell : np.ndarray or Cell
        Cell of the atomic structure.
    gpts, sampling : optional
        Lateral grid points or sampling [Å].
    box : three float, optional
        Extent of the field in `x`, `y` and `z`. If not given, it is derived from the cell.
    plane, origin, periodic, device :
        See :class:`Potential`.
    """

    def __init__(
        self,
        array_object: Type[FieldArray],
        slice_thickness: float | tuple[float, ...],
        exit_planes: int | tuple[int, ...],
        cell: np.ndarray | Cell,
        gpts: int | tuple[int, int] = None,
        sampling: float | tuple[float, float] = None,
        box: tuple[float, float, float] = None,
        plane: (
            str | tuple[tuple[float, float, float], tuple[float, float, float]]
        ) = "xy",
        origin: tuple[float, float, float] = (0.0, 0.0, 0.0),
        periodic: bool = True,
        device: str = None,
    ):
        self._array_object = array_object

        if _require_cell_transform(cell, box=box, plane=plane, origin=origin):
            # Only derive the box from the (permuted) cell when the user did not request one; a user-supplied
            # box is the requested extent of the potential and must be kept.
            if box is None:
                axes = plane_to_axes(plane)
                cell = cell[:, list(axes)]
                box = tuple(best_orthogonal_cell(cell))
        elif box is None:
            box = tuple(np.diag(cell))

        box = tuple(box)

        self._grid = Grid(
            extent=box[:2], gpts=gpts, sampling=sampling, lock_extent=True
        )
        self._device = validate_device(device)
        self._box = box
        self._plane = plane
        self._origin = origin
        self._periodic = periodic
        self._slice_thickness = _validate_slice_thickness(
            slice_thickness, thickness=box[2]
        )
        self._exit_planes = _validate_exit_planes(
            exit_planes, len(self._slice_thickness)
        )

    @property
    def slice_thickness(self) -> tuple[float, ...]:
        """Slice thicknesses for each slice [Å]."""
        return self._slice_thickness

    @property
    def exit_planes(self) -> tuple[int]:
        """The indices of slices after which a measurement is returned."""
        return self._exit_planes

    @property
    def device(self) -> str:
        """The device where the potential is created."""
        return self._device

    @property
    def periodic(self) -> bool:
        """Specifies whether the potential is periodic."""
        return self._periodic

    @property
    def plane(self) -> str:
        """The plane relative to the atoms mapped to `xy` plane of the potential, i.e. the plane is perpendicular to the
        propagation direction."""
        return self._plane

    @property
    def box(self) -> tuple[float, float, float]:
        """The extent of the potential in `x`, `y` and `z`."""
        return self._box

    @property
    def origin(self) -> tuple[float, float, float]:
        """The origin relative to the provided atoms mapped to the origin of the potential."""
        return self._origin

    def __getitem__(self, item) -> PotentialArray:
        return self.build(lazy=False)[item]

    @staticmethod
    def _wrap_build_potential(potential, first_slice, last_slice):
        # Called per dask block: unwrap the single potential in the block and build it eagerly.
        potential = potential.item()
        array = potential.build(first_slice, last_slice, lazy=False).array
        return array

    def build(
        self,
        first_slice: int = 0,
        last_slice: int = None,
        max_batch: int | str = 1,
        lazy: bool = None,
    ) -> FieldArray:
        """
        Build the potential.

        Parameters
        ----------
        first_slice : int, optional
            Index of the first slice of the generated potential.
        last_slice : int, optional
            Index of the last slice of the generated potential
        max_batch : int or str, optional
            Maximum number of slices to calculate in task. Default is 1. Not used by this implementation.
        lazy : bool, optional
            If True, create the wave functions lazily, otherwise, calculate instantly. If None, this defaults to the
            value set in the configuration file.

        Returns
        -------
        potential_array : PotentialArray
            The built potential as an array.
        """
        lazy = _validate_lazy(lazy)

        self.grid.check_is_defined()

        if last_slice is None:
            last_slice = len(self)

        xp = get_array_module(self.device)

        # Only the requested range of slices is built.
        built_base_shape = (last_slice - first_slice,) + self.base_shape[1:]

        if lazy:
            blocks = self.ensemble_blocks(self._default_ensemble_chunks)

            chunks = validate_chunks(self.ensemble_shape, self._default_ensemble_chunks)
            chunks = chunks + built_base_shape

            if self.ensemble_shape:
                new_axis = tuple(
                    range(
                        len(self.ensemble_shape),
                        len(self.ensemble_shape) + len(built_base_shape),
                    )
                )
            else:
                new_axis = tuple(range(1, len(built_base_shape)))

            array = blocks.map_blocks(
                self._wrap_build_potential,
                new_axis=new_axis,
                first_slice=first_slice,
                last_slice=last_slice,
                chunks=chunks,
                meta=xp.array((), dtype=get_dtype(complex=False)),
            )

        else:
            array = xp.zeros(
                self.ensemble_shape + built_base_shape,
                dtype=get_dtype(complex=False),
            )

            if self.ensemble_shape:
                for indices, _, potential in self.generate_blocks(1):
                    potential = potential.item()
                    # With one member per block the block indices are the ensemble indices. Previously the index
                    # was overwritten with a constant, so every member was written to position 0.
                    index = tuple(int(k) for k in np.atleast_1d(indices))

                    for j, slic in enumerate(
                        potential.generate_slices(first_slice, last_slice)
                    ):
                        array[index + (j,)] = slic.array[0]
            else:
                for j, slic in enumerate(self.generate_slices(first_slice, last_slice)):
                    array[j] = slic.array[0]

        potential = self._array_object(
            array,
            sampling=(self.sampling[0], self.sampling[1]),
            slice_thickness=self.slice_thickness[first_slice:last_slice],
            exit_planes=self.exit_planes,
            ensemble_axes_metadata=self.ensemble_axes_metadata,
        )

        return potential
class _FieldBuilderFromAtoms(_FieldBuilder):
    """
    Field builder whose slices are calculated from atoms with a projection integrator.

    Parameters
    ----------
    atoms : ase.Atoms or BaseFrozenPhonons
        Atomic structure(s); see `_validate_frozen_phonons`.
    array_object : type of FieldArray
        The array class used to represent built slices.
    integrator :
        Object calculating the projection integrals of each slice.
    gpts, sampling, slice_thickness, exit_planes, plane, origin, box, periodic, device :
        See :class:`Potential`.
    """

    def __init__(
        self,
        atoms: Atoms | BaseFrozenPhonons,
        array_object: Type[FieldArray],
        gpts: int | tuple[int, int] = None,
        sampling: float | tuple[float, float] = None,
        slice_thickness: float | tuple[float, ...] = 1,
        exit_planes: int | tuple[int, ...] = None,
        plane: (
            str | tuple[tuple[float, float, float], tuple[float, float, float]]
        ) = "xy",
        origin: tuple[float, float, float] = (0.0, 0.0, 0.0),
        box: tuple[float, float, float] = None,
        periodic: bool = True,
        integrator=None,
        device: str = None,
    ):
        self._frozen_phonons = _validate_frozen_phonons(atoms)
        self._integrator = integrator
        # Prepared on first use and cached by `get_sliced_atoms`.
        self._sliced_atoms = None

        # The parent constructor stores `array_object`.
        super().__init__(
            array_object=array_object,
            gpts=gpts,
            sampling=sampling,
            cell=self._frozen_phonons.cell,
            slice_thickness=slice_thickness,
            exit_planes=exit_planes,
            device=device,
            plane=plane,
            origin=origin,
            box=box,
            periodic=periodic,
        )

    @property
    def frozen_phonons(self) -> BaseFrozenPhonons:
        """Ensemble of atomic configurations representing frozen phonons."""
        return self._frozen_phonons

    @property
    def num_configurations(self) -> int:
        """Size of the ensemble of atomic configurations representing frozen phonons."""
        return len(self.frozen_phonons)

    @property
    def integrator(self):
        """The integrator determining how the projection integrals for each slice is calculated."""
        return self._integrator

    def _cutoffs(self):
        # One integrator cutoff per unique element in the (untransformed) atoms.
        atoms = self.frozen_phonons.atoms
        unique_numbers = np.unique(atoms.numbers)
        return tuple(
            self._integrator.cutoff(chemical_symbols[number])
            for number in unique_numbers
        )

    def get_transformed_atoms(self):
        """
        The atoms used in the multislice algorithm, transformed to the given plane, origin and box.

        Returns
        -------
        transformed_atoms : Atoms
            Transformed atoms.
        """
        atoms = self.frozen_phonons.atoms

        if is_cell_orthogonal(atoms.cell) and self.plane != "xy":
            atoms = rotate_atoms_to_plane(atoms, self.plane)

        elif tuple(np.diag(atoms.cell)) != self.box:
            if self.periodic:
                atoms = orthogonalize_cell(
                    atoms,
                    box=self.box,
                    plane=self.plane,
                    origin=self.origin,
                    return_transform=False,
                    allow_transform=True,
                )
                return atoms

            cutoffs = self._cutoffs()
            atoms = cut_cell(
                atoms,
                cell=self.box,
                plane=self.plane,
                origin=self.origin,
                margin=max(cutoffs) if cutoffs else 0.0,
            )

        # NOTE(review): when the cell is orthogonal, the plane is "xy" and the box equals the cell diagonal, no
        # branch above applies and `self.origin` is not used — confirm whether that is intended.
        return atoms

    def _prepare_atoms(self):
        # Transform, randomize (frozen phonons), wrap/pad and finally group the atoms into slices.
        atoms = self.get_transformed_atoms()

        if self.integrator.finite:
            cutoffs = self._cutoffs()
            margins = max(cutoffs) if len(cutoffs) else 0.0
        else:
            margins = 0.0

        if self.periodic:
            atoms = self.frozen_phonons.randomize(atoms)
            atoms.wrap()

        if not self.integrator.periodic and self.integrator.finite:
            atoms = pad_atoms(atoms, margins=margins)
        elif self.integrator.periodic:
            atoms = pad_atoms(atoms, margins=margins, directions="z")

        if not self.periodic:
            atoms = self.frozen_phonons.randomize(atoms)

        if self.integrator.finite:
            sliced_atoms = SlicedAtoms(
                atoms=atoms, slice_thickness=self.slice_thickness, z_padding=margins
            )
        else:
            sliced_atoms = SliceIndexedAtoms(
                atoms=atoms, slice_thickness=self.slice_thickness
            )

        return sliced_atoms

    def get_sliced_atoms(self) -> BaseSlicedAtoms:
        """
        The atoms grouped into the slices given by the slice thicknesses.

        The result is computed once and cached on this instance.

        Returns
        -------
        sliced_atoms : BaseSlicedAtoms
        """
        if self._sliced_atoms is None:
            self._sliced_atoms = self._prepare_atoms()

        return self._sliced_atoms

    def generate_slices(
        self, first_slice: int = 0, last_slice: int = None, return_depth: bool = False
    ):
        """
        Generate the slices for the potential.

        Parameters
        ----------
        first_slice : int, optional
            Index of the first slice of the generated potential.
        last_slice : int, optional
            Index of the last slice of the generated potential.
        return_depth : bool
            If True, return the depth of each generated slice.

        Yields
        ------
        slices : generator of np.ndarray
            Generator for the array of slices. If `return_depth` is True, (depth, slice) pairs are yielded.
        """
        if last_slice is None:
            last_slice = len(self)

        xp = get_array_module(self.device)

        sliced_atoms = self.get_sliced_atoms()

        numbers = np.unique(sliced_atoms.atoms.numbers)

        exit_plane_after = self._exit_plane_after

        cumulative_thickness = np.cumsum(self.slice_thickness)

        for start, stop in generate_chunks(
            last_slice - first_slice, chunks=1, start=first_slice
        ):
            # With several elements (or several slices per chunk) the integrals are added into a zero-initialized
            # array; otherwise the integrator result is used directly.
            if len(numbers) > 1 or stop - start > 1:
                array = xp.zeros(
                    (stop - start,) + self.base_shape[1:],
                    dtype=get_dtype(complex=False),
                )
            else:
                array = None

            for i, slice_idx in enumerate(range(start, stop)):
                atoms = sliced_atoms.get_atoms_in_slices(slice_idx)

                new_array = self._integrator.integrate_on_grid(
                    atoms,
                    a=sliced_atoms.slice_limits[slice_idx][0],
                    b=sliced_atoms.slice_limits[slice_idx][1],
                    gpts=self.gpts,
                    sampling=self.sampling,
                    device=self.device,
                )

                if array is not None:
                    array[i] += new_array
                else:
                    array = new_array[None]

            if array is None:
                array = xp.zeros(
                    (stop - start,) + self.base_shape[1:],
                    dtype=get_dtype(complex=False),
                )

            # Exit planes relative to the start of this chunk.
            exit_planes = tuple(np.where(exit_plane_after[start:stop])[0])

            potential_array = self._array_object(
                array,
                slice_thickness=self.slice_thickness[start:stop],
                exit_planes=exit_planes,
                extent=self.extent,
            )

            if return_depth:
                depth = cumulative_thickness[stop - 1]
                yield depth, potential_array
            else:
                yield potential_array

    @property
    def ensemble_axes_metadata(self):
        """Ensemble axis metadata, taken from the frozen phonons."""
        return self.frozen_phonons.ensemble_axes_metadata

    @property
    def ensemble_shape(self) -> tuple[int, ...]:
        """Ensemble shape, taken from the frozen phonons."""
        return self.frozen_phonons.ensemble_shape

    @classmethod
    def _from_partitioned_args_func(cls, *args, frozen_phonons_partial, **kwargs):
        # Rebuild one potential from a partition of the frozen phonons and wrap it in an object array.
        args = unpack_blockwise_args(args)

        frozen_phonons = frozen_phonons_partial(*args)
        frozen_phonons = frozen_phonons.item()

        new_potential = cls(frozen_phonons, **kwargs)

        ndims = max(len(new_potential.ensemble_shape), 1)
        new_potential = _wrap_with_array(new_potential, ndims)
        return new_potential

    def _from_partitioned_args(self, *args, **kwargs):
        frozen_phonons_partial = self.frozen_phonons._from_partitioned_args()
        kwargs = self._copy_kwargs(exclude=("atoms", "sampling"))

        return partial(
            self._from_partitioned_args_func,
            frozen_phonons_partial=frozen_phonons_partial,
            **kwargs,
        )

    def _partition_args(self, chunks: Chunks = (1,), lazy: bool = True):
        return self.frozen_phonons._partition_args(chunks, lazy=lazy)
class _PotentialBuilder(_FieldBuilder, BasePotential):
    """Field builder specialised to potentials; combines `_FieldBuilder` with the `BasePotential` axes."""
def _validate_frozen_phonons(atoms):
    """
    Convert the atomic input of a potential to a frozen phonons object.

    Parameters
    ----------
    atoms : ase.Atoms, list or tuple of ase.Atoms, or object with a `randomize` method
        A single `Atoms` is copied (with its calculator removed) and wrapped in `DummyFrozenPhonons`. A list or
        tuple becomes an `AtomsEnsemble`. Objects that already have a `randomize` method are returned unchanged.

    Returns
    -------
    frozen_phonons : BaseFrozenPhonons

    Raises
    ------
    ValueError
        If the input is none of the supported types.
    """
    if isinstance(atoms, Atoms):
        atoms = atoms.copy()
        atoms.calc = None

    if hasattr(atoms, "randomize"):
        return atoms

    if isinstance(atoms, (list, tuple)):
        return AtomsEnsemble(atoms)

    if isinstance(atoms, Atoms):
        return DummyFrozenPhonons(atoms)

    raise ValueError(
        "atoms must be an ase.Atoms, a list or tuple of ase.Atoms or a frozen phonons object, "
        f"got {type(atoms).__name__}"
    )
[docs]
class Potential(_FieldBuilderFromAtoms, BasePotential):
    """
    Calculate the electrostatic potential of a set of atoms or frozen phonon configurations. The potential is calculated
    with the Independent Atom Model (IAM) using a user-defined parametrization of the atomic potentials.

    Parameters
    ----------
    atoms : ase.Atoms or abtem.FrozenPhonons
        Atoms or FrozenPhonons defining the atomic configuration(s) used in the independent atom model for calculating
        the electrostatic potential(s).
    gpts : one or two int, optional
        Number of grid points in `x` and `y` describing each slice of the potential. Provide either "sampling" (spacing
        between consecutive grid points) or "gpts" (total number of grid points).
    sampling : one or two float, optional
        Sampling of the potential in `x` and `y` [Å]. Provide either "sampling" or "gpts".
    slice_thickness : float or sequence of float, optional
        Thickness of the potential slices in the propagation direction in [Å] (default is 1 Å).
        If given as a float, the number of slices is calculated by dividing the `z`-height of the supercell by the
        slice thickness. The slice thickness may be given as a sequence of values for each slice, in which case an
        error will be thrown if the sum of slice thicknesses is not equal to the height of the atoms.
    parametrization : 'lobato' or 'kirkland', optional
        The potential parametrization describes the radial dependence of the potential for each element. Two of the
        most accurate parametrizations are available (by Lobato et al. and Kirkland; default is 'lobato').
        See the citation guide for references. Ignored if `integrator` is given.
    projection : 'finite' or 'infinite', optional
        If 'finite' the 3D potential is numerically integrated between the slice boundaries. If 'infinite' (default),
        the infinite potential projection of each atom will be assigned to a single slice. Ignored if `integrator` is
        given.
    exit_planes : int or tuple of int, optional
        The `exit_planes` argument can be used to calculate thickness series.
        Providing `exit_planes` as a tuple of int indicates that the tuple contains the slice indices after which an
        exit plane is desired, and hence during a multislice simulation a measurement is created. If `exit_planes` is
        an integer a measurement will be collected every `exit_planes` number of slices.
    plane : str or two tuples of three float, optional
        The plane relative to the provided atoms mapped to `xy` plane of the potential, i.e. provided plane is
        perpendicular to the propagation direction. If string, it must be a concatenation of two of 'x', 'y' and 'z';
        the default value 'xy' indicates that potential slices are cuts along the `xy`-plane of the atoms.
        The plane may also be specified with two arbitrary 3D vectors, which are mapped to the `x` and `y` directions of
        the potential, respectively. The length of the vectors has no influence. If the vectors are not perpendicular,
        the second vector is rotated in the plane to become perpendicular to the first. Providing a value of
        ((1., 0., 0.), (0., 1., 0.)) is equivalent to providing 'xy'.
    origin : three float, optional
        The origin relative to the provided atoms mapped to the origin of the potential. This is equivalent to
        translating the atoms. The default is (0., 0., 0.).
    box : three float, optional
        The extent of the potential in `x`, `y` and `z`. If not given this is determined from the atoms' cell.
        If the box size does not match an integer number of the atoms' supercell, an affine transformation may be
        necessary to preserve periodicity, determined by the `periodic` keyword.
    periodic : bool, optional
        If a transformation of the atomic structure is required, `periodic` determines how the atomic structure is
        transformed. If True (default), the periodicity of the Atoms is preserved, which may require applying a small
        affine transformation to the atoms. If False, the transformed potential is effectively cut out of a larger
        repeated potential, which may not preserve periodicity.
    integrator : ProjectionIntegrator, optional
        Provide a custom integrator for the projection integrals of the potential slicing.
    device : str, optional
        The device used for calculating the potential, 'cpu' or 'gpu'. The default is determined by the user
        configuration file.
    """

    # These arguments are consumed to construct the integrator and are not forwarded when copying.
    _exclude_from_copy = ("parametrization", "projection")

    def __init__(
        self,
        atoms: Atoms | BaseFrozenPhonons = None,
        gpts: int | tuple[int, int] = None,
        sampling: float | tuple[float, float] = None,
        slice_thickness: float | tuple[float, ...] = 1,
        parametrization: str | Parametrization = "lobato",
        projection: str = "infinite",
        exit_planes: int | tuple[int, ...] = None,
        plane: (
            str | tuple[tuple[float, float, float], tuple[float, float, float]]
        ) = "xy",
        origin: tuple[float, float, float] = (0.0, 0.0, 0.0),
        box: tuple[float, float, float] = None,
        periodic: bool = True,
        integrator: FieldIntegrator = None,
        device: str = None,
    ):
        if integrator is None:
            if projection == "finite":
                integrator = QuadratureProjectionIntegrals(
                    parametrization=parametrization
                )
            elif projection == "infinite":
                integrator = ScatteringFactorProjectionIntegrals(
                    parametrization=parametrization
                )
            else:
                raise NotImplementedError(
                    f"projection must be 'finite' or 'infinite', got {projection!r}"
                )

        super().__init__(
            atoms=atoms,
            array_object=PotentialArray,
            gpts=gpts,
            sampling=sampling,
            slice_thickness=slice_thickness,
            exit_planes=exit_planes,
            device=device,
            plane=plane,
            origin=origin,
            box=box,
            periodic=periodic,
            integrator=integrator,
        )
[docs]
class FieldArray(BaseField, ArrayObject):
    """
    Base class of fields stored as an array of slices.

    The last `_base_dims` axes of the array are the base axes (slice axis first, then the two lateral axes); any
    leading axes are ensemble axes. Subclasses define `_base_dims`.
    """

    def __init__(
        self,
        array: np.ndarray | da.core.Array,
        slice_thickness: float | tuple[float, ...] = None,
        extent: float | tuple[float, float] = None,
        sampling: float | tuple[float, float] = None,
        exit_planes: int | tuple[int, ...] = None,
        ensemble_axes_metadata: list[AxisMetadata] = None,
        metadata: dict = None,
    ):
        self._slice_thickness = _validate_slice_thickness(
            slice_thickness, num_slices=array.shape[-self._base_dims]
        )
        self._exit_planes = _validate_exit_planes(
            exit_planes, len(self._slice_thickness)
        )
        self._grid = Grid(extent=extent, gpts=array.shape[-2:], sampling=sampling)

        super().__init__(
            array=array,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )

    @property
    def num_configurations(self):
        """Product of the sizes of all frozen phonon axes, or 1 if there are none."""
        indices = _find_axes_type(self, FrozenPhononsAxis)
        if indices:
            return reduce(mul, tuple(self.array.shape[i] for i in indices))
        else:
            return 1

    @property
    def slice_thickness(self) -> tuple[float, ...]:
        """Slice thicknesses for each slice [Å]."""
        return self._slice_thickness

    @property
    def exit_planes(self) -> tuple[int, ...]:
        """The indices of slices after which a measurement is returned."""
        return self._exit_planes

    def build(
        self,
        first_slice: int = 0,
        last_slice: int = None,
        chunks: int = 1,
        lazy: bool = None,
    ):
        """An array field is already built; always raises RuntimeError."""
        raise RuntimeError("potential is already built")

    def generate_slices(self, first_slice: int = 0, last_slice: int = None):
        """
        Generate the slices for the potential.

        Only the first member of any ensemble axes is used.

        Parameters
        ----------
        first_slice : int, optional
            Index of the first slice of the generated potential.
        last_slice : int, optional
            Index of the last slice of the generated potential.

        Yields
        ------
        slices : generator of np.ndarray
            Generator for the array of slices.
        """
        if last_slice is None:
            last_slice = len(self)

        exit_plane_after = self._exit_plane_after

        # Leading ensemble axes are indexed at 0.
        ensemble_index = (0,) * (len(self.array.shape) - 3)

        for i in range(first_slice, last_slice):
            array = self.array[ensemble_index + (i,)][None]

            slic = self.__class__(
                array, self.slice_thickness[i : i + 1], extent=self.extent
            )

            # Exit plane of the single slice, relative to that slice (either (0,) or ()).
            slic._exit_planes = tuple(np.where(exit_plane_after[i : i + 1])[0])

            yield slic

    def __getitem__(self, items):
        """
        Index the ensemble axes, then the base axes (slice axis first, then `x` and `y`).

        Indexing the slice axis with an integer keeps a slice axis of length one. When the lateral axes are indexed,
        the lateral sampling is preserved and the extent follows from the new number of grid points.
        """
        if isinstance(items, (Number, slice)):
            items = (items,)

        ensemble_items = items[: len(self.ensemble_shape)]
        base_items = items[len(self.ensemble_shape) :]

        if len(ensemble_items):
            potential_array = super().__getitem__(ensemble_items)
        else:
            potential_array = self

        if len(base_items) == 0:
            return potential_array

        padded_items = (slice(None),) * len(potential_array.ensemble_shape) + base_items
        array = potential_array._array[padded_items]

        # Only the first base item refers to the slice axis; the thicknesses form a 1D array, so indexing them with
        # the lateral items as well would raise an IndexError.
        slice_thickness = np.array(potential_array.slice_thickness)[base_items[0]]

        if len(array.shape) < len(potential_array.shape):
            array = array[
                (slice(None),) * len(potential_array.ensemble_shape) + (None,)
            ]
            slice_thickness = slice_thickness[None]

        kwargs = potential_array._copy_kwargs(exclude=("array", "slice_thickness"))
        kwargs["array"] = array
        kwargs["slice_thickness"] = slice_thickness

        if len(base_items) > 1 and "extent" in kwargs:
            # Lateral crop: keep the physical sampling; the extent is derived from the new grid points.
            kwargs["extent"] = None
            kwargs["sampling"] = potential_array.sampling
        else:
            kwargs["sampling"] = None

        return potential_array.__class__(**kwargs)

    def tile(self, repetitions: tuple[int, int] | tuple[int, int, int]):
        """
        Tile the potential.

        Parameters
        ----------
        repetitions: two or three int
            The number of repetitions of the potential along `x`, `y` and optionally `z`. NOTE: if three integers are
            given, the last represents the number of repetitions along the `z`-axis.

        Returns
        -------
        PotentialArray object
            The tiled potential.
        """
        if len(repetitions) == 2:
            repetitions = tuple(repetitions) + (1,)

        # The last three array axes are (slice, x, y); leading ensemble axes are not repeated.
        new_array = np.tile(
            self.array, (repetitions[2], repetitions[0], repetitions[1])
        )

        new_extent = (self.extent[0] * repetitions[0], self.extent[1] * repetitions[1])
        new_slice_thickness = tuple(np.tile(self.slice_thickness, repetitions[2]))

        return self.__class__(
            array=new_array,
            slice_thickness=new_slice_thickness,
            extent=new_extent,
            ensemble_axes_metadata=self.ensemble_axes_metadata,
        )

    def to_hyperspy(self):
        """Convert the slices of the potential to a hyperspy signal via :meth:`to_images`."""
        return self.to_images().to_hyperspy()

    def to_images(self):
        """Convert slices of the potential to a stack of images."""
        metadata = {"label": "potential", "units": "eV / e"}
        return Images(
            array=self._array,
            sampling=(self.sampling[0], self.sampling[1]),
            metadata=metadata,
            ensemble_axes_metadata=self.axes_metadata[:-2],
        )

    def project(self) -> Images:
        """
        Create a 2D array representing a projected image of the potential(s).

        Returns
        -------
        images : Images
            One or more images of the projected potential(s).
        """
        metadata = {"label": "potential", "units": "eV / e"}

        # Sum over the slice axis.
        array = self.array.sum(-self._base_dims)

        ensemble_axes_metadata = (
            self.ensemble_axes_metadata + self.base_axes_metadata[1:-2]
        )

        return Images(
            array=array,
            sampling=self.sampling,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )
[docs]
class PotentialArray(BasePotential, FieldArray):
    """
    The potential array represents slices of the electrostatic potential as an array. All other potentials build
    potential arrays.

    Parameters
    ----------
    array: 3D np.ndarray
        The array representing the potential slices. The first base dimension is the slice index and the last two are
        the spatial dimensions; any additional leading dimensions are ensemble dimensions.
    slice_thickness: float
        The thicknesses of potential slices [Å]. If a float, the thickness is the same for all slices.
        If a sequence, the length must equal the length of the potential array.
    extent: one or two float, optional
        Lateral extent of the potential [Å].
    sampling: one or two float, optional
        Lateral sampling of the potential [Å].
    exit_planes : int or tuple of int, optional
        The `exit_planes` argument can be used to calculate thickness series.
        Providing `exit_planes` as a tuple of int indicates that the tuple contains the slice indices after which an
        exit plane is desired, and hence during a multislice simulation a measurement is created. If `exit_planes` is
        an integer a measurement will be collected every `exit_planes` number of slices.
    ensemble_axes_metadata : list of AxesMetadata
        Axis metadata for each ensemble axis. The axis metadata must be compatible with the shape of the array.
    metadata : dict
        A dictionary defining the metadata of the potential.
    """

    # Number of base (non-ensemble) axes: slice, x and y.
    _base_dims = 3

    def __init__(
        self,
        array: np.ndarray | da.core.Array,
        slice_thickness: float | tuple[float, ...] = None,
        extent: float | tuple[float, float] = None,
        sampling: float | tuple[float, float] = None,
        exit_planes: int | tuple[int, ...] = None,
        ensemble_axes_metadata: list[AxisMetadata] = None,
        metadata: dict = None,
    ):
        super().__init__(
            array=array,
            slice_thickness=slice_thickness,
            extent=extent,
            sampling=sampling,
            exit_planes=exit_planes,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )

    def transmission_function(self, energy: float) -> TransmissionFunction:
        """
        Calculate the transmission functions for each slice for a specific energy.

        The transmission function is the complex exponential of the potential scaled by the interaction parameter
        `energy2sigma(energy)`. Lazy potentials give lazy transmission functions.

        Parameters
        ----------
        energy: float
            Electron energy [eV].

        Returns
        -------
        transmissionfunction : TransmissionFunction
            Transmission functions for each slice.
        """
        xp = get_array_module(self.array)

        def _transmission_function(array, energy):
            dtype = get_dtype(complex=False)
            sigma = dtype(energy2sigma(energy))
            array = complex_exponential(sigma * array)
            return array

        if self.is_lazy:
            array = self._array.map_blocks(
                _transmission_function,
                energy=energy,
                meta=xp.array((), dtype=get_dtype(complex=True)),
            )
        else:
            array = _transmission_function(self._array, energy=energy)

        t = TransmissionFunction(
            array,
            slice_thickness=self.slice_thickness,
            extent=self.extent,
            energy=energy,
        )
        return t

    def transmit(self, waves: Waves, conjugate: bool = False) -> Waves:
        """
        Transmit a wave function through a potential slice.

        Parameters
        ----------
        waves: Waves
            Waves object to transmit.
        conjugate : bool, optional
            If True, use the conjugate of the transmission function. Default is False.

        Returns
        -------
        transmitted_waves : Waves
            The wave function after transmission through the potential slice.
        """
        transmission_function = self.transmission_function(waves.energy)
        return transmission_function.transmit(waves, conjugate=conjugate)
[docs]
class TransmissionFunction(PotentialArray, HasAcceleratorMixin):
    """Class to describe transmission functions.

    Parameters
    ----------
    array : 3D np.ndarray
        The array representing the potential slices. The first dimension is the slice index and the last two are the
        spatial dimensions.
    slice_thickness : float
        The thicknesses of potential slices [Å]. If a float, the thickness is the same for all slices.
        If a sequence, the length must equal the length of the potential array.
    extent : one or two float, optional
        Lateral extent of the potential [Å].
    sampling : one or two float, optional
        Lateral sampling of the potential [Å].
    energy : float
        Electron energy [eV].
    """

    def __init__(
        self,
        array: np.ndarray,
        slice_thickness: float | Sequence[float],
        extent: float | tuple[float, float] = None,
        sampling: float | tuple[float, float] = None,
        energy: float = None,
    ):
        self._accelerator = Accelerator(energy=energy)
        super().__init__(array, slice_thickness, extent, sampling)

    def get_chunk(self, first_slice, last_slice) -> TransmissionFunction:
        """Return the transmission functions from `first_slice` up to (not including) `last_slice`."""
        array = self.array[first_slice:last_slice]
        if len(array.shape) == 2:
            array = array[None]
        return self.__class__(
            array,
            self.slice_thickness[first_slice:last_slice],
            extent=self.extent,
            energy=self.energy,
        )

    def transmission_function(self, energy) -> TransmissionFunction:
        """
        Calculate the transmission functions for each slice for a specific energy.

        Parameters
        ----------
        energy: float
            Electron energy [eV].

        Returns
        -------
        transmissionfunction : TransmissionFunction
            Transmission functions for each slice; this object itself.

        Raises
        ------
        RuntimeError
            If `energy` differs from the energy of this transmission function.
        """
        if energy != self.energy:
            raise RuntimeError(
                f"energy of {energy} eV does not match the energy of the transmission function "
                f"({self.energy} eV)"
            )
        return self

    def transmit(self, waves: Waves, conjugate: bool = False) -> Waves:
        """
        Transmit a wave function through a potential slice.

        Only the first slice (`self.array[0]`) is applied. The wave array is multiplied in place and the same waves
        object is returned.

        Parameters
        ----------
        waves: Waves
            Waves object to transmit.
        conjugate : bool, optional
            If True, use the conjugate of the transmission function. Default is False.

        Returns
        -------
        transmission_function : Waves
            Transmission function for the wave function through the potential slice.
        """
        self.accelerator.check_match(waves)
        self.grid.check_match(waves)

        xp = get_array_module(self.array[0])

        if conjugate:
            waves._array *= xp.conjugate(self.array[0])
        else:
            waves._array *= self.array[0]

        return waves
[docs]
class CrystalPotential(_PotentialBuilder):
    """
    The crystal potential may be used to represent a potential consisting of a repeating unit. This may allow
    calculations to be performed with lower computational cost by calculating the potential unit once and repeating it.

    If the repeating unit is a potential with frozen phonons, it is treated as an ensemble from which each repeating
    unit along the `z`-direction is randomly drawn. If `num_frozen_phonons` is given, an ensemble of crystal
    potentials is created, each with a random seed for choosing potential units.

    Parameters
    ----------
    potential_unit : BasePotential
        The potential unit to assemble the crystal potential from.
    repetitions : three int
        The repetitions of the potential in `x`, `y` and `z`.
    num_frozen_phonons : int, optional
        Number of frozen phonon configurations assembled from the potential units. If not given but `seeds` is,
        it is inferred from `seeds`: one configuration for a single integer seed, otherwise one per seed.
    exit_planes : int or tuple of int, optional
        The `exit_planes` argument can be used to calculate thickness series.
        Providing `exit_planes` as a tuple of int indicates that the tuple contains the slice indices after which an
        exit plane is desired, and hence during a multislice simulation a measurement is created. If `exit_planes` is
        an integer a measurement will be collected every `exit_planes` number of slices.
    seeds: int or sequence of int
        Seed for the random number generator (RNG), or one seed for each RNG in the frozen phonon ensemble.
    """

    def __init__(
        self,
        potential_unit: BasePotential,
        repetitions: tuple[int, int, int],
        num_frozen_phonons: int = None,
        exit_planes: int = None,
        seeds: int | tuple[int, ...] = None,
    ):
        if num_frozen_phonons is None and seeds is None:
            # No ensemble: a single crystal configuration drawn with an unseeded RNG.
            self._seeds = None
        else:
            if num_frozen_phonons is None:
                # A single integer seed describes one configuration; a sequence gives one per seed.
                num_frozen_phonons = 1 if isinstance(seeds, Number) else len(seeds)

            self._seeds = _validate_seeds(seeds, num_frozen_phonons)

        if (
            (potential_unit.num_configurations == 1)
            and (num_frozen_phonons is not None)
            and (num_frozen_phonons > 1)
        ):
            warnings.warn(
                "'num_frozen_phonons' is greater than one, but the potential unit does not have frozen phonons"
            )

        gpts = (
            potential_unit.gpts[0] * repetitions[0],
            potential_unit.gpts[1] * repetitions[1],
        )
        extent = (
            potential_unit.extent[0] * repetitions[0],
            potential_unit.extent[1] * repetitions[1],
        )

        box = extent + (potential_unit.thickness * repetitions[2],)

        # Tiles the unit's slice thicknesses once per repetition along z; assumes `slice_thickness` is a
        # tuple so that `*` means sequence repetition — TODO confirm for all potential unit types.
        slice_thickness = potential_unit.slice_thickness * repetitions[2]

        super().__init__(
            array_object=PotentialArray,
            gpts=gpts,
            cell=Cell(np.diag(box)),
            slice_thickness=slice_thickness,
            exit_planes=exit_planes,
            device=potential_unit.device,
            plane="xy",
            origin=(0.0, 0.0, 0.0),
            box=box,
            periodic=True,
        )

        self._potential_unit = potential_unit
        self._repetitions = repetitions

    @property
    def ensemble_shape(self) -> tuple[int, ...]:
        """Shape of the frozen phonon ensemble; empty when no seeds are used."""
        if self._seeds is None:
            return ()
        else:
            return (self.num_configurations,)

    @property
    def num_configurations(self):
        """Number of crystal configurations (one per seed, or 1 without seeds)."""
        if self._seeds is None:
            return 1
        else:
            return len(self._seeds)

    @property
    def seeds(self):
        """Seeds of the random number generators used to pick potential units, or None."""
        return self._seeds

    @property
    def potential_unit(self) -> BasePotential:
        """The repeating potential unit."""
        return self._potential_unit

    @HasGridMixin.gpts.setter
    def gpts(self, gpts):
        # Each lateral direction must split evenly over its own number of repetitions.
        if (gpts[0] % self.repetitions[0] != 0) or (gpts[1] % self.repetitions[1] != 0):
            raise ValueError(
                "Number of grid points must be divisible by the number of potential repetitions."
            )

        self.grid.gpts = gpts
        self._potential_unit.gpts = (
            gpts[0] // self._repetitions[0],
            gpts[1] // self._repetitions[1],
        )

    @HasGridMixin.sampling.setter
    def sampling(self, sampling):
        # Delegate to the grid: assigning `self.sampling` here would re-enter this setter forever.
        self.grid.sampling = sampling
        # Repetition does not change the sampling, so the unit shares it.
        self._potential_unit.sampling = sampling

    @property
    def repetitions(self) -> tuple[int, int, int]:
        """The repetitions of the potential unit in `x`, `y` and `z`."""
        return self._repetitions

    @property
    def num_slices(self) -> int:
        """Total number of slices: the unit's slices times the repetitions along `z`."""
        return self._potential_unit.num_slices * self.repetitions[2]

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        """Axis metadata of the ensemble; a frozen phonon axis when seeds are used."""
        if self.seeds is None:
            return []
        else:
            return [FrozenPhononsAxis(_ensemble_mean=True)]

    @classmethod
    def _from_partitioned_args_func(cls, *args, **kwargs):
        # Rebuild a crystal potential from one partition created by `_partition_args`.
        args = unpack_blockwise_args(args)

        potential, seed = args[0]

        if hasattr(potential, "item"):
            potential = potential.item()

        if seed is not None:
            num_frozen_phonons = len(seed)
        else:
            num_frozen_phonons = None

        new = cls(
            potential_unit=potential,
            seeds=seed,
            num_frozen_phonons=num_frozen_phonons,
            **kwargs,
        )
        return _wrap_with_array(new)

    def _from_partitioned_args(self):
        # Everything except the partitioned arguments is bound once via `partial`.
        kwargs = self._copy_kwargs(
            exclude=("potential_unit", "seeds", "num_frozen_phonons")
        )
        output = partial(self._from_partitioned_args_func, **kwargs)
        return output

    def _partition_args(self, chunks: int = 1, lazy: bool = True):
        # Split the ensemble into chunks; each chunk carries the potential unit and its own slice of seeds.
        chunks = validate_chunks(self.ensemble_shape, chunks)

        if chunks == ():
            # Without an ensemble there is a single partition.
            chunks = ((1,),)

        def _chunk_seeds(start, stop):
            # Seeds of the configurations in [start, stop), or None when unseeded.
            return None if self.seeds is None else self.seeds[start:stop]

        if lazy:
            lazy_potential = dask.delayed(self.potential_unit)
            arrays = []
            for start, stop in chunk_ranges(chunks)[0]:
                lazy_args = dask.delayed(_wrap_with_array)(
                    (lazy_potential, _chunk_seeds(start, stop)), ndims=1
                )
                arrays.append(da.from_delayed(lazy_args, shape=(1,), dtype=object))

            array = da.concatenate(arrays)
        else:
            array = np.zeros((len(chunks[0]),), dtype=object)
            for i, (start, stop) in enumerate(chunk_ranges(chunks)[0]):
                # Integer-index assignment stores the tuple as a single object element
                # (`ndarray.itemset` was removed in NumPy 2.0).
                array[i] = (self.potential_unit, _chunk_seeds(start, stop))

        return (array,)

    def generate_slices(
        self, first_slice: int = 0, last_slice: int = None, return_depth: bool = False
    ):
        """
        Generate the slices for the potential.

        For each repetition along `z`, one configuration of the potential unit is drawn at random, and each of
        its slices is tiled laterally by the `x` and `y` repetitions.

        Parameters
        ----------
        first_slice : int, optional
            Index of the first slice of the generated potential. Default is 0.
        last_slice : int, optional
            Index one past the last generated slice (exclusive, as in Python slicing). Default is to generate
            all remaining slices.
        return_depth : bool
            If True, yield ``(depth, slice)`` pairs, where depth is the cumulative thickness up to and including
            the slice.

        Yields
        ------
        slices : generator of np.ndarray
            Generator for the array of slices.
        """
        if hasattr(self.potential_unit, "array"):
            potentials = self.potential_unit
        else:
            potentials = self.potential_unit.build(lazy=False)

        # Ensure a leading configuration axis to draw units from.
        if len(potentials.shape) == 3:
            potentials = potentials.expand_dims(axis=0)

        rng = np.random.default_rng(None if self.seeds is None else self.seeds[0])

        exit_plane_after = self._exit_plane_after
        cum_thickness = np.cumsum(self.slice_thickness)
        slices_per_unit = len(self.potential_unit)

        slice_index = 0  # global index of the next slice in the crystal
        for _ in range(self.repetitions[2]):
            # Draw a configuration for every repetition, including skipped ones, so that
            # the random sequence of units does not depend on `first_slice`.
            unit_slices = potentials[
                rng.integers(0, potentials.shape[0])
            ].generate_slices()

            for _ in range(slices_per_unit):
                if last_slice is not None and slice_index >= last_slice:
                    return

                unit_slice = next(unit_slices)
                current = slice_index
                slice_index += 1

                if current < first_slice:
                    continue

                tiled = unit_slice.tile(self.repetitions[:2])
                tiled._exit_planes = tuple(
                    np.where(exit_plane_after[current : current + 1])[0]
                )

                if return_depth:
                    yield cum_thickness[current], tiled
                else:
                    yield tiled