Source code for abtem.transform
"""Module to describe wave function transformations."""
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Generic, Mapping, Optional, TypeVar
import numpy as np
from abtem.core.axes import AxisMetadata, ParameterAxis
from abtem.core.backend import get_array_module
from abtem.core.chunks import Chunks
from abtem.core.ensemble import (
EmptyEnsemble,
Ensemble,
_wrap_with_array,
)
from abtem.core.fft import ifft2
from abtem.core.utils import CopyMixin, EqualityMixin, expand_dims_to_broadcast
from abtem.distributions import (
BaseDistribution,
EnsembleFromDistributions,
validate_distribution,
)
if TYPE_CHECKING:
    # Real classes are only imported for static type checkers; presumably this
    # avoids circular imports between abtem modules at runtime (TODO confirm).
    from abtem.array import ArrayObject, ArrayObjectType, ArrayObjectTypeAlt
    from abtem.measurements import BaseMeasurements
    from abtem.waves import Waves
else:
    # Runtime stand-ins: the type variables are bound via forward-reference
    # strings, and the class names resolve to ``object`` so that they can still
    # be used as generic parameters and base-class arguments below.
    ArrayObjectType = TypeVar("ArrayObjectType", bound="ArrayObject")
    ArrayObjectTypeAlt = TypeVar("ArrayObjectTypeAlt", bound="ArrayObject")
    ArrayObject = object
    Waves = object
    BaseMeasurements = object

