Source code for abtem.transform
"""Module to describe wave function transformations."""
from __future__ import annotations
import itertools
from abc import abstractmethod, ABCMeta
from functools import partial, reduce
from typing import TYPE_CHECKING, Iterator
import dask.array as da
import numpy as np
from abtem.array import T
from abtem.core.axes import AxisMetadata, ParameterAxis
from abtem.core.backend import get_array_module
from abtem.core.chunks import Chunks, validate_chunks
from abtem.core.ensemble import (
Ensemble,
EmptyEnsemble,
_wrap_with_array,
unpack_blockwise_args,
)
from abtem.core.fft import ifft2
from abtem.core.utils import (
CopyMixin,
expand_dims_to_broadcast,
EqualityMixin,
)
from abtem.distributions import (
EnsembleFromDistributions,
validate_distribution,
BaseDistribution,
)
if TYPE_CHECKING:
from abtem.waves import Waves
from abtem.array import ArrayObject
[docs]
class ArrayObjectTransform(Ensemble, EqualityMixin, CopyMixin, metaclass=ABCMeta):
    """
    Base class for transformations applied to array objects.

    `apply` computes the new array(s) with `_calculate_new_array` and wraps the result into
    array object(s) using the `_out_*` hooks. The hooks take an optional output `index`, used
    when `_num_outputs` is larger than one; the defaults defined here ignore it and derive
    everything from the input array object.
    """

    _allow_base_chunks: bool = False

    @property
    def _num_outputs(self) -> int:
        """Number of array objects returned by `apply`."""
        return 1

    @property
    def metadata(self) -> dict:
        """Metadata added to the waves when applying the transform."""
        return {}

    @property
    def ensemble_shape(self) -> tuple[int, ...]:
        """The shape of the ensemble axes added to the waves when applying the transform."""
        return ()

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        """Axes metadata describing the ensemble axes added to the waves when applying the transform."""
        return []

    def _out_meta(self, array_object: T, index: int = 0) -> np.ndarray:
        """
        Empty array carrying the dtype and array module of the output.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to derive the meta from.
        index : int, optional
            Index of the output the meta describes.

        Returns
        -------
        meta : array-like
            Empty array.
        """
        xp = get_array_module(array_object.device)
        # Forward the index so per-output dtypes of overriding classes are respected.
        return xp.array((), dtype=self._out_dtype(array_object, index))

    def _out_metadata(self, array_object: T, index: int = 0) -> dict:
        """
        Metadata of the array object created by the transform.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to derive the metadata from.
        index : int, optional
            Index of the output (unused by this default implementation).

        Returns
        -------
        metadata : dict
        """
        return array_object.metadata

    def _out_dtype(
        self, array_object: ArrayObject | T, index: int = 0
    ) -> type[np.dtype]:
        """Datatype of the output array; by default the dtype of the input."""
        return array_object.dtype

    def _out_type(self, array_object: T, index: int = 0) -> type[T]:
        """
        The subtype of the created array object after applying the transform.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to derive the output type from.
        index : int, optional
            Index of the output (unused by this default implementation).

        Returns
        -------
        out_type : type of ArrayObject
            By default the class of the input.
        """
        return array_object.__class__

    def _out_ensemble_shape(self, array_object: T, index: int = 0) -> tuple[int, ...]:
        """
        Ensemble shape of the array object created by the transform.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to derive the shape of the output array object from.
        index : int, optional
            Index of the output (unused by this default implementation).

        Returns
        -------
        ensemble_shape : tuple of int
            The ensemble shape of the transform followed by that of the input.
        """
        return self.ensemble_shape + array_object.ensemble_shape

    def _out_base_shape(self, array_object: T, index: int = 0) -> tuple[int, ...]:
        """
        Base shape of the array object created by the transform.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to derive the base shape from.
        index : int, optional
            Index of the output (unused by this default implementation).

        Returns
        -------
        base_shape : tuple of int
            By default the base shape of the input.
        """
        return array_object.base_shape

    def _out_shape(self, array_object: T, index: int = 0) -> tuple[int, ...]:
        """Full shape (ensemble shape followed by base shape) of output number `index`."""
        # Both parts must describe the same output, so the index goes to both hooks.
        ensemble_shape = self._out_ensemble_shape(array_object, index)
        base_shape = self._out_base_shape(array_object, index)
        return ensemble_shape + base_shape

    def _out_base_axes_metadata(
        self, array_object: T, index: int = 0
    ) -> list[AxisMetadata]:
        """
        Base axes metadata of the array object created by the transform.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to derive the base axes metadata from.
        index : int, optional
            Index of the output (unused by this default implementation).

        Returns
        -------
        axes_metadata : list of :class:`AxisMetadata`
        """
        return array_object.base_axes_metadata

    def _out_ensemble_axes_metadata(
        self, array_object: ArrayObject | T, index: int = 0
    ) -> list[AxisMetadata] | tuple[list[AxisMetadata], ...]:
        """Ensemble axes metadata of the transform followed by those of the input."""
        return [*self.ensemble_axes_metadata, *array_object.ensemble_axes_metadata]

    def __add__(self, other: ArrayObjectTransform) -> CompositeArrayObjectTransform:
        """Combine two transforms into a composite, flattening operands that hold `transforms`."""
        transforms = []
        for transform in (self, other):
            if hasattr(transform, "transforms"):
                transforms.extend(transform.transforms)
            else:
                transforms.append(transform)
        return CompositeArrayObjectTransform(transforms)

    def _out_axes_metadata(self, array_object: T, index: int = 0):
        """Ensemble axes metadata followed by base axes metadata of output number `index`."""
        return [
            *self._out_ensemble_axes_metadata(array_object, index),
            *self._out_base_axes_metadata(array_object, index),
        ]

    def _get_blockwise_args(self, chunks):
        """
        Partition the transform arguments and assign each array axis an integer symbol.

        Returns
        -------
        transform_args : tuple
            The result of `_partition_args(chunks=chunks)`.
        transform_symbols : tuple of tuple of int
            One tuple per argument holding consecutive integers, one per axis of that
            argument; numbering continues across arguments so that no symbol repeats.
        """
        transform_args = self._partition_args(chunks=chunks)

        transform_symbols = ()
        offset = 0
        for array in transform_args:
            ndim = len(array.shape)
            transform_symbols += (tuple(range(offset, offset + ndim)),)
            offset += ndim

        return transform_args, transform_symbols

    @staticmethod
    def _extract(array, index):
        """Unwrap the single object held by `array` and return its element at `index`."""
        return array.item()[index]

    def _pack_multiple_outputs(
        self, array_object: T, new_arrays: np.ndarray | da.core.Array
    ):
        """
        Wrap the arrays of a multi-output transform into a tuple of array objects.

        Parameters
        ----------
        array_object : ArrayObject
            The input the transform was applied to; used to query the `_out_*` hooks.
        new_arrays : sequence of arrays or dask array
            Either an indexable collection with one array per output, or a lazy dask array
            whose blocks each wrap such a collection (unwrapped by `_extract`).

        Returns
        -------
        outputs : tuple of ArrayObject
            One array object per output, in output-index order.
        """
        is_lazy = isinstance(new_arrays, da.core.Array)

        outputs = ()
        for output_index in range(self._num_outputs):
            base_shape = self._out_base_shape(array_object, output_index)
            ensemble_shape = self._out_ensemble_shape(array_object, output_index)
            meta = self._out_meta(array_object, output_index)
            cls = self._out_type(array_object, output_index)
            metadata = self._out_metadata(array_object, output_index)
            base_axes_metadata = self._out_base_axes_metadata(
                array_object, output_index
            )
            ensemble_axes_metadata = self._out_ensemble_axes_metadata(
                array_object, output_index
            )

            if is_lazy:
                shape = ensemble_shape + base_shape
                # Axes of the lazy array beyond this output's ensemble axes; the same number
                # of trailing entries is trimmed from `chunks` and `new_axis` below.
                consumed_ensemble_axes = len(new_arrays.shape) - len(ensemble_shape)
                new_axis = tuple(range(len(ensemble_shape), len(shape)))

                # Every base axis becomes a single chunk spanning the whole axis.
                chunks = new_arrays.chunks + tuple((n,) for n in base_shape)
                chunks = chunks[: len(chunks) - consumed_ensemble_axes]
                new_axis = new_axis[: len(new_axis) - consumed_ensemble_axes]

                new_array = new_arrays.map_blocks(
                    self._extract,
                    output_index,
                    chunks=chunks,
                    new_axis=new_axis,
                    meta=meta,
                )
            else:
                new_array = new_arrays[output_index]

            axes_metadata = ensemble_axes_metadata + base_axes_metadata
            output = cls.from_array_and_metadata(
                new_array, axes_metadata=axes_metadata, metadata=metadata
            )
            outputs += (output,)

        return outputs

    def _pack_single_output(
        self,
        array_object: T,
        new_array: np.ndarray,
    ):
        """Wrap the array of a single-output transform into an array object."""
        ensemble_axes_metadata = self._out_ensemble_axes_metadata(array_object)
        base_axes_metadata = self._out_base_axes_metadata(array_object)
        axes_metadata = ensemble_axes_metadata + base_axes_metadata
        metadata = self._out_metadata(array_object)
        cls = self._out_type(array_object)
        return cls.from_array_and_metadata(
            new_array, axes_metadata=axes_metadata, metadata=metadata
        )

    def _calculate_new_array(
        self, array_object: T
    ) -> np.ndarray | tuple[np.ndarray, ...]:
        """Compute the transformed array(s); must be implemented by subclasses."""
        raise NotImplementedError

    def apply(self, array_object: T) -> T | tuple[T, ...]:
        """
        Apply the transform to the given array object.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to transform.

        Returns
        -------
        transformed_array_object : ArrayObject or tuple of ArrayObject
            A single array object, or a tuple of them if the transform has more than one
            output.
        """
        new_array = self._calculate_new_array(array_object)
        if self._num_outputs > 1:
            return self._pack_multiple_outputs(array_object, new_array)
        return self._pack_single_output(array_object, new_array)
