# Source code for abtem.visualize.artists

from __future__ import annotations

import itertools
from abc import ABCMeta, abstractmethod
from typing import TYPE_CHECKING, Literal

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import artist, colors
from matplotlib.axes import Axes
from matplotlib.collections import Collection, mpath, transforms
from matplotlib.colors import Colormap
from matplotlib.patches import Rectangle
from matplotlib.ticker import ScalarFormatter
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar

from abtem.core import config
from abtem.core.axes import LinearAxis
from abtem.core.colors import hsluv_cmap
from abtem.core.units import _get_conversion_factor

if TYPE_CHECKING:
    from matplotlib.text import Annotation

    from abtem.measurements import (
        DiffractionPatterns,
        Images,
        IndexedDiffractionPatterns,
        MeasurementsEnsemble,
    )


def _get_norm(vmin=None, vmax=None, power=1.0, logscale=False):
    """Build the matplotlib normalization for the requested color scaling.

    Parameters
    ----------
    vmin, vmax : float, optional
        Limits forwarded unchanged to the created normalization.
    power : float
        Exponent of the scaling. ``1.0`` selects a linear ``Normalize`` (or
        ``LogNorm`` when `logscale` is set); any other value selects a
        ``PowerNorm`` with ``gamma=power``.
    logscale : bool
        If truthy, a ``LogNorm`` is created. Only valid with ``power == 1.0``.

    Returns
    -------
    norm : matplotlib.colors.Normalize
        A ``Normalize``, ``PowerNorm`` or ``LogNorm`` instance.

    Raises
    ------
    ValueError
        If a power different from 1.0 is combined with a logarithmic scale.
    """
    # Accept any truthy/falsy flag (e.g. numpy booleans), not only the
    # ``True``/``False`` singletons.
    logscale = bool(logscale)

    if logscale and power != 1.0:
        raise ValueError(
            "A power-law scaling (power != 1.0) cannot be combined with a "
            f"logarithmic scale; got power={power!r} and logscale=True."
        )

    if logscale:
        return colors.LogNorm(vmin=vmin, vmax=vmax)

    if power != 1.0:
        return colors.PowerNorm(gamma=power, vmin=vmin, vmax=vmax)

    return colors.Normalize(vmin=vmin, vmax=vmax)


def _get_value_limits(array, value_limits: tuple[float, float] = None, margin=None):
    """Resolve the lower and upper value limits for `array`.

    Parameters
    ----------
    array : array-like
        Data the limits are derived from. Complex data is replaced by its
        absolute value first.
    value_limits : sequence of float or None, optional
        Requested ``(lower, upper)`` limits. Entries that are ``None`` are
        filled with the NaN-ignoring minimum / maximum of the data.
    margin : float, optional
        If truthy, both limits are pushed outwards by this fraction of their
        distance.

    Returns
    -------
    limits : list
        A new list ``[lower, upper]``; the input sequence is not modified.
    """
    magnitudes = np.abs(array) if np.iscomplexobj(array) else array

    limits = [None, None] if value_limits is None else list(value_limits)

    # Index 0 is the lower limit (minimum), index 1 the upper limit (maximum).
    for index, reducer in enumerate((np.nanmin, np.nanmax)):
        if limits[index] is None:
            limits[index] = float(reducer(magnitudes))

    if margin:
        padding = (limits[1] - limits[0]) * margin
        limits[0] -= padding
        limits[1] += padding

    return limits


class ScaleBar:
    """Size bar anchored to an axes and measured in the axes' data coordinates."""

    def __init__(
        self,
        ax: Axes,
        label: str = "",
        size: float = None,
        loc: str = "lower right",
        borderpad: float = 0.5,
        size_vertical: float = None,
        sep: float = 6,
        pad: float = 0.3,
        label_top: bool = True,
        frameon: bool = False,
        **kwargs,
    ):
        """Create an ``AnchoredSizeBar`` and add it to `ax`.

        Parameters
        ----------
        ax : Axes
            Axes the bar is added to; its ``transData`` is used as transform.
        label : str
            Text shown next to the bar.
        size : float, optional
            Horizontal length of the bar. Defaults to a third of the current
            x-range of `ax`.
        size_vertical : float, optional
            Thickness of the bar. Defaults to a twentieth of the current
            y-range of `ax`.
        loc, borderpad, sep, pad, label_top, frameon, **kwargs
            Forwarded to ``AnchoredSizeBar``.
        """
        if size is None:
            x_limits = ax.get_xlim()
            size = (x_limits[1] - x_limits[0]) / 3

        if size_vertical is None:
            y_limits = ax.get_ylim()
            size_vertical = (y_limits[1] - y_limits[0]) / 20

        bar_options = {
            "label": label,
            "label_top": label_top,
            "size": size,
            "borderpad": borderpad,
            "loc": loc,
            "size_vertical": size_vertical,
            "sep": sep,
            "pad": pad,
            "frameon": frameon,
        }

        self._anchored_size_bar = AnchoredSizeBar(
            ax.transData, **bar_options, **kwargs
        )
        ax.add_artist(self._anchored_size_bar)
class AreaIndicator:
    """Rectangular patch added to an axes to mark a region."""

    def __init__(
        self,
        ax,
        xy,
        width,
        height,
        alpha: float = 0.33,
        facecolor: str = "r",
        edgecolor: str = "r",
        **kwargs,
    ):
        """Create the rectangle and add it to `ax`.

        Parameters
        ----------
        ax : Axes
            Axes receiving the patch through ``add_patch``.
        xy, width, height
            Positional arguments of ``matplotlib.patches.Rectangle``.
        alpha, facecolor, edgecolor, **kwargs
            Styling forwarded to ``Rectangle``.
        """
        patch_style = {
            "alpha": alpha,
            "facecolor": facecolor,
            "edgecolor": edgecolor,
        }
        ax.add_patch(Rectangle(xy, width, height, **patch_style, **kwargs))
