"""Module for describing abTEM measurement objects."""
from __future__ import annotations
import copy
import itertools
import warnings
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from numbers import Number
from types import ModuleType
from typing import (
TYPE_CHECKING,
Callable,
Dict,
Hashable,
Optional,
Self,
Sequence,
SupportsFloat,
Type,
TypeVar,
)
import dask.array as da
import numpy as np
from ase import Atom
from ase.cell import Cell
from matplotlib.axes import Axes
from numba import jit # type: ignore
from abtem.array import ArrayObject, _validate_array_items, stack
from abtem.core import config
from abtem.core.axes import (
AxisMetadata,
LinearAxis,
NonLinearAxis,
RealSpaceAxis,
ReciprocalSpaceAxis,
ScaleAxis,
ScanAxis,
)
from abtem.core.backend import asnumpy, cp, get_array_module, get_ndimage_module
from abtem.core.complex import abs2
from abtem.core.energy import energy2wavelength
from abtem.core.fft import fft_crop, fft_interpolate
from abtem.core.grid import (
adjusted_gpts,
polar_spatial_frequencies,
spatial_frequencies,
)
from abtem.core.units import get_conversion_factor
from abtem.core.utils import (
CopyMixin,
EqualityMixin,
get_dtype,
is_broadcastable,
label_to_index,
)
# from abtem.distributions import BaseDistribution
from abtem.noise import NoiseTransform, ScanNoiseTransform
from abtem.visualize.artists import LinesArtist
from abtem.visualize.visualizations import Visualization
from abtem.visualize.widgets import ImageGUI, LinesGUI, ScatterGUI
# Optional CUDA kernels: populated only when CuPy is importable.
interpolate_bilinear_cuda: Optional[Callable] = None
sum_run_length_encoded: Optional[Callable] = None
sum_run_length_encoded_cuda: Optional[Callable] = None

if cp is not None:
    from abtem.core._cuda import interpolate_bilinear as interpolate_bilinear_cuda
    # NOTE(review): only the *_cuda name is rebound here, so the module-level
    # `sum_run_length_encoded` stays None — confirm this is intended.
    from abtem.core._cuda import sum_run_length_encoded as sum_run_length_encoded_cuda

# Optional export dependencies (xarray / pandas); None when unavailable.
xr: Optional[ModuleType] = None
try:
    import xarray as xr
except ImportError:
    xr = None

pd: Optional[ModuleType] = None
try:
    import pandas as pd  # type: ignore
except ImportError:
    pd = None

if TYPE_CHECKING:
    from abtem.waves import BaseWaves

# TypeVar used to express "returns the same measurement subclass it received".
BaseMeasurementsSubclass = TypeVar("BaseMeasurementsSubclass", bound="BaseMeasurements")
def _scanned_measurement_type(
    measurement: BaseMeasurements | BaseWaves,
) -> Type[RealSpaceLineProfiles | Images | MeasurementsEnsemble]:
    """Return the measurement class matching the number of main scan axes."""
    num_scan_dims = len(_scan_shape(measurement))
    if num_scan_dims == 1:
        return RealSpaceLineProfiles
    if num_scan_dims == 2:
        return Images
    return MeasurementsEnsemble
def _bin_extent(n: int) -> tuple[float, float]:
if n % 2 == 0:
return -n // 2 - 0.5, n // 2 - 0.5
else:
return -n // 2 + 0.5, n // 2 + 0.5
def _reduced_scanned_images_or_line_profiles(
    new_array: np.ndarray,
    old_measurement: BaseMeasurements,
    metadata: dict | None = None,
) -> RealSpaceLineProfiles | Images | MeasurementsEnsemble | np.ndarray:
    """
    Wrap an array whose base axes have been reduced into the measurement type
    implied by the scan axes of ``old_measurement`` (line profiles for one scan
    axis, images for two, otherwise a generic ensemble or the bare array).
    """
    if metadata is None:
        metadata = {}

    # New metadata entries override the inherited ones.
    metadata = {**old_measurement.metadata, **metadata}

    ensemble_axes = tuple(range(len(old_measurement.ensemble_shape)))
    source = _scan_axes(old_measurement)
    # Target positions: the scan axes become the trailing ensemble axes.
    destination = tuple(range(len(ensemble_axes) - len(source), len(ensemble_axes)))

    scan_axes_metadata = [old_measurement.ensemble_axes_metadata[i] for i in source]
    ensemble_axes_metadata = [
        m
        for i, m in enumerate(old_measurement.ensemble_axes_metadata)
        if i not in source
    ]

    if source != destination:
        xp = get_array_module(new_array)
        # Lazy arrays must be moved with dask to preserve laziness.
        if old_measurement.is_lazy:
            new_array = da.moveaxis(new_array, source, destination)
        else:
            new_array = xp.moveaxis(new_array, source, destination)

    sampling: float | tuple[float, ...]
    if len(scan_axes_metadata) == 1:
        sampling = _scan_sampling(old_measurement)[-1]
        return RealSpaceLineProfiles(
            new_array,
            sampling=sampling,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )
    elif len(scan_axes_metadata) == 2:
        sampling = _scan_sampling(old_measurement)
        sampling = (sampling[-2], sampling[-1])
        images = Images(
            new_array,
            sampling=sampling,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )
        return images
    elif _scanned_measurement_type(old_measurement) is MeasurementsEnsemble:
        # No scan axes: keep every axis as an ensemble axis.
        ensemble_axes_metadata = old_measurement.ensemble_axes_metadata
        measurement_ensemble = MeasurementsEnsemble(
            new_array,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )
        return measurement_ensemble
    else:
        return new_array
def _scan_axes(measurement):
num_scan_axes = 0
scan_axes = ()
for i, axis in enumerate(measurement.ensemble_axes_metadata):
if num_scan_axes == 2:
break
if isinstance(axis, ScanAxis) and axis._main is True:
scan_axes += (i,)
num_scan_axes += 1
scan_axes = scan_axes[-2:]
return scan_axes
def _scan_sampling(measurements: BaseMeasurements | BaseWaves) -> tuple[float, ...]:
    """Sampling along each main scan axis."""
    samplings = []
    for index in _scan_axes(measurements):
        samplings.append(measurements.axes_metadata[index].sampling)
    return tuple(samplings)
def _scan_axes_metadata(
    measurements: BaseMeasurements | BaseWaves,
) -> list[AxisMetadata]:
    """Axis metadata objects of the main scan axes."""
    all_axes = measurements.axes_metadata
    return [all_axes[index] for index in _scan_axes(measurements)]
def _scan_shape(measurements: BaseMeasurements | BaseWaves) -> tuple[int, ...]:
    """Shape of the measurement along the main scan axes."""
    full_shape = measurements.shape
    return tuple(full_shape[index] for index in _scan_axes(measurements))
def _scan_area_per_pixel(measurements):
    """Area per scan pixel; defined only for measurements with two scan axes."""
    sampling = _scan_sampling(measurements)
    if len(sampling) != 2:
        raise RuntimeError("Cannot infer pixel area from axes metadata.")
    return np.prod(sampling)
def _scan_extent(measurement):
    """Physical extent (gpts * sampling) along each main scan axis."""
    return tuple(
        n * axis_metadata.sampling
        for n, axis_metadata in zip(
            _scan_shape(measurement), _scan_axes_metadata(measurement)
        )
    )
def _annular_detector_mask(
    gpts: tuple[int, int],
    sampling: tuple[float, float],
    inner: float,
    outer: float,
    offset: tuple[float, float] = (0.0, 0.0),
    fftshift: bool = False,
    xp: ModuleType = np,
) -> np.ndarray | list[np.ndarray]:
    """
    Build a boolean mask selecting spatial frequencies with ``inner <= k < outer``.

    Parameters
    ----------
    gpts : two int
        Number of grid points in x and y.
    sampling : two float
        Sampling of the detector plane (same units as `inner`/`outer`).
    inner, outer : float
        Radial limits of the annulus.
    offset : two float, optional
        Offset of the annulus center; rounded to whole pixels before rolling.
    fftshift : bool, optional
        If True, shift the zero frequency to the center of the mask.
    xp : module, optional
        Array module (numpy or cupy) used for the computation.
    """
    # The frequency grid is constructed with the reciprocal of sampling * gpts,
    # so `k2` is expressed in the same units as `inner`/`outer`.
    kx, ky = spatial_frequencies(
        gpts, (1 / sampling[0] / gpts[0], 1 / sampling[1] / gpts[1]), False, xp
    )
    k2 = kx[:, None] ** 2 + ky[None] ** 2
    bins = (k2 >= inner**2) & (k2 < outer**2)

    if np.any(np.array(offset) != 0.0):
        # Convert the offset to an integer number of pixels and roll the mask.
        offset = (
            int(round(offset[0] / sampling[0])),
            int(round(offset[1] / sampling[1])),
        )
        # if (abs(offset[0]) > bins[0]) or (abs(offset[1]) > bins[1]):
        #     raise RuntimeError("Detector offset exceeds maximum detected angle.")
        bins = np.roll(bins, offset, (0, 1))

    if fftshift:
        bins = xp.fft.fftshift(bins)

    return bins
def _polar_detector_bins(
    gpts: tuple[int, int],
    sampling: tuple[float, float],
    inner: float,
    outer: float,
    nbins_radial: int,
    nbins_azimuthal: int,
    rotation: float = 0.0,
    offset: tuple[float, float] = (0.0, 0.0),
    fftshift: bool = False,
    return_indices: bool = False,
) -> np.ndarray | list[np.ndarray]:
    """
    Label each spatial-frequency pixel with a flattened polar-detector bin index.

    Pixels inside the annulus ``[inner, outer)`` get the label
    ``azimuthal_bin + radial_bin * nbins_azimuthal``; pixels outside are -1.

    Parameters
    ----------
    gpts : two int
        Number of grid points in x and y.
    sampling : two float
        Sampling of the detector plane (same units as `inner`/`outer`).
    inner, outer : float
        Radial limits of the binned annulus.
    nbins_radial, nbins_azimuthal : int
        Number of radial and azimuthal bins.
    rotation : float, optional
        Rotation of the azimuthal bin boundaries [rad].
    offset : two float, optional
        Offset of the bin pattern; rounded to whole pixels before rolling.
    fftshift : bool, optional
        If True, shift the zero frequency to the center of the label array.
    return_indices : bool, optional
        If True, return a list of flat pixel indices per bin instead of labels.
    """
    alpha, phi = polar_spatial_frequencies(
        gpts, sampling=(1 / sampling[0] / gpts[0], 1 / sampling[1] / gpts[1])
    )
    # Apply the rotation and wrap the azimuth into [0, 2*pi).
    phi = (phi - rotation) % (2 * np.pi)

    # Radial bin per pixel; -1 marks pixels outside the annulus.
    radial_bins = -np.ones(gpts, dtype=int)
    valid = (alpha >= inner) & (alpha < outer)
    radial_bins[valid] = nbins_radial * (alpha[valid] - inner) / (outer - inner)

    angular_bins = np.floor(nbins_azimuthal * (phi / (2 * np.pi)))
    angular_bins = np.clip(angular_bins, 0, nbins_azimuthal - 1).astype(int)

    bins = -np.ones(gpts, dtype=int)
    bins[valid] = angular_bins[valid] + radial_bins[valid] * nbins_azimuthal

    if np.any(np.array(offset) != 0.0):
        # Convert the offset to an integer number of pixels and roll the labels.
        offset = (
            int(round(offset[0] / sampling[0])),
            int(round(offset[1] / sampling[1])),
        )
        # if (abs(offset[0]) > bins[0]) or (abs(offset[1]) > bins[1]):
        #     raise RuntimeError("Detector offset exceeds maximum detected angle.")
        bins = np.roll(bins, offset, (0, 1))

    if fftshift:
        bins = np.fft.fftshift(bins)

    if return_indices:
        # Collect the flat pixel indices belonging to each bin label.
        indices = []
        for i in label_to_index(bins, nbins_radial * nbins_azimuthal - 1):
            indices.append(i)
        return indices
    else:
        return bins
@jit(nopython=True, nogil=True, fastmath=True)
def _sum_run_length_encoded(array, result, separators):
    # Accumulate segment sums in place:
    #   result[i, x] += sum(array[i, separators[x]:separators[x + 1]])
    # i.e. `separators` delimits run-length-encoded segments along the last axis.
    # Kept as plain nested loops so numba can compile it in nopython mode.
    for x in range(result.shape[1]):
        for i in range(result.shape[0]):
            for j in range(separators[x], separators[x + 1]):
                result[i, x] += array[i, j]
def _interpolate_stack(
    array: np.ndarray, positions: np.ndarray, mode: str, order: int, **kwargs
):
    """
    Spline-interpolate a stack of 2D images at the given pixel positions.

    Parameters
    ----------
    array : ndarray
        Array whose last two axes are interpolated; leading axes are flattened
        into a stack and processed image by image.
    positions : ndarray
        Positions in (fractional) pixel coordinates; the last axis has length 2.
    mode : str
        Boundary mode passed to the padding and to `map_coordinates`.
    order : int
        Spline interpolation order.
    """
    map_coordinates = get_ndimage_module(array).map_coordinates
    xp = get_array_module(array)

    positions_shape = positions.shape
    positions = positions.reshape((-1, 2))

    old_shape = array.shape
    array = array.reshape((-1,) + array.shape[-2:])
    # Pad by 2 * order so the spline support near the border uses `mode` values.
    array = xp.pad(array, ((0, 0), (2 * order,) * 2, (2 * order,) * 2), mode=mode)
    # Shift positions to account for the padding offset.
    positions = positions + 2 * order

    output = xp.zeros((array.shape[0], positions.shape[0]), dtype=array.dtype)
    for i in range(array.shape[0]):
        map_coordinates(array[i], positions.T, output=output[i], order=order, **kwargs)

    # Restore the leading (stack) axes and the positions' original shape.
    output = output.reshape(old_shape[:-2] + positions_shape[:-1])
    return output
# [docs]
class BaseMeasurements(ArrayObject, EqualityMixin, CopyMixin, metaclass=ABCMeta):
    """
    Base class for all measurement types.

    Parameters
    ----------
    array : ndarray
        Array containing data of type `float` or `complex`.
    ensemble_axes_metadata : list of AxisMetadata, optional
        Metadata associated with an ensemble axis.
    metadata : dict, optional
        A dictionary defining simulation metadata.
    """

    def __init__(
        self,
        array: np.ndarray | da.core.Array,
        ensemble_axes_metadata: list[AxisMetadata],
        metadata: dict,
    ):
        super().__init__(
            array=array,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )

    @property
    @abstractmethod
    def base_axes_metadata(self) -> list:
        """List of AxisMetadata of the base axes."""

    @property
    def metadata(self) -> dict:
        """Metadata describing the measurement."""
        return self._metadata

    def _get_from_metadata(self, key: Hashable):
        # Raise a descriptive error instead of a bare KeyError on a missing key.
        if key not in self.metadata.keys():
            raise RuntimeError(f"{key} not in measurement metadata.")
        return self.metadata[key]

    def _check_is_complex(self):
        # Guard for operations that are only defined for complex-valued arrays.
        if not np.iscomplexobj(self.array):
            raise RuntimeError("Function not implemented for non-complex measurements.")

    def real(self) -> Self:
        """Returns the real part of a complex-valued measurement."""
        self._check_is_complex()
        # NOTE(review): label/units are written into self.metadata, which the
        # returned object shares — confirm the in-place mutation is intended.
        self.metadata["label"] = "real"
        self.metadata["units"] = "arb. unit"
        return self._apply_element_wise_func(get_array_module(self.array).real)

    def imag(self) -> Self:
        """Returns the imaginary part of a complex-valued measurement."""
        self._check_is_complex()
        self.metadata["label"] = "imaginary"
        self.metadata["units"] = "arb. unit"
        return self._apply_element_wise_func(get_array_module(self.array).imag)

    def phase(self) -> Self:
        """Calculates the phase of a complex-valued measurement."""
        self._check_is_complex()
        self.metadata["label"] = "phase"
        self.metadata["units"] = "rad."
        return self._apply_element_wise_func(get_array_module(self.array).angle)

    def abs(self) -> Self:
        """Calculates the absolute value of a complex-valued measurement."""
        # self._check_is_complex()
        self.metadata["label"] = "amplitude"
        self.metadata["units"] = "arb. unit"
        return self._apply_element_wise_func(get_array_module(self.array).abs)

    def intensity(self) -> Self:
        """Calculates the squared norm of a complex-valued measurement."""
        self._check_is_complex()
        self.metadata["label"] = "intensity"
        self.metadata["units"] = "arb. unit"
        return self._apply_element_wise_func(abs2)

    def relative_difference(
        self, other: BaseMeasurements, min_relative_tol: float = 0.0
    ) -> Self:
        """
        Calculates the relative difference with respect to another compatible
        measurement.

        Parameters
        ----------
        other : BaseMeasurements
            Measurement to which the difference is calculated.
        min_relative_tol : float
            Avoids division by zero errors by defining a minimum value of the
            divisor in the relative difference.

        Returns
        -------
        difference : BaseMeasurements
            The relative difference as a measurement of the same type (in percent).
        """
        if not isinstance(other, self.__class__):
            raise RuntimeError("Measurements are not of the same type.")

        difference = self - other

        xp = get_array_module(self.array)

        # Divide only where the reference magnitude exceeds the tolerance; the
        # remaining entries are marked invalid (NaN) below.
        valid = xp.abs(self.array) >= min_relative_tol * self.array.max()
        difference._array[valid] /= self.array[valid]
        difference._array[valid == 0] = np.nan
        difference._array *= 100.0

        difference.metadata["label"] = "Relative difference"
        difference.metadata["units"] = "%"
        difference.metadata["tex_units"] = r"$\%$"
        return difference

    def normalize_ensemble(self, scale: str = "max", shift: str = "mean"):
        """
        Normalize the ensemble by shifting and scaling each member.

        Parameters
        ----------
        scale : {'max', 'min', 'sum', 'mean', 'ptp'}
            Name of the numpy reduction used as the scaling divisor.
        shift : {'max', 'min', 'sum', 'mean', 'ptp', 'none'}
            Name of the numpy reduction subtracted before scaling; 'none' skips
            the shift.

        Returns
        -------
        normalized_measurements : BaseMeasurements or subclass of _BaseMeasurement
        """
        if shift != "none":
            array = self.array - getattr(np, shift)(self.array, axis=-1, keepdims=True)
        else:
            array = self.array

        # NOTE(review): the divisor is computed from the *unshifted* array —
        # confirm this is intended when a shift is applied.
        array = array / getattr(np, scale)(self.array, axis=-1, keepdims=True)

        kwargs = self._copy_kwargs(exclude=("array",))
        return self.__class__(array, **kwargs)

    def reduce_ensemble(self) -> Self:
        """
        Calculates the mean of an ensemble measurement (e.g. of frozen phonon
        configurations).
        """
        # Only axes flagged with `_ensemble_mean` participate in the reduction.
        axis = tuple(
            i
            for i, axis in enumerate(self.axes_metadata)
            if hasattr(axis, "_ensemble_mean") and axis._ensemble_mean
        )

        if len(axis) == 0:
            return self

        if np.iscomplexobj(self.array):
            warnings.warn("a complex reducing a complex measurement")

        return self.mean(axis=axis)

    def _apply_element_wise_func(self, func: Callable) -> Self:
        # Rebuild the measurement with `func` applied element-wise to the array.
        d = self._copy_kwargs(exclude=("array",))
        d["array"] = func(self.array)
        return self.__class__(**d)

    @property
    @abstractmethod
    def _area_per_pixel(self):
        # Fallback implementation for subclasses that delegate to super();
        # requires exactly two scan axes to define a pixel area.
        if len(_scan_sampling(self)) == 2:
            return np.prod(_scan_sampling(self))
        else:
            raise RuntimeError("Cannot infer pixel area from axes metadata.")

    def poisson_noise(
        self,
        dose_per_area: SupportsFloat | Sequence[SupportsFloat] | None = None,
        total_dose: SupportsFloat | Sequence[SupportsFloat] | None = None,
        samples: int = 1,
        seed: Optional[int] = None,
    ) -> Self:
        """
        Add Poisson noise (i.e. shot noise) to a measurement corresponding to the
        provided 'total_dose' (per measurement if applied to an ensemble) or
        'dose_per_area' (not applicable for single measurements).

        Parameters
        ----------
        dose_per_area : float, sequence of float, optional
            The irradiation dose [electrons per Å:sup:`2`]. May be given as a single
            value or as a sequence of values for each ensemble member.
        total_dose : float, optional
            The irradiation dose per diffraction pattern.
        samples : int, optional
            The number of samples to draw from a Poisson distribution. If this is
            greater than 1, an additional ensemble axis will be added to the
            measurement.
        seed : int, optional
            Seed the random number generator.

        Returns
        -------
        noisy_measurement : BaseMeasurements or subclass of _BaseMeasurement
            The noisy measurement.
        """
        dtype = get_dtype(complex=False)

        wrong_dose_error = RuntimeError(
            "Provide one of 'dose_per_area' or 'total_dose'."
        )

        # NOTE(review): the branches below are partly redundant (the later arms can
        # never raise after the first check), and when both doses are None execution
        # falls through to np.array(None, ...) rather than raising wrong_dose_error —
        # confirm the intended error path.
        if (dose_per_area is not None) and (total_dose is not None):
            raise wrong_dose_error

        if dose_per_area is not None and total_dose is None:
            total_dose = self._area_per_pixel * dose_per_area
        elif total_dose is None:
            pass
        elif total_dose is not None:
            if dose_per_area is not None:
                raise wrong_dose_error
        else:
            raise wrong_dose_error

        total_dose = np.array(total_dose, dtype=dtype)

        transform = NoiseTransform(dose=total_dose, samples=samples, seeds=seed)

        measurement = transform.apply(self)
        return measurement

    def _scale_axis_from_metadata(self):
        # Axis describing the measured quantity (label/units), used for plotting.
        return ScaleAxis(
            label=self.metadata.get("label", ""),
            units=self.metadata.get("units", ""),
            tex_label=None,
        )

    def to_measurement_ensemble(self):
        # Reinterpret every axis (base and ensemble) as an ensemble axis.
        return MeasurementsEnsemble(
            array=self.array,
            ensemble_axes_metadata=self.axes_metadata,
            metadata=self.metadata,
        )

    @abstractmethod
    def show(self, *args, **kwargs):
        """Documented in subclasses"""
