# Source code for neuralib.plot.colormap

import itertools

import numpy as np
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.cm import ScalarMappable
from matplotlib.colorbar import ColorbarBase
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Names exported by ``from <this module> import *``.
__all__ = [
    'DiscreteColorMapper',
    'get_customized_cmap',
    'insert_colorbar',
    'insert_cyclic_colorbar'
]


[docs] class DiscreteColorMapper: """map color to iterable object **Example of a ``dict`` palette** >>> palette = {3: ('#3182bd', '#6baed6', '#9ecae1'), ... 5: ('#a1d99b', '#c7e9c0', '#756bb1', '#9e9ac8', '#bcbddc')} >>> cmapper = DiscreteColorMapper(palette, 5) >>> x = ['1', '2', '3'] >>> color_list = [cmapper[i] for i in x] **Example of a mpl ``str`` palette** >>> cmapper = DiscreteColorMapper('viridis', 20) >>> regions = ['rsc', 'vis', 'hpc'] >>> color_list = [cmapper[r] for r in regions] """
[docs] def __init__(self, palette: dict[int, tuple[str, ...]] | str, number: int): """ :param palette: A custom dictionary color palette or a Matplotlib colormap name :param number: Number of colors to cycle through in the palette or colormap """ if isinstance(palette, dict): if number not in palette: raise ValueError(f"Palette dictionary does not contain key '{number}'.") self._color_pool = iter(itertools.cycle(palette[number])) elif isinstance(palette, str): import matplotlib as mpl mapper = mpl.colormaps[palette].resampled(number) self._color_pool = iter(itertools.cycle(mapper(range(number)))) else: raise TypeError('palette must be a dict or str') self._key_pool = {}
def __getitem__(self, item: str) -> str | np.ndarray: """get single color""" if item in self._key_pool: return self._key_pool[item] else: color = next(self._color_pool) self._key_pool[item] = color return color
def get_customized_cmap(name: str,
                        value: tuple[float, float],
                        numbers: int,
                        endpoint: bool = True) -> np.ndarray:
    """
    Sample a Matplotlib colormap at evenly spaced positions.

    `N` = number of color

    :param name: name of cmap
    :param value: value range that is sampled, could be 0-1
    :param numbers: `N`
    :param endpoint: Whether the upper bound of ``value`` is included in the samples.
        If cyclic colormap, then used `False`
    :return: RGBA. `Array[float, [N, 4]]`
    """
    colormap = plt.get_cmap(name)
    # ``value`` is unpacked into the start/stop arguments of ``np.linspace``.
    positions = np.linspace(*value, numbers, endpoint=endpoint)
    return colormap(positions)
def insert_colorbar(ax: Axes,
                    im: ScalarMappable,
                    inset: bool = False,
                    **kwargs) -> ColorbarBase:
    """
    Attach a colorbar for ``im`` next to ``ax``.

    :param ax: ``Axes``
    :param im: ``ScalarMappable``
    :param inset: use ``inset axes``, otherwise append a new axis
    :param kwargs: Additional args pass to ``ax.figure.colorbar``
    :return: The colorbar returned by ``ax.figure.colorbar``
    """
    if not inset:
        # Append a dedicated colorbar axis on the right-hand side of ``ax``.
        cax = make_axes_locatable(ax).append_axes("right", size="7%", pad=0.1)
    else:
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        # Anchor box is expressed through ``ax.transAxes``; x = 1.01 starts it
        # just past the right edge of ``ax``.
        inset_options = dict(
            width="5%",
            height="25%",
            loc='upper left',
            bbox_to_anchor=(1.01, 0., 1, 1),
            bbox_transform=ax.transAxes,
            borderpad=0,
        )
        cax = inset_axes(ax, **inset_options)

    return ax.figure.colorbar(im, cax=cax, **kwargs)
def insert_cyclic_colorbar(ax: Axes,
                           im: ScalarMappable,
                           *,
                           num_colors: int = 12,
                           num_labels: int = 4,
                           width: float = 0.6,
                           inner_diameter: float = 0.6,
                           vmin: float | None = None,
                           vmax: float | None = None) -> None:
    """
    Draw a ring-shaped colorbar for a cyclic colormap in a polar ``inset_axes`` of ``ax``.

    :param ax: ``Axes``
    :param im: ``ScalarMappable``
    :param num_colors: Number of color in the cyclic colorbar
    :param num_labels: Number of labels in the cyclic colorbar
    :param width: Width of the each color
    :param inner_diameter: The size of the inner circle
    :param vmin: Min value of the colormap, equal to ``vmax`` in cyclic data.
        Taken from ``im.norm.vmin`` when not given
    :param vmax: Max value of the colormap, equal to ``vmin`` in cyclic data.
        Taken from ``im.norm.vmax`` when not given
    :raises ValueError: If the colormap of ``im`` is not ``twilight``,
        ``twilight_shifted`` or ``hsv``, or if ``vmin``/``vmax`` are neither
        given nor defined by the norm of ``im``
    """
    ring_ax = ax.inset_axes((1.1, 0.65, 0.4, 0.4), polar=True)

    # One bar per color: unit height, stacked on top of the inner circle.
    bar_angles = np.linspace(0, 2 * np.pi, num_colors, endpoint=False)
    bar_heights = np.ones_like(bar_angles)
    bar_bottoms = np.ones_like(bar_angles) * inner_diameter

    cmap = im.cmap.name
    cyclic_cmap = ['twilight', 'twilight_shifted', 'hsv']
    if cmap not in cyclic_cmap:
        raise ValueError(f'cmap should be one of the {cyclic_cmap}, not "{cmap}"')

    ring_colors = get_customized_cmap(cmap, (0, 1), num_colors, endpoint=False)
    ring_ax.bar(bar_angles, bar_heights, color=ring_colors, width=width, bottom=bar_bottoms)

    # Add labels corresponding to the data values
    label_angles = np.linspace(0, 2 * np.pi, num_labels, endpoint=False)

    # Fall back to the limits of the norm for whichever bound was not passed in.
    vmin = im.norm.vmin if vmin is None else vmin
    vmax = im.norm.vmax if vmax is None else vmax
    if vmin is None or vmax is None:
        raise ValueError('vmin and vmax are required when the image norm does not define them')

    label_values = np.linspace(vmin, vmax, num_labels, endpoint=False)
    for angle, value in zip(label_angles, label_values, strict=True):
        # Labels sit at radial coordinate 3.
        ring_ax.text(
            angle,
            3,
            f'{value:.1f}',
            horizontalalignment='center',
            verticalalignment='center',
        )

    ring_ax.set_yticklabels([])
    ring_ax.set_xticks([])  # remove angle labels
    ring_ax.grid(False)  # remove the grid
    ring_ax.set_frame_on(False)