# Source code for abtem.visualize.visualizations

"""Module for plotting atoms, images, line scans, and diffraction patterns."""

from __future__ import annotations

import itertools
from typing import TYPE_CHECKING

import matplotlib
import matplotlib.colorbar as cbar
import matplotlib.pyplot as plt
import numpy as np
from ase import Atoms
from ase.data import chemical_symbols, covalent_radii
from ase.data.colors import jmol_colors
from matplotlib.axes import Axes
from matplotlib.collections import PatchCollection
from matplotlib.lines import Line2D
from matplotlib.patches import Circle

from abtem.atoms import pad_atoms, plane_to_axes
from abtem.core import config
from abtem.core.utils import label_to_index
from abtem.visualize.artists import (
    DomainColoringArtist,
    ImageArtist,
    LinesArtist,
    ScatterArtist,
    _get_value_limits,
    validate_cmap,
)
from abtem.visualize.axes_grid import AxesCollection, AxesGrid
from abtem.visualize.widgets import slider_from_axes_metadata

if TYPE_CHECKING:
    from abtem.measurements import BaseMeasurements


def discrete_cmap(num_colors, base_cmap):
    """Build a colormap from the first ``num_colors`` entries of ``base_cmap``.

    Parameters
    ----------
    num_colors : int
        Number of discrete colors in the returned colormap.
    base_cmap : str or matplotlib.colors.Colormap
        The colormap to sample, or its registered name.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        An unnamed colormap with ``num_colors`` segments.
    """
    cmap = plt.get_cmap(base_cmap) if isinstance(base_cmap, str) else base_cmap
    # NOTE(review): matplotlib treats integer inputs as lookup-table indices,
    # so this picks entries 0..num_colors-1 (suited to listed colormaps such
    # as "tab10"); confirm this is intended for continuous colormaps.
    sampled_colors = cmap(range(num_colors))
    return matplotlib.colors.LinearSegmentedColormap.from_list(
        "", sampled_colors, num_colors
    )
def _validate_axes_types(overlay, explode, ensemble_dims): if explode is True: explode = tuple(range(ensemble_dims)) elif explode is False: explode = () if overlay is True: overlay = tuple(range(ensemble_dims)) elif overlay is False: overlay = () if isinstance(explode, int): explode = (explode,) if isinstance(overlay, int): overlay = (overlay,) if len(overlay + explode) > ensemble_dims: raise ValueError if len(set(explode) & set(overlay)) > 0: raise ValueError("An axis cannot be both exploded and overlaid.") return overlay, explode
def convert_complex(measurement: BaseMeasurements, method: str) -> BaseMeasurements:
    """Convert a complex-valued measurement to a real-valued representation.

    Parameters
    ----------
    measurement : BaseMeasurements
        The measurement to convert. It is returned unchanged if it is not
        complex.
    method : str or None
        ``"phase"``/``"angle"``, ``"amplitude"``/``"abs"``,
        ``"intensity"``/``"abs2"``, ``"real"`` or ``"imaginary"``/``"imag"``.
        ``"domain_coloring"``, ``"none"`` and ``None`` leave the measurement
        unchanged.

    Returns
    -------
    BaseMeasurements
        The converted (or untouched) measurement.

    Raises
    ------
    ValueError
        If ``method`` is not one of the names above.
    """
    if not measurement.is_complex or method in ("domain_coloring", "none", None):
        return measurement

    # (accepted aliases, name of the measurement method performing the conversion)
    conversions = (
        (("phase", "angle"), "phase"),
        (("amplitude", "abs"), "abs"),
        (("intensity", "abs2"), "intensity"),
        (("real",), "real"),
        (("imaginary", "imag"), "imag"),
    )
    for aliases, converter_name in conversions:
        if method in aliases:
            return getattr(measurement, converter_name)()

    raise ValueError(f"complex conversion '{method}' not implemented")
