"""Module for describing different types of scans."""
from __future__ import annotations
import itertools
from abc import abstractmethod
from typing import TYPE_CHECKING
import dask.array as da
import numpy as np
from ase import Atom, Atoms
from matplotlib.axes import Axes
from matplotlib.patches import Rectangle
from abtem.array import ArrayObject, T
from abtem.core.axes import ScanAxis, PositionsAxis, AxisMetadata
from abtem.core.backend import get_array_module, validate_device
from abtem.core.chunks import validate_chunks
from abtem.core.ensemble import _wrap_with_array, unpack_blockwise_args
from abtem.core.fft import fft_shift_kernel
from abtem.core.grid import Grid, HasGridMixin
from abtem.measurements import _scan_axes
from abtem.potentials.iam import BasePotential, _validate_potential
from abtem.transfer import nyquist_sampling
from abtem.transform import ReciprocalSpaceMultiplication
if TYPE_CHECKING:
from abtem.waves import Waves, Probe
from abtem.prism.s_matrix import BaseSMatrix
def _validate_scan(scan: np.ndarray | BaseScan, probe: Probe = None) -> BaseScan:
    """
    Coerce the input into a scan object and optionally match it to a probe.

    Parameters
    ----------
    scan : np.ndarray or BaseScan
        Scan object, or positions that are wrapped in a squeezed :class:`CustomScan`. If None, a default
        :class:`CustomScan` is created.
    probe : Probe, optional
        If given, a copy of the scan is matched to this probe via its `match_probe` method.

    Returns
    -------
    scan : BaseScan
        The validated scan. This is a copy of the input scan whenever a probe is given.
    """
    if scan is None:
        # Without a probe the default is a single position at the origin. With a probe the scan starts out
        # empty, so that `CustomScan.match_probe` can place it at the center of the probe extent.
        num_default_positions = 1 if probe is None else 0
        scan = CustomScan(np.zeros((num_default_positions, 2)), squeeze=True)
    elif not isinstance(scan, BaseScan):
        scan = CustomScan(scan, squeeze=True)

    if probe is None:
        return scan

    # Match a copy so the scan object handed in by the caller is not modified.
    matched_scan = scan.copy()
    matched_scan.match_probe(probe)
    return matched_scan
def _validate_scan_sampling(scan: BaseScan, probe: Probe):
    """
    Derive the scan sampling from the probe if the scan does not define one.

    The scan is modified in place; a scan that already has a sampling is left untouched.

    Parameters
    ----------
    scan : BaseScan
        Scan whose `sampling` attribute is set if it is None.
    probe : Probe
        Object providing `semiangle_cutoff`, `aperture` and `energy`. If it has a `dummy_probes` method, the
        object returned by that method is used instead.

    Raises
    ------
    ValueError
        If the scan has no sampling and the probe has no `semiangle_cutoff` attribute.
    """
    if scan.sampling is not None:
        return

    if not hasattr(probe, "semiangle_cutoff"):
        raise ValueError(
            "the scan sampling is not set and cannot be derived from the probe, "
            "because the probe has no semiangle_cutoff"
        )

    if hasattr(probe, "dummy_probes"):
        probe = probe.dummy_probes()

    semiangle_cutoff = probe.aperture._max_semiangle_cutoff

    # The factor 0.99 presumably keeps the scan slightly finer than the exact Nyquist sampling.
    scan.sampling = 0.99 * nyquist_sampling(semiangle_cutoff, probe.energy)