# [docs]
def periodic_crop(
    array: np.ndarray, corner: tuple[int, int], new_shape: tuple[int, int]
) -> np.ndarray:
    """
    Crop the last two axes of an array with periodic boundary conditions; a
    cropping region extending past an edge wraps around to the opposite edge.

    Parameters
    ----------
    array : ndarray
        The array to crop. Cropping applies to the last two axes.
    corner : two ints
        Index of the corner (in the last two axes) of the cropping region. May
        be negative or exceed the array shape; the crop wraps periodically.
    new_shape : two ints
        The shape of the cropping region.

    Returns
    -------
    cropped_array : ndarray
        The cropped array.
    """
    xp = get_array_module(array)

    # Fast path: the window lies fully inside the array, so a plain slice
    # suffices. Fixed off-by-one: a corner at 0 and a window touching the far
    # edge are also inside, so the checks use >= 0 and <= (the previous > 0 / <
    # sent those valid cases through the slower wrap path below).
    if (
        (corner[0] >= 0)
        and (corner[1] >= 0)
        and (corner[0] + new_shape[0] <= array.shape[-2])
        and (corner[1] + new_shape[1] <= array.shape[-1])
    ):
        return array[
            ...,
            corner[0] : corner[0] + new_shape[0],
            corner[1] : corner[1] + new_shape[1],
        ]

    # General path: build wrapped index grids and gather the wrapped window.
    x = xp.arange(corner[0], corner[0] + new_shape[0], dtype=xp.int64) % array.shape[-2]
    y = xp.arange(corner[1], corner[1] + new_shape[1], dtype=xp.int64) % array.shape[-1]
    x, y = xp.meshgrid(x, y, indexing="ij")
    array = array[..., x.ravel(), y.ravel()].reshape(array.shape[:-2] + new_shape)
    return array
# [docs]
def integrate_disc(
    measurement: Images | DiffractionPatterns,
    position: np.ndarray,
    radius: float,
    return_mean: bool = False,
    border: str = "wrap",
    interpolate: Optional[tuple[float, bool]] = None,
) -> float:
    """
    Integrate the values of a 2d measurement on a disc-shaped region.

    Parameters
    ----------
    measurement : 2d measurement
        The measurement to integrate.
    position : two floats
        Center of disc-shaped integration region.
    radius : float
        Radius of disc-shaped integration region.
    return_mean : bool
        If True return the mean, otherwise return the sum.
    border : str
        Specify how to treat integration regions that cross the image border. The
        valid values and their behaviours are:

        'wrap'
            The measurement is extended by wrapping around to the opposite edge.

        'raise'
            Raise an error if the integration region crosses the measurement border.
    interpolate : float, optional
        The image will be interpolated to this sampling. Units of Angstrom.

    Returns
    -------
    float
        Integral value.
    """
    if interpolate is not None:
        measurement = measurement.interpolate(interpolate)

    x_axis = measurement.base_axes_metadata[-2]
    y_axis = measurement.base_axes_metadata[-1]
    assert isinstance(x_axis, RealSpaceAxis) and isinstance(y_axis, RealSpaceAxis)

    # Work in coordinates relative to the measurement origin.
    offset = (x_axis.offset, y_axis.offset)
    position = np.array(position) - offset

    # Size (in pixels) of the smallest rectangle containing the disc, and the
    # pixel index of its lower corner.
    integration_shape = (
        int(np.ceil(2 * radius / x_axis.sampling)),
        int(np.ceil(2 * radius / y_axis.sampling)),
    )
    corner = (
        int(np.floor(position[-2] / x_axis.sampling)) - integration_shape[0] // 2,
        int(np.floor(position[-1] / y_axis.sampling)) - integration_shape[1] // 2,
    )

    if border == "wrap":
        cropped = periodic_crop(measurement.array, corner, integration_shape)
    elif border == "raise":
        # Bugfix: validate against the base (last two) axes so ensemble
        # measurements are checked correctly (was shape[0]/shape[1]).
        if (
            np.any(np.array(corner) < 0)
            or (corner[0] + integration_shape[0] > measurement.array.shape[-2])
            or (corner[1] + integration_shape[1] > measurement.array.shape[-1])
        ):
            raise RuntimeError("The integration region is outside the image.")
        cropped = periodic_crop(measurement.array, corner, integration_shape)
    else:
        raise RuntimeError('border must be one of "wrap" or "raise"')

    # Pixel coordinates of the cropped window.
    x = np.linspace(
        0.0,
        cropped.shape[-2] * x_axis.sampling,
        cropped.shape[-2],
        endpoint=x_axis.endpoint,
    )
    y = np.linspace(
        0.0,
        # Consistency fix: use the y-axis sampling (was measurement.sampling[-1]).
        cropped.shape[-1] * y_axis.sampling,
        cropped.shape[-1],
        endpoint=y_axis.endpoint,
    )
    x, y = np.meshgrid(x, y, indexing="ij")

    # Disc center expressed in the cropped window's coordinate frame.
    cropped_position = np.array(position)[:2] - (
        corner[-2] * x_axis.sampling,
        corner[-1] * y_axis.sampling,
    )
    r = np.sqrt((x - cropped_position[-2]) ** 2 + (y - cropped_position[-1]) ** 2)

    # Soft-edged disc: weights fall linearly from 1 to 0 over one mean pixel.
    mean_sampling = (x_axis.sampling + y_axis.sampling) / 2
    mask = 1 - np.clip((r - radius) / mean_sampling, 0, 1)

    if return_mean:
        return (cropped * mask).sum((-2, -1)) / mask.sum((-2, -1))
    return (cropped * mask).sum((-2, -1))
# [docs]
class MeasurementsEnsemble(BaseMeasurements):
    """Measurement with no base (display) dimensions; every axis is an ensemble
    axis."""

    # All dimensions are ensemble dimensions.
    _base_dims = 0

    def __init__(
        self,
        array: np.ndarray,
        ensemble_axes_metadata: list[AxisMetadata],
        metadata: dict | None = None,
    ):
        super().__init__(
            array=array,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )

    @property
    def _area_per_pixel(self):
        # An ensemble has no spatial base axes, hence no pixel area.
        raise RuntimeError("Cannot infer pixel area from metadata.")

    @property
    def base_axes_metadata(self):
        # No base axes by definition.
        return []

    def show(
        self,
        type: str = "lines",
        ax: Optional[Axes] = None,
        power: float = 1.0,
        common_scale: bool = False,
        explode: bool | Sequence[int] = (),
        overlay: bool | Sequence[int] = (),
        figsize: Optional[tuple[int, int]] = None,
        title: bool | str = True,
        units: Optional[str] = None,
        interact: bool = False,
        display: bool = True,
        **kwargs,
    ) -> Visualization:
        """
        Show the image(s) using matplotlib.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            If given the plots are added to the axis. This is not available for
            exploded plots.
        power : float
            Show image on a power scale.
        explode : bool, optional
            If True, a grid of images is created for all the items of the last two
            ensemble axes. If False, the first ensemble item is shown. May be given
            as a sequence of axis indices to create a grid of images from the
            specified axes. The default is determined by the axis metadata.
        overlay : bool or sequence of int, optional
            If True, all line profiles in the ensemble are shown in a single plot.
            If False, only the first ensemble item is shown. May be given as a
            sequence of axis indices to specify which line profiles in the ensemble
            to show together. The default is determined by the axis metadata.
        figsize : two int, optional
            The figure size given as width and height in inches, passed to
            `matplotlib.pyplot.figure`.
        title : bool or str, optional
            Set the column title of the images. If True is given instead of a
            string the title will be given by the value corresponding to the "name"
            key of the axes metadata dictionary, if this item exists.
        units : str
            The units used for the x and y axes. The given units must be compatible
            with the axes of the images.
        interact : bool
            If True, create an interactive visualization. This requires enabling
            the ipympl Matplotlib backend.
        display : bool, optional
            If True (default) the figure is displayed immediately.

        Returns
        -------
        measurement_visualization_2d : VisualizationImshow
        """
        # Interactive plots keep the array lazy; static plots need the data now.
        if not interact:
            self.compute()

        artist_type = LinesArtist

        visualization = Visualization(
            measurement=self,
            ax=ax,
            artist_type=artist_type,
            # power=power,
            aspect=False,
            share_x=True,
            share_y=common_scale,
            common_scale=common_scale,
            explode=explode,
            overlay=overlay,
            figsize=figsize,
            # interact=interact,
            title=title,
            **kwargs,
        )

        if common_scale is False and visualization._explode:
            visualization.axes.set_sizes(padding=0.8)

        return visualization
# def show(
# self,
# ax: Axes = None,
# common_scale: bool = True,
# explode: bool | Sequence[int] = None,
# overlay: bool | Sequence[int] = None,
# figsize: tuple[int, int] = None,
# title: str = None,
# units: str = None,
# legend: bool = False,
# interact: bool = False,
# display: bool = True,
# **kwargs,
# ):
# # if not interact:
# # self.compute()
#
# visualization = VisualizationLines(
# array=self.array,
# coordinate_axes=self.ensemble_axes_metadata[-1:],
# scale_axis=self._scale_axis_from_metadata(),
# ensemble_axes=self.ensemble_axes_metadata[:-1],
# ax=ax,
# common_scale=common_scale,
# explode=explode,
# overlay=overlay,
# figsize=figsize,
# interact=interact,
# title=title,
# **kwargs,
# )
#
# if not display and not interact:
# plt.close()
#
# if interact and display:
# from IPython.display import display as ipython_display
#
# ipython_display(visualization.layout_widgets())
#
# return visualization
class _BaseMeasurement2D(BaseMeasurements):
_base_dims = 2
    @property
    def base_shape(self) -> tuple[int, int]:
        # The last two axes form the base (display) shape of a 2D measurement.
        return super().base_shape[-2], super().base_shape[-1]
    @abstractmethod
    def _get_1d_equivalent(self):
        # Return the 1D measurement class corresponding to this 2D type;
        # used by interpolate_line to wrap the interpolated result.
        pass
    @property
    @abstractmethod
    def sampling(self) -> tuple[float, float]:
        """
        Sampling of the measurements in `x` and `y` [Å] or [1/Å].
        """
    @property
    @abstractmethod
    def extent(self) -> tuple[float, float]:
        """
        Extent of measurements in `x` and `y` [Å] or [1/Å].
        """
    @property
    @abstractmethod
    def offset(self) -> tuple[float, float]:
        """
        The offset of the origin of the measurement coordinates [Å] or [1/Å].
        """
    def interpolate_line(
        self,
        start: tuple[float, float] | Atom | None = None,
        end: tuple[float, float] | Atom | None = None,
        sampling: float | None = None,
        gpts: int | None = None,
        width: float = 0.0,
        margin: float = 0.0,
        order: int = 3,
        endpoint: bool = False,
        fractional: bool = False,
    ):
        """
        Interpolate image(s) along a given line. Either 'sampling' or 'gpts' must be
        provided.

        Parameters
        ----------
        start : two float, Atom, optional
            Starting position of the line [Å] (alternatively taken from a selected
            atom).
        end : two float, Atom, optional
            Ending position of the line [Å] (alternatively taken from a selected
            atom).
        sampling : float
            Sampling of grid points along the line [Å].
        gpts : int
            Number of grid points along the line.
        width : float, optional
            The interpolation will be averaged across a perpendicular distance equal
            to this width.
        margin : float or tuple of float, optional
            Add margin [Å] to the start and end interpolated line.
        order : int, optional
            The spline interpolation order.
        endpoint : bool
            Sets whether the ending position is included or not.
        fractional : bool
            If True, use fractional coordinates with respect to the extent of the
            measurement.

        Returns
        -------
        line_profiles : RealSpaceLineProfiles
            The interpolated line(s).
        """
        from abtem.scan import LineScan

        # Default to the finest image sampling when neither is given.
        if (sampling is None) and (gpts is None):
            sampling = min(self.sampling)

        xp = get_array_module(self.array)

        if start is None:
            start = (0.0, 0.0)

        if end is None and fractional:
            end = (0.0, 1.0)
        elif end is None:
            end = (0.0, self.extent[0])

        if fractional:
            extent = self.extent
        else:
            extent = None

        scan = LineScan(
            start=start,
            end=end,
            gpts=gpts,
            sampling=sampling,
            endpoint=endpoint,
            potential=extent,
            fractional=fractional,
        )

        if margin != 0.0:
            scan.add_margin(margin)

        # Convert the scan positions [Å] into (fractional) pixel coordinates.
        positions = xp.asarray(
            (scan.get_positions(lazy=False) - self.offset) / self.sampling
        )

        if width:
            # Sample a band of parallel lines perpendicular to the scan direction;
            # the band is averaged into a single profile further below.
            direction = xp.array(scan.end) - xp.array(scan.start)
            direction = direction / xp.linalg.norm(direction)
            perpendicular_direction = xp.array([-direction[1], direction[0]])
            # Odd number of parallel lines spanning the requested width.
            n = xp.floor(width / min(self.sampling) / 2) * 2 + 1
            perpendicular_positions = (
                xp.linspace(-n / 2, n / 2, int(n))[:, None]
                * perpendicular_direction[None]
            )
            positions = perpendicular_positions[None, :] + positions[:, None]

        if self.is_lazy:
            # TODO: implement a dedicated lazy interpolation.
            # NOTE(review): drop_axis uses leading indices from base_shape while the
            # base axes are the trailing axes of the array — confirm against dask's
            # map_blocks axis conventions.
            base_axes = tuple(range(len(self.base_shape)))
            chunks = self.array.chunks[:-2] + (positions.shape[0],)
            new_axis = (base_axes[0],)
            if width:
                new_axis = new_axis + (new_axis[0] + 1,)
                chunks = chunks + (1,)
            array = da.map_blocks(
                _interpolate_stack,
                self.array,
                positions=positions,
                mode="wrap",
                order=order,
                drop_axis=base_axes,
                new_axis=new_axis,
                chunks=chunks,
                meta=xp.array((), dtype=get_dtype(complex=False)),
            )
        else:
            array = _interpolate_stack(self.array, positions, mode="wrap", order=order)

        metadata = copy.copy(self.metadata)
        metadata.update(scan.metadata)
        metadata["label"] = "intensity"
        metadata["units"] = "arb. unit"

        if width:
            # Average the band of parallel lines into a single profile.
            array = array.mean(-1)
            metadata["width"] = width

        ensemble_axes_metadata = self.ensemble_axes_metadata

        return self._get_1d_equivalent()(
            array=array,
            sampling=scan.sampling,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )
    def gaussian_filter(
        self,
        sigma: float | tuple[float, float],
        boundary: str = "periodic",
        cval: float = 0.0,
    ):
        """
        Apply 2D gaussian filter to measurements.

        Parameters
        ----------
        sigma : float or two float
            Standard deviation for the Gaussian kernel in the `x` and `y`-direction.
            If given as a single number, the standard deviation is equal for both
            axes.
        boundary : {'periodic', 'reflect', 'constant'}
            The boundary parameter determines how the images are extended beyond
            their boundaries when the filter overlaps with a border.

            ``periodic`` :
                The images are extended by wrapping around to the opposite edge. Use
                this mode for periodic (default).

            ``reflect`` :
                The images are extended by reflecting about the edge of the last
                pixel.

            ``constant`` :
                The images are extended by filling all values beyond the edge with
                the same constant value, defined by the 'cval' parameter.
        cval : scalar, optional
            Value to fill past edges in spline interpolation input if boundary is
            'constant' (default is 0.0).

        Returns
        -------
        filtered_images : Images
            The filtered image(s).
        """
        xp = get_array_module(self.array)
        gaussian_filter = get_ndimage_module(self.array).gaussian_filter

        # Translate the abTEM boundary name to the scipy/cupyx ndimage mode name.
        if boundary == "periodic":
            mode = "wrap"
        elif boundary in ("reflect", "constant"):
            mode = boundary
        else:
            raise ValueError()

        if np.isscalar(sigma):
            sigma = (sigma,) * 2

        # Prepend zero sigma for the ensemble axes and convert from [Å] to pixels.
        sigma = (0,) * (len(self.shape) - 2) + tuple(
            s / d for s, d in zip(sigma, self.sampling)
        )

        if self.is_lazy:
            # Overlap depth of 4 sigma captures essentially the full Gaussian
            # kernel across chunk borders.
            depth = tuple(
                min(int(np.ceil(4.0 * s)), n) for s, n in zip(sigma, self.shape)
            )
            # NOTE(review): `boundary` is forwarded to dask's map_overlap; confirm
            # the string "constant" is a valid dask boundary value (dask usually
            # expects a fill value there).
            array = da.map_overlap(
                gaussian_filter,
                self.array,
                sigma=sigma,
                boundary=boundary,
                mode=mode,
                cval=cval,
                depth=depth,
                meta=xp.array((), dtype=xp.float32),
            )
        else:
            array = gaussian_filter(self.array, sigma=sigma, mode=mode, cval=cval)

        kwargs = self._copy_kwargs(exclude=("array",))
        kwargs["array"] = array
        return self.__class__(**kwargs)
    def interpolate_line_at_position(
        self,
        center: tuple[float, float] | Atom,
        angle: float,
        extent: float,
        gpts: Optional[int] = None,
        sampling: Optional[float] = None,
        width: float = 0.0,
        order: int = 3,
        endpoint: bool = True,
    ):
        """
        Interpolate image(s) along a line centered at a specified position.

        Parameters
        ----------
        center : two float
            Center position of the line [Å]. May be given as an Atom.
        angle : float
            Angle of the line [deg.].
        extent : float
            Extent of the line [Å].
        gpts : int
            Number of grid points along the line.
        sampling : float
            Sampling of grid points along the line [Å].
        width : float, optional
            The interpolation will be averaged across a perpendicular distance equal
            to this width.
        order : int, optional
            The spline interpolation order.
        endpoint : bool
            Sets whether the ending position is included or not.

        Returns
        -------
        line_profiles : RealSpaceLineProfiles or ReciprocalSpaceProfiles
            The interpolated line(s).
        """
        from abtem.scan import LineScan

        # Build the line from its center, length and angle, then delegate to
        # interpolate_line with the derived start/end points.
        scan = LineScan.at_position(center=center, extent=extent, angle=angle)

        return self.interpolate_line(
            scan.start,
            scan.end,
            gpts=gpts,
            sampling=sampling,
            width=width,
            order=order,
            endpoint=endpoint,
        )