[docs]
class EnsembleTransform(EnsembleFromDistributions, ArrayObjectTransform):
    """
    Array object transform whose ensemble axes come from parameter distributions.

    Parameters
    ----------
    distributions : tuple of str, optional
        Names of the properties that may be described by a distribution.
    """

    def __init__(self, distributions: tuple[str, ...] = ()):
        super().__init__(distributions=distributions)

    @staticmethod
    def _validate_distribution(distribution):
        """Pass `distribution` through `validate_distribution` and return the result."""
        return validate_distribution(distribution)

    def _validate_ensemble_axes_metadata(self, ensemble_axes_metadata):
        """
        Normalize ensemble axes metadata to a list with one entry per ensemble axis.

        A single `AxisMetadata` is wrapped in a list. The number of entries must equal the
        number of ensemble axes of the transform.
        """
        axes = ensemble_axes_metadata
        if isinstance(axes, AxisMetadata):
            axes = [axes]
        assert len(axes) == len(self.ensemble_shape)
        return axes

    def _get_axes_metadata_from_distributions(self, **kwargs):
        """
        Create a `ParameterAxis` for each named property that currently holds a distribution.

        Parameters
        ----------
        **kwargs : dict
            Maps a property name (which must be one of the registered distributions) to the
            extra keyword arguments passed to its `ParameterAxis`.

        Returns
        -------
        ensemble_axes_metadata : list of ParameterAxis
            One entry per named property that is a `BaseDistribution`, in keyword order.
        """
        axes = []
        for name, axis_kwargs in kwargs.items():
            assert name in self._distributions
            value = getattr(self, name)
            # Properties set to plain values do not contribute an ensemble axis.
            if not isinstance(value, BaseDistribution):
                continue
            axes.append(
                ParameterAxis(
                    values=value,
                    _ensemble_mean=value.ensemble_mean,
                    **axis_kwargs,
                )
            )
        return axes