class BaseScan(ReciprocalSpaceMultiplication):
    """Abstract class to describe scans."""

    def __len__(self) -> int:
        return self.num_positions

    @property
    def num_positions(self) -> int:
        """Number of probe positions in the scan."""
        # The last axis of the positions holds the coordinates of a position, every leading axis enumerates
        # positions. Counting all leading axes also covers positions stored on a multidimensional grid.
        position_axes = np.shape(self.get_positions())[:-1]
        return int(np.prod(position_axes, dtype=int))

    @property
    @abstractmethod
    def shape(self) -> tuple[int, ...]:
        """The shape of the scan."""
        pass

    @property
    def ensemble_shape(self) -> tuple[int, ...]:
        """The ensemble shape of the scan; identical to its shape."""
        return self.shape

    @property
    def _default_ensemble_chunks(self):
        # One "auto" chunk specification per ensemble axis.
        return ("auto",) * len(self.ensemble_shape)

    @abstractmethod
    def get_positions(self, *args, **kwargs) -> np.ndarray:
        """Get the scan positions as numpy array."""
        pass

    def _get_weights(self):
        # Scans are unweighted unless a subclass overrides this method.
        raise NotImplementedError

    @property
    @abstractmethod
    def limits(self):
        """Lower left and upper right corner of the bounding box containing all positions in the scan."""
        pass

    @abstractmethod
    def _sort_into_extents(
        self,
        extents: tuple[
            tuple[tuple[float, float], ...], tuple[tuple[float, float], ...]
        ],
    ):
        pass

    def _evaluate_kernel(self, waves: Waves) -> np.ndarray:
        """
        Evaluate the array to be multiplied with the waves in reciprocal space.

        Parameters
        ----------
        waves : Waves
            The array is evaluated to match the device, grid points and sampling of the provided waves.

        Returns
        -------
        kernel : np.ndarray or dask.array.Array
        """
        device = validate_device(waves.device)
        xp = get_array_module(device)

        waves.grid.check_is_defined()

        positions = self.get_positions()

        # An empty scan does not shift the waves: the kernel is all ones.
        if len(positions) == 0:
            return xp.ones(waves.gpts, dtype=xp.complex64)

        # Convert the positions from length units to units of grid points.
        positions = xp.asarray(positions) / xp.asarray(waves.sampling).astype(
            np.float32
        )

        kernel = fft_shift_kernel(positions, shape=waves.gpts)

        # Only the weight lookup may signal "no weights"; errors raised by the multiplication are not hidden.
        try:
            weights = self._get_weights()
        except NotImplementedError:
            return kernel

        kernel *= weights[..., None, None]
        return kernel