def show(
    self,
    ax: Optional[Axes] = None,
    cbar: bool = False,
    cmap: Optional[str] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    power: float = 1.0,
    common_color_scale: bool = False,
    explode: bool | Sequence[int] = (),
    overlay: bool | Sequence[int] = (),
    figsize: Optional[tuple[int, int]] = None,
    title: bool | str = True,
    units: Optional[str] = None,
    interact: bool = False,
    display: bool = True,
    **kwargs,
) -> Visualization:
    """
    Show the image(s) using matplotlib.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        If given the plots are added to this axis. Not available for exploded
        plots.
    cbar : bool, optional
        Add colorbar(s) to the image(s); size and padding may be adjusted with the
        `set_cbar_size` and `set_cbar_padding` methods.
    cmap : str, optional
        Matplotlib colormap name used to map scalar data to colors. For complex
        measurements the colormap must be one of 'hsv' or 'hsluv'.
    vmin : float, optional
        Minimum of the intensity color scale; defaults to the minimum array value.
    vmax : float, optional
        Maximum of the intensity color scale; defaults to the maximum array value.
    power : float
        Show the image on a power scale.
    common_color_scale : bool, optional
        If True, all images in a grid share one color scale (and a single colorbar,
        if requested). Default is False.
    explode : bool or sequence of int, optional
        If True, create a grid of images from the last two ensemble axes; if False,
        show the first ensemble item. A sequence of axis indices selects the axes
        to explode. The default is determined by the axis metadata.
    overlay : bool or sequence of int, optional
        If True, show all ensemble items in a single plot; if False, only the
        first. A sequence of axis indices selects the items to overlay. The default
        is determined by the axis metadata.
    figsize : two int, optional
        Figure size (width, height) in inches, passed to `matplotlib.pyplot.figure`.
    title : bool or str, optional
        Column title of the images. If True, the title is taken from the "name"
        entry of the axes metadata, if present.
    units : str
        Units for the x and y axes; must be compatible with the image axes.
    interact : bool
        If True, create an interactive visualization (requires the ipympl
        Matplotlib backend).
    display : bool, optional
        If True (default) the figure is displayed immediately.

    Returns
    -------
    measurement_visualization_2d : VisualizationImshow
    """
    visualization = Visualization(
        measurement=self,
        ax=ax,
        explode=explode,
        overlay=overlay,
        figsize=figsize,
        title=title,
        units=units,
        # Images are drawn with equal aspect and shared spatial axes.
        aspect=True,
        share_x=True,
        share_y=True,
        common_scale=common_color_scale,
        value_limits=(vmin, vmax),
        power=power,
        cmap=cmap,
        cbar=cbar,
        # A static figure is only "interactive" in the sense of being drawn
        # immediately; widget-driven interactivity is set up below instead.
        interactive=not interact and display,
        **kwargs,
    )
    if interact:
        visualization.interact(ImageGUI, display=display)
    return visualization
[docs]
class Images(_BaseMeasurement2D):
    """
    A collection of 2D measurements such as HRTEM or STEM-ADF images. May be used to
    represent a reconstructed phase.

    Parameters
    ----------
    array : np.ndarray
        2D or greater array containing data of type `float` or `complex`. The
        second-to-last and last dimensions are the image `x`- and `y`-axis,
        respectively (matching `base_axes_metadata`).
    sampling : two float
        Lateral sampling of images in `x` and `y` [Å].
    ensemble_axes_metadata : list of AxisMetadata, optional
        List of metadata associated with the ensemble axes. The length and item order
        must match the ensemble axes.
    metadata : dict, optional
        A dictionary defining measurement metadata.
    """

    def __init__(
        self,
        array: da.core.Array | np.ndarray,
        sampling: float | tuple[float, float],
        ensemble_axes_metadata: Optional[list[AxisMetadata]] = None,
        metadata: Optional[Dict] = None,
    ):
        # Normalize sampling to a pair of floats so downstream code can rely on
        # indexing sampling[0]/sampling[1].
        if np.isscalar(sampling):
            sampling = (float(sampling),) * 2
        else:
            sampling = float(sampling[0]), float(sampling[1])
        self._sampling = sampling
        super().__init__(
            array=array,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )

    def _get_1d_equivalent(self):
        # 1D counterpart used when reducing images to line profiles.
        return RealSpaceLineProfiles

    @property
    def _area_per_pixel(self):
        # Real-space area of a single pixel [Å^2].
        return np.prod(self.sampling)

    @property
    def sampling(self) -> tuple[float, float]:
        """Sampling of the images in `x` and `y` [Å]."""
        return self._sampling

    @property
    def offset(self) -> tuple[float, float]:
        """Offset of the image origin in `x` and `y` [Å]; always zero for images."""
        return 0.0, 0.0

    @property
    def extent(self) -> tuple[float, float]:
        """Extent of the images in `x` and `y` [Å]."""
        return (
            self.sampling[0] * self.base_shape[0],
            self.sampling[1] * self.base_shape[1],
        )

    @property
    def coordinates(self) -> tuple[np.ndarray, np.ndarray]:
        """Coordinates of pixels in `x` and `y` [Å]."""
        # endpoint=False so consecutive coordinates are separated by exactly one
        # sampling step; with the linspace default (endpoint=True) the spacing
        # would be n * d / (n - 1), inconsistent with `sampling`.
        x = np.linspace(
            0.0, self.shape[-2] * self.sampling[0], self.shape[-2], endpoint=False
        )
        y = np.linspace(
            0.0, self.shape[-1] * self.sampling[1], self.shape[-1], endpoint=False
        )
        return x, y

    @property
    def base_axes_metadata(self) -> list[AxisMetadata]:
        """Axis metadata for the two real-space base axes (`x` at -2, `y` at -1)."""
        return [
            RealSpaceAxis(
                label="x", sampling=self.sampling[0], units="Å", tex_label="$x$"
            ),
            RealSpaceAxis(
                label="y", sampling=self.sampling[1], units="Å", tex_label="$y$"
            ),
        ]

    def integrate_disc(self, position: np.ndarray, radius: float) -> float:
        """
        Integrate the values of the images on a disc-shaped region.

        Parameters
        ----------
        position : two floats
            Center of disc-shaped integration region [Å].
        radius : float
            Radius of disc-shaped integration region [Å].

        Returns
        -------
        float
            Integral value.
        """
        return integrate_disc(self, position=position, radius=radius)

    def integrate_gradient(self):
        """
        Calculate integrated gradients. Requires complex images whose real and imaginary
        parts represent the `x` and `y` components of a gradient.

        Returns
        -------
        integrated_gradient : Images
            The integrated gradient.
        """
        self._check_is_complex()
        if self.is_lazy:
            xp = get_array_module(self.array)
            # The Fourier-space integration is global over the base axes, so
            # they must be contained in a single chunk.
            array = self.array.rechunk(
                self.array.chunks[:-2] + ((self.shape[-2],), (self.shape[-1],))
            )
            array = array.map_blocks(
                _integrate_gradient_2d,
                sampling=self.sampling,
                meta=xp.array((), dtype=np.float32),
            )
        else:
            array = _integrate_gradient_2d(self.array, sampling=self.sampling)
        kwargs = self._copy_kwargs(exclude=("array",))
        kwargs["array"] = array
        kwargs["metadata"].update({"label": "iCOM", "units": "arb. unit"})
        return self.__class__(**kwargs)

    def crop(
        self,
        extent: tuple[float, float],
        offset: tuple[float, float] = (0.0, 0.0),
        centered: bool = False,
    ):
        """
        Crop images to a smaller extent.

        Parameters
        ----------
        extent : tuple of float
            Extent of rectangular cropping region in `x` and `y` [Å].
        offset : tuple of float
            Lower corner of cropping region in `x` and `y` [Å] (default is (0, 0)).
            May not be combined with `centered`.
        centered : bool, optional
            If True, center the cropping region on the image; `offset` must then be
            left at its default. Default is False.

        Returns
        -------
        cropped_images : Images
            The cropped images.
        """
        if centered and offset == (0.0, 0.0):
            offset = (
                self.extent[0] / 2 - extent[0] / 2,
                self.extent[1] / 2 - extent[1] / 2,
            )
        elif centered:
            raise ValueError("Offset is not used when centered is True.")

        if extent[0] > self.extent[0] or extent[1] > self.extent[1]:
            raise ValueError("Extent must be smaller than the original extent.")

        # Convert offset and extent from Å to pixel indices.
        pixel_offset = (
            int(np.round(self.base_shape[0] * offset[0] / self.extent[0])),
            int(np.round(self.base_shape[1] * offset[1] / self.extent[1])),
        )
        new_shape = (
            int(np.round(self.base_shape[0] * extent[0] / self.extent[0])),
            int(np.round(self.base_shape[1] * extent[1] / self.extent[1])),
        )
        array = self.array[
            ...,
            pixel_offset[0] : pixel_offset[0] + new_shape[0],
            pixel_offset[1] : pixel_offset[1] + new_shape[1],
        ]
        kwargs = self._copy_kwargs(exclude=("array",))
        kwargs["array"] = array
        return self.__class__(**kwargs)

    @staticmethod
    def _interpolate_spline(array, old_gpts, new_gpts, pad_mode, order, cval):
        # Spline interpolation by mapping the new grid coordinates onto the old
        # grid (endpoint=False keeps the grid periodic-consistent).
        xp = get_array_module(array)
        x = xp.linspace(0.0, old_gpts[0], new_gpts[0], endpoint=False)
        y = xp.linspace(0.0, old_gpts[1], new_gpts[1], endpoint=False)
        positions = xp.meshgrid(x, y, indexing="ij")
        positions = xp.stack(positions, axis=-1)
        return _interpolate_stack(array, positions, pad_mode, order=order, cval=cval)

    def interpolate(
        self,
        sampling: Optional[float | tuple[float, float]] = None,
        gpts: Optional[int | tuple[int, int]] = None,
        method: str = "fft",
        boundary: str = "periodic",
        order: int = 3,
        normalization: str = "values",
        cval: float = 0.0,
    ) -> Images:
        """
        Interpolate images producing equivalent images with a different sampling.
        Either 'sampling' or 'gpts' must be provided (but not both).

        Parameters
        ----------
        sampling : float or two float
            Sampling of images after interpolation in `x` and `y` [Å].
        gpts : int or two int
            Number of grid points of images after interpolation in `x` and `y`.
            Do not use if 'sampling' is used.
        method : {'fft', 'spline'}
            The interpolation method.

            ``fft`` :
                Interpolate by cropping or zero-padding in reciprocal space.
                This method should be preferred for periodic images.
            ``spline`` :
                Interpolate using spline interpolation. This method should be
                preferred for non-periodic images.
        boundary : {'periodic', 'reflect', 'constant'}
            The boundary parameter determines how the input array is extended beyond
            its boundaries for spline interpolation.

            ``periodic`` :
                The images are extended by wrapping around to the opposite edge.
                Use this mode for periodic images (default).
            ``reflect`` :
                The images are extended by reflecting about the edge of the last
                pixel.
            ``constant`` :
                The images are extended by filling all values beyond the edge with
                the same constant value, defined by the 'cval' parameter.
        order : int
            The order of the spline interpolation (default is 3). The order has to be
            in the range 0-5.
        normalization : {'values', 'amplitude'}
            The normalization parameter determines which quantity is preserved after
            normalization.

            ``values`` :
                The pixel-wise values of the images are preserved.
            ``intensity`` :
                The total intensity of the images is preserved.
        cval : scalar, optional
            Value to fill past edges in spline interpolation input if boundary is
            'constant' (default is 0.0).

        Returns
        -------
        interpolated_images : Images
            The interpolated images.
        """
        # Validate the method up front. Previously an unknown method on a lazy
        # array fell through and silently returned un-resampled data carrying
        # the new sampling metadata.
        if method not in ("fft", "spline"):
            raise ValueError(f"Interpolation method '{method}' not recognized.")

        if method == "fft" and boundary != "periodic":
            raise ValueError(
                "Only periodic boundaries available for FFT interpolation."
            )

        if sampling is None and gpts is None:
            raise ValueError("Provide either 'sampling' or 'gpts'.")

        if gpts is None and sampling is not None:
            if np.isscalar(sampling):
                sampling = (sampling,) * 2
            gpts = tuple(
                int(np.ceil(length / step))
                for step, length in zip(sampling, self.extent)
            )
        elif np.isscalar(gpts):
            gpts = (gpts,) * 2

        xp = get_array_module(self.array)
        # The realized sampling may differ slightly from the requested one
        # because gpts is rounded to integers.
        sampling = (self.extent[0] / gpts[0], self.extent[1] / gpts[1])

        if boundary == "periodic":
            boundary = "wrap"

        if self.is_lazy:
            # Interpolation is global over the base axes; keep them in one chunk.
            array = self.array.rechunk(
                chunks=self.array.chunks[:-2] + ((self.shape[-2],), (self.shape[-1],))
            )
            if method == "fft":
                array = array.map_blocks(
                    fft_interpolate,
                    new_shape=gpts,
                    normalization=normalization,
                    chunks=self.array.chunks[:-2] + ((gpts[0],), (gpts[1],)),
                    meta=xp.array((), dtype=self.array.dtype),
                )
            else:
                array = array.map_blocks(
                    self._interpolate_spline,
                    old_gpts=self.shape[-2:],
                    new_gpts=gpts,
                    order=order,
                    cval=cval,
                    pad_mode=boundary,
                    chunks=self.array.chunks[:-2] + ((gpts[0],), (gpts[1],)),
                    meta=xp.array((), dtype=self.array.dtype),
                )
        else:
            if method == "fft":
                array = fft_interpolate(self.array, gpts, normalization=normalization)
            else:
                array = self._interpolate_spline(
                    self.array,
                    old_gpts=self.shape[-2:],
                    new_gpts=gpts,
                    pad_mode=boundary,
                    order=order,
                    cval=cval,
                )

        kwargs = self._copy_kwargs(exclude=("array",))
        kwargs["sampling"] = sampling
        kwargs["array"] = array
        return self.__class__(**kwargs)

    def tile(self, repetitions: tuple[int, int]) -> Images:
        """
        Tile image(s).

        Parameters
        ----------
        repetitions : tuple of int
            The number of repetitions of the images along the `x`- and `y`-axis,
            respectively.

        Returns
        -------
        tiled_images : Images
            The tiled image(s).
        """
        if len(repetitions) != 2:
            raise RuntimeError("'repetitions' must contain exactly two integers.")
        reps = (1,) * (len(self.array.shape) - 2) + tuple(repetitions)
        if self.is_lazy:
            # Keep the computation lazy; np.tile would materialize a dask array.
            # This mirrors RealSpaceLineProfiles.tile.
            array = da.tile(self.array, reps)
        else:
            array = get_array_module(self.array).tile(self.array, reps)
        kwargs = self._copy_kwargs(exclude=("array",))
        kwargs["array"] = array
        return self.__class__(**kwargs)

    def scan_noise(
        self,
        dwell_time: float,
        flyback_time: float,
        rms_power: float,
        max_frequency: float = 500.0,
        num_components: int = 200,
        seed: Optional[int] = None,
    ):
        """
        Apply scan noise to images.

        Parameters
        ----------
        dwell_time : float
            Dwell time of the beam [s].
        flyback_time : float
            Flyback time of the beam [s].
        rms_power : float
            RMS power of the scan noise [V].
        max_frequency : float
            Maximum frequency of the scan noise [1/Å].
        num_components : int, optional
            Number of noise components used by the transform. Default is 200.
        seed : int, optional
            Seed for the random-number generation of the noise.

        Returns
        -------
        noisy_images : Images
            The images with scan noise applied.
        """
        transform = ScanNoiseTransform(
            dwell_time=dwell_time,
            flyback_time=flyback_time,
            rms_power=rms_power,
            max_frequency=max_frequency,
            num_components=num_components,
            seeds=seed,
        )
        return self.apply_transform(transform)

    @staticmethod
    def _diffractograms(array):
        # Power spectrum: magnitude of the 2D FFT with the zero frequency
        # shifted to the center of the base axes.
        xp = get_array_module(array)
        array = xp.fft.fft2(array)
        return xp.fft.fftshift(xp.abs(array), axes=(-2, -1))

    def diffractograms(self) -> DiffractionPatterns:
        """
        Calculate diffractograms (i.e. power spectra) from image(s).

        Returns
        -------
        diffractograms : DiffractionPatterns
            Diffractograms of image(s).
        """
        xp = get_array_module(self.array)
        if self.is_lazy:
            # The 2D FFT is global over the base axes; keep them in one chunk.
            array = self.array.rechunk(
                chunks=self.array.chunks[:-2] + ((self.shape[-2],), (self.shape[-1],))
            )
            array = array.map_blocks(
                self._diffractograms, meta=xp.array((), dtype=xp.float32)
            )
        else:
            array = self._diffractograms(self.array)
        # Reciprocal-space sampling is the inverse of the real-space extent.
        sampling = 1 / self.extent[0], 1 / self.extent[1]
        return DiffractionPatterns(
            array=array,
            sampling=sampling,
            ensemble_axes_metadata=self.ensemble_axes_metadata,
            metadata=self.metadata,
        )