class Artist(metaclass=ABCMeta):
    """Abstract base class of the objects that draw a measurement on an axes.

    Subclasses implement the data / limit / scaling hooks; the label and
    axis-limit setters defined here simply delegate to the wrapped axes.
    """

    def __init__(self, ax, measurement):
        """Store the axes and the computed measurement.

        Parameters
        ----------
        ax : Axes
            Axes the artist draws on.
        measurement
            Object exposing ``compute()``; the value it returns is stored.
        """
        self._ax = ax
        self._measurement = measurement.compute()
        # The attribute used to be spelled ``_measuremet``; keep the old name
        # as an alias so existing code reading it keeps working.
        self._measuremet = self._measurement

    @abstractmethod
    def get_xlim(self):
        """Return the x-limits covered by the drawn data."""
        pass

    @abstractmethod
    def get_ylim(self):
        """Return the y-limits covered by the drawn data."""
        pass

    @abstractmethod
    def set_data(self, data):
        """Replace the drawn data."""
        pass

    @abstractmethod
    def set_value_limits(self, value_limits):
        """Set the displayed value range."""
        pass

    @abstractmethod
    def set_power(self, power=1.0):
        """Set the exponent of the power-law scaling."""
        pass

    def set_logscale(self):
        """Switch to a logarithmic scale; a no-op unless overridden."""
        pass

    @abstractmethod
    def get_power(self):
        """Return the exponent of the power-law scaling."""
        pass

    @abstractmethod
    def remove(self):
        """Remove everything the artist has drawn."""
        pass

    def set_xlabel(self, label):
        """Set the x-axis label of the wrapped axes."""
        self._ax.set_xlabel(label)

    def set_ylabel(self, label):
        """Set the y-axis label of the wrapped axes."""
        self._ax.set_ylabel(label)

    def set_ylim(self, ylim):
        """Set the y-limits of the wrapped axes."""
        self._ax.set_ylim(ylim)

    def set_xlim(self, xlim):
        """Set the x-limits of the wrapped axes."""
        self._ax.set_xlim(xlim)
class Artist1D(Artist, metaclass=ABCMeta):
    """Abstract base class for artists that draw one-dimensional data."""
