"""Module to describe the effect of temperature on the atomic positions."""
from __future__ import annotations
from abc import ABCMeta, abstractmethod
from functools import partial
from numbers import Number
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Optional,
Sequence,
TypeGuard,
Union,
)
import dask
import dask.array as da
import numpy as np
from ase import Atoms, data
from ase.cell import Cell
from ase.data import chemical_symbols
from ase.io import read
from ase.io.trajectory import read_atoms
from dask.delayed import Delayed
from abtem.core.axes import AxisMetadata, FrozenPhononsAxis, UnknownAxis
from abtem.core.chunks import Chunks, chunk_ranges, validate_chunks
from abtem.core.ensemble import Ensemble, _wrap_with_array, unpack_blockwise_args
from abtem.core.utils import CopyMixin, EqualityMixin, get_dtype, itemset
if TYPE_CHECKING:
pass
Reader: Optional[Callable] = None
try:
from gpaw.io import Reader # noqa
except ImportError:
Reader = None
def _safe_read_atoms(calculator, clean: bool = True) -> Atoms:
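    """Read atoms from a GPAW calculator object or a .gpw file path, optionally
    stripping the constraints and the attached calculator."""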
if isinstance(calculator, str):
assert Reader is not None
with Reader(calculator) as reader:
atoms = read_atoms(reader.atoms)
else:
atoms = calculator.atoms
if clean:
atoms.constraints = None
        atoms.calc = None
return atoms
class BaseFrozenPhonons(Ensemble, EqualityMixin, CopyMixin, metaclass=ABCMeta):
"""Base class for frozen phonons."""
def __init__(
self, atomic_numbers: np.ndarray, cell: Cell, ensemble_mean: bool = True
):
self._cell = cell
self._atomic_numbers = atomic_numbers
self._ensemble_mean = ensemble_mean
@property
    def ensemble_mean(self) -> bool:
        """True if the mean of the ensemble of results from a multislice
        simulation is calculated."""
return self._ensemble_mean
@property
def atomic_numbers(self) -> np.ndarray:
"""The unique atomic number of the atoms."""
return self._atomic_numbers
@property
def cell(self) -> Cell:
"""The cell of the atoms."""
return self._cell
@staticmethod
def _validate_atomic_numbers_and_cell(
atoms: Atoms | np.ndarray,
atomic_numbers: Optional[np.ndarray] = None,
cell: Optional[Cell] = None,
) -> tuple[np.ndarray, Cell]:
if isinstance(atoms, da.core.Array) and (
atomic_numbers is None or cell is None
):
atoms = atoms.compute(scheduler="single-threaded")
if isinstance(atoms, np.ndarray):
atoms = atoms.item()
assert isinstance(atoms, Atoms)
if cell is None:
cell = atoms.cell.copy()
else:
if not isinstance(cell, Cell):
cell = Cell(cell)
if not np.allclose(atoms.cell.array, cell.array):
raise RuntimeError("cell of provided Atoms did not match provided cell")
if atomic_numbers is None:
atomic_numbers = np.unique(atoms.numbers)
else:
atomic_numbers = np.array(atomic_numbers, dtype=int)
return atomic_numbers, cell
@property
@abstractmethod
def atoms(self) -> Atoms:
"""Base atomic configuration used for displacements."""
@abstractmethod
def randomize(self, atoms: Atoms) -> Atoms:
"""
Randomize the atoms.
Parameters
----------
atoms : Atoms
"""
@abstractmethod
def __len__(self) -> int:
pass
@property
@abstractmethod
def num_configs(self):
"""Number of atomic configurations."""
def __iter__(self):
for _, _, fp in self.generate_blocks(1):
fp = fp.item()
yield fp.randomize(fp.atoms)
class DummyFrozenPhonons(BaseFrozenPhonons):
"""Class to allow all potentials to be treated in the same way."""
def __init__(
self,
atoms: Atoms,
num_configs: Optional[int] = None,
):
self._atoms = atoms
self._num_configs = num_configs
atomic_numbers, cell = self._validate_atomic_numbers_and_cell(atoms, None, None)
super().__init__(atomic_numbers=atomic_numbers, cell=cell, ensemble_mean=True)
@property
def num_configs(self):
return self._num_configs
@property
def ensemble_shape(self):
if self._num_configs is None:
return ()
else:
return (self._num_configs,)
@property
def _default_ensemble_chunks(self):
if self._num_configs is None:
return ()
else:
return (1,)
@property
def ensemble_axes_metadata(self) -> list[AxisMetadata]:
if self._num_configs is None:
return []
else:
return [FrozenPhononsAxis(_ensemble_mean=self.ensemble_mean)]
def randomize(self, atoms: Atoms) -> Atoms:
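        """Return the atoms unchanged; the dummy frozen phonons apply no
        displacements."""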
return atoms
@property
def numbers(self):
"""The atomic numbers of the atoms."""
return self.atoms.numbers
@property
def atoms(self):
return self._atoms
@classmethod
def _from_partitioned_args_func(cls, args, **kwargs):
if hasattr(args, "item"):
args = args.item()
atoms = args
new_dummy_frozen_phonons = cls(atoms=atoms, **kwargs)
return _wrap_with_array(new_dummy_frozen_phonons, 0)
def _from_partitioned_args(self):
kwargs = self._copy_kwargs(exclude=("atoms",))
return partial(self._from_partitioned_args_func, **kwargs)
def _partition_args(self, chunks: Optional[Chunks] = None, lazy: bool = True):
if chunks is None:
chunks = 1
if lazy:
lazy_args = dask.delayed(_wrap_with_array)(self.atoms, ndims=0)
array = da.from_delayed(lazy_args, shape=(), dtype=object)
else:
atoms = self.atoms
array = _wrap_with_array(atoms, ndims=0)
return (array,)
def __len__(self):
if self._num_configs is None:
return 1
else:
return self._num_configs
def validate_seeds(
seeds: int | tuple[int, ...] | None,
num_seeds: Optional[int] = None,
) -> tuple[int, ...]:
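    """
    Validate the random seeds used to generate frozen phonon configurations.

    An integer seed is returned as a one-element tuple, or, if `num_seeds` is
    given, used to draw `num_seeds` unique integer seeds. If `seeds` is None,
    the seeds are drawn from an unseeded random number generator. A sequence of
    seeds is returned as a tuple after checking its length against `num_seeds`.

    Examples
    --------
    >>> validate_seeds((1, 2, 3))
    (1, 2, 3)
    >>> len(validate_seeds(13, num_seeds=4))
    4
    """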
if num_seeds is None and seeds is None:
num_seeds = 1
if isinstance(seeds, int) and num_seeds is None:
seeds = (seeds,)
    elif seeds is None:
if num_seeds is None:
raise ValueError("Provide `num_seeds` or a seed for each configuration.")
rng = np.random.default_rng(seed=seeds)
seeds = ()
while len(seeds) < num_seeds:
seed = rng.integers(np.iinfo(np.int32).max)
if seed not in seeds:
seeds += (seed,)
elif isinstance(seeds, int) and num_seeds is not None:
seeds_to_make_new_seeds = seeds
rng = np.random.default_rng(seed=seeds_to_make_new_seeds)
seeds = ()
while len(seeds) < num_seeds:
seed = rng.integers(np.iinfo(np.int32).max)
if seed not in seeds:
seeds += (seed,)
else:
if not hasattr(seeds, "__len__"):
raise ValueError("Invalid type for `seeds`.")
seeds = tuple(seeds)
if num_seeds is not None:
assert num_seeds == len(seeds)
return seeds
def ensure_all_values_are_tuples(
props: dict[Any, Any],
) -> TypeGuard[Dict[str, tuple[float | int, ...]]]:
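    """Check that every value in the properties dictionary is a tuple."""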
return all(isinstance(value, tuple) for value in props.values())
def all_keys_are_ints(props: dict[Any, Any]) -> TypeGuard[dict[int, Any]]:
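    """Check that every key in the properties dictionary is an integer."""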
return all(isinstance(key, int) for key in props.keys())
AtomProperties = Union[
float,
np.ndarray,
dict[str, float],
dict[str, tuple[float, ...]],
dict[str, np.ndarray],
Sequence[float],
]
def validate_per_atom_property(
atoms: Atoms,
props: AtomProperties,
return_array: bool = False,
) -> np.ndarray | dict[str, np.ndarray]:
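    """
    Validate a per-atom or per-species property for an atomic structure.

    A single number is broadcast to every atomic species present in `atoms`; a
    dictionary may map chemical symbols or atomic numbers to a value (or a tuple
    of values) per species; a sequence or array must provide a value for each
    atom. If `return_array` is True, a dictionary is converted to a per-atom
    array.
    """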
atomic_numbers = np.unique(atoms.numbers)
unique_symbols = [chemical_symbols[number] for number in atomic_numbers]
validated_props: np.ndarray | dict[str, np.ndarray]
dtype = get_dtype(complex=False)
if isinstance(props, Number):
validated_props = {
symbol: np.array(props, dtype=dtype) for symbol in unique_symbols
}
    elif isinstance(props, dict):
        if all_keys_are_ints(props):
            props = {
                chemical_symbols[key]: value for key, value in props.items()
            }
        elif not all(isinstance(key, str) for key in props.keys()):
            raise RuntimeError(
                "Keys in the properties dictionary must be either all "
                "atomic numbers or all chemical symbols."
            )
        if not set(unique_symbols).issubset(set(props.keys())):
            raise RuntimeError(
                "Property must be provided for all atomic species."
                f" symbols: {unique_symbols}, keys: {props.keys()}"
            )
        if ensure_all_values_are_tuples(props):
            first_attr = next(iter(props.values()))
            if not all(len(attr) == len(first_attr) for attr in props.values()):
                raise RuntimeError("All values must have the same length.")
        validated_props = {
            symbol: np.array(value, dtype=dtype) for symbol, value in props.items()
        }
elif isinstance(props, (list, tuple, np.ndarray)):
validated_props = np.array(props, dtype=dtype)
if len(props) != len(atoms):
raise RuntimeError("Property must be provided for all atoms.")
else:
raise ValueError("Invalid type for `props`.")
if return_array and isinstance(validated_props, dict):
return atom_property_dict_to_atom_property_array(atoms, validated_props)
return validated_props
def validate_sigmas(
atoms: Atoms, sigmas: AtomProperties, return_array: bool = False
) -> tuple[np.ndarray | dict[str, np.ndarray], bool]:
"""
Validate the standard deviations of displacement for atoms in an atomic structure.
Parameters
----------
atoms : Atoms
        The atomic structure whose standard deviations of displacement are to be
        validated.
sigmas : float, dict[str, float] or Sequence[float]
It can be either:
- a single float value specifying the standard deviation for all atoms,
- a dictionary mapping each atom's symbol or atomic number to a corresponding
standard deviation,
- a sequence of float values providing the standard deviation for each atom
individually.
For anisotropic displacements, either three values for each atom or for each
element must be provided.
Returns
-------
    sigmas : dict[str, np.ndarray] or np.ndarray
        The validated standard deviations.
anisotropic : bool
A boolean value indicating whether the displacements are anisotropic.
Raises
------
ValueError
If the type of `sigmas` is not float, dict, or list.
RuntimeError
If the length of `sigmas` does not match the length of `atoms`, or
three values for each atom or each element are not given for anisotropic
displacements.
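
    Examples
    --------
    A minimal illustrative sketch (assuming ASE's ``molecule`` builder; the
    numeric values are placeholders):

    >>> from ase.build import molecule
    >>> water = molecule("H2O")
    >>> sigmas, anisotropic = validate_sigmas(water, {"H": 0.1, "O": 0.05})
    >>> anisotropic
    False
    >>> sigmas, anisotropic = validate_sigmas(
    ...     water, {"H": (0.1, 0.1, 0.05), "O": (0.05, 0.05, 0.02)}
    ... )
    >>> anisotropic
    True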
"""
validated_sigmas = validate_per_atom_property(
atoms, sigmas, return_array=return_array
)
    if isinstance(validated_sigmas, dict):
        sigmas_array = np.array(list(validated_sigmas.values()))
    else:
        sigmas_array = validated_sigmas
    if len(sigmas_array.shape) == 2 and sigmas_array.shape[-1] == 3:
        anisotropic = True
    elif len(sigmas_array.shape) < 2:
        anisotropic = False
    else:
        raise RuntimeError("Anisotropic displacements must be given as three values.")
return validated_sigmas, anisotropic
def atom_property_dict_to_atom_property_array(
atoms: Atoms, props: dict[str, np.ndarray]
) -> np.ndarray:
dtype = get_dtype(complex=False)
n = next(iter(props.values())).shape
array = np.zeros((len(atoms.numbers),) + n, dtype=dtype)
for unique in np.unique(atoms.numbers):
array[atoms.numbers == unique] = np.array(
props[chemical_symbols[unique]], dtype=dtype
)
return array
class FrozenPhonons(BaseFrozenPhonons):
"""
The frozen phonons randomly displace the atomic positions to emulate thermal
vibrations.
Parameters
----------
atoms : ASE.Atoms
Atomic configuration used for displacements.
num_configs : int
Number of frozen phonon configurations.
    sigmas : float, dict or sequence of float
        If float, the standard deviation of the displacements is assumed to be
        identical for all atoms. If dict, a displacement standard deviation should
        be provided for each species; the atomic species can be specified as an
        atomic number or a symbol, using the ASE standard. If a list or array, a
        displacement standard deviation should be provided for each atom.
        Anisotropic displacements may be given by providing a standard deviation
        for each principal direction: a tuple of three numbers for identical
        displacements of all atoms, a dict of tuples of three numbers to specify
        the displacements for each species, or a list or array with three numbers
        for each atom.
    directions : str, optional
        The displacement directions of the atoms given as a string; for example
        'xy' for displacements in the `x`- and `y`-direction (i.e. perpendicular
        to the propagation direction). Default is 'xyz'.
ensemble_mean : bool, optional
If True (default), the mean of the ensemble of results from a multislice
simulation is calculated, otherwise, the result of every frozen phonon
configuration is returned.
seed : int or sequence of int
Seed(s) for the random number generator used to generate the displacements, or
one seed for each configuration in the frozen phonon ensemble.
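
    Examples
    --------
    A minimal illustrative sketch (assuming ASE's ``bulk`` builder; the numeric
    values below are placeholders rather than tabulated thermal parameters):

    >>> from ase.build import bulk
    >>> silicon = bulk("Si", a=5.43, cubic=True)
    >>> frozen_phonons = FrozenPhonons(silicon, num_configs=4, sigmas=0.078, seed=13)
    >>> len(frozen_phonons)
    4
    >>> displaced = [configuration for configuration in frozen_phonons]

    Each iteration yields an independently displaced copy of the input atoms.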
"""
def __init__(
self,
atoms: Atoms,
num_configs: int,
sigmas: (
float | dict[str, float] | dict[str, tuple[float, ...]] | Sequence[float]
),
directions: str = "xyz",
ensemble_mean: bool = True,
seed: Optional[int | tuple[int, ...]] = None,
):
        if isinstance(sigmas, dict):
            atomic_numbers = np.array(
                [
                    key if isinstance(key, int) else data.atomic_numbers[key]
                    for key in sigmas.keys()
                ]
            )
else:
atomic_numbers = None
atomic_numbers, cell = self._validate_atomic_numbers_and_cell(
atoms, atomic_numbers, cell=None
)
self._sigmas = validate_sigmas(atoms, sigmas)[0]
self._directions = directions
self._atoms = atoms
self._seed = validate_seeds(seed, num_seeds=num_configs)
super().__init__(
atomic_numbers=atomic_numbers, cell=cell, ensemble_mean=ensemble_mean
)
@property
def ensemble_shape(self):
return (len(self),)
@property
def _default_ensemble_chunks(self):
return (1,)
@property
def ensemble_axes_metadata(self) -> list[AxisMetadata]:
return [FrozenPhononsAxis(_ensemble_mean=self.ensemble_mean)]
@property
def num_configs(self) -> int:
return len(self._seed)
@property
def seed(self) -> tuple[int, ...]:
"""Random seed for each displacement configuration."""
return self._seed
@property
def sigmas(self) -> np.ndarray | dict[str, np.ndarray]:
"""Displacement standard deviation for each atom."""
return self._sigmas
@property
def atoms(self) -> Atoms:
return self._atoms
@property
def directions(self) -> str:
"""The directions of the random displacements."""
return self._directions
def __len__(self) -> int:
return self.num_configs
def _validate_sigmas(self, atoms: Atoms):
return validate_sigmas(atoms, self._sigmas)
@property
def _axes(self) -> list[int]:
axes = []
for direction in list(set(self._directions.lower())):
if direction == "x":
axes += [0]
elif direction == "y":
axes += [1]
elif direction == "z":
axes += [2]
else:
raise RuntimeError(f"Directions must be 'x', 'y' or 'z', not {axes}.")
return axes
def randomize(self, atoms: Atoms) -> Atoms:
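        """
        Return a copy of the given atoms with random displacements drawn from the
        standard deviations of this frozen phonon ensemble, applied along the
        directions given by `directions`.

        Parameters
        ----------
        atoms : Atoms
            The atomic configuration to displace.

        Returns
        -------
        displaced_atoms : Atoms
        """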
sigmas, anisotropic = self._validate_sigmas(atoms)
if isinstance(sigmas, dict):
sigmas = atom_property_dict_to_atom_property_array(atoms, sigmas)
assert isinstance(sigmas, np.ndarray)
atoms = atoms.copy()
rng = np.random.default_rng(self.seed[0])
if anisotropic:
r = rng.normal(size=(len(atoms), 3))
for axis in self._axes:
atoms.positions[:, axis] += sigmas[:, axis] * r[:, axis]
else:
r = rng.normal(size=(len(atoms), 3))
for axis in self._axes:
atoms.positions[:, axis] += sigmas * r[:, axis]
return atoms
@classmethod
def _from_partitioned_args_func(cls, *args, **kwargs):
args = unpack_blockwise_args(args)
atoms, seed = args[0]
new = cls(atoms=atoms, seed=seed, num_configs=len(seed), **kwargs)
new = _wrap_with_array(new, len(new.ensemble_shape))
return new
def _from_partitioned_args(self):
kwargs = self._copy_kwargs(exclude=("atoms", "seed", "num_configs"))
output = partial(self._from_partitioned_args_func, **kwargs)
return output
def _partition_args(self, chunks: Optional[Chunks] = None, lazy: bool = True):
if chunks is None:
chunks = 1
chunks = validate_chunks(self.ensemble_shape, chunks)
if lazy:
arrays = []
for i, (start, stop) in enumerate(chunk_ranges(chunks)[0]):
seeds = self.seed[start:stop]
lazy_atoms = dask.delayed(self.atoms)
lazy_args = dask.delayed(_wrap_with_array)((lazy_atoms, seeds), ndims=1)
lazy_array = da.from_delayed(lazy_args, shape=(1,), dtype=object)
arrays.append(lazy_array)
array = da.concatenate(arrays)
else:
atoms = self.atoms
array = np.zeros((len(chunks[0]),), dtype=object)
for i, (start, stop) in enumerate(chunk_ranges(chunks)[0]):
itemset(array, i, (atoms, self.seed[start:stop]))
return (array,)
def to_atoms_ensemble(self):
"""
Convert the frozen phonons to an ensemble of atoms.
Returns
-------
atoms_ensemble : AtomsEnsemble
"""
trajectory = []
for _, _, block in self.generate_blocks(1):
block = block.item()
trajectory.append(block.randomize(block.atoms))
return AtomsEnsemble(trajectory)
class AtomsEnsemble(BaseFrozenPhonons):
"""
Frozen phonons based on a molecular dynamics simulation.
Parameters
----------
    trajectory : sequence of ase.Atoms, np.ndarray, dask.core.Array or list of dask.Delayed
        Sequence of atoms representing a distribution of atomic configurations.
        A filename readable by `ase.io.read` may also be given.
    ensemble_mean : bool, optional
        If True (default), the mean of the ensemble of results from a multislice
        simulation is calculated; otherwise, the result of every frozen phonon
        configuration is returned.
    ensemble_axes_metadata : list of AxisMetadata, optional
        Axis metadata for each ensemble axis. The axis metadata must be compatible
        with the shape of the array.
    cell : Cell, optional
        Unit cell of the atomic configurations. If not given, the cell of the first
        configuration is used.
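
    Examples
    --------
    A minimal illustrative sketch constructing an ensemble directly from a list
    of configurations (a trajectory file readable by ``ase.io.read`` can be
    passed in the same way):

    >>> from ase.build import bulk
    >>> configurations = [bulk("Si", a=5.43, cubic=True) for _ in range(3)]
    >>> ensemble = AtomsEnsemble(configurations)
    >>> len(ensemble)
    3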
"""
def __init__(
self,
trajectory: Sequence[Atoms],
ensemble_mean: bool = True,
ensemble_axes_metadata: Optional[list[AxisMetadata]] = None,
cell: Optional[Cell] = None,
):
if isinstance(trajectory, str):
trajectory = read(trajectory, index=":")
elif isinstance(trajectory, Atoms):
trajectory = [trajectory]
if isinstance(trajectory, (list, tuple)):
if isinstance(trajectory[0], str):
trajectory = [_safe_read_atoms(path) for path in trajectory]
if isinstance(trajectory[0], Delayed):
stack = []
for atoms in trajectory:
atoms = dask.delayed(_wrap_with_array)(atoms, 1)
atoms = da.from_delayed(atoms, shape=(1,), dtype=object)
stack.append(atoms)
trajectory_array = da.concatenate(stack)
else:
stack_array = np.empty(len(trajectory), dtype=object)
for i, atoms in enumerate(trajectory):
itemset(stack_array, i, atoms)
trajectory_array = stack_array
elif isinstance(trajectory, (np.ndarray, da.core.Array)):
trajectory_array = trajectory
else:
raise ValueError(f"Invalid type for `trajectory`, got {type(trajectory)}")
# assert isinstance(trajectory, (np.ndarray, da.core.Array))
if ensemble_axes_metadata is None:
ensemble_axes_metadata = [FrozenPhononsAxis(_ensemble_mean=ensemble_mean)]
elif isinstance(ensemble_axes_metadata, AxisMetadata):
ensemble_axes_metadata = [ensemble_axes_metadata]
elif not isinstance(ensemble_axes_metadata, list):
            raise ValueError("Invalid type for `ensemble_axes_metadata`.")
assert len(ensemble_axes_metadata) == len(trajectory_array.shape)
atoms = trajectory_array.ravel()[0]
atomic_numbers, cell = self._validate_atomic_numbers_and_cell(atoms, None, cell)
self._trajectory = trajectory_array
super().__init__(
atomic_numbers=atomic_numbers, cell=cell, ensemble_mean=ensemble_mean
)
self._ensemble_axes_metadata = ensemble_axes_metadata
@property
def trajectory(self) -> np.ndarray | da.core.Array:
"""Array of atoms representing an ensemble of atomic configurations."""
return self._trajectory
@property
def numbers(self):
"""The atomic numbers of the atoms."""
return self.trajectory[0].numbers
def __getitem__(self, item):
new_trajectory = self._trajectory[item]
kwargs = self._copy_kwargs(exclude=("trajectory",))
return AtomsEnsemble(new_trajectory, **kwargs)
@property
def ensemble_axes_metadata(self) -> list[AxisMetadata]:
return self._ensemble_axes_metadata
def __len__(self) -> int:
return len(self._trajectory)
@property
def num_configs(self) -> int:
return len(self._trajectory)
@property
def atoms(self) -> Atoms:
atoms = self._trajectory.ravel()[0]
if isinstance(atoms, np.ndarray):
atoms = atoms.item()
return atoms
@property
def ensemble_shape(self) -> tuple[int, ...]:
if isinstance(self._trajectory, (da.core.Array, np.ndarray)):
return self._trajectory.shape
return (len(self),)
@property
def _default_ensemble_chunks(self) -> tuple[int, ...]:
if isinstance(self._trajectory, (da.core.Array, np.ndarray)):
return (1,) * len(self.ensemble_shape)
return (1,)
def _partition_args(self, chunks: Optional[Chunks] = None, lazy: bool = True):
if chunks is None:
chunks = 1
chunks = validate_chunks(self.ensemble_shape, chunks)
if lazy:
arrays = []
for i, (start, stop) in enumerate(chunk_ranges(chunks)[0]):
trajectory = self.trajectory[start:stop]
lazy_args = dask.delayed(_wrap_with_array)(trajectory, ndims=1)
lazy_array = da.from_delayed(lazy_args, shape=(1,), dtype=object)
arrays.append(lazy_array)
array = da.concatenate(arrays)
else:
trajectory = self.trajectory
if isinstance(trajectory, da.core.Array):
trajectory = trajectory.compute()
array = np.zeros((len(chunks[0]),), dtype=object)
for i, (start, stop) in enumerate(chunk_ranges(chunks)[0]):
itemset(array, i, _wrap_with_array(trajectory[start:stop], 1))
return (array,)
@staticmethod
def _from_partition_args_func(*args, **kwargs):
args = unpack_blockwise_args(args)
trajectory = args[0]
atoms_ensemble = AtomsEnsemble(trajectory, **kwargs)
return _wrap_with_array(atoms_ensemble, 1)
def _from_partitioned_args(self):
kwargs = self._copy_kwargs(exclude=("trajectory", "ensemble_shape"))
kwargs["cell"] = self.cell.array
kwargs["ensemble_axes_metadata"] = [UnknownAxis()] * len(self.ensemble_shape)
return partial(self._from_partition_args_func, **kwargs)
def randomize(self, atoms: Atoms) -> Atoms:
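        """Return the atoms unchanged; the configurations of the ensemble are
        used directly."""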
return atoms
def mean_squared_deviations(self) -> np.ndarray:
"""
        Mean squared deviation of the position of each atom in each direction.
"""
positions = np.stack([atoms.positions for atoms in self.trajectory])
return ((positions - positions.mean(0)) ** 2).mean(0)
def standard_deviations(self) -> np.ndarray:
"""
Standard deviation of the positions of each atom in each direction.
"""
positions = np.stack([atoms.positions for atoms in self.trajectory])
        return (positions - positions.mean(0)).std(axis=0)