# def _plot_base_axes_metadata(self, units: str = None):
# return self.base_axes_metadata
class _BaseMeasurement1D(BaseMeasurements):
    """Shared behavior for 1D measurements (line profiles): sampling/extent
    bookkeeping, interpolation, width estimation and plotting."""

    # One base (non-ensemble) dimension; the last array axis is the profile axis.
    _base_dims = 1

    def __init__(
        self,
        array: np.ndarray,
        sampling: Optional[float] = None,
        ensemble_axes_metadata: Optional[list[AxisMetadata]] = None,
        metadata: Optional[dict] = None,
    ):
        self._sampling = sampling
        super().__init__(
            array=array,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )

    @property
    def _area_per_pixel(self):
        # Line profiles have no 2D pixels, so an area normalization is undefined.
        raise RuntimeError("Cannot infer pixel area from metadata.")

    @classmethod
    def from_array_and_metadata(
        cls,
        array: np.ndarray,
        axes_metadata: list[AxisMetadata],
        metadata: Optional[dict] = None,
    ) -> Self:
        """
        Creates line profile(s) from a given array and metadata.

        Parameters
        ----------
        array : array
            Array defining one or more 1D line profiles.
        axes_metadata : list of AxesMetadata
            Axis metadata for each axis. The axis metadata must be compatible with
            the shape of the array. The last axis must be a LinearAxis providing
            the profile sampling.
        metadata : dict, optional
            A dictionary defining the measurement metadata.

        Returns
        -------
        line_profiles : _BaseMeasurement1D
            Line profiles from the array and metadata.
        """
        # The last axis describes the profile direction; its sampling becomes
        # the profile sampling, the remaining axes become ensemble axes.
        x_axis = axes_metadata[-1]
        if isinstance(x_axis, LinearAxis):
            sampling = x_axis.sampling
        else:
            raise RuntimeError()
        axes_metadata = axes_metadata[:-1]
        return cls(
            array,
            sampling=sampling,
            ensemble_axes_metadata=axes_metadata,
            metadata=metadata,
        )

    @property
    def extent(self) -> float:
        """
        Extent of measurements [Å] or [1/Å].
        """
        return self.sampling * self.shape[-1]

    @property
    def sampling(self) -> float:
        """
        Sampling of measurements [Å] or [1/Å].
        """
        return self._sampling

    @property
    @abstractmethod
    def base_axes_metadata(self) -> list[RealSpaceAxis | ReciprocalSpaceAxis]:
        # Subclasses decide whether the profile axis is real or reciprocal space.
        pass

    def _line_scan(self, sampling=None):
        # Reconstruct the line scan the profile was extracted along from the
        # 'start'/'end' entries stored in the metadata.
        start, end = self.metadata["start"], self.metadata["end"]
        from abtem.scan import LineScan

        return LineScan(start=start, end=end, sampling=sampling)

    def _add_to_visualization(self, *args, **kwargs):
        # Draw the originating line scan onto an existing plot; requires the
        # 'start'/'end' metadata written when the profile was created.
        if not all(key in self.metadata for key in ("start", "end")):
            raise RuntimeError(
                "The metadata does not contain the keys 'start' and 'end'"
            )
        if "width" in self.metadata:
            kwargs["width"] = self.metadata["width"]
        self._line_scan().add_to_axes(*args, **kwargs)

    @staticmethod
    def _calculate_widths(array, sampling, height):
        # Shift each profile so it crosses zero at `height * max`, then measure
        # the distance between the outermost (first and last) zero crossings.
        xp = get_array_module(array)
        array = array - xp.max(array, axis=-1, keepdims=True) * height
        widths = xp.zeros(array.shape[:-1], dtype=np.float32)
        for i in np.ndindex(array.shape[:-1]):
            # Sign changes of the shifted profile mark the threshold crossings.
            zero_crossings = xp.where(xp.diff(xp.sign(array[i]), axis=-1))[0]
            left, right = zero_crossings[0], zero_crossings[-1]
            widths[i] = (right - left) * sampling
        return widths

    def width(self, height: float = 0.5):
        """
        Calculate the width of line(s) at a given height, e.g. full width at half
        maximum (the default).

        Parameters
        ----------
        height : float
            Fractional height at which the width is calculated.

        Returns
        -------
        width : array
            The calculated width(s), one per ensemble item (lazy when the
            measurement is lazy).
        """
        if self.is_lazy:
            # The profile axis is reduced away, hence drop_axis on the last axis.
            return self.array.map_blocks(
                self._calculate_widths,
                drop_axis=(len(self.array.shape) - 1,),
                dtype=np.float32,
                sampling=self.sampling,
                height=height,
            )
        else:
            return self._calculate_widths(self.array, self.sampling, height)

    @staticmethod
    def _interpolate(array, gpts, endpoint, order):
        # Spline-interpolate each profile onto `gpts` points. Profiles are
        # flattened to 2D and padded by 3 samples on each side with periodic
        # ("wrap") values so the spline has support at the boundaries.
        xp = get_array_module(array)
        map_coordinates = get_ndimage_module(array).map_coordinates
        old_shape = array.shape
        array = array.reshape((-1, array.shape[-1]))
        array = xp.pad(array, ((0,) * 2, (3,) * 2), mode="wrap")
        # Sample coordinates in the padded frame (the offset of 3 skips the pad).
        new_points = xp.linspace(3.0, array.shape[-1] - 3.0, gpts, endpoint=endpoint)[
            None
        ]
        new_array = xp.zeros(array.shape[:-1] + (gpts,), dtype=xp.float32)
        for i in range(len(array)):
            # Third positional argument is the preallocated output array.
            map_coordinates(array[i], new_points, new_array[i], order=order)
        return new_array.reshape(old_shape[:-1] + (gpts,))

    def interpolate(
        self,
        sampling: Optional[float] = None,
        gpts: Optional[int] = None,
        order: int = 3,
        endpoint: bool = False,
    ) -> BaseMeasurementsSubclass:
        """
        Interpolate line profile(s) producing equivalent line profile(s) with a
        different sampling. Either 'sampling' or 'gpts' must be provided (but not both).

        Parameters
        ----------
        sampling : float, optional
            Sampling of line profiles after interpolation [Å].
        gpts : int, optional
            Number of grid points of line profiles after interpolation. Do not use if
            'sampling' is used.
        order : int, optional
            The order of the spline interpolation (default is 3). The order has to be in
            the range 0-5.
        endpoint : bool, optional
            If True, end is the last position. Otherwise, it is not included.
            Default is False.

        Returns
        -------
        interpolated_profiles : RealSpaceLineProfiles
            The interpolated line profile(s).
        """
        xp = get_array_module(self.array)

        if (gpts is not None) and (sampling is not None):
            raise RuntimeError()

        # With neither given, re-interpolate at the current sampling.
        if sampling is None and gpts is None:
            sampling = self.sampling

        if gpts is None:
            gpts = int(np.ceil(self.extent / sampling))

        if sampling is None:
            sampling = self.extent / gpts

        if self.is_lazy:
            # The spline interpolation is global; keep the profile axis in one chunk.
            array = self.array.rechunk(self.array.chunks[:-1] + ((self.shape[-1],),))
            array = array.map_blocks(
                self._interpolate,
                gpts=gpts,
                endpoint=endpoint,
                order=order,
                chunks=self.array.chunks[:-1] + (gpts,),
                meta=xp.array((), dtype=xp.float32),
            )
        else:
            array = self._interpolate(self.array, gpts, endpoint, order)

        kwargs = self._copy_kwargs(exclude=("array",))
        kwargs["array"] = array
        kwargs["sampling"] = sampling
        return self.__class__(**kwargs)

    def show(
        self,
        ax: Optional[Axes] = None,
        common_scale: bool = True,
        explode: bool | Sequence[int] = False,
        overlay: Optional[bool | Sequence[int]] = None,
        figsize: Optional[tuple[int, int]] = None,
        title: str | bool = True,
        units: Optional[str] = None,
        legend: bool = False,
        interact: bool = False,
        display: bool = True,
        **kwargs,
    ) -> Visualization:
        """
        Show the line profile(s) using matplotlib.

        Parameters
        ----------
        ax : matplotlib Axes, optional
            If given the plots are added to the Axes. This is not available for image
            grids.
        common_scale : bool
            If True all plots are shown with a common y-axis. Default is False.
        explode : bool or sequence of bool, optional
            If True, a grid of plots is created for all the items of the last two
            ensemble axes. If False, only the one plot is created. May be given as a
            sequence of axis indices to create a grid of plots from the specified
            axes. The default is determined by the axis metadata.
        overlay : bool or sequence of int, optional
            If True, all line profiles in the ensemble are shown in a single plot.
            If False, only the first ensemble item is shown. May be given as a sequence
            of axis indices to specify which line profiles in the ensemble to show
            together. The default is determined by the axis metadata.
        figsize : two int, optional
            The figure size given as width and height in inches, passed to
            `matplotlib.pyplot.figure`.
        title : bool or str, optional
            Set the column title of the plots. If True is given instead of a string the
            title will be given by the value corresponding to the "name" key of the axes
            metadata dictionary, if this item exists.
        legend : bool
            Add a legend to the plot. The labels will be derived from
        units : str, optional
            The units used for the x-axis. The given units must be compatible.
        interact : bool
            If True, create an interactive visualization. This requires enabling the
            `ipympl` Matplotlib backend.
        display : bool, optional
            If True (default) the figure is displayed immediately.

        Returns
        -------
        visualization : Visualization
        """
        # Default overlay behavior: overlay everything in one plot unless an
        # exploded grid was requested, in which case nothing is overlaid.
        if overlay is None and explode is False:
            overlay = True
        elif overlay is False or overlay is None:
            overlay = ()

        visualization = Visualization(
            measurement=self,
            ax=ax,
            figsize=figsize,
            title=title,
            aspect=False,
            share_x=True,
            share_y=common_scale,
            explode=explode,
            overlay=overlay,
            interactive=not interact and display,
            legend=legend,
            common_scale=common_scale,
            **kwargs,
        )
        if interact:
            visualization.interact(LinesGUI, display=display)

        if common_scale is False and visualization._explode:
            # Independent y-scales need extra padding between panels.
            visualization.axes.set_sizes(padding=0.8)
        return visualization
[docs]
class RealSpaceLineProfiles(_BaseMeasurement1D):
    """
    A collection of line profiles sampled in real space.

    Parameters
    ----------
    array : np.ndarray
        1D or greater array containing data of type `float` or `complex`.
    sampling : float
        Sampling of line profiles [Å].
    ensemble_axes_metadata : list of AxisMetadata, optional
        List of metadata associated with the ensemble axes. The length and item order
        must match the ensemble axes.
    metadata : dict, optional
        A dictionary defining measurement metadata.
    """

    def __init__(
        self,
        array: np.ndarray,
        sampling: Optional[float] = None,
        ensemble_axes_metadata: Optional[list[AxisMetadata]] = None,
        metadata: Optional[dict] = None,
    ):
        super().__init__(
            array=array,
            sampling=sampling,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )

    @property
    def base_axes_metadata(self) -> list[RealSpaceAxis]:
        """Metadata describing the single real-space base axis."""
        axis = RealSpaceAxis(
            label="r", sampling=self.sampling, units="Å", tex_label="$r$"
        )
        return [axis]

    def tile(self, repetitions: int) -> "RealSpaceLineProfiles":
        """
        Tile line profiles(s).

        Parameters
        ----------
        repetitions : int
            The number of repetitions of the line profiles.

        Returns
        -------
        tiled_line_profiles : RealSpaceLineProfiles
            The tiled line profiles(s).
        """
        xp = get_array_module(self.array)
        # Repeat only along the profile (last) axis; ensemble axes are untouched.
        tiling = (1,) * (len(self.array.shape) - 1) + (repetitions,)
        tiled = da.tile(self.array, tiling) if self.is_lazy else xp.tile(self.array, tiling)
        kwargs = self._copy_kwargs(exclude=("array",))
        kwargs["array"] = tiled
        return self.__class__(**kwargs)
# def _plot_extent(self, units=None):
# scale = get_conversion_factor(units, "Å")
# return [0, self.extent * scale]
#
# # def _plot_extent(self, units=None):
# # scale = {"Å": 1, "nm": 0.1}[_validate_real_space_units(units)]
# # return [0, self.extent * scale]
#
# def _plot_x_label(self, units=None):
# return f"x [{_validate_units(units, 'Å')}]"
#
# def _plot_y_label(self, units=None):
# return f"y [{_validate_units(units, 'Å')}]"
[docs]
class ReciprocalSpaceLineProfiles(_BaseMeasurement1D):
    """
    A collection of line profiles sampled in reciprocal space.

    Parameters
    ----------
    array : np.ndarray
        1D or greater array containing data of type `float` or `complex`.
    sampling : float
        Sampling of line profiles [1 / Å].
    ensemble_axes_metadata : list of AxisMetadata, optional
        List of metadata associated with the ensemble axes. The length and item order
        must match the ensemble axes.
    metadata : dict, optional
        A dictionary defining measurement metadata.
    """

    def __init__(
        self,
        array: np.ndarray,
        sampling: Optional[float] = None,
        ensemble_axes_metadata: Optional[list[AxisMetadata]] = None,
        metadata: Optional[dict] = None,
    ):
        super().__init__(
            array=array,
            sampling=sampling,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )

    @property
    def base_axes_metadata(self) -> list[AxisMetadata]:
        """Metadata describing the single reciprocal-space base axis."""
        axis = ReciprocalSpaceAxis(
            label="k", sampling=self.sampling, units="1/Å", tex_label="$k$"
        )
        return [axis]

    @property
    def angular_extent(self):
        """Extent of line profiles given as scattering angles [mrad]."""
        # Converts the reciprocal-space extent [1/Å] to an angle via the electron
        # wavelength from the 'energy' metadata entry (1e3 converts rad to mrad).
        energy = self._get_from_metadata("energy")
        return self.extent * energy2wavelength(energy) * 1e3
# def _plot_x_label(self, units=None):
# return f"x [{_validate_units(units, '1/Å')}]"
#
# def _plot_y_label(self, units=None):
# return f"y [{_validate_units(units, '1/Å')}]"
#
# def _plot_extent(self, units=None):
# if units is None:
# units = "1/Å"
#
# if units == "mrad":
# return [0, self.angular_extent]
# elif units == "1/Å":
# return [0, self.extent]
def _integrate_gradient_2d(gradient, sampling):
    # Integrate a complex-valued 2D gradient field (real part = x component,
    # imaginary part = y component) by inverse filtering in Fourier space.
    xp = get_array_module(gradient)
    gx = gradient.real
    gy = gradient.imag
    nx, ny = gx.shape[-2:]
    kx = xp.fft.fftfreq(nx, d=sampling[0])
    ky = xp.fft.fftfreq(ny, d=sampling[1])
    kx_grid, ky_grid = xp.meshgrid(kx, ky, indexing="ij")
    # Squared spatial-frequency magnitude; the DC term is replaced by a tiny
    # value to avoid division by zero (the constant offset is undetermined).
    k2 = kx_grid**2 + ky_grid**2
    k2[k2 == 0] = 1e-12
    numerator = xp.fft.fft2(gx) * kx_grid + xp.fft.fft2(gy) * ky_grid
    integrated = xp.real(xp.fft.ifft2(numerator / (2j * np.pi * k2)))
    # Shift the result so its minimum is zero.
    return integrated - xp.min(integrated)