class LinesArtist(Artist1D):
    """Draw a one-dimensional measurement, or an ensemble of them, as lines.

    Complex data is drawn as separate lines for the real and imaginary parts.
    """

    num_cbars = 0

    def __init__(
        self,
        ax: Axes,
        measurement,
        caxes: list[Axes] = None,
        label=None,
        units: str = None,
        legend: bool = False,
        **kwargs,
    ):
        """Plot `measurement` on `ax`.

        Parameters
        ----------
        ax : Axes
            Axes to plot on.
        measurement
            Measurement whose last axis is the plotted axis.
        caxes : list of Axes, optional
            Accepted for signature compatibility; not used by this artist.
        label : str or list of str, optional
            Line label(s). If omitted and the measurement has ensemble axes,
            labels are built from the titles of the ensemble axes.
        units : str, optional
            Units used when formatting the x-axis label.
        legend : bool
            If True and labels exist, a legend is added.
        **kwargs
            Forwarded to ``Axes.plot``.
        """
        super().__init__(ax=ax, measurement=measurement)

        # Must be read before reshaping: ``_reshape_data`` splits complex data
        # into real columns, so the reshaped array is never complex.
        is_complex = np.iscomplexobj(measurement.array)

        y = self._reshape_data(measurement.array)
        x = measurement.axes_metadata[-1].coordinates(measurement.shape[-1])

        if label is None and measurement.ensemble_shape:
            axis_titles = [
                [item.format_title(".3g") for item in axis]
                for axis in measurement.ensemble_axes_metadata
            ]
            label = [", ".join(combo) for combo in itertools.product(*axis_titles)]

        if is_complex:
            n_lines = y.shape[-1] if len(y.shape) > 1 else 1
            if label is None:
                if n_lines == 2:
                    label = ["real", "imaginary"]
            elif not isinstance(label, str) and 2 * len(label) == n_lines:
                # One label per ensemble item was given, but each item yields
                # two lines (all real parts first, then all imaginary parts).
                label = [l + "(real)" for l in label] + [
                    l + "(imag)" for l in label
                ]

        self._lines = ax.plot(x, y, label=label, **kwargs)

        xlabel = measurement.axes_metadata[-1].format_label(units)
        ylabel = measurement._scale_axis_from_metadata().format_label()

        self.set_xlabel(xlabel)

        # NOTE(review): this guard tests the label, not the data, so it is
        # presumably always true; kept as-is to preserve the current labels.
        if not np.iscomplexobj(ylabel):
            self.set_ylabel(ylabel)

        if label and legend:
            self.set_legend()

        formatter = ScalarFormatter(useMathText=True)
        formatter.set_powerlimits((-2, 2))
        ax.yaxis.set_major_formatter(formatter)

    def remove(self):
        """Remove all lines and reset the property cycle of the axes."""
        for line in self._lines:
            line.remove()
        self._ax.set_prop_cycle(None)

    @staticmethod
    def _reshape_data(y):
        """Return `y` as a real array with the plotted axis first.

        Multi-dimensional input is reshaped to ``(n_points, n_lines)``. Complex
        input is split into real and imaginary parts along the last axis.
        """
        if len(y.shape) > 1:
            y = np.moveaxis(y, -1, 0).reshape((y.shape[-1], -1))

        if np.iscomplexobj(y):
            if len(y.shape) == 2:
                y = np.concatenate((y.real, y.imag), axis=-1)
            else:
                y = np.stack((y.real, y.imag), axis=-1)
        return y

    def get_xlim(self):
        """Return the x-range of all lines padded by 5 % on both sides."""
        xlim = [np.inf, -np.inf]
        for line in self._lines:
            data = line.get_data()[0]
            xlim = [min(np.min(data), xlim[0]), max(np.max(data), xlim[1])]

        ptp = xlim[1] - xlim[0]
        return [xlim[0] - 0.05 * ptp, xlim[1] + 0.05 * ptp]

    def get_ylim(self):
        """Return the y-range of all lines padded by 5 % on both sides."""
        ylim = [np.inf, -np.inf]
        for line in self._lines:
            data = line.get_data()[1]
            ylim = [min(np.min(data), ylim[0]), max(np.max(data), ylim[1])]

        ptp = ylim[1] - ylim[0]
        # Avoid a degenerate range for (nearly) constant data.
        ptp = max(ptp, ylim[1] * 0.01)
        return [ylim[0] - 0.05 * ptp, ylim[1] + 0.05 * ptp]

    def get_value_limits(self):
        """Return the value limits, which for lines are the y-limits."""
        return self.get_ylim()

    def set_data(self, data):
        """Update the x/y data of the existing lines from `data`."""
        y = self._reshape_data(data.array)
        x = data.base_axes_metadata[-1].coordinates(data.shape[-1])

        if len(y.shape) == 1:
            y = y[..., None]

        for i, line in enumerate(self._lines):
            line.set_data(x, y[..., i])

    def get_logscale(self):
        """Return True if the y-axis currently uses a logarithmic scale."""
        return self._ax.get_yscale() == "log"

    def set_logscale(self):
        """Switch the y-axis to a logarithmic scale."""
        self._ax.set_yscale("log")

    def get_power(self):
        """Return the power-law exponent; always 1.0 for lines."""
        return 1.0

    def set_power(self, power: float = 1.0) -> None:
        """Not supported for line plots.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("Power scaling is not implemented for line plots.")

    def set_value_limits(self, value_limits: list[float] = None):
        """Set the y-limits, filling missing entries from the line data."""
        data = np.stack([line.get_data()[1] for line in self._lines], axis=0)
        value_limits = _get_value_limits(data, value_limits, margin=0.05)
        self._ax.set_ylim(value_limits)

    def set_legend(self, **kwargs):
        """Add a legend to the axes; `kwargs` are forwarded to ``Axes.legend``."""
        self._ax.legend(**kwargs)
class Artist2D(Artist):
    """Abstract base class for artists that draw two-dimensional data.

    Adds the colormap / colorbar hooks and shared helpers for updating a
    matplotlib normalization in place.
    """

    # ---- hooks every concrete 2D artist has to provide ------------------

    @abstractmethod
    def set_data(self, data):
        """Replace the drawn data."""
        pass

    @abstractmethod
    def set_value_limits(self, value_limits):
        """Set the displayed value range."""
        pass

    @abstractmethod
    def set_power(self, power: float = 1.0):
        """Set the exponent of the power-law color scaling."""
        pass

    @abstractmethod
    def set_cmap(self, cmap):
        """Set the colormap."""
        pass

    @abstractmethod
    def set_cbars(self, cmap):
        """Create the colorbar(s) of the artist."""
        pass

    def set_logscale(self):
        """Switch to a logarithmic scale; a no-op unless overridden."""
        pass

    # ---- shared helpers --------------------------------------------------

    @staticmethod
    def _set_vmin_vmax(norm, vmin, vmax, array=None):
        """Assign both limits of `norm` with its callbacks blocked.

        Limits given as ``None`` are taken from `array` when one is passed.
        """
        if array is not None:
            vmin = array.min() if vmin is None else vmin
            vmax = array.max() if vmax is None else vmax

        with norm.callbacks.blocked():
            norm.vmin = vmin
            norm.vmax = vmax

    @staticmethod
    def _update_norm(old_norm, power, artist):
        """Give `artist` a normalization with the requested `power`.

        An existing ``PowerNorm`` is updated in place when the new power is
        not 1.0; otherwise a new normalization is created with the old limits.
        """
        reuse_existing = (power != 1.0) and isinstance(old_norm, colors.PowerNorm)

        if not reuse_existing:
            artist.norm = _get_norm(
                vmin=old_norm.vmin, vmax=old_norm.vmax, power=power
            )
            return

        old_norm.gamma = power
        artist.norm = old_norm
        old_norm._changed()

    @staticmethod
    def _make_cbar(mappable, cax, **kwargs):
        """Create a colorbar for `mappable` in the axes `cax`."""
        return plt.colorbar(mappable, cax=cax, **kwargs)

    # ---- further hooks ---------------------------------------------------

    @abstractmethod
    def get_ylim(self):
        """Return the y-limits covered by the drawn data."""
        pass

    @abstractmethod
    def get_xlim(self):
        """Return the x-limits covered by the drawn data."""
        pass

    @property
    @abstractmethod
    def num_cbars(self):
        """Number of colorbars the artist uses."""
        pass

    @abstractmethod
    def get_value_limits(self):
        """Return the value range of the drawn data."""
        pass

    def set_scale_bars(self, **kwargs):
        """Add a ``ScaleBar`` to the axes; `kwargs` are forwarded to it."""
        self._scale_bar = ScaleBar(ax=self._ax, **kwargs)

    def add_area_indicator(self, area_indicator, panel="first", **kwargs):
        """Not implemented.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError
        # Disabled sketch of the intended implementation:
        # for i, ax in enumerate(np.array(self.axes).ravel()):
        #     if panel == "first" and i == 0:
        #         area_indicator._add_to_visualization(ax, **kwargs)
        #     elif panel == "all":
        #         area_indicator._add_to_visualization(ax, **kwargs)
def default_cbar_scalar_formatter():
    """Return the default tick formatter used for colorbars.

    Returns
    -------
    formatter : ScalarFormatter
        Formatter using math text, with power limits ``(-3, 3)``.
    """
    formatter = ScalarFormatter(useMathText=True)
    formatter.set_powerlimits((-3, 3))
    return formatter
def validate_cmap(cmap, measurement, complex_conversion="none"):
    """Resolve `cmap` to the colormap that should be used for `measurement`.

    Parameters
    ----------
    cmap : str, Colormap or None
        Requested colormap. ``None`` selects a default from the configuration:
        ``visualize.phase_cmap`` for complex measurements shown as-is or as
        phase, ``visualize.cmap`` otherwise.
    measurement
        Only consulted (through ``is_complex``) when `cmap` is ``None``.
    complex_conversion : str
        How complex data is converted for display.

    Returns
    -------
    cmap
        The ``hsluv`` colormap object for ``"hsluv"``, a single-color
        ``ListedColormap`` for names starting with ``"solid"`` (the color is
        the text after the last space), otherwise the input unchanged.
    """
    if cmap is None:
        shows_phase = measurement.is_complex and complex_conversion in (
            "none",
            "phase",
        )
        if shows_phase:
            config_key, fallback = "visualize.phase_cmap", "hsluv"
        else:
            config_key, fallback = "visualize.cmap", "viridis"
        cmap = config.get(config_key, fallback)

    if cmap == "hsluv":
        return hsluv_cmap

    if isinstance(cmap, str) and cmap.startswith("solid"):
        color = cmap.rsplit(" ", 1)[-1]
        return colors.ListedColormap([color])

    return cmap