[docs]
class WavesTransform(EnsembleTransform):
    """Ensemble transform that exposes its registered distribution names."""

    @property
    def distributions(self):
        """The stored `_distributions` attribute."""
        return self._distributions
[docs]
class TransformFromFunc(WavesTransform):
    """
    Transform that delegates the array calculation to a user-supplied function.

    Parameters
    ----------
    func : callable
        Called as ``func(array_object, **func_kwargs)``; its return value is used as the
        new array.
    func_kwargs : dict
        Keyword arguments forwarded to `func`.
    """

    def __init__(self, func, func_kwargs):
        # Attributes are stored before the parent initializer runs, as in the original order.
        self._func = func
        self._func_kwargs = func_kwargs
        super().__init__()

    @property
    def func(self):
        """The wrapped function."""
        return self._func

    @property
    def func_kwargs(self):
        """Keyword arguments forwarded to the wrapped function."""
        return self._func_kwargs

    def _calculate_new_array(self, array_object):
        """Call the wrapped function on `array_object` with the stored keyword arguments."""
        call = self.func
        extra = self.func_kwargs
        return call(array_object, **extra)
[docs]
class CompositeArrayObjectTransform(ArrayObjectTransform):
    """
    Combines multiple array object transformations into a single transformation.

    The transforms are applied in reverse list order: the last transform in the list is
    applied first and the first transform in the list is applied last. The first transform
    therefore determines the number and kind of outputs.

    Parameters
    ----------
    transforms : list of ArrayObjectTransform, optional
        The transforms to combine. If not given, the composite is empty and `apply` returns
        its input unchanged.
    """

    def __init__(
        self,
        transforms: list[ArrayObjectTransform] | None = None,
    ):
        # Copy the sequence so that `insert` never mutates a list owned by the caller.
        self._transforms = [] if transforms is None else list(transforms)

        # Output specification; stays None until `set_output_specification` is called,
        # after which the `_out_*` hooks answer from these cached tuples.
        self._base_shapes = None
        self._ensemble_shapes = None
        self._base_axes_metadata = None
        self._ensemble_axes_metadata = None
        self._types = None
        self._metas = None
        self._metadata = None
        super().__init__()

    @property
    def _num_outputs(self) -> int:
        """Number of outputs, taken from the first transform (the one applied last)."""
        if not self._transforms:
            # An empty composite is the identity and so has a single output.
            return 1
        return self._transforms[0]._num_outputs

    def set_output_specification(self, array_object):
        """
        Record the output specification by applying all transforms to `array_object`.

        The shapes, axes metadata, metas, types and metadata of the resulting output(s) are
        cached and subsequently returned by the `_out_*` hooks instead of being derived
        from the hook's input.

        Parameters
        ----------
        array_object : ArrayObject
            The array object the transforms are applied to via `apply_transform`.

        Returns
        -------
        composite_array_transform : CompositeArrayObjectTransform
            This object.
        """
        for transform in reversed(self.transforms):
            array_object = array_object.apply_transform(transform)

        num_outputs = self._num_outputs
        outputs = [array_object] if num_outputs == 1 else array_object

        self._base_shapes = ()
        self._ensemble_shapes = ()
        self._base_axes_metadata = ()
        self._ensemble_axes_metadata = ()
        self._metas = ()
        self._types = ()
        self._metadata = ()
        for i in range(num_outputs):
            output = outputs[i]
            self._base_shapes += (output.base_shape,)
            self._ensemble_shapes += (output.ensemble_shape,)
            self._base_axes_metadata += (output.base_axes_metadata,)
            self._ensemble_axes_metadata += (output.ensemble_axes_metadata,)
            xp = get_array_module(output.array)
            self._metas += (xp.array((), dtype=output.dtype),)
            self._types += (output.__class__,)
            self._metadata += (output.metadata,)
        return self

    def insert(
        self, transform: ArrayObjectTransform, index: int
    ) -> CompositeArrayObjectTransform:
        """
        Inserts an array object transform to the sequence of transforms before the specified index.

        Parameters
        ----------
        transform : ArrayObjectTransform
            Array object transform to insert.
        index : int
            The array object transform is inserted before this index.

        Returns
        -------
        composite_array_transform : CompositeArrayObjectTransform
            This object, modified in place.
        """
        self._transforms.insert(index, transform)
        return self

    def __len__(self) -> int:
        return len(self.transforms)

    def __iter__(self) -> Iterator[ArrayObjectTransform]:
        return iter(self.transforms)

    def _out_metadata(self, array_object, index=0):
        """Cached metadata if set, otherwise the merged metadata of all transforms."""
        if self._metadata is not None:
            return self._metadata[index]

        if not self.transforms:
            # reduce() without an initial value raises on an empty sequence; an empty
            # composite behaves as the identity, so use the base class default.
            return super()._out_metadata(array_object, index)

        metadata = [
            transform._out_metadata(array_object, index)
            for transform in self.transforms
        ]
        # On duplicate keys, entries of later transforms in the list win.
        return reduce(lambda a, b: {**a, **b}, metadata)

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        """Concatenated ensemble axes metadata of all transforms, in list order."""
        return list(
            itertools.chain.from_iterable(
                transform.ensemble_axes_metadata for transform in self.transforms
            )
        )

    def _out_ensemble_axes_metadata(self, array_object, index=0) -> list[AxisMetadata]:
        """Cached ensemble axes metadata if set, otherwise those of the composite plus the input."""
        if self._ensemble_axes_metadata is not None:
            return self._ensemble_axes_metadata[index]
        return self.ensemble_axes_metadata + array_object.ensemble_axes_metadata

    def _out_base_axes_metadata(self, array_object, index=0) -> list[AxisMetadata]:
        """Cached base axes metadata if set, otherwise delegated to the first transform."""
        if self._base_axes_metadata is not None:
            return self._base_axes_metadata[index]
        return self._transforms[0]._out_base_axes_metadata(array_object, index)

    @property
    def ensemble_shape(self) -> tuple[int, ...]:
        """Concatenated ensemble shapes of all transforms, in list order."""
        return tuple(
            itertools.chain.from_iterable(
                transform.ensemble_shape for transform in self.transforms
            )
        )

    def _out_ensemble_shape(self, array_object, index=0) -> tuple[int, ...]:
        """Cached ensemble shape if set, otherwise that of the composite plus the input."""
        if self._ensemble_shapes is not None:
            return self._ensemble_shapes[index]
        return self.ensemble_shape + array_object.ensemble_shape

    def _out_base_shape(self, array_object, index=0):
        """Cached base shape if set, otherwise delegated to the first transform."""
        if self._base_shapes is not None:
            return self._base_shapes[index]
        return self._transforms[0]._out_base_shape(array_object, index)

    def _out_meta(self, array_object, index=0):
        """Cached meta array if set, otherwise delegated to the first transform."""
        if self._metas is not None:
            return self._metas[index]
        return self._transforms[0]._out_meta(array_object, index)

    def _out_dtype(self, array_object, index=0):
        """Datatype of the meta array of output number `index`."""
        if self._metas is not None:
            return self._metas[index].dtype
        return self._out_meta(array_object, index).dtype

    def _out_type(self, array_object, index=0):
        """Cached output type if set, otherwise delegated to the first transform."""
        if self._types is not None:
            return self._types[index]
        return self._transforms[0]._out_type(array_object, index)

    @property
    def transforms(self) -> list[ArrayObjectTransform]:
        """The list of transforms in the composite."""
        return self._transforms

    @property
    def _default_ensemble_chunks(self) -> Chunks:
        """Concatenated default ensemble chunks of all transforms, in list order."""
        return tuple(
            itertools.chain.from_iterable(
                transform._default_ensemble_chunks for transform in self.transforms
            )
        )

    def apply(self, array_object):
        """
        Apply all transforms to the given array object.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to transform.

        Returns
        -------
        transformed_array_object : ArrayObject or tuple of ArrayObject
            The input itself if the composite is empty, otherwise the packed output(s).
        """
        if not len(self):
            return array_object
        return super().apply(array_object)

    def _calculate_new_array(
        self, array_object: T
    ) -> np.ndarray | tuple[np.ndarray, ...]:
        """Apply the transforms from last to first and return the resulting array(s)."""
        for transform in reversed(self.transforms):
            array_object = transform.apply(array_object)

        if self._num_outputs > 1:
            return tuple(array_object[i].array for i in range(self._num_outputs))
        return array_object.array

    @staticmethod
    def _partial(*args, partials):
        """
        Rebuild a composite transform from blockwise arguments.

        Parameters
        ----------
        *args
            Blockwise arguments, unpacked with `unpack_blockwise_args`.
        partials : tuple of (callable, tuple of int)
            For each member transform, a factory and the indices of the unpacked arguments
            it receives. The factory's return value is unwrapped with `.item()`.

        Returns
        -------
        The new composite transform wrapped by `_wrap_with_array`.
        """
        args = unpack_blockwise_args(args)
        transforms = []
        # The loop variable is deliberately not called `partial`, which would shadow
        # functools.partial.
        for make_transform, arg_indices in partials:
            partial_args = tuple(args[i] for i in arg_indices)
            transforms.append(make_transform(*partial_args).item())

        new_transform = CompositeArrayObjectTransform(transforms)
        return _wrap_with_array(new_transform)

    def _from_partitioned_args(self):
        """Return a callable that rebuilds this composite from its partitioned arguments."""
        partials = ()
        i = 0
        for transform in self.transforms:
            # Only the number of argument arrays per transform is needed here.
            num_args = len(transform._partition_args(1))
            arg_indices = tuple(range(i, i + num_args))
            partials += ((transform._from_partitioned_args(), arg_indices),)
            i += num_args

        return partial(self._partial, partials=partials)

    def _partition_args(self, chunks=None, lazy: bool = True):
        """
        Partition the arguments of every member transform.

        Parameters
        ----------
        chunks : Chunks, optional
            Chunks for the combined ensemble axes. Defaults to `_default_ensemble_chunks`.
        lazy : bool, optional
            Forwarded to the `_partition_args` method of each member transform.

        Returns
        -------
        blocks : tuple
            The partitioned arguments of all member transforms, concatenated in list order.
        """
        if chunks is None:
            chunks = self._default_ensemble_chunks

        chunks = validate_chunks(self.ensemble_shape, chunks, limit="auto")
        chunks = self._validate_ensemble_chunks(chunks)

        blocks = ()
        start = 0
        for transform in self.transforms:
            # Each transform receives the slice of chunks matching its own ensemble axes.
            stop = start + len(transform.ensemble_shape)
            blocks += transform._partition_args(chunks[start:stop], lazy=lazy)
            start = stop

        return blocks