#
# class SourceDistribution(BaseScan):
# """
# Distribution of electron source offsets.
#
# Parameters
# ----------
# distribution : 2D :class:`.BaseDistribution`
# Distribution describing the positions and weights of the source offsets.
# """
#
# def __init__(self, distribution: BaseDistribution):
# self._distribution = distribution
#
# @property
# def shape(self):
# return self._distribution.shape
#
# def get_positions(self):
# xi = [factor.values for factor in self._distribution.factors]
# return np.stack(np.meshgrid(*xi, indexing="ij"), axis=-1)
#
# def _get_weights(self):
# return self._distribution.weights
#
# @property
# def ensemble_axes_metadata(self):
# return [PositionsAxis()] * len(self.shape)
#
# def ensemble_blocks(self, chunks=None):
# if chunks is None:
# chunks = self._default_ensemble_chunks
#
# chunks = validate_chunks(self.ensemble_shape, chunks, limit=None)
#
# blocks = ()
# for parameter, n in zip(self._distribution.factors, chunks):
# blocks += (parameter.divide(n, lazy=True),)
#
# return blocks
#
# def ensemble_partial(self):
# def distribution(*args):
# factors = [arg.item() for arg in args]
# dist = SourceDistribution(AxisAlignedDistributionND(factors))
# arr = np.empty((1,) * len(args), dtype=object)
# arr.itemset(dist)
# return arr
#
# return distribution
#
# @property
# def limits(self):
# pass
class CustomScan(BaseScan):
    """
    Custom scan based on explicit 2D probe positions.

    Parameters
    ----------
    positions : np.ndarray, optional
        Scan positions [Å]. Anything that can be converted to an ndarray of shape (n, 2) is accepted. A single
        position of shape (2,) is promoted to shape (1, 2). Default is (0., 0.).
    squeeze : bool, optional
        Forwarded as `_squeeze` to the `PositionsAxis` describing the scan axis. Default is False.
    """

    def __init__(self, positions: np.ndarray = (0.0, 0.0), squeeze: bool = False):
        positions = np.array(positions, dtype=np.float32)

        # A single position is stored as an array with one row.
        if len(positions.shape) == 1:
            positions = positions[None]

        self._positions = positions
        self._squeeze = squeeze
        super().__init__()

    def match_probe(self, probe: Probe | BaseSMatrix):
        """
        Sets the positions to a single position in the center of the probe extent, if the scan is empty.

        Parameters
        ----------
        probe : Probe or BaseSMatrix
            The matched probe or s-matrix.
        """
        if len(self.positions) == 0:
            probe.grid.check_is_defined()
            self._positions = np.array(probe.extent, dtype=np.float32)[None] / 2.0

    @property
    def ensemble_axes_metadata(self):
        """Axis metadata of the scan; empty if the scan has no positions."""
        if len(self.positions) == 0:
            return []

        return [
            PositionsAxis(
                values=tuple(
                    (float(position[0]), float(position[1]))
                    for position in self.positions
                ),
                _squeeze=self._squeeze,
            )
        ]

    @staticmethod
    def _from_partitioned_args_func(*args, **kwargs):
        # Rebuild a scan from one block of positions produced by `_partition_args`.
        scan = unpack_blockwise_args(args)
        positions = scan[0]["positions"]
        new_scan = CustomScan(positions, **kwargs)
        new_scan = _wrap_with_array(new_scan, 1)
        return new_scan

    def _from_partitioned_args(self):
        # An empty scan has no partitioned arguments; the scan itself is returned wrapped in an array.
        if len(self.positions) == 0:
            return lambda *args, **kwargs: _wrap_with_array(self)

        return self._from_partitioned_args_func

    def _partition_args(self, chunks=None, lazy: bool = True):
        if len(self.positions) == 0:
            return ()

        chunks = self._validate_ensemble_chunks(chunks)
        cumchunks = tuple(np.cumsum(chunks[0]))

        # One dict per chunk, stored in an object array so each chunk can become one dask block.
        positions = np.empty(len(chunks[0]), dtype=object)
        for i, (start_chunk, chunk) in enumerate(zip((0,) + cumchunks, chunks[0])):
            # Index assignment instead of `ndarray.itemset`, which was removed in NumPy 2.0.
            positions[i] = {
                "positions": self._positions[start_chunk : start_chunk + chunk]
            }

        if lazy:
            positions = da.from_array(positions, chunks=1)

        return (positions,)

    def _sort_into_extents(self, extents):
        # Reorder the positions so that positions falling into the same (x, y) extent are contiguous and
        # return the reordered scan together with the number of positions in each extent.
        new_positions = np.zeros_like(self.positions)
        chunks = ()
        start = 0
        for x_extents, y_extents in itertools.product(*extents):
            # Multiplying the boolean arrays acts as an elementwise logical "and".
            mask = (
                (self.positions[:, 0] >= x_extents[0])
                * (self.positions[:, 0] < x_extents[1])
                * (self.positions[:, 1] >= y_extents[0])
                * (self.positions[:, 1] < y_extents[1])
            )

            n = np.sum(mask)
            chunks += (n,)
            stop = start + n
            new_positions[start:stop] = self.positions[mask]
            start = stop

        # Every position has to fall into exactly one of the extents.
        assert sum(chunks) == len(self)
        return CustomScan(new_positions), (chunks,)

    @property
    def shape(self):
        """Shape of the scan; empty if the scan has no positions."""
        if len(self.positions) == 0:
            return ()

        return self.positions.shape[:-1]

    @property
    def positions(self):
        """Scan positions [Å]."""
        return self._positions

    @property
    def limits(self):
        """Lower left and upper right corner of the bounding box containing all positions in the scan."""
        return [
            (np.min(self.positions[:, 0]), np.min(self.positions[:, 1])),
            (np.max(self.positions[:, 0]), np.max(self.positions[:, 1])),
        ]

    def get_positions(self) -> np.ndarray:
        """Get the scan positions as numpy array."""
        return self._positions
def _validate_coordinate(
    coordinate: tuple[float, float] | Atom,
    potential: BasePotential | Atoms = None,
    fractional: bool = False,
) -> tuple[float, float] | None:
    """
    Convert a coordinate given as two numbers or as an Atom to a tuple of Cartesian coordinates.

    Parameters
    ----------
    coordinate : two float or Atom
        The coordinate to validate. None is passed through unchanged.
    potential : BasePotential or Atoms, optional
        Defines the extent by which fractional coordinates are scaled. Required if `fractional` is True.
    fractional : bool, optional
        If True, the coordinate is multiplied by the extent derived from `potential`.

    Returns
    -------
    coordinate : tuple of float or None

    Raises
    ------
    ValueError
        If an Atom is given together with `fractional=True`, or if fractional coordinates are requested
        without a potential.
    """
    if isinstance(coordinate, Atom):
        if fractional:
            raise ValueError("an Atom cannot be used as a fractional coordinate")
        coordinate = coordinate.x, coordinate.y

    # An unset coordinate stays unset; there is nothing to scale or convert.
    if coordinate is None:
        return None

    if fractional:
        # Check before the potential is used, so a missing potential gives a clear error.
        if potential is None:
            raise ValueError("provide potential for fractional coordinates")

        potential = _validate_potential(potential)

        # Anything that is not a potential after validation is used directly as the extent.
        if isinstance(potential, BasePotential):
            extent = potential.extent
        else:
            extent = potential

        coordinate = (
            extent[0] * coordinate[0],
            extent[1] * coordinate[1],
        )

    return tuple(coordinate)