def get_extent(measurement, units=None):
    """Return the image extent of a two-dimensional measurement.

    Each edge is placed half a sampling step outside the first / last sample,
    and all values are converted to `units`.

    Parameters
    ----------
    measurement
        Measurement with at least two base axes providing ``offset``,
        ``sampling`` and ``units``.
    units : str, optional
        Target units passed to the unit conversion.

    Returns
    -------
    extent : tuple of float
        ``(left, right, bottom, top)``.
    """
    energy = measurement.metadata.get("energy", None)

    x_axis = measurement.base_axes_metadata[0]
    y_axis = measurement.base_axes_metadata[1]

    conversion_x = _get_conversion_factor(units, x_axis.units, energy=energy)
    # Each axis is converted from its own units; the y factor was previously
    # derived from the x-axis units.
    conversion_y = _get_conversion_factor(units, y_axis.units, energy=energy)

    left = (x_axis.offset - x_axis.sampling / 2) * conversion_x
    right = left + x_axis.sampling * measurement.base_shape[0] * conversion_x

    bottom = (y_axis.offset - y_axis.sampling / 2) * conversion_y
    top = bottom + y_axis.sampling * measurement.base_shape[1] * conversion_y

    return (left, right, bottom, top)
class ImageArtist(Artist2D):
    """Draw a real-valued two-dimensional measurement with ``imshow``."""

    num_cbars = 1

    def __init__(
        self,
        ax: Axes,
        measurement: Images | DiffractionPatterns | MeasurementsEnsemble,
        caxes: list[Axes] = None,
        cmap: str | Colormap | None = None,
        vmin: float = None,
        vmax: float = None,
        power: float = 1.0,
        logscale: bool = False,
        origin: Literal["upper", "lower"] | None = "lower",
        units: str = None,
        **kwargs,
    ):
        """Show `measurement` on `ax`.

        Parameters
        ----------
        ax : Axes
            Axes to draw on.
        measurement
            Real-valued measurement with two base axes.
        caxes : list of Axes, optional
            If given and non-empty, a colorbar is created in ``caxes[0]``.
        cmap : str or Colormap, optional
            Colormap; resolved through ``validate_cmap``.
        vmin, vmax, power, logscale
            Define the color normalization.
        origin : {"upper", "lower"}, optional
            Forwarded to ``imshow``.
        units : str, optional
            Units of the extent and the axis labels.
        **kwargs
            Forwarded to ``imshow``; ``interpolation`` defaults to ``"none"``.

        Raises
        ------
        ValueError
            If the measurement is complex.
        """
        super().__init__(ax=ax, measurement=measurement)

        if measurement.is_complex:
            raise ValueError("Complex measurements are not supported.")

        extent = get_extent(measurement, units=units)
        cmap = validate_cmap(cmap, measurement)

        self._axes_image = ax.imshow(
            measurement.array.T,
            origin=origin,
            cmap=cmap,
            extent=extent,
            interpolation=kwargs.pop("interpolation", "none"),
            **kwargs,
        )

        norm = _get_norm(vmin, vmax, power, logscale)
        self._axes_image.set_norm(norm)
        self._cbar = None

        xlabel = measurement.base_axes_metadata[0].format_label(units)
        ylabel = measurement.base_axes_metadata[1].format_label(units)
        self.set_xlabel(xlabel)
        self.set_ylabel(ylabel)

        if caxes:
            cbar_label = measurement._scale_axis_from_metadata().format_label()
            self.set_cbars(caxes=caxes, label=cbar_label)

    @property
    def axes_image(self):
        """The ``AxesImage`` created by ``imshow``."""
        return self._axes_image

    @property
    def norm(self):
        """The normalization currently attached to the image."""
        return self.axes_image.norm

    def remove(self):
        """Remove the image from its axes."""
        self.axes_image.remove()

    def get_power(self):
        """Return the ``gamma`` of the normalization, or 1.0 if it has none."""
        if hasattr(self.norm, "gamma"):
            return self.norm.gamma
        return 1.0

    def get_value_limits(self):
        """Return ``[min, max]`` of the displayed array."""
        array = self.axes_image.get_array()
        return [array.min(), array.max()]

    def get_xlim(self):
        """Return the horizontal part of the image extent."""
        return self.axes_image.get_extent()[:2]

    def get_ylim(self):
        """Return the vertical part of the image extent."""
        return self.axes_image.get_extent()[2:]

    def set_cbars(self, caxes, **kwargs):
        """Create the colorbar in ``caxes[0]``.

        ``format`` defaults to ``default_cbar_scalar_formatter()``; remaining
        keyword arguments are forwarded to the colorbar constructor.
        """
        formatter = kwargs.pop("format", default_cbar_scalar_formatter())
        cbar = self._make_cbar(self.axes_image, caxes[0], format=formatter, **kwargs)
        cbar.ax.yaxis.set_offset_position("left")
        # Keep a handle on the colorbar; it used to be discarded.
        self._cbar = cbar

    def set_cmap(self, cmap):
        """Set the colormap of the image."""
        self.axes_image.set_cmap(cmap)

    def set_data(self, data):
        """Replace the displayed array with the transposed array of `data`."""
        self.axes_image.set_data(data.array.T)

    def set_extent(self, extent):
        """Set the extent of the image."""
        self.axes_image.set_extent(extent)

    def set_power(self, power: float = 1.0):
        """Set the exponent of the power-law color scaling."""
        self._update_norm(self.norm, power, self.axes_image)

    def set_value_limits(self, value_limits: tuple[float, float] = None):
        """Set ``(vmin, vmax)`` of the normalization.

        ``None`` (the default) is treated as ``(None, None)``, which clears
        both limits instead of failing on unpacking.
        """
        if value_limits is None:
            value_limits = (None, None)
        self._set_vmin_vmax(self.norm, *value_limits)