# Type variable for "some kind of Waves"; at runtime its bound is ``object``.
WavesType = TypeVar("WavesType", bound=Waves)
[docs]
class ArrayObjectTransform(
    Generic[ArrayObjectType, ArrayObjectTypeAlt],
    Ensemble,
    EqualityMixin,
    CopyMixin,
    metaclass=ABCMeta,
):
    """
    Abstract base class for transforms that map an array object to new array
    object(s).

    Subclasses implement :meth:`_calculate_new_array` and :meth:`apply`. The
    ``_out_*`` hooks describe the output(s) of the transform and each returns a
    tuple. The defaults here describe a single output: it keeps the input's
    type, dtype, base shape and base axes metadata, and places this
    transform's ensemble axes in front of the input's ensemble axes.
    """

    # NOTE(review): class-level flag read outside this module; presumably it
    # controls whether the base (non-ensemble) axes may be chunked.
    _allow_base_chunks: bool = False

    @property
    def _num_outputs(self) -> int:
        # The default transform produces a single output.
        return 1

    @property
    def metadata(self) -> dict:
        """Metadata added to the waves when applying the transform."""
        return {}

    @property
    def ensemble_shape(self) -> tuple[int, ...]:
        """The shape of the ensemble axes added to the waves when applying the
        transform."""
        return ()

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        """Axes metadata describing the ensemble axes added to the waves when applying
        the transform."""
        return []

    def _out_meta(self, array_object: ArrayObjectType) -> tuple[np.ndarray, ...]:
        """
        Empty "meta" array(s) describing the output array type and dtype.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to derive the meta from.

        Returns
        -------
        meta : tuple of array-like
            A 1-tuple holding an empty array created by the array module that
            ``get_array_module(array_object.device)`` returns. Its dtype is the
            first dtype reported by :meth:`_out_dtype`.
        """
        xp = get_array_module(array_object.device)
        first_dtype = self._out_dtype(array_object)[0]
        return (xp.array((), dtype=first_dtype),)

    def _out_metadata(self, array_object: ArrayObjectType) -> tuple[dict, ...]:
        """
        Metadata of the array object(s) created by the transform.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to derive the metadata from.

        Returns
        -------
        metadata : tuple of dict
            A 1-tuple holding the input's metadata updated with
            :attr:`metadata`. On key collisions, the transform's entries win.
        """
        merged = dict(array_object.metadata)
        merged.update(self.metadata)
        return (merged,)

    def _out_dtype(self, array_object: ArrayObjectType) -> tuple[np.dtype, ...]:
        """Datatype(s) of the output array(s); by default the input's dtype."""
        input_dtype = array_object.dtype
        return (input_dtype,)

    def _out_type(
        self, array_object: ArrayObjectType
    ) -> tuple[type[ArrayObjectType], ...] | tuple[type[ArrayObjectTypeAlt], ...]:
        """
        Class(es) of the array object(s) created by the transform.

        Parameters
        ----------
        array_object : ArrayObject
            The array object the transform is applied to.

        Returns
        -------
        out_type : tuple of type
            By default, a 1-tuple holding the class of ``array_object``.
        """
        input_class = array_object.__class__
        return (input_class,)

    def _out_ensemble_shape(
        self, array_object: ArrayObjectType
    ) -> tuple[tuple[int, ...], ...]:
        """
        Ensemble shape(s) of the array object(s) created by the transform.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to derive the output ensemble shape from.

        Returns
        -------
        ensemble_shape : tuple of tuple of int
            A 1-tuple holding this transform's ensemble shape followed by the
            input's ensemble shape.
        """
        added = self.ensemble_shape
        existing = array_object.ensemble_shape
        return (added + existing,)

    def _out_base_shape(
        self, array_object: ArrayObjectType
    ) -> tuple[tuple[int, ...], ...]:
        """
        Base shape(s) of the array object(s) created by the transform.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to derive the output base shape from.

        Returns
        -------
        base_shape : tuple of tuple of int
            By default, a 1-tuple holding the input's base shape.
        """
        base_shape = array_object.base_shape
        return (base_shape,)

    def _out_shape(self, array_object: ArrayObjectType) -> tuple[tuple[int, ...], ...]:
        """Full shape of each output: its ensemble shape followed by its base shape."""
        ensemble_shapes = self._out_ensemble_shape(array_object)
        base_shapes = self._out_base_shape(array_object)
        full_shapes = []
        # zip pairs the i-th ensemble shape with the i-th base shape.
        for leading, trailing in zip(ensemble_shapes, base_shapes):
            full_shapes.append(leading + trailing)
        return tuple(full_shapes)

    def _out_dims(self, array_object: ArrayObjectType) -> tuple[int, ...]:
        """Number of dimensions of each output."""
        return tuple(map(len, self._out_shape(array_object)))

    def _out_base_axes_metadata(
        self, array_object: ArrayObjectType
    ) -> tuple[list[AxisMetadata], ...]:
        """
        Base axes metadata of the array object(s) created by the transform.

        Parameters
        ----------
        array_object : ArrayObject
            The array object to derive the base axes metadata from.

        Returns
        -------
        axes_metadata : tuple of list of :class:`AxisMetadata`
            By default, a 1-tuple holding the input's base axes metadata.
        """
        base_axes = array_object.base_axes_metadata
        return (base_axes,)

    def _out_ensemble_axes_metadata(
        self, array_object: ArrayObjectType
    ) -> tuple[list[AxisMetadata], ...]:
        """Ensemble axes metadata of the output: the transform's axes, then the input's."""
        combined = list(self.ensemble_axes_metadata) + list(
            array_object.ensemble_axes_metadata
        )
        return (combined,)

    def _out_axes_metadata(
        self, array_object: ArrayObjectType
    ) -> tuple[list[AxisMetadata], ...]:
        """All axes metadata of the first output: ensemble axes, then base axes."""
        ensemble_axes = self._out_ensemble_axes_metadata(array_object)[0]
        base_axes = self._out_base_axes_metadata(array_object)[0]
        return ([*ensemble_axes, *base_axes],)

    @abstractmethod
    def _calculate_new_array(
        self, array_object: ArrayObjectType
    ) -> np.ndarray | tuple[np.ndarray, ...]:
        """Compute the new array (or arrays) from ``array_object``."""
        ...

    @abstractmethod
    def apply(
        self, array_object: ArrayObjectType, max_batch: int | str = "auto"
    ) -> (
        ArrayObjectType
        | ArrayObjectTypeAlt
        | list[ArrayObjectType | ArrayObjectTypeAlt]
    ):
        """Apply the transform to ``array_object``."""
        ...