def _validate_coordinates(
    start: tuple[float, float] | Atom,
    end: tuple[float, float] | Atom,
    potential: BasePotential | Atoms,
    fractional: bool,
) -> tuple[tuple[float, float], tuple[float, float]]:
    """
    Validate the start and end point of a scan.

    Parameters
    ----------
    start, end : two float or Atom
        The two points, each validated with `_validate_coordinate`.
    potential : BasePotential or Atoms
        Potential used for fractional coordinates. Only validated if `fractional` is True.
    fractional : bool
        If True, the points are treated as fractional coordinates.

    Returns
    -------
    start, end : tuple
        The validated points.

    Raises
    ------
    RuntimeError
        If both points are given and they are numerically identical.
    """
    if fractional:
        potential = _validate_potential(potential)

    start, end = (
        _validate_coordinate(point, potential, fractional) for point in (start, end)
    )

    both_given = start is not None and end is not None
    if both_given and np.allclose(start, end):
        raise RuntimeError("scan start and end is identical")

    return start, end
class LineScan(BaseScan):
    """
    A scan along a straight line.

    Parameters
    ----------
    start : two float or Atom, optional
        Start point of the scan [Å]. May be given as fractional coordinate if `fractional=True`. Default is (0., 0.).
    end : two float or Atom, optional
        End point of the scan [Å]. May be given as fractional coordinate if `fractional=True`.
        Default is None, the scan end point will match the extent of the potential.
    gpts : int, optional
        Number of scan positions. Default is None. Provide one of gpts or sampling.
    sampling : float, optional
        Distance between scan positions [Å]. Provide one of gpts or sampling. If not provided the sampling will
        match the Nyquist sampling of the Probe in a multislice simulation.
    endpoint : bool, optional
        If True, end is the last position. Otherwise, it is not included. Default is True.
    fractional : bool, optional
        If True, use fractional coordinates with respect to the given potential for `start` and `end`.
    potential : BasePotential or Atoms, optional
        Potential defining the grid with respect to which the fractional coordinates should be given.
    """

    def __init__(
        self,
        start: tuple[float, float] | Atom = (0.0, 0.0),
        end: tuple[float, float] | Atom = None,
        gpts: int = None,
        sampling: float = None,
        endpoint: bool = True,
        fractional: bool = False,
        potential: BasePotential | Atoms = None,
    ):
        self._gpts = gpts
        self._sampling = sampling

        self._start, self._end = _validate_coordinates(
            start, end, potential, fractional
        )

        self._endpoint = endpoint

        # Make gpts and sampling consistent with the extent of the line, when it is known.
        self._adjust_gpts()
        self._adjust_sampling()
        super().__init__()

    @property
    def direction(self):
        """Unit vector pointing from `start` to `end`."""
        direction = np.array(self.end) - np.array(self.start)
        return direction / np.linalg.norm(direction)

    @property
    def angle(self):
        """Angle between the line from `start` to `end` and the `x`-axis [rad.]."""
        direction = self.direction
        return np.arctan2(direction[1], direction[0])

    def add_margin(self, margin: float | tuple[float, float]):
        """
        Extend the line scan by adding a margin to the start and end of the line scan.

        Parameters
        ----------
        margin : float or tuple of float
            The margin added to the start and end of the linescan [Å]. If float the same margin is added.

        Returns
        -------
        line_scan : LineScan
            The line scan itself, modified in place.
        """
        # A single number applies to both ends; a tuple gives the start and end margin separately.
        if np.isscalar(margin):
            margin = (margin,) * 2

        direction = self.direction

        self.start = tuple(np.array(self.start) - direction * margin[0])
        self.end = tuple(np.array(self.end) + direction * margin[1])
        return self

    @classmethod
    def at_position(
        cls,
        center: tuple[float, float] | Atom,
        extent: float = 1.0,
        angle: float = 0.0,
        gpts: int = None,
        sampling: float = None,
        endpoint: bool = True,
    ):
        """
        Make a line scan centered at a given position.

        Parameters
        ----------
        center : two float
            Center position of the line [Å]. May be given as an Atom.
        extent : float
            Extent of the line [Å].
        angle : float
            Angle of the line [deg.].
        gpts : int
            Number of grid points along the line.
        sampling : float
            Sampling of grid points along the line [Å].
        endpoint : bool
            Sets whether the ending position is included or not.

        Returns
        -------
        line_scan : LineScan
        """
        if isinstance(center, Atom):
            center = (center.x, center.y)

        direction = np.array((np.cos(np.deg2rad(angle)), np.sin(np.deg2rad(angle))))

        # The line extends half of the extent to either side of the center.
        start = tuple(np.array(center) - extent / 2 * direction)
        end = tuple(np.array(center) + extent / 2 * direction)
        return cls(
            start=start, end=end, gpts=gpts, sampling=sampling, endpoint=endpoint
        )

    def match_probe(self, probe: Probe | BaseSMatrix):
        """
        Sets sampling to the Nyquist frequency. If the start and end point of the scan is not given, set them to the
        lower and upper left corners of the probe extent.

        Parameters
        ----------
        probe : Probe or BaseSMatrix
            The matched probe or s-matrix.
        """
        if self.start is None:
            self.start = (0.0, 0.0)

        if self.end is None and probe.extent is not None:
            self.end = (0.0, probe.extent[1])

        _validate_scan_sampling(self, probe)

    @property
    def extent(self) -> float | None:
        """Length of the line [Å], or None if the start or the end point is not set."""
        if self._start is None or self._end is None:
            return None

        return np.linalg.norm(np.array(self._end) - np.array(self._start))

    def _adjust_gpts(self):
        # Derive the number of positions from the extent and the sampling, then update the sampling so that
        # the positions exactly span the extent.
        if self.extent is None or self.sampling is None:
            return

        self._gpts = int(np.ceil(self.extent / self.sampling))
        self._adjust_sampling()

    def _adjust_sampling(self):
        # Derive the sampling from the extent and the number of positions.
        if self.extent is None or self.gpts is None:
            return

        # With the endpoint included, gpts positions span gpts - 1 intervals.
        if self.endpoint and self.gpts > 1:
            self._sampling = self.extent / (self.gpts - 1)
        else:
            self._sampling = self.extent / self.gpts

    @property
    def endpoint(self) -> bool:
        """True if the scan endpoint is the last position. Otherwise, the endpoint is not included."""
        return self._endpoint

    @property
    def limits(self) -> tuple[tuple[float, float], tuple[float, float]]:
        """Start and end point of the scan [Å]."""
        return self.start, self.end

    @property
    def gpts(self) -> int:
        """Number of grid points."""
        return self._gpts

    @gpts.setter
    def gpts(self, gpts: int):
        self._gpts = gpts
        self._adjust_sampling()

    @property
    def sampling(self) -> float:
        """Grid sampling [Å]."""
        return self._sampling

    @sampling.setter
    def sampling(self, sampling: float):
        self._sampling = sampling
        self._adjust_gpts()

    @property
    def shape(self) -> tuple[int]:
        """Shape of the scan: the number of grid points along the line."""
        return (self._gpts,)

    @property
    def metadata(self):
        """Dictionary with the start and end point of the scan."""
        return {"start": self.start, "end": self.end}

    @property
    def start(self) -> tuple[float, float] | None:
        """
        Start point of the scan [Å].
        """
        return self._start

    @start.setter
    def start(self, start: tuple[float, float]):
        if start is not None:
            start = (float(start[0]), float(start[1]))
        self._start = start
        self._adjust_gpts()

    @property
    def end(self) -> tuple[float, float] | None:
        """
        End point of the scan [Å].
        """
        return self._end

    @end.setter
    def end(self, end: tuple[float, float]):
        if end is not None:
            end = (float(end[0]), float(end[1]))
        self._end = end
        self._adjust_gpts()

    @property
    def ensemble_axes_metadata(self) -> list[AxisMetadata]:
        """Axis metadata describing the single scan axis along the line."""
        return [
            ScanAxis(
                label="r",
                sampling=self.sampling,
                offset=0.0,
                units="Å",
                endpoint=self.endpoint,
            )
        ]

    @property
    def ensemble_shape(self):
        """The ensemble shape of the scan; identical to its shape."""
        return self.shape

    def _out_ensemble_shape(
        self, array_object: ArrayObject, index: int = 0
    ) -> tuple[int, ...]:
        # The scan axis is prepended to the ensemble axes of the array object.
        return self.ensemble_shape + array_object.ensemble_shape

    def _out_ensemble_axes_metadata(
        self, array_object: ArrayObject | T, index: int = 0
    ) -> list[AxisMetadata] | tuple[list[AxisMetadata], ...]:
        ensemble_axes_metadata = self.ensemble_axes_metadata

        # If the array object already has scan axes, the axes added here are flagged as not main.
        if len(_scan_axes(array_object)) > 0:
            for axis in ensemble_axes_metadata:
                axis._main = False

        return [*ensemble_axes_metadata, *array_object.ensemble_axes_metadata]

    @property
    def _default_ensemble_chunks(self):
        return ("auto",)

    def _sort_into_extents(self, extents):
        raise NotImplementedError

    @staticmethod
    def _from_partitioned_args_func(*args, **kwargs):
        # Each block produced by `_partition_args` already holds a complete line scan.
        args = unpack_blockwise_args(args)
        line_scan = args[0]
        return _wrap_with_array(line_scan)

    def _from_partitioned_args(self):
        return self._from_partitioned_args_func

    def _partition_args(self, chunks=None, lazy: bool = True):
        if chunks is None:
            chunks = self._default_ensemble_chunks

        chunks = validate_chunks(self.ensemble_shape, chunks)

        direction = self.direction
        cumchunks = tuple(np.cumsum(chunks[0]))

        blocks = []
        for start_chunk, chunk in zip((0,) + cumchunks, chunks[0]):
            # Each chunk is a shorter line scan continuing where the previous chunk stopped. The endpoint is
            # excluded, so consecutive chunks do not share a position.
            start = np.array(self.start) + start_chunk * self.sampling * direction
            end = start + self.sampling * chunk * direction

            block = _wrap_with_array(
                LineScan(start=start, end=end, gpts=chunk, endpoint=False)
            )

            if lazy:
                block = da.from_array(block, chunks=1)

            blocks.append(block)

        if lazy:
            blocks = da.concatenate(blocks)
        else:
            blocks = np.concatenate(blocks)

        return (blocks,)

    def get_positions(self, chunks: int = None, lazy: bool = False) -> np.ndarray:
        """
        Get the scan positions as numpy array.

        Parameters
        ----------
        chunks : int, optional
            Not used.
        lazy : bool, optional
            Not used.

        Returns
        -------
        positions : np.ndarray
            Array of shape (gpts, 2) with the `x`- and `y`-coordinate of each position [Å].
        """
        x = np.linspace(
            self.start[0],
            self.end[0],
            self.gpts,
            endpoint=self.endpoint,
            dtype=np.float32,
        )
        y = np.linspace(
            self.start[1],
            self.end[1],
            self.gpts,
            endpoint=self.endpoint,
            dtype=np.float32,
        )
        return np.stack((x, y), axis=1)

    def add_to_axes(self, ax: Axes, width: float = 0.0, **kwargs):
        """
        Add a visualization of a scan line to a matplotlib plot.

        Parameters
        ----------
        ax : matplotlib Axes
            The axes of the matplotlib plot the visualization should be added to.
        width : float, optional
            Width of line [Å]. If zero (default), the scan is drawn as a plain line.
        kwargs :
            Additional keyword arguments passed to matplotlib.patches.Rectangle if `width` is nonzero, and to
            the `plot` method of the axes otherwise.
        """
        if width:
            # matplotlib expects the rotation of the rectangle in degrees, `angle` is in radians.
            rect = Rectangle(
                tuple(self.start),
                self.extent,
                width,
                angle=np.rad2deg(self.angle),
                **kwargs,
            )
            ax.add_patch(rect)
        else:
            ax.plot(
                [self.start[0], self.end[0]], [self.start[1], self.end[1]], **kwargs
            )