def _fourier_space_bilinear_nodes_and_weight(
old_shape: tuple[int, int],
new_shape: tuple[int, int],
old_angular_sampling: tuple[float, float],
new_angular_sampling: tuple[float, float],
xp,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
nodes = []
weights = []
old_sampling = (
1 / old_angular_sampling[0] / old_shape[0],
1 / old_angular_sampling[1] / old_shape[1],
)
new_sampling = (
1 / new_angular_sampling[0] / new_shape[0],
1 / new_angular_sampling[1] / new_shape[1],
)
for n, m, r, d in zip(old_shape, new_shape, old_sampling, new_sampling):
k = xp.fft.fftshift(xp.fft.fftfreq(n, r).astype(xp.float32))
k_new = xp.fft.fftshift(xp.fft.fftfreq(m, d).astype(xp.float32))
distances = k_new[None] - k[:, None]
distances[distances < 0.0] = np.inf
w = distances.min(0) / (k[1] - k[0])
w[w == np.inf] = 0.0
nodes.append(distances.argmin(0))
weights.append(w)
v, u = nodes
vw, uw = weights
v, u, vw, uw = xp.broadcast_arrays(v[:, None], u[None, :], vw[:, None], uw[None, :])
return v, u, vw, uw
def _gaussian_source_size(measurements, sigma: float | tuple[float, float]):
    # Blur the scan axes of a measurement with a Gaussian of the given
    # real-space standard deviation(s); all non-scan axes are left untouched.
    scan_axes = _scan_axes(measurements)
    if len(scan_axes) < 2:
        raise RuntimeError(
            "Gaussian source size not implemented for diffraction patterns with less"
            "than two scan axes."
        )

    if np.isscalar(sigma):
        sigma = (sigma,) * 2

    xp = get_array_module(measurements.array)
    gaussian_filter = get_ndimage_module(measurements._array).gaussian_filter

    # Build per-axis sigmas (in pixels) and overlap depths: nonzero only for
    # scan axes, zero elsewhere (including the two base axes appended last).
    scan_index = 0
    padded_sigma: tuple = ()
    depth: tuple = ()
    for axis, n in enumerate(measurements.ensemble_shape):
        if axis in scan_axes:
            step = _scan_sampling(measurements)[scan_index]
            padded_sigma += (sigma[scan_index] / step,)
            # Truncate the overlap at 4 sigma, but never beyond the axis length.
            depth += (min(int(np.ceil(4.0 * sigma[scan_index] / step)), n),)
            scan_index += 1
        else:
            padded_sigma += (0.0,)
            depth += (0,)
    padded_sigma += (0.0, 0.0)
    depth += (0, 0)

    if measurements.is_lazy:
        array = measurements.array.map_overlap(
            gaussian_filter,
            sigma=padded_sigma,
            mode="wrap",
            depth=depth,
            meta=xp.array((), dtype=xp.float32),
        )
    else:
        array = gaussian_filter(measurements.array, sigma=padded_sigma, mode="wrap")

    kwargs = measurements._copy_kwargs(exclude=("array",))
    return measurements.__class__(array, **kwargs)
def _infer_lines(b, H, W, out_H, out_W, kH, kW):
target_size = 2**17
line_size = b * (H * W // out_H + kH * kW * out_W)
target_lines = target_size // line_size
if target_lines < out_H:
lines = 1
while True:
next_lines = lines * 2
if next_lines > target_lines:
break
lines = next_lines
else:
lines = out_H
return lines
def _interpolate_bilinear(x, v, u, vw, uw):
    # Bilinear interpolation of a batch of 2D arrays.
    #
    #   x      : (B, H, W) input stack.
    #   v, u   : (out_H, out_W) integer row/column indices of the lower source
    #            node for every output pixel.
    #   vw, uw : (out_H, out_W) fractional weights toward the upper node.
    #
    # Returns a (B, out_H, out_W) array.
    B, H, W = x.shape
    out_H, out_W = v.shape
    # Interpolation is done by each output panel (i.e. multi lines)
    # in order to better utilize CPU cache memory.
    lines = _infer_lines(B, H, W, out_H, out_W, 2, 2)
    # Reusable scratch buffers for one panel of `lines` output rows.
    vcol = np.empty((2, lines, out_W), dtype=v.dtype)
    ucol = np.empty((2, lines, out_W), dtype=u.dtype)
    wcol = np.empty((2, 2, lines, out_W), dtype=x.dtype)
    y = np.empty((B, out_H * out_W), dtype=x.dtype)
    for i in range(0, out_H, lines):
        # The final panel may be shorter; shrink the buffer views accordingly.
        # (n never increases between iterations, so re-slicing is safe.)
        n = min(lines, out_H - i)
        vcol = vcol[:, :n]
        ucol = ucol[:, :n]
        wcol = wcol[:, :, :n]
        i_end = i + n
        # indices: row 0 holds the lower node, row 1 the upper node clamped to
        # the last valid array index.
        vcol[0] = v[i:i_end]
        ucol[0] = u[i:i_end]
        np.add(vcol[0], 1, out=vcol[1])
        np.add(ucol[0], 1, out=ucol[1])
        np.minimum(vcol[1], H - 1, out=vcol[1])
        np.minimum(ucol[1], W - 1, out=ucol[1])
        # Build the four bilinear weights in place:
        #   wcol[a, b] = (vw if a else 1 - vw) * (uw if b else 1 - uw)
        wcol[0, 1] = uw[i:i_end]
        np.subtract(1, wcol[0, 1], out=wcol[0, 0])
        np.multiply(wcol[0], vw[i:i_end], out=wcol[1])
        wcol[0] -= wcol[1]
        # packing to the panel whose shape is (B, 2, 2, n, out_W)
        panel = x[:, vcol[:, None], ucol[None, :]]
        # interpolation: weighted sum of the four neighboring nodes per pixel.
        panel = panel.reshape((B, 4, n * out_W))
        weights = wcol.reshape((4, n * out_W))
        iout = i * out_W
        iout_end = i_end * out_W
        np.einsum("ijk,jk->ik", panel, weights, out=y[:, iout:iout_end])
        del panel, weights
    return y.reshape((B, out_H, out_W))
def _diffraction_pattern_resampling_gpts(
old_sampling: tuple[float, float],
old_gpts: tuple[int, int],
sampling: str | float | tuple[float, float],
gpts: Optional[tuple[int, int]] = None,
adjust_sampling: bool = True,
):
if gpts is None:
if sampling == "uniform":
validated_sampling = (max(old_sampling),) * 2
elif isinstance(sampling, float):
validated_sampling = (sampling,) * 2
else:
raise ValueError("Invalid sampling.")
assert isinstance(validated_sampling, tuple)
adjusted_sampling, gpts = adjusted_gpts(
validated_sampling, old_sampling, old_gpts
)
if adjust_sampling:
validated_sampling = adjusted_sampling
else:
if np.isscalar(gpts):
gpts = (gpts,) * 2
validated_sampling = tuple(
d * old_n / new_n for d, old_n, new_n in zip(sampling, old_gpts, gpts)
)
return gpts, validated_sampling
[docs]
class DiffractionPatterns(_BaseMeasurement2D):
"""
One or more diffraction patterns.
Parameters
----------
array : np.ndarray
2D or greater array containing data with `float` type. The second-to-last and
last dimensions are the reciprocal space `y`- and `x`-axis of the diffraction
pattern.
sampling : float or two float
The reciprocal-space sampling of the diffraction patterns [1 / Å].
fftshift : bool, optional
If True, the diffraction patterns are assumed to have the zero-frequency
component to the center of the spectrum, otherwise the center(s) are assumed to
be at `(0, 0)`.
ensemble_axes_metadata : list of AxisMetadata, optional
List of metadata associated with the ensemble axes. The length and item order
must match the ensemble axes.
metadata : dict, optional
A dictionary defining measurement metadata.
"""
[docs]
def __init__(
    self,
    array: np.ndarray | da.core.Array,
    sampling: float | tuple[float, float],
    fftshift: bool = False,
    ensemble_axes_metadata: Optional[list[AxisMetadata]] = None,
    metadata: Optional[dict] = None,
):
    """Initialize diffraction patterns from an array and its reciprocal-space
    sampling; see the class docstring for parameter details."""
    # Normalize sampling to a pair of floats.
    if np.isscalar(sampling):
        validated_sampling = (float(sampling),) * 2
    else:
        validated_sampling = (float(sampling[0]), float(sampling[1]))
    self._fftshift = fftshift
    self._sampling = validated_sampling
    # The last two array axes are the reciprocal-space base axes.
    self._base_axes = (-2, -1)
    super().__init__(
        array=array,
        ensemble_axes_metadata=ensemble_axes_metadata,
        metadata=metadata,
    )
@property
def _area_per_pixel(self):
    # Area per pixel is delegated to the scan axes — presumably the real-space
    # area of one scan position (see _scan_area_per_pixel; defined elsewhere).
    return _scan_area_per_pixel(self)
def _get_1d_equivalent(self):
    # 1D counterpart used when reducing diffraction patterns to line profiles
    # (mirrors Images._get_1d_equivalent returning RealSpaceLineProfiles).
    return ReciprocalSpaceLineProfiles
@property
def base_axes_metadata(self):
    """Axis metadata for the two reciprocal-space base axes (`kx`, `ky`)."""
    limits = self.limits
    axes = []
    # Both axes share units and fftshift; only labels and samplings differ.
    for i, (label, tex) in enumerate((("kx", "$k_x$"), ("ky", "$k_y$"))):
        axes.append(
            ReciprocalSpaceAxis(
                sampling=self.sampling[i],
                offset=limits[i][0],
                label=label,
                units="1/Å",
                fftshift=self.fftshift,
                tex_label=tex,
            )
        )
    return axes
def tile_scan(self, repetitions: tuple[int, int]) -> DiffractionPatterns:
    """
    Tile the scan axes of the diffraction patterns. The diffraction patterns must
    have exactly two scan axes.

    Parameters
    ----------
    repetitions : two int
        The number of repetitions of the scan positions along the `x`- and `y`-axis.

    Returns
    -------
    tiled_diffraction_patterns : DiffractionPatterns
        The tiled diffraction patterns.
    """
    scan_axes = _scan_axes(self)
    if len(scan_axes) != 2:
        raise NotImplementedError

    xp = get_array_module(self.array)

    # Repeat along the scan axes only; every other axis is tiled once.
    reps = iter(repetitions)
    tiling = tuple(
        next(reps) if axis in scan_axes else 1 for axis in range(len(self.shape))
    )

    kwargs = self._copy_kwargs(exclude=("array",))
    kwargs["array"] = xp.tile(self.array, tiling)
    return self.__class__(**kwargs)
@staticmethod
def _index_diffraction_spots(
    array,
    orientation_matrices,
    hkl,
    mask_all,
    sampling,
    cell,
    sg_max,
    g_max,
    centering,
    energy,
    radius,
):
    """Integrate the intensities of the `hkl` reflections for one block.

    Returns an array with one entry per reflection selected by `mask_all`;
    reflections excluded by the per-block filtering are left at zero.
    """
    # Local imports avoid a circular dependency on the bloch subpackage.
    from abtem.bloch.indexing import index_diffraction_spots
    from abtem.bloch.utils import filter_reciprocal_space_vectors

    # Reflections satisfying the excitation-error, scattering-vector and
    # centering conditions for this block's orientation matrices.
    mask = filter_reciprocal_space_vectors(
        hkl,
        cell,
        energy=energy,
        sg_max=sg_max,
        g_max=g_max,
        centering=centering,
        orientation_matrices=orientation_matrices,
    )

    array = index_diffraction_spots(
        array=array,
        hkl=hkl[mask],
        sampling=sampling,
        cell=cell,
        energy=energy,
        radius=radius,
        orientation_matrices=orientation_matrices,
    )

    # Scatter the integrated intensities into the globally selected set of
    # reflections (`mask_all`); spots not selected in this block stay zero.
    array_all = np.zeros(array.shape[:-1] + (mask_all.sum(),), dtype=array.dtype)
    array_all[..., mask[mask_all]] = array
    return array_all
[docs]
def index_diffraction_spots(
    self,
    cell: Cell | float | tuple[float, float, float],
    sg_max: Optional[float] = None,
    g_max: Optional[float] = None,
    orientation_matrices: Optional[np.ndarray] = None,
    radius: Optional[float] = None,
    centering: str = "P",
    energy: Optional[float] = None,
) -> IndexedDiffractionPatterns:
    """
    Indexes the Bragg reflections (diffraction spots) by their Miller indices.

    Parameters
    ----------
    cell : ase.cell.Cell or float or tuple of float
        The assumed unit cell with respect to the diffraction pattern should be
        indexed. Must be one of ASE `Cell` object, float (for a cubic unit cell) or
        three floats (for orthorhombic unit cells).
    sg_max : float, optional
        Maximum excitation error [1/Å] of the indexed diffraction spots.
        The default is estimated from the energy and `g_max`.
    g_max : float, optional
        Maximum scattering vector [1/Å] of the indexed diffraction spots.
        The default is the maximum frequency of the diffraction patterns.
    orientation_matrices : np.ndarray, optional
        Orientation matrices used for indexing the diffraction spots. The shape of
        the orientation matrices must be broadcastable with the ensemble shape of
        the diffraction patterns.
    radius : float, optional
        Integration radius of the diffraction spots [1/Å]. The default is the
        reciprocal-space sampling of the diffraction patterns.
    centering : {'P', 'F', 'I', 'A', 'B', 'C'}
        Assumed lattice centering used for determining the reflection conditions.
    energy : float, optional
        The energy of the electrons [keV]. The default is the energy stored in the
        metadata.

    Returns
    -------
    indexed_patterns : IndexedDiffractionPatterns
        The indexed diffraction pattern(s).
    """
    # Local imports avoid a circular dependency on the bloch subpackage.
    from abtem.bloch.indexing import (
        estimate_necessary_excitation_error,
        validate_cell,
    )
    from abtem.bloch.utils import filter_reciprocal_space_vectors, make_hkl_grid

    if orientation_matrices is not None and not is_broadcastable(
        self.ensemble_shape, orientation_matrices.shape[:-2]
    ):
        raise ValueError(
            "The ensemble shape and the shape of the orientation matrices must be"
            " broadcastable."
        )

    # Fill in defaults from the metadata and the pattern geometry.
    if energy is None:
        energy = self._get_from_metadata("energy")

    if g_max is None:
        g_max = max(self.max_frequency)

    if sg_max is None:
        sg_max = estimate_necessary_excitation_error(energy, g_max)

    if orientation_matrices is None:
        # Identity orientation, padded so it broadcasts over the ensemble axes.
        orientation_matrices = np.eye(3)[(None,) * len(self.array.shape[:-2])]

    cell = validate_cell(cell)

    hkl = make_hkl_grid(cell, g_max)

    # Global reflection selection, used to size the output arrays.
    mask = filter_reciprocal_space_vectors(
        hkl,
        cell,
        energy=energy,
        sg_max=sg_max,
        g_max=g_max,
        centering=centering,
        orientation_matrices=orientation_matrices,
    )

    if self.is_lazy:
        # Pad with leading axes so the orientation matrices broadcast against
        # the ensemble axes of the lazy pattern array.
        orientation_matrices = orientation_matrices[
            (None,)
            * (len(self.array.shape[:-2]) - len(orientation_matrices.shape[:-2]))
        ]

        # Align chunking with the pattern array where the axis sizes match;
        # broadcastable (size-mismatched) axes get a single chunk.
        chunks = tuple(
            c if n == sum(c) else 1
            for n, c in zip(orientation_matrices.shape, self.array.chunks[:-2])
        )

        lazy_orientation_matrices = da.from_array(
            orientation_matrices, chunks=chunks + (3, 3)
        )

        intensities = da.map_blocks(
            self._index_diffraction_spots,
            self.array,
            lazy_orientation_matrices,
            hkl=hkl,
            mask_all=mask,
            sampling=self.sampling,
            cell=cell,
            sg_max=sg_max,
            g_max=g_max,
            centering=centering,
            energy=energy,
            radius=radius,
            drop_axis=len(self.array.shape) - 1,
            chunks=self.array.chunks[:-2] + (mask.sum(),),
            meta=np.array((), dtype=self.dtype),
        )
    else:
        intensities = self._index_diffraction_spots(
            array=self.array,
            orientation_matrices=orientation_matrices,
            hkl=hkl,
            mask_all=mask,
            sampling=self.sampling,
            cell=cell,
            sg_max=sg_max,
            g_max=g_max,
            centering=centering,
            energy=energy,
            radius=radius,
        )

    # Reciprocal lattice rotated by each orientation matrix.
    reciprocal_lattice_vectors = np.matmul(
        cell.reciprocal(),
        np.swapaxes(orientation_matrices, -2, -1),
    )

    return IndexedDiffractionPatterns(
        intensities,
        hkl[mask],
        reciprocal_lattice_vectors=reciprocal_lattice_vectors,
        ensemble_axes_metadata=self.ensemble_axes_metadata,
        metadata=self.metadata,
    )
@property
def fftshift(self) -> bool:
    """
    True if the zero-frequency is shifted to the center of the array.
    """
    return self._fftshift

@property
def sampling(self) -> tuple[float, float]:
    """Sampling of diffraction patterns in `x` and `y` [1 / Å]."""
    return self._sampling

@property
def angular_sampling(self) -> tuple[float, float]:
    """
    Angular sampling of diffraction patterns in `x` and `y` [mrad].
    """
    # Scattering angle [mrad] = spatial frequency [1/Å] * wavelength [Å] * 1e3.
    wavelength = energy2wavelength(self._get_from_metadata("energy"))
    return (
        self.sampling[0] * wavelength * 1e3,
        self.sampling[1] * wavelength * 1e3,
    )

@property
def max_angles(self) -> tuple[float, float]:
    """Maximum scattering angle in `x` and `y` [mrad]."""
    return (
        self.shape[-2] // 2 * self.angular_sampling[0],
        self.shape[-1] // 2 * self.angular_sampling[1],
    )

@property
def max_frequency(self) -> tuple[float, float]:
    """Maximum spatial frequency in `x` and `y` [1 / Å]."""
    # NOTE(review): this uses the positive (upper) frequency limit; for an even
    # number of pixels the negative limit is one sampling step larger — confirm
    # this is the intended convention.
    return abs(self.limits[0][1]), abs(self.limits[1][1])
@property
def limits(self) -> list[tuple[float, float]]:
    """Lowest and highest spatial frequency in `x` and `y` [1 / Å]."""
    # For an odd axis the frequencies are symmetric; for an even axis the
    # negative side extends one sample further than the positive side.
    out = []
    for n, d in zip(self.shape[-2:], self.sampling):
        if n % 2:
            half = (n - 1) // 2
            out.append((-half * d, half * d))
        else:
            out.append((-(n // 2) * d, (n // 2 - 1) * d))
    return out
@property
def angular_limits(self) -> list[tuple[float, float]]:
    """Lowest and highest scattering angle in `x` and `y` [mrad]."""
    # Convert each spatial-frequency limit [1/Å] to a scattering angle [mrad].
    wavelength = energy2wavelength(self._get_from_metadata("energy"))
    return [
        (lo * wavelength * 1e3, hi * wavelength * 1e3) for lo, hi in self.limits
    ]

@property
def offset(self) -> tuple[float, float]:
    # Lower-left corner of the reciprocal-space window [1 / Å].
    limits = self.limits
    return limits[0][0], limits[1][0]
@property
def extent(self) -> tuple[float, float]:
    """Total reciprocal-space extent in `x` and `y` [1 / Å].

    Fixed: the extent is the upper limit minus the lower limit; the previous
    implementation returned lower minus upper, which yields negative extents.
    """
    limits = self.limits
    return limits[0][1] - limits[0][0], limits[1][1] - limits[1][0]
@property
def coordinates(self) -> tuple[np.ndarray, np.ndarray]:
    """Reciprocal-space frequency coordinates [1 / Å]."""
    # Delegates to the axis metadata so offset and fftshift are honored.
    return (
        self.axes_metadata[-2].coordinates(self.base_shape[-2]),
        self.axes_metadata[-1].coordinates(self.base_shape[-1]),
    )

@property
def angular_coordinates(self) -> tuple[np.ndarray, np.ndarray]:
    """Scattering angle coordinates [mrad]."""
    xp = get_array_module(self.array)
    limits = self.angular_limits
    alpha_x = xp.linspace(
        limits[0][0], limits[0][1], self.shape[-2], dtype=xp.float32
    )
    alpha_y = xp.linspace(
        limits[1][0], limits[1][1], self.shape[-1], dtype=xp.float32
    )
    if self.fftshift:
        # Data is stored centered, matching the monotonically increasing angles.
        return alpha_x, alpha_y
    else:
        # Reorder the centered coordinates to match fft-ordered data.
        return np.fft.fftshift(alpha_x), np.fft.fftshift(alpha_y)
@staticmethod
def _batch_interpolate_bilinear(array, new_sampling, sampling, new_gpts):
    """Bilinearly resample the last two axes of `array` onto `new_gpts`.

    The total intensity of each pattern is preserved by rescaling with the
    pre-interpolation sums.
    """
    xp = get_array_module(array)
    v, u, vw, uw = _fourier_space_bilinear_nodes_and_weight(
        array.shape[-2:], new_gpts, sampling, new_sampling, xp
    )

    # Collapse all ensemble axes into one batch axis for the kernels.
    old_shape = array.shape
    array = array.reshape((-1,) + array.shape[-2:])

    old_sums = array.sum((-2, -1), keepdims=True)

    if xp is cp:
        array = interpolate_bilinear_cuda(array, v, u, vw, uw)
    else:
        array = _interpolate_bilinear(array, v, u, vw, uw)

    # Renormalize so each pattern keeps its original integrated intensity.
    array = array / array.sum((-2, -1), keepdims=True) * old_sums

    return array.reshape(old_shape[:-2] + array.shape[-2:])
[docs]
def interpolate(
    self,
    sampling: Optional[str | float | tuple[float, float]] = None,
    gpts: Optional[tuple[int, int]] = None,
):
    """
    Interpolate diffraction pattern(s) producing equivalent pattern(s) with a
    different sampling.

    Parameters
    ----------
    sampling : 'uniform' or float or two floats
        Sampling of diffraction patterns after interpolation in `x` and `y` [1 / Å].
        If a single value, the same sampling is used for both axes. If 'uniform',
        the diffraction patterns are down-sampled along the axis with the smallest
        pixel size such that the sampling is uniform.
    gpts : tuple of int
        Number of grid points of the diffraction patterns after interpolation in
        `x` and `y`. Do not use if 'sampling' is used.

    Returns
    -------
    interpolated_diffraction_patterns : DiffractionPatterns
        The interpolated diffraction pattern(s).
    """
    # Resolve the requested sampling/gpts into a consistent pair.
    gpts, sampling = _diffraction_pattern_resampling_gpts(
        self.sampling, self.base_shape, sampling, gpts, adjust_sampling=False
    )

    if self.is_lazy:
        # NOTE(review): output dtype is fixed to float32 here — confirm float64
        # inputs are not expected.
        array = self.array.map_blocks(
            self._batch_interpolate_bilinear,
            sampling=self.sampling,
            new_sampling=sampling,
            new_gpts=gpts,
            chunks=self.array.chunks[:-2] + ((gpts[0],), (gpts[1],)),
            dtype=np.float32,
        )
    else:
        array = self._batch_interpolate_bilinear(
            self.array, sampling=self.sampling, new_sampling=sampling, new_gpts=gpts
        )

    kwargs = self._copy_kwargs(exclude=("array",))
    kwargs["sampling"] = sampling
    kwargs["array"] = array
    return self.__class__(**kwargs)
def _check_integration_limits(self, inner: float, outer: float):
    # Validate that [inner, outer] is an admissible integration interval for
    # the simulated maximum angles; raises RuntimeError otherwise.
    if inner > outer:
        raise RuntimeError(
            f"Inner detection ({inner} mrad) angle cannot exceed the outer"
            f" detection angle ({outer} mrad)."
        )

    max_angle = min(self.max_angles)
    exceeds = outer > self.max_angles[0] or outer > self.max_angles[1]
    # A tiny overshoot of the limit (within tolerance) is accepted.
    if exceeds and not np.isclose(max_angle, outer, atol=1e-5):
        raise RuntimeError(
            f"Outer integration limit cannot exceed the maximum simulated angle"
            f" ({outer} mrad > {max_angle} mrad), increase the"
            " number of grid points."
        )
[docs]
def gaussian_source_size(
    self, sigma: float | tuple[float, float]
) -> DiffractionPatterns:
    """
    Simulate the effect of a finite source size on diffraction pattern(s) using a
    Gaussian filter.

    The filter is not applied to diffraction pattern individually, but the intensity
    of diffraction patterns are mixed across scan axes. Applying this filter
    requires two linear scan axes.

    Applying this filter before integrating the diffraction patterns will produce
    the same image as integrating the diffraction patterns first then applying a
    Gaussian filter.

    Parameters
    ----------
    sigma : float or two float
        Standard deviation of Gaussian kernel in the `x` and `y`-direction. If given
        as a single number, the standard deviation is equal for both axes.

    Returns
    -------
    filtered_diffraction_patterns : DiffractionPatterns
        The filtered diffraction pattern(s).
    """
    # Delegates to the module-level helper shared with other measurement types.
    return _gaussian_source_size(self, sigma)
def poisson_noise(
    self,
    dose_per_area: Optional[float] = None,
    total_dose: Optional[float] = None,
    samples: int = 1,
    seed: Optional[int] = None,
):
    """
    Add Poisson noise (shot noise) to the measurement for a given `total_dose`
    (per measurement if applied to an ensemble) or `dose_per_area` (only valid
    when the diffraction patterns have two scan axes).

    Parameters
    ----------
    dose_per_area : float, optional
        The irradiation dose per unit of scan area [electrons per Å:sup:`2`].
    total_dose : float, optional
        The irradiation dose per diffraction pattern.
    samples : int, optional
        The number of samples to draw from a Poisson distribution. If greater
        than 1, an additional ensemble axis is added to the measurement.
    seed : int, optional
        Seed the random number generator.

    Returns
    -------
    noisy_measurement : BaseMeasurements
        The noisy measurement.
    """
    # `dose_per_area` is only meaningful for patterns scanned over two axes.
    if dose_per_area is not None and len(_scan_shape(self)) < 2:
        raise ValueError(
            "diffraction patterns has less than two scan axes, provide `total_dose`"
            " not `dose_per_area`"
        )
    # TODO: normalization
    return super().poisson_noise(
        dose_per_area=dose_per_area,
        total_dose=total_dose,
        samples=samples,
        seed=seed,
    )
@staticmethod
def _radial_binning(
    array,
    nbins_radial,
    nbins_azimuthal,
    sampling,
    inner,
    outer,
    fftshift,
    rotation,
    offset,
):
    """Sum the last two axes of `array` into a polar grid of detector bins."""
    xp = get_array_module(array)

    # Flat pixel indices belonging to each polar bin.
    indices = _polar_detector_bins(
        gpts=array.shape[-2:],
        sampling=sampling,
        inner=inner,
        outer=outer,
        nbins_radial=nbins_radial,
        nbins_azimuthal=nbins_azimuthal,
        fftshift=fftshift,
        rotation=rotation,
        offset=offset,
        return_indices=True,
    )

    # Start offset of each bin's run inside the concatenated index array.
    separators = xp.concatenate(
        (xp.array([0]), xp.cumsum(xp.array([len(i) for i in indices])))
    )

    new_shape = array.shape[:-2] + (nbins_radial, nbins_azimuthal)

    # Flatten the detector axes and gather pixels bin by bin so each bin's
    # pixels form one contiguous run.
    array = array.reshape(
        (
            -1,
            array.shape[-2] * array.shape[-1],
        )
    )[..., np.concatenate(indices)]

    result = xp.zeros(
        (
            array.shape[0],
            len(indices),
        ),
        dtype=xp.float32,
    )

    # Sum each run into its bin (CUDA kernel on GPU, numba kernel on CPU).
    if xp is cp:
        sum_run_length_encoded_cuda(array, result, separators)
    else:
        _sum_run_length_encoded(array, result, separators)

    return result.reshape(new_shape)
[docs]
def polar_binning(
    self,
    nbins_radial: int,
    nbins_azimuthal: int,
    inner: float = 0.0,
    outer: Optional[float] = None,
    rotation: float = 0.0,
    offset: tuple[float, float] = (0.0, 0.0),
):
    """
    Create polar measurements from the diffraction patterns by binning the
    measurements on a polar grid. This method may be used to simulate a segmented
    detector with a specified number of radial and azimuthal bins.

    Each bin is a segment of an annulus and the bins are spaced equally in the
    radial and azimuthal directions.

    The bins fit between a given inner and outer integration limit, they may be
    rotated around the origin, and their center may be shifted from the origin.

    Parameters
    ----------
    nbins_radial : int
        Number of radial bins.
    nbins_azimuthal : int
        Number of angular bins.
    inner : float
        Inner integration limit of the bins [mrad] (default is 0.0).
    outer : float
        Outer integration limit of the bins [mrad]. If not specified, this is set to
        be the maximum detected angle of the diffraction pattern.
    rotation : float
        Rotation of the bins around the origin [mrad] (default is 0.0).
    offset : two float
        Offset of the bins from the origin in `x` and `y` [mrad].
        Default is (0.0, 0.0).

    Returns
    -------
    polar_measurements : PolarMeasurements
        The polar measurements.
    """
    if nbins_radial <= 0 or nbins_azimuthal <= 0:
        raise RuntimeError("number of bins must be greater than zero")

    if outer is None:
        outer = min(self.max_angles)

    self._check_integration_limits(inner, outer)

    xp = get_array_module(self.array)

    if self.is_lazy:
        # The two detector axes are replaced by (nbins_radial, nbins_azimuthal).
        array = self.array.map_blocks(
            self._radial_binning,
            nbins_radial=nbins_radial,
            nbins_azimuthal=nbins_azimuthal,
            sampling=self.angular_sampling,
            inner=inner,
            outer=outer,
            fftshift=self.fftshift,
            rotation=rotation,
            offset=offset,
            drop_axis=(len(self.shape) - 2, len(self.shape) - 1),
            chunks=self.array.chunks[:-2]
            + (
                (nbins_radial,),
                (nbins_azimuthal,),
            ),
            new_axis=(
                len(self.shape) - 2,
                len(self.shape) - 1,
            ),
            meta=xp.array((), dtype=xp.float32),
        )
    else:
        array = self._radial_binning(
            self.array,
            nbins_radial=nbins_radial,
            nbins_azimuthal=nbins_azimuthal,
            sampling=self.angular_sampling,
            inner=inner,
            outer=outer,
            fftshift=self.fftshift,
            rotation=rotation,
            offset=offset,
        )

    # Bin sizes of the resulting polar grid.
    radial_sampling = (outer - inner) / nbins_radial
    azimuthal_sampling = 2 * np.pi / nbins_azimuthal

    return PolarMeasurements(
        array,
        radial_sampling=radial_sampling,
        azimuthal_sampling=azimuthal_sampling,
        radial_offset=inner,
        azimuthal_offset=rotation,
        ensemble_axes_metadata=self.ensemble_axes_metadata,
        metadata=self.metadata,
    )
[docs]
def radial_binning(
self, step_size: float = 1.0, inner: float = 0.0, outer: Optional[float] = None
) -> PolarMeasurements:
"""
Create polar measurement(s) from the diffraction pattern(s) by binning the
measurements in annular regions. This method may be used to simulate a segmented
detector with a specified number of radial bins.
This is equivalent to detecting a wave function using the
`FlexibleAnnularDetector`.
Parameters
----------
step_size : float, optional
Radial extent of the bins [mrad]. Default is 1.0.
inner : float, optional
Inner integration limit of the bins [mrad]. Default is 0.0.
outer : float, optional
Outer integration limit of the bins [mrad]. If not specified, this is set to
be the maximum detected angle of the diffraction pattern.
Returns
-------
radially_binned_measurement : PolarMeasurements
Radially binned polar measurement(s).
"""
if outer is None:
outer = min(self.max_angles)
nbins_radial = int((outer - inner) / step_size)
return self.polar_binning(nbins_radial, 1, inner, outer)
@staticmethod
def _integrate_fourier_space(array, sampling, inner, outer, fftshift, offset):
    """Sum the intensity inside an annular detector region.

    `inner` and `outer` are scattering angles in the same units as `sampling`.
    """
    xp = get_array_module(array)

    # Boolean mask selecting the pixels inside the annulus.
    bins = _annular_detector_mask(
        gpts=array.shape[-2:],
        sampling=sampling,
        inner=inner,
        outer=outer,
        fftshift=fftshift,
        offset=offset,
        xp=xp,
    )

    return xp.sum(array * bins, axis=(-2, -1))
def integrate_radial(
    self,
    inner: float | Sequence[float],
    outer: Optional[float | Sequence[float]] = None,
    offset: tuple[float, float] = (0.0, 0.0),
) -> Images:
    """
    Create images by integrating the diffraction patterns over an annulus defined by
    an inner and outer integration angle.

    Parameters
    ----------
    inner : float or sequence of float
        Inner integration limit [mrad]. A sequence produces one image per pair of
        limits, stacked along a new ensemble axis.
    outer : float or sequence of float, optional
        Outer integration limit [mrad]. Defaults to the maximum simulated angle.
    offset : tuple of float
        Offset of center of annular integration region [mrad].

    Returns
    -------
    integrated_images : Images
        The integrated images.
    """
    if isinstance(inner, Sequence) or isinstance(outer, Sequence):
        # Broadcast a scalar limit against the sequence of the other limit.
        if isinstance(inner, Number):
            inners = (inner,) * len(outer)
            outers = outer
        else:
            outers = (outer,) * len(inner)
            inners = inner

        # Bug fix: forward `offset` to the per-limit integrations; it was
        # previously dropped and silently reset to (0, 0).
        measurements = [
            self.integrate_radial(inner=inner, outer=outer, offset=offset)
            for inner, outer in zip(inners, outers)
        ]
        return stack(
            measurements,
            axis_metadata=NonLinearAxis(
                label="Limits", values=tuple(zip(inners, outers)), units="mrad"
            ),
        )

    if outer is None:
        outer = min(self.max_angles)

    self._check_integration_limits(inner, outer)

    xp = get_array_module(self.array)

    if self.is_lazy:
        # The two detector axes are reduced away by the integration.
        integrated_intensity = self.array.map_blocks(
            self._integrate_fourier_space,
            sampling=self.angular_sampling,
            inner=inner,
            outer=outer,
            fftshift=self.fftshift,
            offset=offset,
            drop_axis=(len(self.shape) - 2, len(self.shape) - 1),
            meta=xp.array((), dtype=xp.float32),
        )
    else:
        integrated_intensity = self._integrate_fourier_space(
            self.array,
            sampling=self.angular_sampling,
            inner=inner,
            outer=outer,
            fftshift=self.fftshift,
            offset=offset,
        )

    return _reduced_scanned_images_or_line_profiles(integrated_intensity, self)
def integrated_center_of_mass(self) -> Images:
    """
    Calculate integrated center-of-mass (iCOM) images from diffraction patterns.
    Only implemented for diffraction patterns with exactly two scan axes.

    Returns
    -------
    icom_images : Images
        The iCOM images.
    """
    com = self.center_of_mass()

    # `center_of_mass` returns `Images` only when there are two scan axes.
    if not isinstance(com, Images):
        raise RuntimeError(
            f"Integrated center-of-mass not implemented for DiffractionPatterns"
            f" with {len(_scan_shape(self))} scan axes."
        )
    return com.integrate_gradient()
@staticmethod
def _com(array: np.ndarray, x: np.ndarray, y: np.ndarray):
    # First moments along the two detector axes, packed as complex x + i*y.
    weighted_x = (array * x[:, None]).sum(axis=(-2, -1))
    weighted_y = (array * y[None]).sum(axis=(-2, -1))
    return weighted_x + 1.0j * weighted_y
[docs]
def center_of_mass(self, units: str = "1/Å") -> Images | RealSpaceLineProfiles:
    """
    Calculate center-of-mass images or line profiles from diffraction patterns.
    The results are of type `complex` where the real and imaginary part represents
    the `x` and `y` component.

    Parameters
    ----------
    units : {'1/Å', 'mrad'}
        Units of the center-of-mass values of the output images or line profiles.
        Default is '1/Å'.

    Returns
    -------
    com_images : Images
        Center-of-mass images.
    com_line_profiles : RealSpaceLineProfiles
        Center-of-mass line profiles (returned if there is only one scan axis).
    """
    if units == "mrad":
        x, y = self.angular_coordinates
    elif units == "1/Å":
        x, y = self.coordinates
    else:
        raise ValueError("units must be '1/Å' or 'mrad'")

    xp = get_array_module(self.array)
    # Keep the coordinate vectors on the same device as the data.
    x, y = xp.asarray(x), xp.asarray(y)

    if self.is_lazy:
        # The two detector axes are reduced away by `_com`.
        base_axes = tuple(
            range(
                len(self.ensemble_shape),
                len(self.base_shape) + len(self.ensemble_shape),
            )
        )
        array = self.array.map_blocks(
            self._com, x=x, y=y, drop_axis=base_axes, dtype=np.complex64
        )
    else:
        array = self._com(self.array, x=x, y=y)

    return _reduced_scanned_images_or_line_profiles(array, self)
@staticmethod
def _bandlimit(
    array: np.ndarray,
    inner: float,
    outer: float,
    angular_coordinates: tuple[np.ndarray, np.ndarray],
):
    # Radial scattering angle of every pixel.
    alpha_x, alpha_y = angular_coordinates
    alpha = np.sqrt(alpha_x[:, None] ** 2 + alpha_y[None] ** 2)

    # Keep only pixels strictly inside the annulus (inner, outer).
    keep = alpha > inner
    if outer != np.inf:
        keep = keep * (alpha < outer)
    return array * keep
def bandlimit(
    self, inner: float = 0.0, outer: float = np.inf
) -> DiffractionPatterns:
    """
    Zero out every pixel of the diffraction pattern(s) that lies outside an
    annulus bounded by two radial scattering angles.

    Parameters
    ----------
    inner : float
        Inner limit of zero region [mrad].
    outer : float
        Outer limit of zero region [mrad].

    Returns
    -------
    band-limited_diffraction_patterns : DiffractionPatterns
        The band-limited diffraction pattern(s).
    """
    xp = get_array_module(self.array)

    if self.is_lazy:
        new_array = self.array.map_blocks(
            self._bandlimit,
            inner=inner,
            outer=outer,
            angular_coordinates=self.angular_coordinates,
            meta=xp.array((), dtype=xp.float32),
        )
    else:
        new_array = self._bandlimit(
            self.array, inner, outer, self.angular_coordinates
        )

    kwargs = self._copy_kwargs(exclude=("array",))
    kwargs["array"] = new_array
    return self.__class__(**kwargs)
@staticmethod
def _crop(array: np.ndarray, gpts: tuple[int, int]):
    # Work in fft order so `fft_crop` discards the highest frequencies, then
    # shift the zero frequency back to the center.
    # NOTE(review): this assumes the patterns are stored fftshifted — confirm
    # behavior for measurements with `fftshift=False`.
    xp = get_array_module(array)
    array = xp.fft.fftshift(
        fft_crop(xp.fft.ifftshift(array, axes=(-2, -1)), new_shape=gpts),
        axes=(-2, -1),
    )
    return array
[docs]
def crop(
self,
max_angle: Optional[float] = None,
max_frequency: Optional[float] = None,
gpts: Optional[tuple[int, int]] = None,
) -> DiffractionPatterns:
"""
Crop the diffraction patterns such that they only include spatial frequencies
(scattering angles) up to a given limit.
Parameters
----------
max_angle : float, optional
The maximum included scattering angle in the cropped diffraction patterns.
max_frequency : float, optional
The maximum included spatial frequency in the cropped diffraction patterns.
gpts : tuple of int
The number of gpts in the cropped diffraction patterns.
Returns
-------
cropped_diffraction_patterns : DiffractionPatterns
The cropped diffraction pattern(s).
"""
none_args = (max_angle is None) + (max_frequency is None) + (gpts is None)
if none_args != 2:
raise ValueError(
"provide exactly one of 'max_angle', 'max_frequency' or 'gpts'"
)
if gpts is None and max_angle is not None:
gpts = (
int(2 * np.round(max_angle / self.angular_sampling[0])) + 1,
int(2 * np.round(max_angle / self.angular_sampling[1])) + 1,
)
elif gpts is None and max_frequency is not None:
gpts = (
int(2 * np.round(max_frequency / self.sampling[0])) + 1,
int(2 * np.round(max_frequency / self.sampling[1])) + 1,
)
if gpts is None:
raise ValueError()
if self.is_lazy:
xp = get_array_module(self.array)
array = self.array.map_blocks(
self._crop,
gpts=gpts,
chunks=self.array.chunks[:-2] + gpts,
meta=xp.array((), dtype=self.dtype),
)
else:
array = self._crop(self.array, gpts=gpts)
kwargs = self._copy_kwargs(exclude=("array",))
kwargs["array"] = array
return self.__class__(**kwargs)
@staticmethod
def _azimuthal_average(
    array: np.ndarray,
    angular_coordinates: tuple[np.ndarray, np.ndarray],
    max_angle: float,
    radial_sampling: float,
    weighting_function: str,
    width: float,
):
    # Radial scattering angle of every detector pixel.
    grid_x, grid_y = np.meshgrid(*angular_coordinates, indexing="ij")
    radii = np.sqrt(grid_x**2 + grid_y**2)

    centers = np.arange(0, max_angle, radial_sampling)
    averages = np.zeros(array.shape[:-2] + centers.shape)

    for i, center in enumerate(centers):
        # Weight each pixel by its distance to the current radial shell.
        if weighting_function == "step":
            weights = np.abs(radii - center) < width
        elif weighting_function == "gaussian":
            weights = np.exp(-((radii - center) ** 2) / (width**2 / 2))
        else:
            raise ValueError("weighting function must be 'step' or 'gaussian'")

        total_weight = np.sum(weights)
        if total_weight > 0:
            averages[..., i] = np.sum(array * weights, axis=(-2, -1)) / total_weight
        else:
            # No pixels fall inside this shell.
            averages[..., i] = 0.0

    return averages
def azimuthal_average(
    self,
    max_angle: Optional[float] = None,
    radial_sampling: float = 1.0,
    weighting_function: str = "step",
    width: float = 1.0,
) -> ReciprocalSpaceLineProfiles:
    """
    Calculate the azimuthal averages of the diffraction patterns.

    Parameters
    ----------
    max_angle : float, optional
        The maximum included scattering angle in the azimuthal averages [mrad].
    radial_sampling : float, optional
        The radial sampling of the azimuthal averages in units of the smallest
        component of the angular sampling. Default is 1.0.
    weighting_function : str
        The weighting function determining the mask applied to the diffraction
        patterns before averaging. The options are 'step' and 'gaussian'.
        Default is 'step'.
    width : float, optional
        The width of the weighting function in units of the smallest component of
        the angular sampling (step width for 'step', standard deviation for
        'gaussian'). Default is 1.0.

    Returns
    -------
    azimuthal_averages : ReciprocalSpaceLineProfiles
        The azimuthal averages of the diffraction patterns.
    """
    if max_angle is None:
        # Largest angle contained within all the angular limits.
        max_angle = -min(min(self.angular_limits))

    # `radial_sampling` and `width` are given in units of the angular sampling.
    radial_sampling = radial_sampling * min(self.angular_sampling)
    width = width * min(self.angular_sampling)

    if self.is_lazy:
        xp = get_array_module(self.array)

        # Bug fix: the number of output bins must match the length of
        # `np.arange(0, max_angle, radial_sampling)` used by
        # `_azimuthal_average`; `int(max_angle / radial_sampling)` is one too
        # small whenever `max_angle` is not an exact multiple of the sampling,
        # which made the declared chunks disagree with the computed blocks.
        n = len(np.arange(0, max_angle, radial_sampling))

        base_axes = tuple(
            range(
                len(self.ensemble_shape),
                len(self.ensemble_shape) + len(self.base_shape),
            )
        )
        array = self.array.map_blocks(
            self._azimuthal_average,
            angular_coordinates=self.angular_coordinates,
            max_angle=max_angle,
            radial_sampling=radial_sampling,
            weighting_function=weighting_function,
            width=width,
            drop_axis=base_axes,
            new_axis=base_axes[0],
            chunks=self.array.chunks[:-2] + (n,),
            meta=xp.array((), dtype=np.float32),
        )
    else:
        array = self._azimuthal_average(
            self.array,
            angular_coordinates=self.angular_coordinates,
            max_angle=max_angle,
            radial_sampling=radial_sampling,
            weighting_function=weighting_function,
            width=width,
        )

    # Convert the radial sampling from [mrad] back to spatial frequency [1/Å].
    wavelength = energy2wavelength(self._get_from_metadata("energy"))

    return ReciprocalSpaceLineProfiles(
        array,
        sampling=radial_sampling / (wavelength * 1e3),
        ensemble_axes_metadata=self.ensemble_axes_metadata,
        metadata=self.metadata,
    )
# def fourier_shell_correlation(
# self,
# other: DiffractionPatterns,
# radial_sampling: float = 1.0,
# width: float = 1.0,
# weighting_function: str = "step",
# ):
# fsc = (self**0.5 * other**0.5).azimuthal_average(
# radial_sampling=radial_sampling,
# width=width,
# weighting_function=weighting_function,
# ) / (
# self.azimuthal_average(
# radial_sampling=radial_sampling,
# width=width,
# weighting_function=weighting_function,
# )
# * other.azimuthal_average(
# radial_sampling=radial_sampling,
# width=width,
# weighting_function=weighting_function,
# )
# )
# return fsc
[docs]
def block_direct(
self, radius: Optional[float] = None, margin: Optional[bool] = None
) -> DiffractionPatterns:
"""
Block the direct beam by setting the pixels of the zeroth-order Bragg reflection
(non-scattered beam) to zero.
Parameters
----------
radius : float, optional
The radius of the zeroth-order reflection to block [mrad]. If not given this
will be inferred from the metadata, if available.
margin : bool, optional
If True adds a margin to the blocking radius to fully block soft apertures.
Margin is true by default for diffraction patterns with `semiangle_cutoff`
in metadata.
Returns
-------
diffraction_patterns : DiffractionPatterns
The diffraction pattern(s) with the direct beam removed.
"""
if radius is None:
if "semiangle_cutoff" in self.metadata.keys():
radius = self.metadata["semiangle_cutoff"]
else:
radius = max(self.angular_sampling) * 1.0001
if "semiangle_cutoff" in self.metadata.keys() and margin is None:
margin = True
if margin:
radius += max(self.angular_sampling)
return self.bandlimit(radius, outer=np.inf)
class PolarMeasurements(BaseMeasurements):
    """
    Measurements binned on a polar grid with a given number of radial and
    azimuthal bins.

    Each bin is a segment of an annulus; the bins are spaced equally in the
    radial and azimuthal directions, may be rotated around the origin, and their
    center may be shifted from the origin.

    Parameters
    ----------
    array : np.ndarray
        Array containing the measurement.
    radial_sampling : float
        Sampling of the radial bins [mrad].
    azimuthal_sampling : float
        Sampling of the azimuthal bins [rad].
    radial_offset : float, optional
        Offset of the bins from the origin [mrad] (default is 0.0).
    azimuthal_offset : float, optional
        Rotation of the bins around the origin [rad] (default is 0.0).
    ensemble_axes_metadata : list of AxisMetadata, optional
        Metadata for the ensemble axes; length and order must match those axes.
    metadata : dict, optional
        A dictionary defining measurement metadata.
    """

    # The last two axes are the radial and azimuthal bins.
    _base_dims = 2

    def __init__(
        self,
        array: np.ndarray | da.core.Array,
        radial_sampling: float,
        azimuthal_sampling: float,
        radial_offset: float = 0.0,
        azimuthal_offset: float = 0.0,
        ensemble_axes_metadata: Optional[list[AxisMetadata]] = None,
        metadata: Optional[dict] = None,
    ):
        self._radial_sampling = radial_sampling
        self._azimuthal_sampling = azimuthal_sampling
        self._radial_offset = radial_offset
        self._azimuthal_offset = azimuthal_offset

        super().__init__(
            array=array,
            ensemble_axes_metadata=ensemble_axes_metadata,
            metadata=metadata,
        )
@property
def _area_per_pixel(self):
    # Scan area per pixel; delegates to the module-level helper.
    return _scan_area_per_pixel(self)

@property
def sampling(self):
    # Cartesian sampling is not meaningful on a polar grid.
    raise RuntimeError("Sampling not defined for polar measurement.")

@property
def offset(self):
    # Cartesian offset is not meaningful on a polar grid.
    raise RuntimeError("Offset not defined for polar measurement.")

@property
def extent(self):
    # Cartesian extent is not meaningful on a polar grid.
    raise RuntimeError("Extent not defined for polar measurement.")

def _get_1d_equivalent(self):
    # Polar measurements have no 1D line-profile equivalent.
    raise RuntimeError("Not defined for polar measurement.")
@property
def base_axes_metadata(self) -> list[AxisMetadata]:
    """Axis metadata for the radial and azimuthal bin axes."""
    radial_axis = LinearAxis(
        label="Radial scattering angle",
        offset=self.radial_offset,
        sampling=self.radial_sampling,
        _concatenate=False,
        units="mrad",
    )
    azimuthal_axis = LinearAxis(
        label="Azimuthal scattering angle",
        offset=self.azimuthal_offset,
        sampling=self.azimuthal_sampling,
        _concatenate=False,
        units="rad",
    )
    return [radial_axis, azimuthal_axis]
@property
def radial_offset(self) -> float:
    """Offset of the bins from the origin [mrad]."""
    return self._radial_offset

@property
def outer_angle(self) -> float:
    """The outer angle of the outermost radial bin [mrad]."""
    # shape[-2] is the number of radial bins.
    return self._radial_offset + self.radial_sampling * self.shape[-2]

@property
def radial_sampling(self) -> float:
    """Sampling of the radial bins [mrad]."""
    return self._radial_sampling

@property
def azimuthal_sampling(self) -> float:
    """Sampling of the azimuthal bins [rad]."""
    return self._azimuthal_sampling

@property
def azimuthal_offset(self) -> float:
    """Rotation of the bins around the origin [rad]."""
    return self._azimuthal_offset
[docs]
def integrate_radial(
    self, inner: float, outer: float
) -> Images | RealSpaceLineProfiles:
    """
    Create images by integrating the polar measurements over an annulus defined by
    an inner and outer integration angle.

    Parameters
    ----------
    inner : float
        Inner integration limit [mrad].
    outer : float
        Outer integration limit [mrad].

    Returns
    -------
    integrated_images : Images
        The integrated images.
    real_space_line_profiles : RealSpaceLineProfiles
        Integrated line profiles (returned if there is only one scan axis).
    """
    # Thin wrapper around `integrate` restricted to radial limits.
    return self.integrate(radial_limits=(inner, outer))
[docs]
def integrate(
    self,
    radial_limits: Optional[tuple[float, float]] = None,
    azimuthal_limits: Optional[tuple[float, float]] = None,
    detector_regions: Optional[int | Sequence[int]] = None,
) -> Images | RealSpaceLineProfiles:
    """
    Integrate polar regions to produce an image or line profiles.

    Parameters
    ----------
    radial_limits : tuple of float
        Inner and outer radial angles of the integration limits [mrad].
    azimuthal_limits : tuple of float
        Lower and upper azimuthal angles of the integration limits [rad].
    detector_regions : int or sequence of int
        The explicit detector regions to integrate over. Mutually exclusive
        with `radial_limits` and `azimuthal_limits`.

    Returns
    -------
    integrated_images : Images or RealSpaceLineProfiles

    Raises
    ------
    ValueError
        If `detector_regions` is combined with angular limits.
    RuntimeError
        If the radial integration limit exceeds the outermost bin.
    """
    if detector_regions is not None:
        if (radial_limits is not None) or (azimuthal_limits is not None):
            raise ValueError(
                "Provide either 'detector_regions' or "
                "'radial_limits'/'azimuthal_limits', not both."
            )

        if np.isscalar(detector_regions):
            detector_regions = [detector_regions]

        # Flatten the two polar base axes so that regions can be selected by a
        # single flat index, then sum the selected bins.
        array = self.array.reshape(self.shape[:-2] + (-1,))[
            ..., list(detector_regions)
        ].sum(axis=-1)
    else:
        if radial_limits is None:
            radial_slice = slice(None)
        else:
            # Convert the angular limits to bin indices relative to the inner
            # radial angle.
            inner_index = int(
                (radial_limits[0] - self.radial_offset) / self.radial_sampling
            )
            outer_index = int(
                (radial_limits[1] - self.radial_offset) / self.radial_sampling
            )
            if outer_index > self.shape[-2]:
                raise RuntimeError("Integration limit exceeded.")
            radial_slice = slice(inner_index, outer_index)

        if azimuthal_limits is None:
            azimuthal_slice = slice(None)
        else:
            # Bug fix: azimuthal indices were previously computed with the
            # radial sampling (mrad) instead of the azimuthal sampling (rad),
            # and ignored the azimuthal offset (rotation) of the bins.
            left_index = int(
                (azimuthal_limits[0] - self.azimuthal_offset)
                / self.azimuthal_sampling
            )
            right_index = int(
                (azimuthal_limits[1] - self.azimuthal_offset)
                / self.azimuthal_sampling
            )
            azimuthal_slice = slice(left_index, right_index)

        array = self.array[..., radial_slice, azimuthal_slice].sum(axis=(-2, -1))

    return _reduced_scanned_images_or_line_profiles(array, self)
[docs]
def gaussian_source_size(
    self, sigma: float | tuple[float, float]
) -> PolarMeasurements:
    """
    Simulate the effect of a finite source size by mixing the intensities of
    the diffraction patterns across the scan axes with a Gaussian filter.
    Applying this filter requires two linear scan axes.

    Integrating the diffraction patterns and then applying a Gaussian filter
    produces the same image as applying this filter first and integrating
    afterwards.

    Parameters
    ----------
    sigma : float or two float
        Standard deviation of Gaussian kernel in the `x` and `y`-direction. If
        given as a single number, the standard deviation is equal for both axes.

    Returns
    -------
    filtered_diffraction_patterns : DiffractionPatterns
        The filtered diffraction pattern(s).
    """
    # Delegates to the shared source-size implementation.
    return _gaussian_source_size(self, sigma)
[docs]
def poisson_noise(
    self,
    dose_per_area: Optional[float] = None,
    total_dose: Optional[float] = None,
    samples: int = 1,
    seed: Optional[int] = None,
):
    """
    Add Poisson noise (i.e. shot noise) corresponding to the provided
    'total_dose' (per measurement if applied to an ensemble) or
    'dose_per_area' (requires two scan axes).

    Parameters
    ----------
    dose_per_area : float, optional
        The irradiation dose per unit of scan area [electrons per Å^2]. Only
        valid if the measurement has two scan axes.
    total_dose : float, optional
        The irradiation dose per diffraction pattern.
    samples : int, optional
        The number of samples to draw from a Poisson distribution. If greater
        than 1, an additional ensemble axis is added to the measurement.
    seed : int, optional
        Seed the random number generator.

    Returns
    -------
    noisy_measurement : BaseMeasurements
        The noisy measurement.
    """
    num_scan_axes = len(_scan_shape(self))

    # An area dose is only meaningful for a two-dimensional scan.
    if dose_per_area is not None and num_scan_axes < 2:
        raise ValueError(
            "Polar measurement has less than two scan axes, provide 'total_dose' not 'dose_per_area'."
        )

    return super().poisson_noise(
        dose_per_area=dose_per_area,
        total_dose=total_dose,
        samples=samples,
        seed=seed,
    )
[docs]
def to_diffraction_patterns(
    self, gpts: int | tuple[int, int], margin: float | tuple[float, float] = 0.1
) -> DiffractionPatterns:
    """
    Convert the polar measurements to diffraction patterns by discretizing the polar
    bins on a regular grid.

    Parameters
    ----------
    gpts : int or two int
        Number of grid points describing the diffraction patterns.
    margin : float or two float, optional
        The margin as a fraction of the outer angle of the polar measurements to add
        to the maximum angle of the diffraction patterns.

    Returns
    -------
    diffraction_patterns : DiffractionPatterns
        The diffraction patterns discretized from the polar measurements.
    """
    if np.isscalar(gpts):
        gpts = (gpts,) * 2
    if np.isscalar(margin):
        margin = (margin,) * 2
    # Choose the sampling so the grid spans ±outer_angle plus the fractional
    # margin in each direction.
    angular_sampling = (
        (1 + margin[0]) * self.outer_angle / gpts[0] * 2,
        (1 + margin[1]) * self.outer_angle / gpts[1] * 2,
    )
    nbins_radial, nbins_azimuthal = self.base_shape
    # Label every grid pixel with the flat index of the polar bin it falls in.
    regions = _polar_detector_bins(
        gpts=gpts,
        sampling=angular_sampling,
        inner=self.radial_offset,
        outer=self.outer_angle,
        nbins_radial=nbins_radial,
        nbins_azimuthal=nbins_azimuthal,
        fftshift=True,
        rotation=self.azimuthal_offset,
        offset=(0.0, 0.0),
        return_indices=False,
    )
    new_array = np.zeros(self.ensemble_shape + regions.shape, dtype=np.float32)
    # Scatter each polar bin's value onto all grid pixels carrying its label.
    for i, indices in enumerate(label_to_index(regions)):
        x, y = np.unravel_index(indices, regions.shape)
        radial, azimuthal = np.unravel_index(i, (nbins_radial, nbins_azimuthal))
        new_array[..., x, y] = self.array[..., radial, azimuthal][..., None]
    # Pixels with a negative label lie outside the detector bins; mark them NaN.
    new_array[..., regions < 0] = np.nan
    # Convert the angular sampling [mrad] to spatial frequencies [1/Å].
    wavelength = energy2wavelength(self._get_from_metadata("energy"))
    sampling = (
        angular_sampling[0] / (wavelength * 1e3),
        angular_sampling[1] / (wavelength * 1e3),
    )
    return DiffractionPatterns(
        new_array,
        sampling=sampling,
        ensemble_axes_metadata=self.ensemble_axes_metadata,
        metadata=self.metadata,
    )
[docs]
def differentials(
    self,
    direction_1: tuple[int | tuple[int, ...], int | tuple[int, ...]],
    direction_2: tuple[int | tuple[int, ...], int | tuple[int, ...]],
    return_complex: bool = True,
) -> Images:
    """
    Calculate the differential signal by subtracting the intensity of specified
    detector regions.

    Parameters
    ----------
    direction_1 : tuple of int or tuple of tuple of int
        The detector regions used for calculating the differential signal for the
        first direction. The first item is the detector region(s) contributing to
        the positive term and the second item is the detector region(s) contributing
        to the negative terms.
    direction_2 : tuple of int or tuple of tuple of int
        The detector regions used for calculating the differential signal for the
        second direction. The first item is the detector region(s) contributing to
        the positive term and the second item is the detector region(s) contributing
        to the negative terms.
    return_complex : bool, optional
        If True, return a complex image where the real and imaginary part represents
        `direction_1` and `direction_2`. If False, return images with an ensemble
        dimension for the directions.

    Returns
    -------
    differential_image : Images
        The (complex) differential image(s).
    """
    # NOTE(review): the docstring says the first item of each direction
    # contributes the positive term, but the code computes
    # integrate(second item) - integrate(first item), i.e. the second item is
    # positive — confirm which sign convention is intended.
    differential_1 = self.integrate(
        detector_regions=direction_1[1]
    ) - self.integrate(detector_regions=direction_1[0])
    differential_2 = self.integrate(
        detector_regions=direction_2[1]
    ) - self.integrate(detector_regions=direction_2[0])
    if not return_complex:
        # Keep the two directions as an extra ensemble axis.
        stacked = stack(
            (differential_1, differential_2), ("direction_1", "direction_2")
        )
        return stacked
    # Pack the two directions into the real/imaginary parts of a single image.
    xp = get_array_module(self.array)
    array = xp.zeros_like(xp.array(differential_1.array), dtype=xp.complex64)
    array.real = differential_1.array
    array.imag = differential_2.array
    return Images(array, **differential_1._copy_kwargs(exclude=("array",)))
[docs]
def to_image_ensemble(self) -> Images:
    """
    Convert the polar measurements to an ensemble of images, where the radial and
    azimuthal angles becomes ensemble axes.

    Returns
    -------
    image_ensemble : Images
    """
    image_axes = _scan_axes(self)
    xp = get_array_module(self.array)
    # Move the scan axes to the last two positions so they become image axes.
    # NOTE(review): the trailing '[..., 0, :, :]' keeps only the first element
    # of the axis just before the scan axes — confirm that discarding the
    # remaining elements of that axis is intended.
    array = xp.moveaxis(self.array, image_axes, (-2, -1))[..., 0, :, :]
    # Every non-scan axis (except the last axis metadata entry, matching the
    # reduction above) becomes an ensemble axis.
    ensemble_axes_metadata = [
        axis.copy()
        for i, axis in enumerate(self.axes_metadata[:-1])
        if i not in image_axes
    ]
    ensemble_axes_metadata[-1]._default_type = "range"
    sampling = _scan_sampling(self)
    return Images(
        array,
        sampling=(sampling[0], sampling[1]),
        ensemble_axes_metadata=ensemble_axes_metadata,
        metadata=self.metadata,
    )
[docs]
def show(
    self,
    ax: Optional[Axes] = None,
    gpts: int | tuple[int, int] = (512, 512),
    cbar: bool = False,
    cmap: Optional[str] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    power: float = 1.0,
    common_color_scale: bool = False,
    explode: bool | Sequence[bool] = (),
    overlay: bool | Sequence[int] = (),
    figsize: Optional[tuple[int, int]] = None,
    title: bool | str = True,
    units: Optional[str] = None,
    interact: bool = False,
    display: bool = True,
) -> Visualization:
    """
    Show the polar measurements as an image or a grid of images by
    discretizing the polar bins on a regular grid.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axis to add the plot(s) to. Not available for exploded plots.
    gpts : int or tuple of int, optional
        Number of grid points in the image(s).
    cbar : bool, optional
        Add colorbar(s) to the image(s). Size and padding may be adjusted with
        the `set_cbar_size` and `set_cbar_padding` methods.
    cmap : str, optional
        Matplotlib colormap name used to map scalar data to colors. For
        complex measurements the colormap must be one of 'hsv' or 'hsluv'.
    vmin : float, optional
        Minimum of the intensity color scale; defaults to the array minimum.
    vmax : float, optional
        Maximum of the intensity color scale; defaults to the array maximum.
    power : float
        Show the image on a power scale.
    common_color_scale : bool, optional
        Show all images in a grid on one color scale with a single colorbar
        (if requested). Default is False.
    explode : bool, optional
        Create a grid of images over the last two ensemble axes (True), show
        only the first ensemble item (False), or give a sequence of axis
        indices. Default determined by the axis metadata.
    overlay : bool or sequence of int, optional
        Show all (True), only the first (False), or the specified ensemble
        items together in one plot. Default determined by the axis metadata.
    figsize : two int, optional
        Figure width and height in inches, passed to `matplotlib.pyplot.figure`.
    title : bool or str, optional
        Column title(s) of the images; True uses the "name" key of the axes
        metadata dictionary, if present.
    units : str
        Units for the x and y axes; must be compatible with the image axes.
    interact : bool
        Create an interactive visualization (requires the `ipympl` backend).
    display : bool, optional
        If True (default) the figure is displayed immediately.

    Returns
    -------
    measurement_visualization_2d : MeasurementVisualizationImshow
    """
    diffraction_patterns = self.to_diffraction_patterns(gpts=gpts)

    # Interactive views stay lazy; otherwise compute the patterns up front.
    if not interact:
        diffraction_patterns.compute()

    return diffraction_patterns.show(
        ax=ax,
        cbar=cbar,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        power=power,
        common_color_scale=common_color_scale,
        explode=explode,
        overlay=overlay,
        figsize=figsize,
        title=title,
        units=units,
        interact=interact,
        display=display,
    )
[docs]
@jit(nopython=True, nogil=True, fastmath=True)
def calculate_max_reciprocal_space_vector(hkl, reciprocal_lattice_vectors):
    """Return the maximum length of the reciprocal lattice vectors hkl @ B.

    `hkl` is an (N, 3) array of Miller indices and `reciprocal_lattice_vectors`
    a (..., 3, 3) array of reciprocal basis vectors.
    """
    # Accumulate the maximum *squared* length; take a single sqrt at the end.
    k_max = 0.0
    for i in range(len(hkl)):
        lengths = (
            (
                hkl[i, 0] * reciprocal_lattice_vectors[..., 0, :]
                + hkl[i, 1] * reciprocal_lattice_vectors[..., 1, :]
                + hkl[i, 2] * reciprocal_lattice_vectors[..., 2, :]
            )
            ** 2
        ).sum(-1)
        # Reduce a stacked-basis result to its scalar maximum.
        # NOTE(review): 'hasattr' is generally not supported in numba nopython
        # mode — confirm this compiles for the argument types actually used.
        if hasattr(lengths, "max"):
            lengths = lengths.max()
        k_max = max(k_max, lengths)
    return np.sqrt(k_max)
# @jit(nopython=True, nogil=True)
def _reciprocal_lattice_vector_lengths(max_lengths, hkl, reciprocal_lattice_vectors):
for i in range(len(hkl)): # pylint: disable=not-an-iterable
lengths = np.sqrt(
(
(
hkl[i, 0] * reciprocal_lattice_vectors[..., 0, :]
+ hkl[i, 1] * reciprocal_lattice_vectors[..., 1, :]
+ hkl[i, 2] * reciprocal_lattice_vectors[..., 2, :]
)
** 2
).sum(-1)
)
max_lengths[i] = lengths.max()
return max_lengths
[docs]
def reciprocal_lattice_vector_lengths(hkl, reciprocal_lattice_vectors):
    """Return the length of each reciprocal lattice vector hkl @ B [1/Å].

    Allocates the output buffer and delegates to the in-place helper.
    """
    buffer = np.zeros(len(hkl))
    return _reciprocal_lattice_vector_lengths(buffer, hkl, reciprocal_lattice_vectors)
@jit(nopython=True, nogil=True)
def _reciprocal_lattice_vector_mask(mask, hkl, reciprocal_lattice_vectors, k_max):
    """Fill `mask` in place: True where the vector hkl @ B is shorter than `k_max`.

    For a stacked set of bases an entry is True if any basis yields a vector
    shorter than `k_max`. Mutates `mask`; returns None.
    """
    for i in range(len(hkl)):  # pylint: disable=not-an-iterable
        # Squared length(s) of the i-th reciprocal lattice vector; comparing
        # squares avoids a square root per vector.
        lengths = (
            (
                hkl[i, 0] * reciprocal_lattice_vectors[..., 0, :]
                + hkl[i, 1] * reciprocal_lattice_vectors[..., 1, :]
                + hkl[i, 2] * reciprocal_lattice_vectors[..., 2, :]
            )
            ** 2
        ).sum(-1)
        new_mask = lengths < k_max**2
        # Scalar result for a single basis; '.any()' over a stacked basis.
        if isinstance(new_mask, bool):
            mask[i] = new_mask
        else:
            mask[i] = new_mask.any()
[docs]
class IndexedDiffractionPatterns(BaseMeasurements):
    """
    Diffraction patterns indexed by their Miller indices.

    Parameters
    ----------
    array : np.ndarray
        1D or greater array of type `float` or `complex`. The last axis represents the
        diffraction spots and should have the same length as the number of miller
        indices, any preceding axis represents an ensemble axis.
    miller_indices : np.ndarray
        The miller indices of the diffraction spots as an N x 3 array where N is the
        number of miller indices. The order of the miller indices must correspond to the
        array of intensities. The second axis represents each hkl miller index.
    reciprocal_lattice_vectors : np.ndarray
        The reciprocal lattice vectors of the crystal as a 3 x 3 array. The first axis
        represents miller indices and the order of the items must correspond to the
        array of intensities. The second axis represents the reciprocal space positions
        in x, y and z [1/Å].
    ensemble_axes_metadata : list of AxisMetadata, optional
        List of metadata associated with the ensemble axes. The length and item order
        must match the ensemble axes.
    metadata : dict, optional
        A dictionary defining measurement metadata.
    """

    # The single base (non-ensemble) dimension is the axis of hkl spots.
    _base_dims = 1
[docs]
def __init__(
    self,
    array: da.core.Array | np.ndarray,
    miller_indices: np.ndarray,
    reciprocal_lattice_vectors: np.ndarray,
    ensemble_axes_metadata: Optional[list[AxisMetadata]] = None,
    metadata: Optional[dict] = None,
):
    # Accept an ASE Cell as the reciprocal basis.
    if isinstance(reciprocal_lattice_vectors, Cell):
        reciprocal_lattice_vectors = np.array(reciprocal_lattice_vectors)

    # Prepend broadcast axes so the basis aligns with the ensemble axes of the
    # array.
    # NOTE(review): the number of prepended axes is
    # ndim(basis) - ndim(array) + 1, which is zero or negative for arrays with
    # two or more ensemble axes — confirm the intended alignment for those
    # cases.
    if len(reciprocal_lattice_vectors.shape) <= len(array.shape):
        reciprocal_lattice_vectors = reciprocal_lattice_vectors[
            (None,) * (len(reciprocal_lattice_vectors.shape) - len(array.shape) + 1)
        ]

    if not is_broadcastable(
        array.shape[:-1], reciprocal_lattice_vectors.shape[:-2]
    ):
        # Bug fix: previously raised a bare ValueError with no message.
        raise ValueError(
            "The ensemble shape of the array is not broadcastable with the "
            "shape of the reciprocal lattice vectors."
        )

    if not len(miller_indices) == array.shape[-1]:
        raise ValueError(
            "The number of miller indices must be equal to the number of"
            " diffraction spots."
        )

    super().__init__(
        array=array,
        ensemble_axes_metadata=ensemble_axes_metadata,
        metadata=metadata,
    )
    self._miller_indices = miller_indices
    self._intensities = array
    self._reciprocal_lattice_vectors = reciprocal_lattice_vectors
def _area_per_pixel(self):
    # Indexed diffraction spots have no pixel area, so dose-per-area based
    # operations are not supported.
    raise NotImplementedError
@property
def base_axes_metadata(self) -> list[AxisMetadata]:
    """Axis metadata for the single base axis holding the hkl spots."""
    return [AxisMetadata(label="hkl")]
@property
def intensities(self) -> np.ndarray:
    """
    Intensities of the diffraction spots.
    """
    # The last axis of the array holds the spots; preceding axes are ensemble.
    return self._array
@property
def reciprocal_lattice_vectors(self) -> np.ndarray:
    """
    Reciprocal lattice vectors of the diffraction spots as a (..., 3, 3) array.
    """
    return self._reciprocal_lattice_vectors
@property
def positions(self) -> np.ndarray:
    """
    Reciprocal space positions of the diffraction spots [1/Å], computed as the
    Miller indices multiplied onto the reciprocal basis.
    """
    return self.miller_indices @ self.reciprocal_lattice_vectors
@property
def all_positions(self) -> np.ndarray:
    """
    Reciprocal space positions of the diffraction spots, tiled so that a
    broadcast (length-1) basis axis is repeated over the matching ensemble axis.
    """
    repeats = []
    for ensemble_dim, basis_dim in zip(
        self.shape[:-1], self.reciprocal_lattice_vectors.shape[:-2]
    ):
        if ensemble_dim == basis_dim:
            repeats.append(1)
        elif basis_dim == 1:
            repeats.append(ensemble_dim)
        else:
            raise RuntimeError("Incompatible shapes.")
    return np.tile(self.positions, tuple(repeats) + (1, 1))
@property
def miller_indices(self) -> np.ndarray:
    """
    Miller indices of the diffraction spots as an (N, 3) array.
    """
    return self._miller_indices
@property
def angular_positions(self) -> np.ndarray:
    """
    Scattering angles of the diffraction spots [mrad].
    """
    wavelength = energy2wavelength(self._get_from_metadata("energy"))
    # Small-angle conversion: angle [mrad] = k [1/Å] * wavelength [Å] * 1e3.
    return self.positions * wavelength * 1e3
@property
def ensemble_shape(self) -> tuple:
    """Shape of the ensemble axes (all axes preceding the hkl axis)."""
    return self.intensities.shape[:-1]
@classmethod
def _pack_kwargs(cls, kwargs):
    # Serialize the array-valued kwargs to nested lists so they can be stored
    # alongside the other packed attributes.
    kwargs["miller_indices"] = kwargs["miller_indices"].tolist()
    kwargs["reciprocal_lattice_vectors"] = kwargs[
        "reciprocal_lattice_vectors"
    ].tolist()
    return super()._pack_kwargs(kwargs)
@classmethod
def _unpack_kwargs(cls, attrs):
    # Inverse of _pack_kwargs: restore the arrays from their list form.
    kwargs = super()._unpack_kwargs(attrs)
    miller_indices = np.array(kwargs["miller_indices"], dtype=int)
    lattice_vectors = np.array(kwargs["reciprocal_lattice_vectors"], dtype=float)
    kwargs["miller_indices"] = miller_indices
    kwargs["reciprocal_lattice_vectors"] = lattice_vectors
    return kwargs
def __getitem__(self, items):
    items = _validate_array_items(items, self.shape)
    kwargs = self.get_items(items)

    # Broadcast (length-1) axes of the reciprocal lattice vectors are not
    # sliced like the data: an integer index keeps the singleton element,
    # while any other index leaves the axis untouched.
    adjusted = tuple(
        ((0 if isinstance(item, int) else slice(None)) if length == 1 else item)
        for item, length in zip(items, self.reciprocal_lattice_vectors.shape)
    )
    kwargs["reciprocal_lattice_vectors"] = self.reciprocal_lattice_vectors[adjusted]
    return self.__class__(**kwargs)
def _reduction(
    self,
    reduction_func,
    axes,
    keepdims: bool = False,
    split_every: int = 2,
    **kwargs,
):
    # Apply the same reduction (e.g. 'mean', 'sum') to the stacked reciprocal
    # lattice vectors so they stay aligned with the reduced ensemble axes,
    # then delegate the array reduction to the base class.
    reciprocal_lattice_vectors = getattr(np, reduction_func)(
        self.reciprocal_lattice_vectors, axis=axes, keepdims=keepdims
    )
    kwargs["reciprocal_lattice_vectors"] = reciprocal_lattice_vectors
    return super()._reduction(reduction_func, axes, keepdims, split_every, **kwargs)
[docs]
def remove_low_intensity(self, threshold: float = 1e-3):
    """
    Remove diffraction spots whose intensity is below a threshold across all
    ensemble dimensions.

    Parameters
    ----------
    threshold : float
        Intensity threshold for removing diffraction spots.

    Returns
    -------
    thresholded_spots : IndexedDiffractionPatterns
        The indexed diffraction spots with an intensity above the given threshold.
    """
    if self.is_lazy:
        raise RuntimeError("Cannot threshold lazy IndexedDiffractionPatterns.")

    xp = get_array_module(self.intensities)
    ensemble_axes = tuple(range(self.ensemble_dims))

    # A spot is kept if its maximum intensity over the ensemble exceeds the
    # threshold; indexing requires a host (numpy) mask.
    keep = asnumpy(xp.max(self.intensities, axis=ensemble_axes) > threshold)

    return self.__class__(
        self.intensities[..., keep],
        self.miller_indices[keep],
        reciprocal_lattice_vectors=self.reciprocal_lattice_vectors,
        ensemble_axes_metadata=self.ensemble_axes_metadata,
        metadata=self._metadata,
    )
[docs]
def sort(self, criterion: str = "distance"):
    """
    Sort the diffraction spots according to a given criterion.

    Parameters
    ----------
    criterion : {'distance', 'intensity'}
        ``distance`` :
            Sort according to the distance in reciprocal space from the zero
            frequency.
        ``intensity`` :
            Sort according to the intensity of the diffraction spots.

    Returns
    -------
    sorted_spots : IndexedDiffractionPatterns
        The indexed diffraction spots sorted according to the given criterion.

    Raises
    ------
    RuntimeError
        If the measurement is lazy.
    ValueError
        If `criterion` is not 'distance' or 'intensity'.
    """
    # Bug fix: this previously accessed 'self.lazy'; the property is named
    # 'is_lazy' (cf. remove_low_intensity).
    if self.is_lazy:
        raise RuntimeError("Cannot sort lazy IndexedDiffractionPatterns.")

    # Negated keys so argsort yields descending order of the criterion.
    if criterion == "distance":
        sort_key = -np.linalg.norm(self.positions, axis=1)
    elif criterion == "intensity":
        ensemble_axes = tuple(range(len(self.ensemble_shape)))
        sort_key = -np.max(self.intensities, axis=ensemble_axes)
    else:
        raise ValueError("'criterion' must be either 'distance' or 'intensity'.")

    order = np.argsort(sort_key)
    array = self.array[..., order]
    miller_indices = self.miller_indices[order]
    # NOTE(review): this reorders axis -3 of the reciprocal lattice vectors,
    # which elsewhere acts as an ensemble/broadcast axis rather than a spot
    # axis — confirm the indexing is intended.
    reciprocal_lattice_vectors = self.reciprocal_lattice_vectors[..., order, :, :]

    return self.__class__(
        array,
        miller_indices,
        reciprocal_lattice_vectors=reciprocal_lattice_vectors,
        ensemble_axes_metadata=self.ensemble_axes_metadata,
        metadata=self._metadata,
    )
[docs]
def crop(self, max_angle: Optional[float] = None, k_max: Optional[float] = None):
    """
    Crop the indexed diffraction patterns such that they only include spots with
    spatial frequencies (scattering angles) up to a given limit.

    Parameters
    ----------
    max_angle : float, optional
        The maximum included scattering angle in the cropped diffraction patterns
        [mrad].
    k_max : float, optional
        The maximum included reciprocal lattice vector in the cropped diffraction
        spots [1/Å].

    Returns
    -------
    cropped : IndexedDiffractionPatterns
        The cropped indexed diffraction spots.

    Raises
    ------
    ValueError
        If neither or both of 'max_angle' and 'k_max' are given.
    """
    # Bug fix: the previous check ('elif not k_max or max_angle:') used
    # truthiness, so an explicit k_max of 0.0 raised "must be given" and
    # passing both limits produced a misleading message. Compare against
    # None instead.
    if max_angle is not None and k_max is not None:
        raise ValueError("Provide either 'max_angle' or 'k_max', not both.")

    if k_max is None:
        if max_angle is None:
            raise ValueError("Either 'max_angle' or 'k_max' must be given.")
        # Convert the scattering angle [mrad] to a spatial frequency [1/Å].
        wavelength = energy2wavelength(self._get_from_metadata("energy"))
        k_max = max_angle / wavelength / 1e3

    # The mask is filled in place by the jit-compiled helper.
    mask = np.zeros(len(self.miller_indices), dtype=bool)
    _reciprocal_lattice_vector_mask(
        mask,
        self.miller_indices.astype(self.reciprocal_lattice_vectors.dtype),
        self.reciprocal_lattice_vectors,
        k_max,
    )

    miller_indices = self.miller_indices[mask]
    array = self.array[..., mask]

    return self.__class__(
        array,
        miller_indices,
        reciprocal_lattice_vectors=self.reciprocal_lattice_vectors,
        ensemble_axes_metadata=self.ensemble_axes_metadata,
        metadata=self._metadata,
    )
[docs]
def normalize_to_spot(self, spot: Optional[tuple[int, int, int]] = None):
    """
    Normalize the intensity of the diffraction spots.

    Parameters
    ----------
    spot : tuple of three int, optional
        The intensities will be normalized with respect to the intensity of this
        spot. Defaults to the most intense spot.

    Returns
    -------
    normalized_indexed_diffraction_patterns : IndexedDiffractionPatterns
    """
    if spot is None:
        # Normalize by the global maximum across all spots (and ensembles).
        normalization = np.max(self.intensities)
    else:
        normalization = self.intensities_dict[spot]

    intensities = self.intensities / normalization

    # Bug fix: previously 'self.positions' (the N x 3 spot coordinates) was
    # passed where the constructor expects the reciprocal lattice vectors;
    # pass the lattice vectors, consistent with every other method here.
    return self.__class__(
        intensities,
        self.miller_indices.copy(),
        reciprocal_lattice_vectors=self.reciprocal_lattice_vectors,
        ensemble_axes_metadata=self.ensemble_axes_metadata,
        metadata=self._metadata,
    )
[docs]
def to_data_array(self):
    """
    Convert the indexed diffraction patterns to an xarray DataArray.

    Returns
    -------
    data_array_of_indexed_spots : xarray.DataArray

    Raises
    ------
    RuntimeError
        If xarray is not installed.
    """
    if xr is None:
        raise RuntimeError("xarray is required to convert to DataArray.")

    # One labelled coordinate array per ensemble axis, annotated with units.
    coords = []
    for axes_metadata, n in zip(self.ensemble_axes_metadata, self.ensemble_shape):
        coords.append(
            xr.DataArray(
                list(axes_metadata.coordinates(n)),
                dims=[axes_metadata.label],
                attrs={"units": axes_metadata.units},
            )
        )

    # Consistency: reuse the shared hkl formatting helper instead of
    # duplicating the "h k l" string formatting inline.
    hkl = xr.DataArray(self._miller_indices_to_string(), name="hkl", dims=("hkl",))
    coords.append(hkl)

    dims = [axes_metadata.label for axes_metadata in self.ensemble_axes_metadata]
    dims = dims + ["hkl"]

    data = xr.DataArray(
        self.array,
        dims=dims,
        coords=coords,
        attrs={
            "long_name": "intensity",
            "units": self.metadata.get("units", "arb. unit"),
        },
    )

    # Attach the scattering-vector length of each spot as an extra coordinate.
    k = reciprocal_lattice_vector_lengths(
        self.miller_indices, self.reciprocal_lattice_vectors
    )
    return data.assign_coords(k=("hkl", k))
def _miller_indices_to_string(self):
    # Format each (h, k, l) triple as a space-separated string, e.g. "1 -1 0".
    return [" ".join(str(index) for index in hkl) for hkl in self.miller_indices]
[docs]
def to_dataframe(self):
    """
    Convert the indexed diffraction patterns to a pandas DataFrame with one
    column per hkl spot.

    Returns
    -------
    data_frame : pd.DataFrame

    Raises
    ------
    RuntimeError
        If pandas is not installed, or if the measurement has more than one
        ensemble axis.
    """
    if pd is None:
        raise RuntimeError("pandas is required to convert to DataFrame.")

    labels = self._miller_indices_to_string()

    if not self.ensemble_shape:
        # A single measurement becomes a one-row frame.
        row = {label: value for label, value in zip(labels, self.intensities)}
        return pd.DataFrame(row, index=[0])

    if len(self.ensemble_shape) > 1:
        raise RuntimeError(
            "cannot convert indexed diffraction patterns with more than one"
            " ensemble axis to dataframe"
        )

    columns = {
        label: self.intensities[..., i] for i, label in enumerate(labels)
    }

    # Use the ensemble axis values as the index when available.
    axes_metadata = self.ensemble_axes_metadata[0]
    if hasattr(axes_metadata, "values"):
        index = axes_metadata.values
    else:
        index = list(range(len(self.intensities)))

    df = pd.DataFrame(columns, index=index)
    # Disable TeX formatting while generating the plain-text axis labels.
    with config.set({"visualize.use_tex": False}):
        df.index.name = self.axes_metadata[0].format_label()
        df.columns.name = self.axes_metadata[1].format_label()
    return df
[docs]
def block_direct(self):
    """
    Remove the zero-order (0 0 0) spot.

    Returns
    -------
    blocked : IndexedDiffractionPatterns
        The indexed diffraction spots without the zero-order spot.
    """
    # Indices of spots whose miller indices are all zero.
    is_direct = np.all(self.miller_indices == 0, axis=1)
    direct_indices = np.where(is_direct)[0]

    return self.__class__(
        np.delete(self.intensities, direct_indices, axis=-1),
        np.delete(self.miller_indices, direct_indices, axis=0),
        reciprocal_lattice_vectors=self.reciprocal_lattice_vectors,
        ensemble_axes_metadata=self.ensemble_axes_metadata,
        metadata=self._metadata,
    )
def max_reciprocal_space_vector_length(self):
    """Return the maximum length of the spots' reciprocal lattice vectors [1/Å]."""
    return calculate_max_reciprocal_space_vector(
        self.miller_indices, self.reciprocal_lattice_vectors
    )
[docs]
def show(
    self,
    ax: Optional[Axes] = None,
    cbar: bool = False,
    cmap: Optional[str] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
    power: float = 1.0,
    common_color_scale: bool = False,
    scale: float = 0.5,
    explode: bool | Sequence[bool] = (),
    overlay: bool | Sequence[bool] = (),
    figsize: Optional[tuple[int, int]] = None,
    title: bool | str = True,
    units: Optional[str] = None,
    interact: bool = False,
    display: bool = True,
    **kwargs,
):
    """
    Show the diffraction spots as an EllipseCollection using matplotlib.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Axis to add the plot(s) to. Not available for exploded plots.
    cbar : bool, optional
        Add colorbar(s). Size and padding may be adjusted with the
        `set_cbar_size` and `set_cbar_padding` methods.
    cmap : str, optional
        Matplotlib colormap name used to map scalar data to colors. For complex
        measurements the colormap must be one of 'hsv' or 'hsluv'.
    vmin : float, optional
        Minimum of the intensity color scale; defaults to the array minimum.
    vmax : float, optional
        Maximum of the intensity color scale; defaults to the array maximum.
    power : float
        Show diffraction spot intensities on a power scale.
    common_color_scale : bool, optional
        Show all plots in a grid on one color scale with a single colorbar
        (if requested). Default is False.
    scale : float, optional
        Scale the radii of the circles representing the diffraction spots.
    explode : bool or sequence of bool, optional
        Create a grid of plots over the last two ensemble axes (True), show
        only the first ensemble item (False), or give a sequence of axis
        indices. Default determined by the axis metadata.
    overlay : bool or sequence of int, optional
        Show all (True), only the first (False), or the specified ensemble
        items together in one plot. Default determined by the axis metadata.
    figsize : two int, optional
        Figure width and height in inches, passed to `matplotlib.pyplot.figure`.
    title : bool or str, optional
        Column title(s) of the plots; True uses the "name" key of the axes
        metadata dictionary, if present.
    units : str
        Units for the x and y axes; must be compatible with the plot axes.
    interact : bool
        Create an interactive visualization (requires the `ipympl` backend).
    display : bool, optional
        If True (default) the figure is displayed immediately.

    Returns
    -------
    visualization : Visualization
    """
    # Symmetric axis limits covering the longest reciprocal lattice vector,
    # converted to the requested units.
    k_max = self.max_reciprocal_space_vector_length() * get_conversion_factor(
        units, "1/Å", self.metadata.get("energy", None)
    )
    limits = (-k_max, k_max)

    visualization = Visualization(
        measurement=self,
        ax=ax,
        common_scale=common_color_scale,
        figsize=figsize,
        title=title,
        aspect=True,
        share_x=True,
        share_y=True,
        explode=explode,
        overlay=overlay,
        interactive=not interact and display,
        value_limits=(vmin, vmax),
        power=power,
        cmap=cmap,
        cbar=cbar,
        scale=scale,
        units=units,
        xlim=limits,
        ylim=limits,
        **kwargs,
    )

    if interact:
        visualization.interact(ScatterGUI, display=display)

    return visualization
@property
def intensities_dict(self) -> dict[tuple, np.ndarray]:
    """
    A dictionary mapping miller indices to intensities.
    """
    intensities = {
        tuple(hkl): intensity
        for hkl, intensity in zip(
            self.miller_indices,
            np.moveaxis(self.intensities, -1, 0),
        )
    }
    # Missing keys default to zero intensity with the ensemble shape.
    # NOTE(review): the same zeros array instance is returned for every
    # missing key — callers must not mutate the default in place.
    values = np.zeros(self.shape[:-1], dtype=np.float32)
    intensities = defaultdict(lambda: values, intensities)
    return intensities
@property
def positions_dict(self) -> dict[tuple, np.ndarray]:
    """
    A dictionary mapping miller indices to reciprocal space positions [1/Å].
    """
    # NOTE(review): axis -2 of the reciprocal lattice vectors is the basis-row
    # axis (length 3 for a 3 x 3 basis), so this zip pairs only the first
    # three miller indices with basis rows rather than per-spot positions —
    # verify against callers (cf. the commented-out usage in _stack).
    positions = {
        tuple(hkl): position
        for hkl, position in zip(
            self.miller_indices,
            np.moveaxis(self.reciprocal_lattice_vectors, -2, 0),
        )
    }
    return positions
@classmethod
def _stack(
    cls,
    diffraction_spots: IndexedDiffractionPatterns,
    axis_metadata: list[AxisMetadata],
    axis: int,
):
    # Stack several indexed diffraction patterns along a new ensemble axis.
    # Spots missing from an item default to zero intensity via intensities_dict.
    intensities = [spots.intensities_dict for spots in diffraction_spots]
    # positions = [spots.positions_dict for spots in diffraction_spots]
    # def merge_dicts_no_overwrite(dict1, dict2):
    #     return {**dict1, **{k: v for k, v in dict2.items() if k not in dict1}}
    # merged = {}
    # for positions1 in positions:
    #     merged = merge_dicts_no_overwrite(merged, positions1)
    # positions = [
    #     merge_dicts_no_overwrite(positions1, merged) for positions1 in positions
    # ]
    # Union of all miller indices present in any item.
    # NOTE(review): iterating a set yields a nondeterministic spot order —
    # confirm downstream code does not rely on a stable ordering.
    miller_indices = list(
        set(itertools.chain(*[intensities1.keys() for intensities1 in intensities]))
    )
    # Collect each spot's intensity from every item and stack along the new axis.
    new_intensities = {}
    for hkl in miller_indices:
        new_intensities[hkl] = []
        for intensities1 in intensities:
            new_intensities[hkl].append(intensities1[hkl])
        new_intensities[hkl] = np.stack(new_intensities[hkl], axis=axis)
    miller_indices = np.stack(list(new_intensities.keys()), axis=0)
    # The spot axis is last, matching the class convention.
    intensities = np.stack(list(new_intensities.values()), axis=-1)
    positions = np.stack(
        [spots.reciprocal_lattice_vectors for spots in diffraction_spots], axis=0
    )
    # Insert the new axis metadata at the stacking position.
    ensemble_axes_metadata = [
        axis_metadata.copy()
        for axis_metadata in diffraction_spots[0].ensemble_axes_metadata
    ]
    ensemble_axes_metadata.insert(axis, axis_metadata)
    metadata = diffraction_spots[0].metadata
    return IndexedDiffractionPatterns(
        intensities, miller_indices, positions, ensemble_axes_metadata, metadata
    )