[docs]
class EmptyTransform(EmptyEnsemble, ArrayObjectTransform[ArrayObject, ArrayObject]):
    """Identity transform that hands back the array object it is given."""

    def apply(
        self, array_object: ArrayObject, max_batch: int | str = "auto"
    ) -> ArrayObject:
        """Return ``array_object`` unchanged; ``max_batch`` is not used."""
        unchanged = array_object
        return unchanged
[docs]
class EnsembleTransform(
    EnsembleFromDistributions, ArrayObjectTransform[ArrayObjectType, ArrayObjectTypeAlt]
):
    """
    Transform whose parameters may be given as distributions, forming an ensemble.

    Parameters
    ----------
    distributions : tuple of str, optional
        Names of properties that may be described by a distribution.
    """

    def __init__(self, distributions: tuple[str, ...] = ()):
        super().__init__(distributions=distributions)

    @staticmethod
    def _validate_distribution(distribution):
        # Delegates to the shared validator in abtem.distributions.
        return validate_distribution(distribution)

    def _validate_ensemble_axes_metadata(
        self, ensemble_axes_metadata: AxisMetadata | list[AxisMetadata]
    ) -> list[AxisMetadata]:
        """
        Normalize ensemble axes metadata to a list and check its length.

        Parameters
        ----------
        ensemble_axes_metadata : AxisMetadata or list of AxisMetadata
            A single axis (wrapped into a list) or one entry per ensemble axis.

        Returns
        -------
        list of AxisMetadata

        Raises
        ------
        ValueError
            If the number of entries differs from ``len(self.ensemble_shape)``.
        """
        if isinstance(ensemble_axes_metadata, AxisMetadata):
            ensemble_axes_metadata = [ensemble_axes_metadata]

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        expected = len(self.ensemble_shape)
        if len(ensemble_axes_metadata) != expected:
            raise ValueError(
                f"expected {expected} ensemble axes metadata entries, "
                f"got {len(ensemble_axes_metadata)}"
            )
        return ensemble_axes_metadata

    def _get_axes_metadata_from_distributions(
        self, **kwargs: Mapping[str, Any]
    ) -> list[AxisMetadata]:
        """
        Build a :class:`ParameterAxis` for every named property that is a distribution.

        Parameters
        ----------
        **kwargs : mapping
            Keyed by property name; each value holds extra keyword arguments
            forwarded to :class:`ParameterAxis`. Properties that are not a
            :class:`BaseDistribution` are skipped.

        Returns
        -------
        list of AxisMetadata

        Raises
        ------
        ValueError
            If a key is not one of the transform's distributable properties.
        """
        ensemble_axes_metadata: list[AxisMetadata] = []
        for name, value in kwargs.items():
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if name not in self._distributions:
                raise ValueError(
                    f"'{name}' is not one of the distributable properties "
                    f"{self._distributions}"
                )
            distribution = getattr(self, name)
            if isinstance(distribution, BaseDistribution):
                ensemble_axes_metadata.append(
                    ParameterAxis(
                        values=tuple(distribution),
                        _ensemble_mean=distribution.ensemble_mean,
                        **value,
                    )
                )
        return ensemble_axes_metadata