class GridScan(HasGridMixin, BaseScan):
    """
    A scan over a regular grid for calculating scanning transmission electron microscopy.

    Parameters
    ----------
    start : two float or Atom, optional
        Start corner of the scan [Å]. May be given as fractional coordinate if `fractional=True`. Default is (0., 0.).
    end : two float or Atom, optional
        End corner of the scan [Å]. May be given as fractional coordinate if `fractional=True`.
        Default is None, the scan end point will match the extent of the potential.
    gpts : two int, optional
        Number of scan positions in the `x`- and `y`-direction of the scan. Provide one of gpts or sampling.
    sampling : two float, optional
        Distance between scan positions in the `x`- and `y`-direction [Å]. Provide one of gpts or sampling. If not
        provided the sampling will match the Nyquist sampling of the Probe in a multislice simulation.
    endpoint : bool, optional
        If True, end is the last position. Otherwise, it is not included. Default is False.
    fractional : bool, optional
        If True, use fractional coordinates with respect to the given potential for `start` and `end`.
    potential : BasePotential or Atoms, optional
        Potential defining the grid with respect to which the fractional coordinates should be given.
    """

    def __init__(
        self,
        start: tuple[float, float] | Atom = (0.0, 0.0),
        end: tuple[float, float] | Atom = None,
        gpts: int | tuple[int, int] = None,
        sampling: float | tuple[float, float] = None,
        endpoint: bool | tuple[bool, bool] = False,
        fractional: bool = False,
        potential: BasePotential | Atoms = None,
    ):
        super().__init__()

        start, end = _validate_coordinates(start, end, potential, fractional)

        if start is not None:
            if np.isscalar(start):
                start = (start,) * 2

            start = tuple(map(float, start))
            assert len(start) == 2

        if end is not None:
            if np.isscalar(end):
                end = (end,) * 2

            end = tuple(map(float, end))
            assert len(end) == 2

        # The extent of the grid is only known when both corners are given.
        if start is not None and end is not None:
            extent = np.array(end, dtype=float) - start
        else:
            extent = None

        self._start = start
        self._end = end

        self._grid = Grid(
            extent=extent, gpts=gpts, sampling=sampling, dimensions=2, endpoint=endpoint
        )

    def __len__(self):
        return self.gpts[0] * self.gpts[1]

    @property
    def limits(self):
        """Start and end corner of the scan [Å]."""
        return [self.start, self.end]

    @property
    def endpoint(self) -> tuple[bool, bool]:
        """True if the scan endpoint is the last position. Otherwise, the endpoint is not included."""
        return self.grid.endpoint

    @property
    def shape(self) -> tuple[int, int]:
        """Shape of the scan: the number of grid points in the `x`- and `y`-direction."""
        return self.gpts

    @property
    def start(self) -> tuple[float, float] | None:
        """Start corner of the scan [Å]."""
        return self._start

    @start.setter
    def start(self, start: tuple[float, float]):
        self._start = start
        self._adjust_extent()

    @property
    def end(self) -> tuple[float, float] | None:
        """End corner of the scan [Å]."""
        return self._end

    @end.setter
    def end(self, end: tuple[float, float]):
        self._end = end
        self._adjust_extent()

    def _adjust_extent(self):
        # Keep the grid extent in sync with the corners, once both of them are set.
        if self.start is None or self.end is None:
            return

        self.extent = np.array(self.end) - self.start

    def match_probe(self, probe: Probe | BaseSMatrix):
        """
        Sets sampling to the Nyquist frequency. If the start and end point of the scan is not given, set them to the
        lower left and upper right corners of the probe extent.

        Parameters
        ----------
        probe : Probe or BaseSMatrix
            The matched probe or s-matrix.
        """
        if self.start is None:
            self.start = (0.0, 0.0)

        if self.end is None:
            self.end = probe.extent

        _validate_scan_sampling(self, probe)

    def _x_coordinates(self):
        # Scan coordinates along `x` as float32.
        return np.linspace(
            self.start[0],
            self.end[0],
            self.gpts[0],
            endpoint=self.endpoint[0],
            dtype=np.float32,
        )

    def _y_coordinates(self):
        # Scan coordinates along `y` as float32.
        return np.linspace(
            self.start[1],
            self.end[1],
            self.gpts[1],
            endpoint=self.endpoint[1],
            dtype=np.float32,
        )

    def get_positions(self) -> np.ndarray:
        """
        Get the scan positions as numpy array.

        Returns
        -------
        positions : np.ndarray
            Array of shape (gpts[0], gpts[1], 2) with the `x`- and `y`-coordinate of each position [Å].
        """
        xi = [
            np.linspace(start, end, gpts, endpoint=endpoint, dtype=np.float32)
            for start, end, gpts, endpoint in zip(
                self.start, self.end, self.gpts, self.endpoint
            )
        ]

        if len(xi) == 1:
            return xi[0]

        return np.stack(np.meshgrid(*xi, indexing="ij"), axis=-1)

    @staticmethod
    def _count_coordinates_in_extents(coordinates, extents):
        # Count how many of the 1D coordinates fall into each of the given (lower, upper) extents.
        # Only the upper limits are used as bin separators. Coordinates at or beyond the last separator get
        # the bin index len(separators) from np.digitize and are not counted.
        separators = [upper for _, upper in extents]
        bins, counts = np.unique(
            np.digitize(coordinates, separators), return_counts=True
        )
        counts_by_bin = dict(zip(bins.tolist(), counts))
        return tuple(counts_by_bin.get(i, 0) for i in range(len(separators)))

    def _sort_into_extents(self, extents):
        # A regular grid is already ordered along each axis, so the scan itself is returned together with
        # the number of grid points in each extent along `x` and `y`.
        x = np.linspace(
            self.start[0], self.end[0], self.gpts[0], endpoint=self.endpoint[0]
        )
        y = np.linspace(
            self.start[1], self.end[1], self.gpts[1], endpoint=self.endpoint[1]
        )

        x_chunks = self._count_coordinates_in_extents(x, extents[0])
        y_chunks = self._count_coordinates_in_extents(y, extents[1])
        return self, (x_chunks, y_chunks)

    @property
    def ensemble_axes_metadata(self):
        """Axis metadata describing the `x`- and `y`-axis of the scan."""
        axes_metadata = []
        for label, sampling, offset, endpoint in zip(
            ("x", "y"), self.sampling, self.start, self.endpoint
        ):
            axes_metadata.append(
                ScanAxis(
                    label=label,
                    sampling=sampling,
                    offset=offset,
                    units="Å",
                    endpoint=endpoint,
                )
            )

        return axes_metadata

    def _out_ensemble_axes_metadata(
        self, array_object: ArrayObject | T, index: int = 0
    ) -> list[AxisMetadata] | tuple[list[AxisMetadata], ...]:
        ensemble_axes_metadata = self.ensemble_axes_metadata

        # If the array object already has scan axes, the axes added here are flagged as not main.
        if len(_scan_axes(array_object)) > 0:
            for axis in ensemble_axes_metadata:
                axis._main = False

        return [*ensemble_axes_metadata, *array_object.ensemble_axes_metadata]

    @classmethod
    def _from_partitioned_args_func(cls, *args, **kwargs):
        # Rebuild a grid scan from one `x`-block and one `y`-block produced by `_partition_args`.
        x_scan, y_scan = unpack_blockwise_args(args)

        start = (x_scan["start"], y_scan["start"])
        end = (x_scan["end"], y_scan["end"])
        gpts = (x_scan["gpts"], y_scan["gpts"])
        endpoint = (x_scan["endpoint"], y_scan["endpoint"])

        new_scan = cls(start=start, end=end, gpts=gpts, endpoint=endpoint, **kwargs)
        new_scan = _wrap_with_array(new_scan, 2)
        return new_scan

    def _from_partitioned_args(self):
        return self._from_partitioned_args_func

    @property
    def ensemble_shape(self):
        """The ensemble shape of the scan; identical to its shape."""
        return self.shape

    @property
    def _default_ensemble_chunks(self):
        return "auto", "auto"

    def _partition_args(self, chunks=None, lazy=True):
        self.grid.check_is_defined()
        chunks = self._validate_ensemble_chunks(chunks)

        blocks = ()
        for i in range(2):
            cumchunks = tuple(np.cumsum(chunks[i]))

            # One dict per chunk along axis i, stored in an object array.
            block = np.empty(len(chunks[i]), dtype=object)
            for j, (start_chunk, chunk) in enumerate(zip((0,) + cumchunks, chunks[i])):
                # Each chunk continues where the previous one stopped. The endpoint is excluded, so
                # consecutive chunks do not share a position.
                start = self.start[i] + start_chunk * self.sampling[i]
                end = start + self.sampling[i] * chunk

                block[j] = {
                    "start": start,
                    "end": end,
                    "gpts": chunk,
                    "endpoint": False,
                }

            if lazy:
                blocks += (da.from_array(block, chunks=1),)
            else:
                blocks += (block,)

        return blocks

    def add_to_plot(
        self,
        ax,
        alpha: float = 0.33,
        facecolor: str = "r",
        edgecolor: str = "r",
        **kwargs,
    ):
        """
        Add a visualization of the scan area to a matplotlib plot.

        Parameters
        ----------
        ax : matplotlib Axes
            The axes of the matplotlib plot the visualization should be added to.
        alpha : float, optional
            Transparency of the scan area visualization. Default is 0.33.
        facecolor : str, optional
            Color of the scan area visualization.
        edgecolor : str, optional
            Color of the edge of the scan area visualization.
        kwargs :
            Additional options for matplotlib.patches.Rectangle used for scan area visualization as keyword arguments.
        """
        rect = Rectangle(
            tuple(self.start),
            *self.extent,
            alpha=alpha,
            facecolor=facecolor,
            edgecolor=edgecolor,
            **kwargs,
        )
        ax.add_patch(rect)