class ScaledCircleCollection(Collection):
    """Collection of circles whose radii follow the normalized data values.

    Only entries whose value exceeds the threshold are handed to the base
    collection; the full ("unmasked") arrays are kept so that the mask can be
    recomputed when the threshold changes.
    """

    def __init__(self, array, offsets, scale=1.0, threshold: float = 0.0, **kwargs):
        """Create the collection.

        Parameters
        ----------
        array : np.ndarray
            Value of each circle; mapped to color and radius.
        offsets : np.ndarray
            Position of each circle; indexed with the same mask as `array`.
        scale : float
            Factor applied to the normalized values before the square root
            that yields the radii.
        threshold : float
            Entries with ``array <= threshold`` are masked out.
        **kwargs
            Forwarded to ``Collection``.
        """
        self._scale = scale
        self._threshold = threshold
        self._mask = array > threshold
        self._unmasked_offsets = offsets
        self._unmasked_array = array

        super().__init__(array=array[self._mask], offsets=offsets[self._mask], **kwargs)

        # Radii are computed for all (unmasked) entries and masked on use.
        self._radii = self._calculate_radii()

        self.set_transform(transforms.IdentityTransform())
        self._transforms = np.empty((0, 3, 3))
        self._paths = [mpath.Path.unit_circle()]

        self._set_transforms()

        # Recompute the radii whenever the collection reports a change.
        self.callbacks.connect("changed", lambda *args: self._update_radii())

    def _set_transforms(self):
        """Rebuild the per-circle scaling matrices from the current radii."""
        ax = self.axes

        radii = self._radii[self._mask]

        # One 3x3 matrix per visible circle, scaling the unit circle by its
        # radius in both directions.
        self._transforms = np.zeros((len(radii), 3, 3))
        self._transforms[:, 0, 0] = radii
        self._transforms[:, 1, 1] = radii
        self._transforms[:, 2, 2] = 1.0

        if ax is not None:
            # Use the affine part of the data transform with its translation
            # column zeroed out.
            A = ax.transData.get_affine().get_matrix().copy()
            A[:2, 2:] = 0
            self.set_transform(transforms.Affine2D(A))

    @artist.allow_rasterization
    def draw(self, renderer):
        """Refresh the transforms, then draw through the base class."""
        self._set_transforms()
        super().draw(renderer)

    @property
    def threshold(self):
        """Values at or below this threshold are masked out."""
        return self._threshold

    def set_threshold(self, threshold):
        """Set the threshold, recompute the mask and update the collection."""
        self._threshold = threshold
        self._mask = self._unmasked_array > self._threshold
        self._update()

    def get_radii(self):
        """Return the radii of all entries, including masked ones."""
        return self._radii

    def set_norm(self, norm):
        """Set the normalization and recompute the radii when it changes."""
        super().set_norm(norm)
        self.norm.callbacks.connect("changed", lambda *args: self._update_radii())

    def set_array(self, array):
        """Set the color-mapped array and signal a change."""
        super().set_array(array)
        self.changed()

    def _update(self):
        """Push the masked offsets and values to the base collection."""
        self.set_offsets(self._unmasked_offsets[self._mask])
        self.set_array(self._unmasked_array[self._mask])

    def set_data(self, array, offsets):
        """Replace values and positions, recomputing the mask."""
        self._unmasked_array = array
        self._unmasked_offsets = offsets
        self._mask = array > self._threshold
        self._update()

    def _calculate_radii(self):
        """Return ``sqrt(norm(values) * scale)`` with the argument clipped to >= 1e-5."""
        norm = self.norm
        data = self._unmasked_array
        radii = np.sqrt(np.clip(norm(data) * self._scale, a_min=1e-5, a_max=np.inf))
        return radii

    def _update_radii(self):
        """Recompute and store the radii."""
        self._radii = self._calculate_radii()

    def get_scale(self):
        """Return the scale factor applied to the normalized values."""
        return self._scale

    def set_scale(self, scale):
        """Set the scale factor, recompute the radii and signal a change."""
        self._scale = scale
        self._update_radii()
        self.changed()
class CircleAnnotations:
    """Text annotations attached to the circles of a circle collection.

    Positions and visibilities follow the collection: both are refreshed when
    the collection emits its ``"changed"`` signal. Public attribute access
    that is not defined here is forwarded to every annotation.
    """

    _placement_to_alignment = {"top": "bottom", "center": "center", "bottom": "top"}

    def __init__(
        self,
        circle_collection,
        annotations,
        fontsize: int = 8,
        placement: str = "top",
        threshold: float = 0.0001,
        **kwargs,
    ):
        """Create one annotation per entry of `annotations`.

        Parameters
        ----------
        circle_collection
            Collection providing ``axes``, ``callbacks``, ``get_radii()`` and
            the unmasked offsets / values / mask.
        annotations : sequence of str
            Annotation texts, one per circle.
        fontsize : int
            Font size of the annotations.
        placement : {"top", "center", "bottom"}
            Position of the text relative to its circle.
        threshold : float
            Annotations of circles whose value is not above this are hidden.
        **kwargs
            Forwarded to ``Axes.annotate``.

        Raises
        ------
        ValueError
            If `placement` is not one of the supported values.
        """
        self._circle_collection = circle_collection
        self._placement = placement
        self._threshold = threshold
        # Must exist before the first visibility query; it was previously left
        # unset and only worked because ``__getattr__`` returned a truthy
        # forwarding function for the missing name.
        self._visible = True
        self._annotations = []

        ax = circle_collection.axes

        positions = self._get_positions()
        visibilities = self._get_visibilities()

        for annotation, position, visible in zip(annotations, positions, visibilities):
            self._annotations.append(
                ax.annotate(
                    annotation,
                    xy=position,
                    ha="center",
                    va=self._placement_to_alignment.get(placement),
                    visible=visible,
                    fontsize=fontsize,
                    clip_on=True,
                    **kwargs,
                )
            )

        circle_collection.callbacks.connect(
            "changed", lambda *args: self._update_visibilities()
        )
        circle_collection.callbacks.connect(
            "changed", lambda *args: self._update_positions()
        )

    def __getattr__(self, name):
        """Forward a public method call to every annotation.

        Only invoked when normal attribute lookup fails. Returns a function
        that calls ``name`` on each annotation and returns the results as a
        tuple.

        Raises
        ------
        AttributeError
            For private / dunder names, or when an annotation lacks `name`.
        """
        # Never forward private names: this keeps ``hasattr`` meaningful and
        # avoids recursion while ``_annotations`` does not exist yet.
        if name.startswith("_"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        if not all(hasattr(annotation, name) for annotation in self._annotations):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )

        def method(*args, **kwargs):
            return tuple(
                getattr(annotation, name)(*args, **kwargs)
                for annotation in self._annotations
            )

        return method

    @property
    def threshold(self):
        """Annotations of circles with values not above this are hidden."""
        return self._threshold

    def set_threshold(self, threshold):
        """Set the threshold and refresh the visibilities."""
        self._threshold = threshold
        self._update_visibilities()

    def set_placement(self, placement):
        """Set the placement and update alignment and positions."""
        self._placement = placement
        self.set_verticalalignment(self._placement_to_alignment[placement])
        self._update_positions()

    def _get_visibilities(self):
        """Return a boolean array telling which annotations should be shown."""
        if not self._visible:
            return np.zeros(len(self._annotations), dtype=bool)

        mask = self._circle_collection._mask
        array = self._circle_collection._unmasked_array
        return mask * (array > self._threshold)

    def set_visible(self, visible):
        """Show or hide all annotations (subject to mask and threshold)."""
        self._visible = visible
        self._update_visibilities()

    def _get_positions(self):
        """Return the annotation positions for the current placement.

        Raises
        ------
        ValueError
            If the placement is not "top", "bottom" or "center".
        """
        positions = self._circle_collection._unmasked_offsets.copy()
        radii = self._circle_collection.get_radii()

        if self._placement == "top":
            positions[:, 1] += radii
        elif self._placement == "bottom":
            positions[:, 1] -= radii
        elif self._placement != "center":
            raise ValueError()

        return positions

    def _update_visibilities(self):
        """Apply the current visibilities to the annotations."""
        for annotation, visible in zip(self._annotations, self._get_visibilities()):
            annotation.set_visible(visible)

    def _update_positions(self):
        """Apply the current positions to the annotations."""
        for annotation, position in zip(self._annotations, self._get_positions()):
            annotation.set_position(position)