[docs]
class WavesTransform(EnsembleTransform[Waves, ArrayObjectType]):
    """
    Base class for ensemble transforms whose input is :class:`Waves`.

    Parameters
    ----------
    distributions : tuple of str, optional
        Names of properties that may be described by a distribution.
    """

    def __init__(self, distributions: tuple[str, ...] = ()):
        super().__init__(distributions=distributions)

    @property
    def distributions(self) -> tuple[str, ...]:
        """Names of the distributable properties, as stored in ``self._distributions``."""
        names = self._distributions
        return names

    @abstractmethod
    def _calculate_new_array(
        self, waves: WavesType
    ) -> np.ndarray | tuple[np.ndarray, ...]:
        """Compute the new array (or arrays) from ``waves``."""
        ...

    @abstractmethod
    def apply(
        self, waves: Waves, max_batch: int | str = "auto"
    ) -> Waves | ArrayObjectType | list[Waves | ArrayObjectType]:
        """Apply the transform to ``waves``."""
        ...
[docs]
class WavesToMeasurementTransform(WavesTransform[BaseMeasurements]):
    """Waves transform whose output type parameter is :class:`BaseMeasurements`."""

    def apply(
        self, waves: Waves, max_batch: int | str = "auto"
    ) -> Waves | BaseMeasurements | list[Waves | BaseMeasurements]:
        """Apply the transform by delegating to ``waves.apply_transform``."""
        result = waves.apply_transform(self, max_batch=max_batch)
        if TYPE_CHECKING:
            # Type narrowing for static checkers only; never executed at runtime.
            assert isinstance(result, BaseMeasurements)
        return result
[docs]
class WavesToWavesTransform(WavesTransform):
    """
    Waves transform whose output has the same class as its input waves.

    Parameters
    ----------
    distributions : tuple of str, optional
        Names of properties that may be described by a distribution.
    """

    def __init__(self, distributions: tuple[str, ...] = ()):
        super().__init__(distributions=distributions)

    @abstractmethod
    def _calculate_new_array(self, waves: Waves) -> np.ndarray:
        """Compute the new wave function array from ``waves``."""
        ...

    def _out_type(self, array_object: Waves) -> tuple[type[Waves], ...]:
        # The output keeps the exact class of the input waves.
        waves_class = array_object.__class__
        return (waves_class,)

    def apply(self, waves: Waves, max_batch: int | str = "auto") -> Waves:
        """Apply the transform by delegating to ``waves.apply_transform``."""
        result = waves.apply_transform(self, max_batch=max_batch)
        if TYPE_CHECKING:
            # Type narrowing for static checkers only; never executed at runtime.
            assert isinstance(result, Waves)
        return result
[docs]
class TransformFromFunc(ArrayObjectTransform[ArrayObject, ArrayObject]):
    """
    Transform that applies a user-supplied function to an array object.

    Parameters
    ----------
    func : callable
        Called as ``func(array_object, **func_kwargs)``; its return value is
        used as the new array.
    func_kwargs : mapping, optional
        Extra keyword arguments passed to ``func``. ``None`` (the default) means
        no extra arguments.
    """

    def __init__(self, func, func_kwargs=None):
        self._func = func
        # Keep a private copy so that later mutation of the caller's mapping
        # cannot silently change this transform.
        self._func_kwargs = {} if func_kwargs is None else dict(func_kwargs)
        super().__init__()

    @property
    def ensemble_shape(self) -> tuple[int, ...]:
        """This transform adds no ensemble axes."""
        return ()

    @property
    def _default_ensemble_chunks(self) -> Chunks:
        # No ensemble axes, hence no ensemble chunks.
        return ()

    @property
    def func(self):
        """The function applied to the array object."""
        return self._func

    @property
    def func_kwargs(self):
        """Keyword arguments passed to :attr:`func`."""
        return self._func_kwargs

    def _out_type(
        self, array_object: ArrayObject
    ) -> tuple[type[ArrayObject], ...]:
        # The output keeps the class of the input array object.
        return (array_object.__class__,)

    def _calculate_new_array(self, array_object: ArrayObject) -> np.ndarray:
        return self.func(array_object, **self.func_kwargs)

    def _partition_args(self, chunks: Optional[Chunks] = 1, lazy: bool = True) -> tuple:
        # Nothing to partition: the transform has no ensemble axes.
        return ()

    @classmethod
    def _partial_transform(cls, *args, **kwargs) -> np.ndarray:
        # Rebuild the transform from its copied constructor kwargs and wrap it
        # with ``_wrap_with_array(..., ndims=0)``; positional args are ignored.
        return _wrap_with_array(cls(**kwargs), ndims=0)

    def _from_partitioned_args(self) -> Callable[..., np.ndarray]:
        kwargs = self._copy_kwargs()
        return partial(self._partial_transform, **kwargs)

    def apply(
        self, array_object: ArrayObjectType, max_batch: int | str = "auto"
    ) -> ArrayObjectType:
        """Apply :attr:`func` to ``array_object`` via ``array_object.apply_transform``."""
        new_array_object = array_object.apply_transform(self, max_batch=max_batch)
        # Runtime sanity check (unlike sibling classes, not limited to
        # TYPE_CHECKING): the result must keep the input's class.
        assert isinstance(new_array_object, array_object.__class__)
        return new_array_object
