Module: matplotlib_util
Expand source code
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.image import AxesImage
class MatplotlibUtil:
"""Utilities for plots created using Matplotlib."""
@staticmethod
def heatmap(data: np.ndarray, row_labels: List[str], col_labels: List[str], ax=None, **kwargs):
"""
Create a heatmap from a numpy array and two lists of labels.
Parameters
----------
data
A 2D numpy array of shape (M, N).
row_labels
A list or array of length M with the labels for the rows.
col_labels
A list or array of length N with the labels for the columns.
ax
A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If
not provided, use current Axes or create a new one. Optional.
**kwargs
All other arguments are forwarded to `imshow`.
"""
if ax is None:
ax = plt.gca()
# Plot the heatmap
im = ax.imshow(data, **kwargs)
# Show all ticks and label them with the respective list entries.
ax.set_xticks(np.arange(data.shape[1]), labels=col_labels)
ax.set_yticks(np.arange(data.shape[0]), labels=row_labels)
# Let the horizontal axes labeling appear on top.
ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")
# Turn spines off and create white grid.
ax.spines[:].set_visible(False)
ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
ax.grid(which="minor", color="w", linestyle="-", linewidth=3)
ax.tick_params(which="minor", bottom=False, left=False)
return im
@staticmethod
def annotate_heatmap(
im: AxesImage,
labels: List[List[str]],
textcolors: Union[str, Tuple[str]] = ("black", "white"),
threshold: Optional[float] = None,
**textkw,
):
"""
A function to annotate a heatmap.
Parameters
----------
im
The AxesImage to be labeled.
labels:
Label for each cell
textcolors:
One color or a pair of colors. The first is used for values below a threshold,
the second for those above. Optional.
threshold
Value in data units according to which the colors from textcolors are
applied. If None (the default) uses the middle of the colormap as
separation. Optional.
**textkw
All other arguments are forwarded to each call to `text` used to create
the text labels.
"""
data = im.get_array()
# Normalize the threshold to the images color range.
if threshold is not None:
threshold = im.norm(threshold)
# Set default alignment to center, but allow it to be
# overwritten by textkw.
kw = dict(horizontalalignment="center", verticalalignment="center")
kw.update(textkw)
# Loop over the data and create a `Text` for each "pixel".
# Change the text's color depending on the data.
texts = []
for i in range(data.shape[0]):
for j in range(data.shape[1]):
kw.update(
color=(
textcolors[int(im.norm(data[i, j]) < threshold)]
if isinstance(textcolors, tuple)
else textcolors
),
)
text = im.axes.text(j, i, labels[i][j], **kw)
texts.append(text)
return texts
Classes
class MatplotlibUtil
-
Utilities for plots created using Matplotlib.
Expand source code
class MatplotlibUtil: """Utilities for plots created using Matplotlib.""" @staticmethod def heatmap(data: np.ndarray, row_labels: List[str], col_labels: List[str], ax=None, **kwargs): """ Create a heatmap from a numpy array and two lists of labels. Parameters ---------- data A 2D numpy array of shape (M, N). row_labels A list or array of length M with the labels for the rows. col_labels A list or array of length N with the labels for the columns. ax A `matplotlib.axes.Axes` instance to which the heatmap is plotted. If not provided, use current Axes or create a new one. Optional. **kwargs All other arguments are forwarded to `imshow`. """ if ax is None: ax = plt.gca() # Plot the heatmap im = ax.imshow(data, **kwargs) # Show all ticks and label them with the respective list entries. ax.set_xticks(np.arange(data.shape[1]), labels=col_labels) ax.set_yticks(np.arange(data.shape[0]), labels=row_labels) # Let the horizontal axes labeling appear on top. ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False) # Rotate the tick labels and set their alignment. plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor") # Turn spines off and create white grid. ax.spines[:].set_visible(False) ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True) ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True) ax.grid(which="minor", color="w", linestyle="-", linewidth=3) ax.tick_params(which="minor", bottom=False, left=False) return im @staticmethod def annotate_heatmap( im: AxesImage, labels: List[List[str]], textcolors: Union[str, Tuple[str]] = ("black", "white"), threshold: Optional[float] = None, **textkw, ): """ A function to annotate a heatmap. Parameters ---------- im The AxesImage to be labeled. labels: Label for each cell textcolors: One color or a pair of colors. The first is used for values below a threshold, the second for those above. Optional. threshold Value in data units according to which the colors from textcolors are applied. If None (the default) uses the middle of the colormap as separation. Optional. **textkw All other arguments are forwarded to each call to `text` used to create the text labels. """ data = im.get_array() # Normalize the threshold to the images color range. if threshold is not None: threshold = im.norm(threshold) # Set default alignment to center, but allow it to be # overwritten by textkw. kw = dict(horizontalalignment="center", verticalalignment="center") kw.update(textkw) # Loop over the data and create a `Text` for each "pixel". # Change the text's color depending on the data. texts = [] for i in range(data.shape[0]): for j in range(data.shape[1]): kw.update( color=( textcolors[int(im.norm(data[i, j]) < threshold)] if isinstance(textcolors, tuple) else textcolors ), ) text = im.axes.text(j, i, labels[i][j], **kw) texts.append(text) return texts
Static methods
def annotate_heatmap(im: matplotlib.image.AxesImage, labels: List[List[str]], textcolors: Union[str, Tuple[str]] = ('black', 'white'), threshold: Optional[float] = None, **textkw)
-
A function to annotate a heatmap.
Parameters
im
- The AxesImage to be labeled.
- labels:
- Label for each cell
- textcolors:
- One color or a pair of colors. The first is used for values below a threshold,
- the second for those above. Optional.
threshold
- Value in data units according to which the colors from textcolors are applied. If None (the default) uses the middle of the colormap as separation. Optional.
**textkw
- All other arguments are forwarded to each call to
text
used to create the text labels.
def heatmap(data: numpy.ndarray, row_labels: List[str], col_labels: List[str], ax=None, **kwargs)
-
Create a heatmap from a numpy array and two lists of labels.
Parameters
data
- A 2D numpy array of shape (M, N).
row_labels
- A list or array of length M with the labels for the rows.
col_labels
- A list or array of length N with the labels for the columns.
ax
- A
matplotlib.axes.Axes
instance to which the heatmap is plotted. If not provided, use current Axes or create a new one. Optional. **kwargs
- All other arguments are forwarded to
imshow
.