class ScatterArtist(Artist2D):
    """Draw indexed diffraction spots as scaled, color-mapped circles.

    Attributes that are not defined on the artist are looked up on the
    underlying circle collection.
    """

    num_cbars = 1

    def __init__(
        self,
        ax: Axes,
        measurement: IndexedDiffractionPatterns,
        caxes: list[Axes] = None,
        cmap: str | Colormap | None = None,
        value_limits: tuple[float, float] = None,
        power: float = 1.0,
        logscale: bool = False,
        units: str = None,
        scale: float = 0.5,
        annotations: bool = True,
        annotation_kwargs: dict = None,
        **kwargs,
    ):
        """Create the circle collection (and optional annotations) on `ax`.

        Parameters
        ----------
        ax : Axes
            Axes to draw on.
        measurement
            Provides ``array``, ``positions``, ``miller_indices`` and
            ``metadata``.
        caxes : list of Axes, optional
            If given and non-empty, a colorbar is created in ``caxes[0]``.
        cmap : str or Colormap, optional
            Colormap; resolved through ``validate_cmap``.
        value_limits : tuple of float, optional
            Color limits; missing entries are taken from the data.
        power, logscale
            Define the color normalization.
        units : str, optional
            Units the positions are converted to; labels default to "1/Å".
        scale : float
            Scale of the circle radii.
        annotations : bool
            If True, every spot is annotated with its Miller indices.
        annotation_kwargs : dict, optional
            Forwarded to ``CircleAnnotations``.
        **kwargs
            Forwarded to ``ScaledCircleCollection``.
        """
        if annotation_kwargs is None:
            annotation_kwargs = {}

        super().__init__(ax=ax, measurement=measurement)

        vmin, vmax = _get_value_limits(measurement.array, value_limits=value_limits)
        norm = _get_norm(vmin, vmax, power, logscale)

        energy = measurement.metadata.get("energy", None)
        self._unit_conversion = _get_conversion_factor(
            units, old_units="1/Å", energy=energy
        )

        cmap = validate_cmap(cmap, measurement)

        # NOTE(review): newer matplotlib releases appear to name this keyword
        # ``offset_transform``; confirm against the supported versions.
        self._circles = ScaledCircleCollection(
            array=measurement.array,
            cmap=cmap,
            offsets=measurement.positions[:, :2] * self._unit_conversion,
            transOffset=ax.transData,
            norm=norm,
            scale=scale,
            **kwargs,
        )

        ax.add_collection(self._circles)

        units = "1/Å" if units is None else units
        x_axis = LinearAxis(label="k_x", units=units, tex_label="$k_x$")
        y_axis = LinearAxis(label="k_y", units=units, tex_label="$k_y$")
        self.set_xlabel(x_axis.format_label(units))
        self.set_ylabel(y_axis.format_label(units))

        if annotations:
            annotations = []
            for hkl in measurement.miller_indices:
                if config.get("visualize.use_tex"):
                    # Raw string: backslash-space is a TeX control space, not
                    # a Python escape sequence.
                    annotation = r" \ ".join(
                        [f"\\bar{{{abs(i)}}}" if i < 0 else f"{i}" for i in hkl]
                    )
                    annotations.append(f"${annotation}$")
                else:
                    annotations.append("{} {} {}".format(*hkl))

            self._annotations = CircleAnnotations(
                self._circles, annotations, **annotation_kwargs
            )
        else:
            self._annotations = None

        if caxes:
            cbar_label = measurement._scale_axis_from_metadata().format_label()
            self.set_cbars(caxes=caxes, label=cbar_label)

    def __getattr__(self, name):
        """Fall back to the circle collection for unknown attributes.

        Only invoked when normal lookup fails.

        Raises
        ------
        AttributeError
            If neither the artist nor the collection has the attribute.
        """
        message = f"{self.__class__.__name__} object has no attribute {name}"

        # ``_circles`` itself may be missing (e.g. before ``__init__`` has
        # run); looking it up through ``self`` would recurse forever.
        if name == "_circles":
            raise AttributeError(message)

        try:
            return getattr(self._circles, name)
        except AttributeError:
            raise AttributeError(message) from None

    @property
    def circle_collection(self) -> ScaledCircleCollection:
        """The underlying circle collection."""
        return self._circles

    @property
    def annotations(self) -> CircleAnnotations | None:
        """The annotations, or None if they were disabled."""
        return self._annotations

    def get_ylim(self):
        """Return the y-range of the spot positions, scaled by 1.1."""
        y = self.get_offsets()[:, 1]
        return [y.min() * 1.1, y.max() * 1.1]

    def get_xlim(self):
        """Return the x-range of the spot positions, scaled by 1.1."""
        x = self.get_offsets()[:, 0]
        return [x.min() * 1.1, x.max() * 1.1]

    def get_value_limits(self):
        """Return ``[min, max]`` of the displayed values."""
        array = self.circle_collection.get_array()
        return [array.min(), array.max()]

    def remove(self):
        """Remove the circle collection from its axes."""
        self.circle_collection.remove()

    def get_power(self):
        """Return the ``gamma`` of the normalization, or 1.0 if it has none."""
        if hasattr(self.norm, "gamma"):
            return self.norm.gamma
        return 1.0

    def set_data(self, measurement: IndexedDiffractionPatterns):
        """Replace intensities and (unit-converted) positions of the spots."""
        positions = measurement.positions[:, :2] * self._unit_conversion
        intensities = measurement.array
        self._circles.set_data(intensities, positions)

    def set_annotation_kwargs(self, **kwargs):
        """Call ``set_<key>(<key>=value)`` on the annotations for each item.

        Raises
        ------
        ValueError
            If the artist was created without annotations.
        """
        if self._annotations is None:
            raise ValueError("Annotations are not enabled.")

        for k, v in kwargs.items():
            getattr(self._annotations, f"set_{k}")(**{k: v})

    def get_scale(self):
        """Return the scale of the circle radii."""
        return self.circle_collection.get_scale()

    def set_scale(self, scale: float):
        """Set the scale of the circle radii."""
        self.circle_collection.set_scale(scale)

    def set_cmap(self, cmap: str):
        """Set the colormap of the circles."""
        self.circle_collection.set_cmap(cmap)

    def set_cbars(self, caxes=None, **kwargs):
        """Create the colorbar in ``caxes[0]``."""
        self._make_cbar(self.circle_collection, caxes[0], **kwargs)

    def set_value_limits(self, value_limits: tuple[float, float] = None):
        """Set ``(vmin, vmax)``; entries that are None are taken from the data.

        ``None`` (the default) is treated as ``(None, None)`` instead of
        failing on unpacking.
        """
        if value_limits is None:
            value_limits = (None, None)
        data = self._circles.get_array().data
        self._set_vmin_vmax(self.norm, *value_limits, data)

    def set_power(self, power: float = 1.0):
        """Set the exponent of the power-law color scaling."""
        self._update_norm(self.norm, power, self._circles)
