from typing import Tuple
import numpy as np
from abtem.core.backend import get_array_module, cp
from abtem.core.complex import complex_exponential
from abtem.core.energy import energy2wavelength
from abtem.core.utils import expand_dims_to_broadcast
[docs]
def batch_crop_2d(array: np.ndarray, corners: np.ndarray, new_shape: Tuple[int, int]):
xp = get_array_module(array)
if len(array.shape) > 3:
old_shape = array.shape
batch_shape = array.shape[: -len(corners.shape) - 1]
array = array.reshape((-1,) + array.shape[-2:])
corners = corners.reshape((-1, 2))
if batch_shape:
assert array.shape[0] == corners.shape[0] * np.prod(batch_shape)
corners = np.tile(corners, (np.prod(batch_shape), 1))
else:
old_shape = None
# if xp is cp:
i = xp.arange(array.shape[0])[:, None, None]
ix = xp.arange(new_shape[0]) + xp.asarray(corners[:, 0, None])
iy = xp.arange(new_shape[1]) + xp.asarray(corners[:, 1, None])
ix = ix[:, :, None]
iy = iy[:, None]
array = array[i, ix, iy]
# else:
# array = np.lib.stride_tricks.sliding_window_view(array, (1,) + new_shape)
# array = array[xp.arange(array.shape[0]), corners[:, 0], corners[:, 1], 0]
if old_shape is not None:
array = array.reshape(old_shape[:-2] + array.shape[-2:])
return array
[docs]
def minimum_crop(positions: np.ndarray, shape):
xp = get_array_module(positions)
offset = (shape[0] // 2, shape[1] // 2)
corners = xp.rint(positions - xp.asarray(offset)).astype(int)
upper_corners = corners + xp.asarray(shape)
crop_corner = (xp.min(corners[..., 0]).item(), xp.min(corners[..., 1]).item())
size = (
xp.max(upper_corners[..., 0]).item() - crop_corner[0],
xp.max(upper_corners[..., 1]).item() - crop_corner[1],
)
corners -= xp.asarray(crop_corner)
return crop_corner, size, corners
[docs]
def wrapped_slices(start: int, stop: int, n: int) -> Tuple[slice, slice]:
if start < 0:
if stop > n:
raise RuntimeError(f"start = {start} stop = {stop}, n = {n}")
a = slice(start % n, None)
b = slice(0, stop)
elif stop > n:
if start < 0:
raise RuntimeError(f"start = {start} stop = {stop}, n = {n}")
a = slice(start, None)
b = slice(0, stop - n)
else:
a = slice(start, stop)
b = slice(0, 0)
return a, b
[docs]
def wrapped_crop_2d(
array: np.ndarray, corner: Tuple[int, int], size: Tuple[int, int]
) -> np.ndarray:
upper_corner = (corner[0] + size[0], corner[1] + size[1])
xp = get_array_module(array)
try:
a, c = wrapped_slices(corner[0], upper_corner[0], array.shape[-2])
b, d = wrapped_slices(corner[1], upper_corner[1], array.shape[-1])
except RuntimeError:
padding = tuple(
(abs(min(c, 0)), max(c + k - l, 0))
for c, l, k in zip(corner, array.shape[-2:], size)
)
slices = tuple(
slice(c + p[0], c + p[0] + l) for c, l, p in zip(corner, size, padding)
)
padding = ((0, 0),) * (len(array.shape) - 2) + padding
slices = (slice(None),) * (len(array.shape) - 2) + slices
array = xp.pad(array, padding, mode="wrap")[slices]
return array
A = array[..., a, b]
B = array[..., c, b]
D = array[..., c, d]
C = array[..., a, d]
if A.size == 0:
AB = B
elif B.size == 0:
AB = A
else:
AB = xp.concatenate([A, B], axis=-2)
if C.size == 0:
CD = D
elif D.size == 0:
CD = C
else:
CD = xp.concatenate([C, D], axis=-2)
if CD.size == 0:
return AB
if AB.size == 0:
return CD
return xp.concatenate([AB, CD], axis=-1)
[docs]
def prism_wave_vectors(
cutoff: float,
extent: Tuple[float, float],
energy: float,
interpolation: Tuple[int, int],
xp=np,
) -> np.ndarray:
wavelength = energy2wavelength(energy)
n_max = int(np.ceil(cutoff / 1.0e3 / (wavelength / extent[0] * interpolation[0])))
m_max = int(np.ceil(cutoff / 1.0e3 / (wavelength / extent[1] * interpolation[1])))
n = np.arange(-n_max, n_max + 1, dtype=np.float32)
w = np.asarray(extent[0], dtype=np.float32)
m = np.arange(-m_max, m_max + 1, dtype=np.float32)
h = np.asarray(extent[1], dtype=np.float32)
kx = n / w * np.float32(interpolation[0])
ky = m / h * np.float32(interpolation[1])
mask = kx[:, None] ** 2 + ky[None, :] ** 2 < (cutoff / 1.0e3 / wavelength) ** 2
kx, ky = np.meshgrid(kx, ky, indexing="ij")
kx = kx[mask]
ky = ky[mask]
return xp.asarray([kx, ky]).T
[docs]
def plane_waves(
wave_vectors: np.ndarray,
extent: Tuple[float, float],
gpts: Tuple[int, int],
reverse: bool = False,
) -> np.ndarray:
xp = get_array_module(wave_vectors)
x = xp.linspace(0, extent[0], gpts[0], endpoint=False, dtype=np.float32)
y = xp.linspace(0, extent[1], gpts[1], endpoint=False, dtype=np.float32)
sign = -1.0 if reverse else 1.0
array = complex_exponential(
sign * 2 * np.pi * wave_vectors[:, 0, None, None] * x[:, None]
) * complex_exponential(
sign * 2 * np.pi * wave_vectors[:, 1, None, None] * y[None, :]
)
return array
def _planewave_shift_coefficients(positions, wave_vectors):
xp = get_array_module(positions)
# wave_vectors = xp.asarray(wave_vectors)
coefficients = complex_exponential(
-2.0 * xp.pi * positions[..., 0, None] * wave_vectors[:, 0][None]
)
# print(coefficients.shape, coefficients.dtype)
coefficients *= complex_exponential(
-2.0 * xp.pi * positions[..., 1, None] * wave_vectors[:, 1][None]
)
return coefficients
[docs]
def prism_coefficients(positions, wave_vectors, xp, ctf=None):
wave_vectors = xp.asarray(wave_vectors)
coefficients = _planewave_shift_coefficients(positions, wave_vectors)
if ctf is not None:
alpha = (
xp.sqrt(wave_vectors[:, 0] ** 2 + wave_vectors[:, 1] ** 2) * ctf.wavelength
)
phi = xp.arctan2(wave_vectors[:, 0], wave_vectors[:, 1])
basis = ctf._evaluate_from_angular_grid(alpha, phi)
basis, coefficients = expand_dims_to_broadcast(
basis, coefficients, match_dims=[(-1,), (-1,)]
)
coefficients = coefficients * basis
return coefficients