def _validate_artist_type(measurement, complex_conversion, artist_type=None): if artist_type is not None: return artist_type if len(measurement.base_shape) == 2: if measurement.is_complex and ( complex_conversion in ("domain_coloring", "none", None) ): return DomainColoringArtist else: return ImageArtist elif hasattr(measurement, "miller_indices"): return ScatterArtist elif len(measurement.base_shape) == 1: return LinesArtist elif artist_type is None: raise ValueError("artist type not recognized") else: return artist_type def _make_cax(ax, use_gridspec=True, **kwargs): if ax is None: raise ValueError( "Unable to determine Axes to steal space for Colorbar. " "Either provide the *cax* argument to use as the Axes for " "the Colorbar, provide the *ax* argument to steal space " "from it, or add *mappable* to an Axes." ) fig = ( # Figure of first axes; logic copied from make_axes. [*ax.flat] if isinstance(ax, np.ndarray) else [*ax] if np.iterable(ax) else [ax] )[0].figure current_ax = fig.gca() if ( fig.get_layout_engine() is not None and not fig.get_layout_engine().colorbar_gridspec ): use_gridspec = False if ( use_gridspec and isinstance(ax, matplotlib.axes._base._AxesBase) and ax.get_subplotspec() ): cax, kwargs = cbar.make_axes_gridspec(ax, **kwargs) else: cax, kwargs = cbar.make_axes(ax, **kwargs) # make_axes calls add_{axes,subplot} which changes gca; undo that. fig.sca(current_ax) cax.grid(visible=False, which="both", axis="both") return cax
class Visualization:
    """Draw a measurement onto a grid of matplotlib axes.

    The ensemble axes of the measurement are split into three groups:
    *overlay* axes are drawn together in the same panel, *explode* axes are
    spread over a grid of panels, and the remaining *indexing* axes (see
    ``indexing_axes``) are reduced by selecting an index or summing a range.
    """

    def __init__(
        self,
        measurement,
        ax: Axes = None,
        artist_type=None,
        figsize: tuple[int, int] = None,
        aspect: bool = False,
        common_scale: bool = False,
        value_limits: tuple[float, float] = (None, None),
        overlay: bool | tuple[int, ...] = False,
        explode: bool | tuple[int, ...] = False,
        share_x: bool = False,
        share_y: bool = False,
        cbar: bool = False,
        interactive: bool = True,
        title: str = None,
        xlim: tuple[float, float] = None,
        ylim: tuple[float, float] = None,
        **kwargs,
    ):
        """Create the figure/axes and draw the initial artists.

        Parameters
        ----------
        measurement : BaseMeasurements
            The measurement to show. It is moved to the CPU and computed.
        ax : matplotlib.axes.Axes, optional
            Draw into this single Axes instead of creating a new figure/grid.
        artist_type : type, optional
            Artist class to use. It is inferred from the measurement if omitted.
        figsize : two int, optional
            Size of a newly created figure.
        aspect, share_x, share_y : bool
            Forwarded to ``AxesGrid`` when a new grid is created.
        common_scale : bool
            If True, all panels use a single colorbar and value limits computed
            from the whole measurement.
        value_limits : two float or None
            Passed on to ``set_value_limits`` / ``set_common_value_limits``.
        overlay, explode : bool or tuple of int
            Ensemble axes drawn in the same panel / spread across panels.
        cbar : bool
            If True, colorbar axes are created for the artists.
        interactive : bool
            If False and ``ax`` is None, the figure is created inside
            ``plt.ioff()``.
        title : str or bool, optional
            A string is used as the title of every column. Another truthy value
            titles columns/rows from the metadata of the exploded axes.
        xlim, ylim : two float, optional
            Coordinate limits. Limits left as None span all artists.
        **kwargs
            Forwarded to the artist constructor (and, when ``ax`` is given, to
            ``_make_cax``).
        """
        self._measurement = measurement.to_cpu().compute()

        overlay, explode = _validate_axes_types(
            overlay, explode, len(measurement.ensemble_shape)
        )
        self._overlay = overlay
        self._explode = explode

        if ax is None and not interactive:
            with plt.ioff():
                fig = plt.figure(figsize=figsize)
        elif ax is None:
            fig = plt.figure(figsize=figsize)
        else:
            fig = ax.get_figure()

        artist_type = _validate_artist_type(
            measurement, complex_conversion="none", artist_type=artist_type
        )

        ncbars = artist_type.num_cbars if cbar else 0
        cbar_mode = "single" if common_scale else "each"

        if ax is None:
            # One panel per index combination of the exploded axes.
            axes_shape = tuple(
                n for i, n in enumerate(measurement.ensemble_shape) if i in explode
            )
            shape = axes_shape + (0,) * (2 - len(axes_shape))
            axes = AxesGrid(
                fig=fig,
                ncols=max(shape[0], 1),
                nrows=max(shape[1], 1),
                ncbars=ncbars,
                cbar_mode=cbar_mode,
                aspect=aspect,
                sharex=share_x,
                sharey=share_y,
            )
        else:
            axes = np.array([[ax]], dtype=object)
            caxes = np.zeros_like(axes, dtype=object)
            for index in np.ndindex(axes.shape):
                caxes[index] = [_make_cax(ax, **kwargs) for _ in range(ncbars)]
            axes = AxesCollection(axes, caxes, cbar_mode=cbar_mode)

        self._axes = axes
        self._indices = ()
        self._complex_conversion = "none"
        self._autoscale = config.get("visualize.autoscale", False)
        self._column_titles = []
        self._row_titles = []
        self._panel_labels = []
        self._artists = None

        # Presumably hides the ipympl canvas header; plain attribute otherwise.
        self.get_figure().canvas.header_visible = False

        if isinstance(title, str):
            self.set_column_titles(title)
        elif title and len(explode) > 0:
            self.set_column_titles(self._ordinal_titles(measurement, explode[0]))

        if title and len(explode) > 1:
            self.set_row_titles(self._ordinal_titles(measurement, explode[1]))

        self._make_new_artists(artist_type=artist_type, **kwargs)
        self.adjust_coordinate_limits_to_artists(xlim=xlim, ylim=ylim)

        if common_scale:
            self.set_common_value_limits(value_limits)
        else:
            self.set_value_limits(value_limits)

        if artist_type is DomainColoringArtist:
            self.axes.set_sizes(cbar_spacing=0.5)

    @staticmethod
    def _ordinal_titles(measurement, axis: int) -> list[str]:
        """Return one title per index along ensemble ``axis`` (label on the first only)."""
        axes_metadata = measurement.axes_metadata[axis].to_ordinal_axis(
            measurement.shape[axis]
        )
        return [
            metadata.format_title(".3g", include_label=i == 0)
            for i, metadata in enumerate(axes_metadata)
        ]

    def interact(self, gui_type, display):
        """Attach slider widgets for the indexing axes.

        Parameters
        ----------
        gui_type : callable
            Called as ``gui_type(sliders, canvas)``; the returned object must
            provide ``attach_visualization``.
        display : bool
            If True, the GUI is shown with ``IPython.display.display``.

        Returns
        -------
        The GUI object.

        Raises
        ------
        RuntimeError
            If the matplotlib backend is not ipympl.
        """
        if "ipympl" not in matplotlib.get_backend():
            raise RuntimeError(
                "interactive visualizations requires the 'ipympl' matplotlib backend"
            )

        sliders = [
            slider_from_axes_metadata(
                self.measurement.axes_metadata[i], self.measurement.shape[i]
            )
            for i in self.indexing_axes
        ]
        gui = gui_type(sliders, self.axes.fig.canvas)
        gui.attach_visualization(self)

        if display:
            from IPython.display import display as ipython_display

            ipython_display(gui)

        return gui

    @property
    def autoscale(self):
        """Whether value limits are rescaled; setting it re-applies the limits."""
        return self._autoscale

    @autoscale.setter
    def autoscale(self, autoscale: bool):
        self._autoscale = autoscale
        self.set_value_limits()

    @property
    def indexing_axes(self):
        """Ensemble axes that are neither overlaid nor exploded, in ascending order."""
        num_axes = len(self.measurement.ensemble_shape)
        return tuple(
            i
            for i in range(num_axes)
            if i not in self._overlay and i not in self._explode
        )

    def _reduce_measurement(
        self, indices: tuple[int | tuple[int, int], ...], axis_indices
    ) -> BaseMeasurements:
        """Return the part of the measurement shown in one panel.

        Parameters
        ----------
        indices : tuple of int or (int, int)
            One entry per indexing axis. An int selects that index; a tuple is
            turned into a slice that is summed over. Missing trailing entries
            default to index 0.
        axis_indices : tuple of two int
            Position of the panel in the axes grid. Its entries select, in
            order, the index along each exploded axis.

        Raises
        ------
        TypeError
            If an entry of ``indices`` is neither an integer nor a tuple.
        """
        indexing_axes = self.indexing_axes
        assert len(indices) <= len(indexing_axes)
        assert len(axis_indices) == 2

        validated_indices = ()
        summed_axes = ()
        removed_axes = 0  # ensemble axes dropped so far by integer indexing
        j = 0  # position in `indices`
        k = 0  # position in `axis_indices`
        for i in range(len(self._measurement.ensemble_shape)):
            if i in indexing_axes:
                index = indices[j] if j < len(indices) else 0
                j += 1
                if isinstance(index, tuple):
                    validated_indices += (slice(*index),)
                    # Position of this axis once the earlier integer-indexed
                    # axes have been removed.
                    summed_axes += (i - removed_axes,)
                elif isinstance(index, (int, np.integer)):
                    # Integer indexing (including the default 0) drops the axis.
                    validated_indices += (index,)
                    removed_axes += 1
                else:
                    raise TypeError(
                        f"expected an int or a (start, stop) tuple as index, "
                        f"got {index!r}"
                    )
            elif i in self._explode:
                validated_indices += (axis_indices[k],)
                k += 1
                removed_axes += 1
            else:
                # Overlay axis: keep every entry and keep later indices aligned
                # with their axes.
                validated_indices += (slice(None),)

        measurement = self._measurement[validated_indices]

        if summed_axes:
            measurement = measurement.sum(axis=summed_axes)

        return convert_complex(measurement, self._complex_conversion)

    @property
    def measurement(self) -> BaseMeasurements:
        """The (computed, CPU) measurement being visualized."""
        return self._measurement

    @property
    def artists(self):
        """Object array of artists, one per panel (None before they are made)."""
        return self._artists

    @property
    def axes(self):
        """The axes grid/collection holding the panels."""
        return self._axes

    def get_figure(self):
        """Return the figure of the first panel."""
        return self.axes[0, 0].get_figure()

    def adjust_coordinate_limits_to_artists(self, xlim=None, ylim=None):
        """Set the coordinate limits of every panel.

        An explicitly given ``xlim``/``ylim`` is applied as is. A limit left as
        None is chosen to span the extents of all artists.
        """

        def _span(limits):
            lower, upper = np.inf, -np.inf
            for low, high in limits:
                lower, upper = min(low, lower), max(high, upper)
            return [lower, upper]

        artists = self.artists.ravel()

        if xlim is None:
            xlim = _span(artist.get_xlim() for artist in artists)

        if ylim is None:
            ylim = _span(artist.get_ylim() for artist in artists)

        self.set_xlim(xlim)
        self.set_ylim(ylim)

    def set_xlabel(self, label: str = None):
        """Set the x-axis label of the artists."""
        self.set_artists("xlabel", label=label)

    def set_ylabel(self, label: str = None):
        """Set the y-axis label of the artists."""
        self.set_artists("ylabel", label=label)

    def set_xlim(self, xlim: tuple[float, float] | list[float] = None):
        """Set the x-limits of the artists."""
        self.set_artists("xlim", xlim=xlim)

    def set_ylim(self, ylim: tuple[float, float] | list[float] = None):
        """Set the y-limits of the artists."""
        self.set_artists("ylim", ylim=ylim)

    def set_value_limits(
        self, value_limits: tuple[float, float] | list[float] = (None, None)
    ):
        """Set the value limits of each artist."""
        self.set_artists("value_limits", value_limits=value_limits)

    def set_power(self, power: float = 1.0):
        """Set the power of the artists."""
        self.set_artists("power", power=power)

    def set_common_value_limits(self, value_limits=(None, None)):
        """Set value limits computed from the whole measurement on every artist."""
        value_limits = _get_value_limits(
            self._measurement.array, value_limits=value_limits
        )
        self.set_value_limits(value_limits)

    def set_column_titles(
        self,
        titles: str | list[str],
        pad: float = 10.0,
        fontsize: float = 12,
        **kwargs,
    ):
        """Annotate each column with a title, replacing previous column titles.

        A single string is repeated for every column.
        """
        if isinstance(titles, str):
            titles = [titles] * self.axes.shape[0]

        for column_title in self._column_titles:
            column_title.remove()

        column_titles = []
        for i, ax in enumerate(self.axes[:, -1]):
            annotation = ax.annotate(
                titles[i],
                xy=(0.5, 1),
                xytext=(0, pad),
                xycoords="axes fraction",
                textcoords="offset points",
                ha="center",
                va="baseline",
                fontsize=fontsize,
                **kwargs,
            )
            column_titles.append(annotation)

        self._column_titles = column_titles

    def set_row_titles(
        self,
        titles: str | list[str],
        shift: float = 0.0,
        fontsize: float = 12,
        **kwargs,
    ):
        """Annotate each row with a rotated title, replacing previous row titles.

        A single string is repeated for every row, as in ``set_column_titles``.
        """
        if isinstance(titles, str):
            titles = [titles] * self.axes.shape[1]

        for row_title in self._row_titles:
            row_title.remove()

        row_titles = []
        for i, ax in enumerate(self.axes[0, :]):
            annotation = ax.annotate(
                titles[i],
                xy=(0, 0.5),
                xytext=(-ax.yaxis.labelpad - shift, 0),
                xycoords=ax.yaxis.label,
                textcoords="offset points",
                ha="right",
                va="center",
                rotation=90,
                fontsize=fontsize,
                **kwargs,
            )
            row_titles.append(annotation)

        self._row_titles = row_titles

    # NOTE: per-panel labels (`set_panel_labels`) are not implemented yet;
    # `self._panel_labels` is reserved for them.

    def axis(self, mode: str = "all", ticks: bool = False, spines: bool = True):
        """Hide tick labels and axis labels of panels outside ``mode``.

        ``"all"`` keeps every panel unchanged and ``"none"`` strips every panel.
        Any other value is resolved by ``axes.axis_location_to_indices`` to the
        panels that are kept.
        """
        if mode == "all":
            return

        if mode == "none":
            indices = ()
        else:
            indices = tuple(self.axes.axis_location_to_indices(mode))

        for index in np.ndindex(self.axes.shape):
            if index in indices:
                continue

            ax = self.axes[index]
            ax._axislines["bottom"].toggle(ticklabels=False, label=False, ticks=ticks)
            ax._axislines["left"].toggle(ticklabels=False, label=False, ticks=ticks)

            if not spines:
                for side in ("top", "right", "bottom", "left"):
                    ax.spines[side].set_visible(False)

    def remove_artists(self):
        """Remove all current artists (no-op if none were made)."""
        if self.artists is None:
            return

        for artist in self.artists.ravel():
            artist.remove()

    def update_data_indices(self, indices):
        """Select new indices along the indexing axes and update every artist."""
        self._indices = indices
        for index in np.ndindex(self.axes.shape):
            data = self._reduce_measurement(indices, index)
            self._artists[index].set_data(data)

    def _make_new_artists(
        self,
        artist_type=None,
        **kwargs,
    ):
        """Replace the artists with new ones of ``artist_type`` (inferred if None)."""
        self.remove_artists()

        artist_type = _validate_artist_type(
            self.measurement,
            complex_conversion=self._complex_conversion,
            artist_type=artist_type,
        )

        artists = np.zeros(self.axes.shape, dtype=object)
        for index in np.ndindex(self.axes.shape):
            caxes = self.axes._caxes[index]
            # With a single shared colorbar, only the first panel gets colorbar axes.
            if self.axes._cbar_mode == "single" and index != (0, 0):
                caxes = None

            measurement = self._reduce_measurement(self._indices, index)
            # Plain item assignment: `ndarray.itemset` was removed in NumPy 2.0.
            artists[index] = artist_type(
                ax=self.axes[index],
                caxes=caxes,
                measurement=measurement,
                **kwargs,
            )

        self._artists = artists

    def set_artists(self, name, locs: str | tuple[int, ...] = "all", **kwargs):
        """Call ``set_<name>(**kwargs)`` on the artists at ``locs``.

        Raises
        ------
        RuntimeError
            If the artist type has no ``set_<name>`` method.
        """
        artist_type = _validate_artist_type(self._measurement, self._complex_conversion)

        if not hasattr(artist_type, f"set_{name}"):
            raise RuntimeError(
                f"artist of type '{artist_type.__name__}' does not have a method 'set_{name}'"
            )

        # NOTE(review): this checks a private `_axis_location_to_indices` while
        # the public `axis_location_to_indices` is called below; confirm intent.
        if not hasattr(self.axes, "_axis_location_to_indices"):
            locs = tuple(np.ndindex(self.axes.shape))

        if isinstance(locs, str):
            locs = self.axes.axis_location_to_indices(locs)

        for index in locs:
            getattr(self.artists[index], f"set_{name}")(**kwargs)

    def set_legend(self, **kwargs):
        """Add legends to all artists."""
        self.set_artists("legend", locs="all", **kwargs)

    def set_cbars(self, **kwargs):
        """Configure the colorbars of all artists."""
        self.set_artists("cbars", locs="all", **kwargs)

    def set_complex_conversion(self, complex_conversion: str):
        """Change the complex conversion (not implemented yet)."""
        # Intended: store the conversion, rebuild the colorbar layout for the
        # new artist type and recreate the artists.
        raise NotImplementedError()

    def set_cmap(self, cmap):
        """Validate ``cmap`` for the measurement and apply it to all artists."""
        cmap = validate_cmap(cmap, self.measurement, self._complex_conversion)
        self.set_artists("cmap", cmap=cmap)

    def set_scale_bars(self, locs: str = "lower right", **kwargs):
        """Add scale bars to the artists at ``locs``."""
        self.set_artists("scale_bars", locs=locs, **kwargs)
_cube = np.array( [ [[0, 0, 0], [0, 0, 1]], [[0, 0, 0], [0, 1, 0]], [[0, 0, 0], [1, 0, 0]], [[0, 0, 1], [0, 1, 1]], [[0, 0, 1], [1, 0, 1]], [[0, 1, 0], [1, 1, 0]], [[0, 1, 0], [0, 1, 1]], [[1, 0, 0], [1, 1, 0]], [[1, 0, 0], [1, 0, 1]], [[0, 1, 1], [1, 1, 1]], [[1, 0, 1], [1, 1, 1]], [[1, 1, 0], [1, 1, 1]], ] ) def _merge_columns(atoms: Atoms, plane, tol: float = 1e-7) -> Atoms: uniques, labels = np.unique(atoms.numbers, return_inverse=True) new_atoms = Atoms(cell=atoms.cell) for unique, indices in zip(uniques, label_to_index(labels)): positions = atoms.positions[indices] positions = _merge_positions(positions, plane, tol) numbers = np.full((len(positions),), unique) new_atoms += Atoms(positions=positions, numbers=numbers) return new_atoms def _merge_positions(positions, plane, tol: float = 1e-7) -> np.ndarray: axes = plane_to_axes(plane) rounded_positions = tol * np.round(positions[:, axes[:2]] / tol) unique, labels = np.unique(rounded_positions, axis=0, return_inverse=True) new_positions = np.zeros((len(unique), 3)) for i, label in enumerate(label_to_index(labels)): top_atom = np.argmax(-positions[label][:, axes[2]]) new_positions[i] = positions[label][top_atom] # new_positions[i, axes[2]] = np.max(positions[label][top_atom, axes[2]]) return new_positions
def show_atoms(
    atoms: Atoms,
    plane: tuple[float, float] | str = "xy",
    ax: Axes = None,
    scale: float = 0.75,
    title: str = None,
    numbering: bool = False,
    show_periodic: bool = False,
    figsize: tuple[float, float] = None,
    legend: bool = False,
    merge: float = 1e-2,
    tight_limits: bool = False,
    show_cell: bool = None,
    **kwargs,
):
    """
    Display 2D projection of atoms as a matplotlib plot.

    Parameters
    ----------
    atoms : ase.Atoms
        The atoms to be shown.
    plane : str, two float
        The projection plane given as a concatenation of 'x' 'y' and 'z', e.g. 'xy', or as two floats representing the
        azimuth and elevation angles of the viewing direction [degrees], e.g. (45, 45).
    ax : matplotlib.axes.Axes, optional
        If given the plots are added to the axes.
    scale : float
        Factor scaling the covalent radii used as atom display sizes (default is 0.75).
    title : str
        Title of the displayed image. Default is None.
    numbering : bool
        Annotate each atom with its index. Default is False. Requires that no merging takes place (`merge` <= 0);
        with `show_periodic`, the indices refer to the padded atoms.
    show_periodic : bool
        If True, the atoms are padded with `pad_atoms` (margins of 1e-3 Å) to show the periodic images of the atoms
        at the cell boundary.
    figsize : two int, optional
        The figure size given as width and height in inches, passed to `matplotlib.pyplot.subplots`.
    legend : bool
        If True, add a legend indicating the color of the atomic species.
    merge: float
        To speed up plotting large numbers of atoms, atoms of the same species whose in-plane positions agree within
        the given value [Å] are merged. Values <= 0 disable merging.
    tight_limits : bool
        If True, the limits of the plot are set to the extent of the projected cell, and the cell is not drawn
        unless `show_cell` is True.
    show_cell : bool, optional
        Draw the outline of the cell. Defaults to True, or to False when `tight_limits` is True.
    kwargs :
        Keyword arguments for matplotlib.collections.PatchCollection.

    Returns
    -------
    matplotlib.figure.Figure, matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If `numbering` is requested while atoms are merged.
    """
    # Validate before any figure is created: after merging, indices no longer
    # identify the input atoms.
    if numbering and merge > 0.0:
        raise ValueError("atom numbering requires 'merge' to be False")

    if show_periodic:
        atoms = pad_atoms(atoms.copy(), margins=1e-3)

    if merge > 0.0:
        atoms = _merge_columns(atoms, plane, merge)

    if show_cell is None:
        show_cell = not tight_limits

    if ax is None:
        fig, ax = plt.subplots(figsize=figsize)
    else:
        fig = ax.get_figure()

    axes = plane_to_axes(plane)

    # Cell edges in Cartesian coordinates: shape (12 edges, 2 end points, 3).
    cell_lines = np.array(
        [[np.dot(start, atoms.cell), np.dot(end, atoms.cell)] for start, end in _cube]
    )
    cell_lines_x = cell_lines[..., axes[0]]
    cell_lines_y = cell_lines[..., axes[1]]

    if show_cell:
        for line_x, line_y in zip(cell_lines_x, cell_lines_y):
            ax.plot(line_x, line_y, "k-")

    if len(atoms) > 0:
        # Sort by decreasing coordinate along the third axis; later patches are
        # drawn on top of earlier ones.
        order = np.argsort(-atoms.positions[:, axes[2]])
        positions = atoms.positions[:, axes[:2]][order]
        numbers = atoms.numbers[order]
        colors = jmol_colors[numbers]
        sizes = covalent_radii[numbers] * scale

        circles = [Circle(position, size) for position, size in zip(positions, sizes)]
        coll = PatchCollection(circles, facecolors=colors, edgecolors="black", **kwargs)
        ax.add_collection(coll)

        if numbering:
            for atom_index, position in zip(order, positions):
                ax.annotate(f"{atom_index}", xy=position, ha="center", va="center")

        if legend:
            legend_elements = [
                Line2D(
                    [0],
                    [0],
                    marker="o",
                    color="w",
                    markeredgecolor="k",
                    label=chemical_symbols[unique],
                    markerfacecolor=jmol_colors[unique],
                    markersize=12,
                )
                for unique in np.unique(atoms.numbers)
            ]
            ax.legend(handles=legend_elements, loc="upper right")

    ax.axis("equal")
    # NOTE(review): the labels are built from plane[0]/plane[1] and so assume a
    # string plane such as "xy"; confirm behavior for (azimuth, elevation) tuples.
    ax.set_xlabel(plane[0] + " [Å]")
    ax.set_ylabel(plane[1] + " [Å]")
    ax.set_title(title)

    if tight_limits:
        ax.set_adjustable("box")
        ax.set_xlim([np.min(cell_lines_x), np.max(cell_lines_x)])
        ax.set_ylim([np.min(cell_lines_y), np.max(cell_lines_y)])

    return fig, ax