class DomainColoringArtist(Artist2D):
    """Draw a complex two-dimensional measurement with domain coloring.

    The phase is shown as a color image whose per-pixel alpha is the
    normalized amplitude, drawn on top of a gray-scale amplitude image.
    """

    num_cbars = 2

    # Tick positions / labels shared by every phase colorbar update.
    _phase_ticks = (-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi)
    _phase_ticklabels = (
        r"$-\pi$",
        r"$-\dfrac{\pi}{2}$",
        r"$0$",
        r"$\dfrac{\pi}{2}$",
        r"$\pi$",
    )

    def __init__(
        self,
        ax: Axes,
        measurement,
        caxes: list[Axes] = None,
        cmap: str | Colormap | None = None,
        vmin: float = None,
        vmax: float = None,
        power: float = 1.0,
        logscale: bool = False,
        units: str = None,
        **kwargs,
    ):
        """Show `measurement` on `ax`.

        Parameters
        ----------
        ax : Axes
            Axes to draw on.
        measurement
            Complex measurement with two base axes.
        caxes : list of Axes, optional
            If given and non-empty, phase and amplitude colorbars are created
            in ``caxes[0]`` and ``caxes[1]``.
        cmap : str or Colormap, optional
            Colormap of the phase image; resolved through ``validate_cmap``.
        vmin, vmax, power, logscale
            Define the amplitude normalization.
        units : str, optional
            Units of the extent and the axis labels.
        **kwargs
            Forwarded to both ``imshow`` calls; ``interpolation`` defaults to
            ``"none"``.
        """
        super().__init__(ax=ax, measurement=measurement)

        norm = _get_norm(vmin, vmax, power, logscale)

        abs_array = np.abs(measurement.array)
        alpha = np.clip(norm(abs_array), a_min=0.0, a_max=1.0)

        extent = get_extent(measurement, units=units)
        cmap = validate_cmap(cmap, measurement)

        # Pop once so that both images get the caller's interpolation; popping
        # per call left the amplitude image with the default.
        interpolation = kwargs.pop("interpolation", "none")

        self._phase_axes_image = ax.imshow(
            np.angle(measurement.array).T,
            origin="lower",
            interpolation=interpolation,
            alpha=alpha.T,
            vmin=-np.pi,
            vmax=np.pi,
            cmap=cmap,
            extent=extent,
            **kwargs,
        )

        self._amplitude_axes_image = ax.imshow(
            abs_array.T,
            origin="lower",
            interpolation=interpolation,
            cmap="gray",
            zorder=-1,
            extent=extent,
            **kwargs,
        )
        self._amplitude_axes_image.set_norm(norm)

        # Both handles exist from the start so the setters can test for them.
        self._phase_cbar = None
        self._amplitude_cbar = None

        self.set_xlabel(measurement.base_axes_metadata[0].format_label(units))
        self.set_ylabel(measurement.base_axes_metadata[1].format_label(units))

        if caxes is not None and len(caxes):
            cbar_label = measurement._scale_axis_from_metadata().format_label()
            self.set_cbars(caxes, label=cbar_label)

    def remove(self):
        """Remove both images from the axes."""
        self.amplitude_axes_image.remove()
        self.phase_axes_image.remove()

    def get_power(self):
        """Return the ``gamma`` of the amplitude norm, or 1.0 if it has none."""
        if hasattr(self.amplitude_norm, "gamma"):
            return self.amplitude_norm.gamma
        return 1.0

    def get_value_limits(self):
        """Return ``[min, max]`` of the displayed amplitude."""
        array = self.amplitude_axes_image.get_array()
        return [array.min(), array.max()]

    def get_xlim(self):
        """Return the horizontal part of the image extent."""
        return self.amplitude_axes_image.get_extent()[:2]

    def get_ylim(self):
        """Return the vertical part of the image extent."""
        return self.amplitude_axes_image.get_extent()[2:]

    @property
    def amplitude_norm(self):
        """The normalization attached to the amplitude image."""
        return self.amplitude_axes_image.norm

    @property
    def amplitude_axes_image(self):
        """The gray-scale amplitude image."""
        return self._amplitude_axes_image

    @property
    def phase_axes_image(self):
        """The color phase image."""
        return self._phase_axes_image

    def _update_alpha(self):
        """Recompute the alpha of the phase image from the amplitude image."""
        data = self.amplitude_axes_image.get_array().data
        alpha = self.amplitude_axes_image.norm(np.abs(data))
        alpha = np.clip(alpha, a_min=0, a_max=1)
        self.phase_axes_image.set_alpha(alpha)

    def _format_phase_cbar(self):
        """Put the multiples-of-pi ticks on the phase colorbar."""
        self._phase_cbar.set_ticks(list(self._phase_ticks))
        self._phase_cbar.set_ticklabels(list(self._phase_ticklabels))

    def set_value_limits(self, value_limits: tuple[float, float] = (None, None)):
        """Set ``(vmin, vmax)`` of the amplitude norm and refresh the alpha."""
        if value_limits is None:
            value_limits = (None, None)

        self._set_vmin_vmax(self.amplitude_norm, *value_limits)
        self._update_alpha()

        if self._amplitude_cbar:
            self._amplitude_cbar.ax.yaxis.set_offset_position("left")

    def set_cmap(self, cmap):
        """Set the colormap of the phase image.

        The phase colorbar ticks are restored afterwards when a phase colorbar
        exists; without one only the image is updated.
        """
        self.phase_axes_image.set_cmap(cmap)
        if self._phase_cbar is not None:
            self._format_phase_cbar()

    def set_power(self, power: float = 1.0):
        """Set the power-law exponent of the amplitude scaling."""
        self._update_norm(self.amplitude_norm, power, self.amplitude_axes_image)
        self._update_alpha()

    def set_extent(self, extent):
        """Set the extent of both images."""
        self.phase_axes_image.set_extent(extent)
        self.amplitude_axes_image.set_extent(extent)

    def set_data(self, data):
        """Replace phase, amplitude and alpha from the complex array of `data`."""
        abs_array = np.abs(data._array)
        alpha = self.amplitude_norm(abs_array)
        alpha = np.clip(alpha, a_min=0, a_max=1)

        self.phase_axes_image.set_alpha(alpha.T)
        self.phase_axes_image.set_data(np.angle(data._array).T)
        self.amplitude_axes_image.set_data(abs_array.T)

    def set_cbars(self, caxes, label: str = None, **kwargs):
        """Create the phase and amplitude colorbars.

        Parameters
        ----------
        caxes : list of Axes or None
            Axes for the phase (index 0) and amplitude (index 1) colorbar;
            ``None`` passes ``cax=None`` for both.
        label : str, optional
            Accepted but not used; the colorbars are labelled "arg" and "abs".
        **kwargs
            Forwarded to both colorbar constructors.
        """
        if caxes is None:
            caxes = [None, None]

        self._phase_cbar = self._make_cbar(
            self.phase_axes_image, cax=caxes[0], **kwargs
        )

        formatter = default_cbar_scalar_formatter()
        self._amplitude_cbar = self._make_cbar(
            self.amplitude_axes_image, cax=caxes[1], format=formatter, **kwargs
        )

        self._phase_cbar.set_label("arg", rotation=0, ha="center", va="top")
        self._phase_cbar.ax.yaxis.set_label_coords(0.5, -0.02)
        self._format_phase_cbar()

        self._amplitude_cbar.set_label("abs", rotation=0, ha="center", va="top")
        self._amplitude_cbar.ax.yaxis.set_label_coords(0.5, -0.02)
        self._amplitude_cbar.ax.yaxis.set_offset_position("left")
