Source code for abtem.visualize.axes_grid

"""Module for plotting atoms, images, line scans, and diffraction patterns."""

from __future__ import annotations

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from mpl_toolkits.axes_grid1 import Divider, Size
from mpl_toolkits.axes_grid1.axes_grid import _cbaraxes_class_factory
import matplotlib
import matplotlib.colorbar as cbar
from abtem.core.utils import flatten_list_of_lists, interleave


def _cbar_orientation(cbar_loc):
    if cbar_loc == "right":
        orientation = "vertical"
    elif cbar_loc == "below":
        orientation = "horizontal"
    else:
        raise ValueError()
    return orientation



[docs] class AxesCollection:
[docs] def __init__(self, axes, caxes, cbar_mode="single"): self._axes = axes self._caxes = caxes self._cbar_mode = cbar_mode
@property def axes(self): return self._axes @property def fig(self): return self._fig @property def ncols(self) -> int: return self._axes.shape[0] @property def nrows(self) -> int: return self._axes.shape[1] def __getitem__(self, item): return self._axes[item] def __len__(self) -> int: return len(self._axes) @property def shape(self) -> tuple[int, int]: return self._axes.shape
[docs] class AxesGrid:
[docs] def __init__( self, fig, ncols: int, nrows: int, ncbars: int = 0, cbar_mode: str = "single", cbar_loc: str = "right", aspect: bool = True, anchor: str = "NW", sharex: bool = True, sharey: bool = True, rect: tuple = (0.1, 0.1, 0.9, 0.9), origin: str = "lower", ): self._fig = fig self._ncols = ncols self._nrows = nrows self._ncbars = ncbars self._aspect = aspect self._sharex = sharex self._sharey = sharey self._rect = rect self._anchor = anchor self._cbar_loc = cbar_loc self._cbar_mode = cbar_mode if not origin == "lower": # TODO raise NotImplementedError() self._sizes = { "cbar_spacing": Size.Fixed(1.1), "padding": Size.Fixed(0.1), "cbar_shift": Size.Fixed(0.0), "cbar_width": Size.Fixed(0.15), "left": Size.Fixed(0.0), "right": Size.Fixed(0.2), "top": Size.Fixed(0.2), "bottom": Size.Fixed(0.0), } self._axes = self._make_axes() self._caxes = self._make_caxes() self._divider = self._make_divider() self._set_axes_locators() self._set_caxes_locators() if sharex: for inner_axes in self._axes[:, 1:]: for ax in inner_axes: ax._axislines["bottom"].toggle(ticklabels=False, label=False) if sharey: for inner_axes in self._axes[1:]: for ax in inner_axes: ax._axislines["left"].toggle(ticklabels=False, label=False)
def axis_location_to_indices(self, axis_location): axis_locations = { "all": tuple(np.ndindex(self.shape)), "upper left": ((0, self.axes.shape[1] - 1),), "upper right": (self.axes.shape[0] - 1, self.axes.shape[1] - 1), "lower left": ((0, 0),), "lower right": ((self.axes.shape[0] - 1, 0),), } return axis_locations[axis_location] def _make_axes(self): from mpl_toolkits.axes_grid1.mpl_axes import Axes ax = Axes(self.fig, (0, 0, 1, 1), sharex=None, sharey=None) sharex = ax if self._sharex else None sharey = ax if self._sharey else None axes = [ax] + [ Axes(self.fig, (0, 0, 1, 1), sharex=sharex, sharey=sharey) for _ in range(self._nrows * self._ncols - 1) ] for i, ax in enumerate(axes): ax._tag = i for ax in axes: self._fig.add_axes(ax) axes = np.array(axes, dtype=object).reshape( (self._ncols, self._nrows), order="C" ) return axes def _remove_caxes(self): if self._cbar_mode == "single": self._caxes = self._caxes[0, 0] for cax in self._caxes.ravel(): cax.remove() def _make_caxes(self): orientation = _cbar_orientation(self._cbar_loc) if self._cbar_mode == "each": caxes = [ _cbaraxes_class_factory(Axes)( self._fig, self._rect, orientation=orientation ) for _ in range(self._nrows * self._ncols * self._ncbars) ] else: caxes = ( [ _cbaraxes_class_factory(Axes)( self._fig, self._rect, orientation=orientation ) for _ in range(self._ncbars) ] * self._nrows * self._ncols ) for cax in caxes: self._fig.add_axes(cax) caxes = np.array(caxes, dtype=object).reshape( (self._ncols, self._nrows, self._ncbars) ) return caxes def _make_size(self, axes, line_types, axes_size): i = 0 line_sizes = [] for row_type in line_types: if row_type == "ax": # line_sizes.append(Size.Scaled(1.)) # TODO line_sizes.append(axes_size(axes[i], aspect=0.2, ref_ax=axes[0])) i += 1 else: line_sizes.append(self._sizes[row_type]) return line_sizes def _make_divider(self): row_types, col_types = self._get_row_types(), self._get_col_types() row_sizes = self._make_size(self._axes[0], row_types, Size.AxesY) col_sizes = self._make_size(self._axes[:, 0], col_types, Size.AxesX) divider = Divider( self._fig, self._rect, horizontal=col_sizes, vertical=row_sizes, aspect=self._aspect, anchor="C", ) return divider def _set_axes_locators(self): row_types, col_types = self._get_row_types(), self._get_col_types() i = 0 for nx, col_type in enumerate(col_types): for ny, row_type in enumerate(row_types): if (row_type == "ax") and (col_type == "ax"): locator = self._divider.new_locator(nx=nx, ny=ny) self._axes.ravel()[i].set_axes_locator(locator) i += 1 def _set_caxes_locators(self): row_types, col_types = self._get_row_types(), self._get_col_types() i = 0 if (self._cbar_mode == "single") and (self._cbar_loc == "right"): for nx, col_type in enumerate(col_types): if col_type == "cbar_width": locator = self._divider.new_locator(nx=nx, ny=1, ny1=-2) self._caxes.ravel()[i].set_axes_locator(locator) i += 1 elif (self._cbar_mode == "single") and (self._cbar_loc == "below"): for ny, row_type in enumerate(row_types): if row_type == "cbar_width": locator = self._divider.new_locator(ny=ny, nx=1, nx1=-2) self._caxes.ravel()[i].set_axes_locator(locator) i += 1 else: for ny, row_type in enumerate(row_types): for nx, col_type in enumerate(col_types): if ((row_type == "ax") and (col_type == "cbar_width")) or ( (row_type == "cbar_width") and (col_type == "ax") ): locator = self._divider.new_locator(nx=nx, ny=ny) self._caxes.ravel()[i].set_axes_locator(locator) i += 1 def _get_col_types(self): if self._cbar_loc == "right": return self._get_line_types( n=self._ncols, ncbars=self._ncbars, orientation="horizontal" ) else: return self._get_line_types( n=self._ncols, ncbars=0, orientation="horizontal" ) def _get_row_types(self): if self._cbar_loc == "right": return self._get_line_types(n=self._nrows, ncbars=0, orientation="vertical") else: return self._get_line_types( n=self._nrows, ncbars=self._ncbars, orientation="vertical" ) def _get_line_types(self, n, ncbars, orientation): cbar_types = [ ["cbar_shift"] * (ncbars > 0) + ["cbar_width", "cbar_spacing"] * ncbars ] axes_types = [["ax", "padding"]] * (n - 1) + [["ax", "padding"]] if ncbars == 0: axes_types[-1] = axes_types[-1][:-1] if self._cbar_mode == "each": line_types = interleave(axes_types, cbar_types * n) line_types = flatten_list_of_lists(line_types) else: line_types = flatten_list_of_lists(axes_types + cbar_types) if self._cbar_loc == "below": line_types = line_types[::-1] if orientation == "horizontal": line_types = ["left"] + line_types + ["right"] else: line_types = ["bottom"] + line_types + ["top"] return line_types def set_sizes(self, **kwargs): for key, value in kwargs.items(): setattr(self._sizes[key], "fixed_size", value) self.fig.canvas.draw_idle() def set_cbar_layout(self, **kwargs): self._ncbars = kwargs.pop("ncbars", self._ncbars) self._cbar_loc = kwargs.pop("cbar_loc", self._cbar_loc) self._cbar_mode = kwargs.pop("cbar_mode", self._cbar_mode) self._remove_caxes() self._caxes = self._make_caxes() self._divider = self._make_divider() self._set_axes_locators() self._set_caxes_locators() def adjust_figure_to_bbox(self): self.set_sizes(left=0.0, bottom=0.0) size = self.fig.get_size_inches() bbox_inches = self.fig.get_tightbbox() pad_inches = plt.rcParams["savefig.pad_inches"] bbox_inches.padded(pad_inches) aspect = bbox_inches.width / bbox_inches.height self.fig.set_size_inches((size[0], size[0] / aspect)) bbox_inches = self.fig.get_tightbbox().padded(pad_inches) self.set_sizes(left=-bbox_inches.xmin) bbox_inches = self.fig.get_tightbbox().padded(pad_inches) self.set_sizes(bottom=-bbox_inches.ymin) self.fig.canvas.draw_idle() @property def axes(self): return self._axes @property def fig(self): return self._fig @property def ncols(self) -> int: return self._axes.shape[0] @property def nrows(self) -> int: return self._axes.shape[1] def __getitem__(self, item): return self._axes[item] def __len__(self) -> int: return len(self._axes) @property def shape(self) -> tuple[int, int]: return self._axes.shape