[docs]
class ReciprocalSpaceMultiplication(WavesTransform):
    """
    Wave function transformation for multiplying each member of an ensemble of wave functions with an array.

    The multiplication is carried out in reciprocal space. Waves given in real space are
    converted first and the product is transformed back with `ifft2`.

    Parameters
    ----------
    in_place: bool, optional
        If True, the array representing the waves may be modified in-place.
    distributions : tuple of str, optional
        Names of properties that may be described by a distribution.
    """

    def __init__(
        self,
        in_place: bool = False,
        distributions: tuple[str, ...] = (),
    ):
        self._in_place = in_place
        super().__init__(distributions=distributions)

    @property
    def in_place(self) -> bool:
        """The array representing the waves may be modified in-place."""
        return self._in_place

    @abstractmethod
    def _evaluate_kernel(self, array_object):
        """Return the array the waves are multiplied with; implemented by subclasses."""

    def _calculate_new_array(self, waves: Waves) -> np.ndarray:
        """
        Multiply the waves with the kernel in reciprocal space.

        Parameters
        ----------
        waves : Waves
            The waves to transform, in either real or reciprocal space.

        Returns
        -------
        new_array : array
            The product, transformed back with `ifft2` if the input was in real space.
        """
        overwrite = self.in_place
        # Remember the input representation so the result can be returned in the same one.
        needs_inverse_fft = not waves.reciprocal_space

        waves = waves.ensure_reciprocal_space(overwrite_x=overwrite)
        kernel = self._evaluate_kernel(waves)

        # Match the last two axes of the kernel with the last two axes of the waves array.
        kernel, product = expand_dims_to_broadcast(
            kernel, waves.array, match_dims=[(-2, -1), (-2, -1)]
        )

        xp = get_array_module(waves.array)
        kernel = xp.array(kernel)

        if overwrite:
            product *= kernel
        else:
            product = product * kernel

        if not needs_inverse_fft:
            return product
        return ifft2(product, overwrite_x=overwrite)