class OverlayImshowArtist(Artist2D):
    """Placeholder for an artist overlaying several images; not implemented."""

    def __init__(
        self,
        ax,
        data,
        cmap,
        vmin: float = None,
        vmax: float = None,
        power: float = 1.0,
        logscale: bool = False,
    ):
        """Not implemented.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError

        # Disabled sketch of a possible implementation, kept for reference:
        #
        # cmaps = [ListedColormap(c) for c in cmap]
        #
        # alphas = [np.clip(norm(alpha), a_min=0.0, a_max=1.0) for alpha in array]
        #
        # print(alphas[0], array.shape, array[0].shape)
        # ims = [
        #     ax.imshow(
        #         np.ones_like(alpha.T),
        #         origin="lower",
        #         interpolation="none",
        #         cmap=cmap,
        #         alpha=alpha.T,
        #     )
        #     for alpha, cmap in zip(alphas, cmaps)
        # ]
        #
        # ax.set_facecolor("k")
        #
        #
        # from matplotlib import colors
        # from matplotlib.cm import ScalarMappable
        # from matplotlib.colors import LinearSegmentedColormap, ListedColormap
        #
        # fig, ax = plt.subplots()
        # ax.set_facecolor("k")
        #
        # cmap = ListedColormap(["lime"])
        # norm = colors.Normalize()
        # norm.autoscale_None(stacked.array[0])
        # alpha = norm(stacked.array[0])
        #
        # im = ax.imshow(
        #     np.ones_like(stacked.array[0]).T, alpha=alpha.T, cmap=cmap, origin="lower"
        # )
        #
        # cmap = LinearSegmentedColormap.from_list("red", ["k", "lime"])
        # plt.colorbar(ScalarMappable(norm=norm, cmap=cmap), ax=ax)
        #
        # c = "r"
        # cmap = ListedColormap([c])
        # norm = colors.Normalize()
        # norm.autoscale_None(stacked.array[1])
        # alpha = norm(stacked.array[1])
        #
        # im = ax.imshow(
        #     np.ones_like(stacked.array[1]).T, alpha=alpha.T, cmap=cmap, origin="lower"
        # )
        #
        # cmap = LinearSegmentedColormap.from_list("c", ["k", c])
        # plt.colorbar(ScalarMappable(norm=norm, cmap=cmap), ax=ax)
        #