[docs]
def join_tuples(tuples: tuple[tuple[Any, ...], ...]) -> tuple[Any, ...]:
    """
    Flatten one level of nesting: concatenate the items of each inner tuple.

    Parameters
    ----------
    tuples : tuple of tuple
        The tuples to join, in order.

    Returns
    -------
    tuple
        All items of the inner tuples, in their original order.
    """
    flattened = []
    for subtuple in tuples:
        flattened.extend(subtuple)
    return tuple(flattened)
[docs]
class ReciprocalSpaceMultiplication(WavesToWavesTransform):
    """
    Wave function transformation for multiplying each member of an ensemble of wave
    functions with an array in reciprocal space.

    Parameters
    ----------
    in_place : bool, optional
        If True, the array representing the waves may be modified in-place.
    distributions : tuple of str, optional
        Names of properties that may be described by a distribution.
    """

    def __init__(
        self,
        in_place: bool = False,
        distributions: tuple[str, ...] = (),
    ):
        self._in_place = in_place
        super().__init__(distributions=distributions)

    @property
    def in_place(self) -> bool:
        """The array representing the waves may be modified in-place."""
        return self._in_place

    @abstractmethod
    def _evaluate_kernel(self, waves: Waves) -> np.ndarray:
        """Return the array the reciprocal-space waves are multiplied with."""
        pass

    def _calculate_new_array(self, waves: Waves) -> np.ndarray:
        """
        Multiply the reciprocal-space wave array with the kernel.

        The waves are moved to reciprocal space if needed. The product is
        transformed back with ``ifft2`` when the input was in real space.

        Parameters
        ----------
        waves : Waves
            The waves to transform.

        Returns
        -------
        new_array : array-like
            The multiplied array, in the same space as the input waves.
        """
        real_space_in = not waves.reciprocal_space
        waves = waves.ensure_reciprocal_space(overwrite_x=self.in_place)
        kernel = self._evaluate_kernel(waves)
        array = waves._eager_array

        # Insert singleton axes so that the last two (spatial) axes of kernel
        # and array line up and the remaining axes broadcast.
        kernel, new_array = expand_dims_to_broadcast(
            kernel, array, match_dims=((-2, -1), (-2, -1))
        )

        xp = get_array_module(array)
        # ``asarray`` moves the kernel to the waves' array module without the
        # unconditional copy ``array`` would make; the kernel is not mutated.
        kernel = xp.asarray(kernel)

        # An in-place product only works when the broadcast result has the same
        # shape as ``new_array``. If the kernel contributes extra (ensemble)
        # dimensions, ``*=`` raises "non-broadcastable output operand", so fall
        # back to allocating a new array.
        product_shape = np.broadcast_shapes(new_array.shape, kernel.shape)
        if self.in_place and product_shape == new_array.shape:
            new_array *= kernel
        else:
            new_array = new_array * kernel

        if real_space_in:
            new_array = ifft2(new_array, overwrite_x=self.in_place)

        return new_array