Source code for descarteslabs.common.display._display

# Copyright 2018-2024 Descartes Labs.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Displays ndarrays as images, but is easier to use and more flexible than matplotlib's ``imshow``.
"""

from __future__ import division

import math

import numpy as np
from strenum import StrEnum


def _import_matplotlib_pyplot():
    """
    Import matplotlib together with its ``pyplot`` submodule.

    Returns
    -------
    module
        The top-level ``matplotlib`` package, with ``matplotlib.pyplot``
        already imported.

    Raises
    ------
    ImportError
        If matplotlib cannot be imported.
    RuntimeError
        If importing ``matplotlib.pyplot`` fails. When the active backend is
        "MacOSX" the error is replaced by one explaining how to switch
        backends; otherwise the original error is re-raised.
    """
    try:
        import matplotlib
    except ImportError:
        raise ImportError("The matplotlib package is required for displaying images.")

    # Some matplotlib versions fail at runtime unless these backend modules
    # have been imported beforehand. This is best-effort: it is fine if they
    # are not available.
    try:
        import matplotlib.backends
        import matplotlib.backends.backend_agg  # noqa F401
    except ImportError:
        pass

    try:
        import matplotlib.pyplot
    except RuntimeError as err:
        if matplotlib.get_backend() != "MacOSX":
            raise err
        macosx_help = (
            "Python is not installed as a framework; the Mac OS X backend will not work.\n"
            "To resolve this, *before* calling dl.utils.display(), execute this code:\n\n"
            "import matplotlib\n"
            "matplotlib.use('TkAgg')\n"
            "import matplotlib.pyplot as plt\n\n"
            "In an interactive session, you'll have to restart your Python interpreter first."
        )
        raise RuntimeError(macosx_help)

    return matplotlib


class LayoutDirection(StrEnum):
    """
    Order in which images are placed on a grid with more than one column.

    Each member's value is the string accepted from callers
    (``"left-to-right"`` or ``"top-to-bottom"``).
    """

    left_to_right = "left-to-right"
    top_to_bottom = "top-to-bottom"

    @classmethod
    def directions(cls):
        """
        Return the names of all members, in definition order.

        Returns
        -------
        list of str
            ``["left_to_right", "top_to_bottom"]``
        """
        # Iterate the members themselves rather than filtering ``dir(cls)``:
        # on Python 3.11+ ``dir()`` of an enum with a ``str`` mixin also lists
        # every public ``str`` method (``capitalize``, ``split``, ...).
        return [member.name for member in cls]


def display(*imgs, **kwargs):
    """
    Display 2D and 3D ndarrays as images with matplotlib.

    The ndarrays must either be 2D, or 3D with 1 or 3 bands. A 3D array with
    4 bands is also accepted; its fourth band is treated as an alpha band and
    is left out when computing color limits. If the arrays are masked arrays,
    the mask will be used as an alpha channel (this cannot be combined with a
    4-band array).

    Unlike matplotlib's ``imshow``, arrays can be any dtype; internally, each
    is normalized to the range [0..1].

    If no images are given, nothing is displayed.

    Parameters
    ----------
    *imgs: 1 or more ndarrays
        When multiple images are given, each is displayed on its own row by
        default.
    bands_axis: int, default 0
        Axis which contains bands in each array.
    title: str, or sequence of str; optional
        Title for each image. If a sequence, must be the same length as
        ``imgs``.
    size: int, default 10
        Length, in inches, to display the longer side of each image.
    robust: bool, default True
        Use the 2nd and 98th percentiles to compute color limits. Otherwise
        (or if the two percentiles are equal), the minimum and maximum values
        in each array are used.
    interpolation: str, default "bilinear"
        Interpolation method for matplotlib to use when scaling images for
        display.

        Bilinear is the default, since it produces smoother results when
        scaling down continuously-valued data (i.e. images). For displaying
        discrete data, however, choose 'nearest' to prevent values not
        existing in the input from appearing in the output.

        Acceptable values are 'none', 'nearest', 'bilinear', 'bicubic',
        'spline16', 'spline36', 'hanning', 'hamming', 'hermite', 'kaiser',
        'quadric', 'catrom', 'gaussian', 'bessel', 'mitchell', 'sinc',
        'lanczos'
    colormap: str, default None
        The name of a Colormap registered with matplotlib. Some commonly used
        built-in options are 'plasma', 'magma', 'viridis', 'inferno'. See
        https://matplotlib.org/users/colormaps.html for more options.

        To use a Colormap, the input images must have a single band. The
        Colormap will be ignored for images with more than one band.
    figsize: tuple(int), default (size, (size / ncols) * nrows)
        Width, height in inches.
    nrows: int, default ceil(number of images / ncols)
        Number of rows if there are multiple images.
    ncols: int, default 1
        Number of columns if there are multiple images.
    layoutdirection: str, default "left-to-right"
        If ncols is greater than 1, it determines whether the layout is
        left-to-right for the images, or top-to-bottom.

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If an unexpected keyword argument is given, if an image is not an
        ndarray, or if a single list, tuple or 4D ndarray is passed instead of
        being unpacked (``display(*images)``).
    ValueError
        If ``title`` is a sequence whose length differs from the number of
        images.
    NotImplementedError
        If an array is not 2D or 3D, has an unsupported number of bands, or
        has both a mask and an explicit alpha band.
    """
    _display_or_save(None, *imgs, **kwargs)


def save_image(filename, *imgs, **kwargs):
    """
    Save 2D and 3D ndarrays as images with matplotlib.

    The images are prepared and laid out the same way as by :func:`display`;
    see that function for an explanation of ``imgs`` and of the keyword
    arguments. For the file extensions that can be used, see matplotlib's
    :func:`~matplotlib.pyplot.savefig`.

    If no images are given, nothing is saved.

    Parameters
    ----------
    filename: str
        The name and extension of the image to be saved.
    """
    _display_or_save(filename, *imgs, **kwargs)


def _flatten(axs, left_to_right=True):
    """
    Yield the entries of a 2D grid of axes one at a time.

    Parameters
    ----------
    axs: 2D sequence, indexed ``[row][column]``
        The grid to walk.
    left_to_right: bool, default True
        If True, yield each row in turn (left to right within a row).
        Otherwise yield each column in turn (top to bottom within a column).
    """
    if left_to_right:
        for row in axs:
            yield from row
    else:
        # zip(*axs) regroups the rows into columns
        for column in zip(*axs):
            yield from column


def _to_displayable(img, bands_axis, robust, colormap):
    """
    Turn one ndarray into an array that ``imshow`` can draw.

    The bands axis is moved last, values are rescaled to [0, 1] using either
    the 2nd/98th percentiles (``robust``) or the minimum/maximum, a mask (if
    any) becomes an alpha channel, and ``colormap`` (if given) is applied to
    single-band images.

    Parameters
    ----------
    img: ndarray
        2D, or 3D with 1, 3 or 4 bands along ``bands_axis``.
    bands_axis: int
        Axis of ``img`` holding the bands.
    robust: bool
        Whether to use percentiles rather than min/max for the color limits.
    colormap: callable or None
        A matplotlib Colormap, or None.

    Returns
    -------
    ndarray
        Laid out as (rows, columns) or (rows, columns, channels).

    Raises
    ------
    TypeError
        If ``img`` is not an ndarray.
    NotImplementedError
        If ``img`` is not 2D or 3D, has an unsupported number of bands, or has
        both a mask and an explicit alpha band.
    """
    if not isinstance(img, np.ndarray):
        raise TypeError("Expected ndarray, instead got {}".format(type(img)))
    if img.ndim not in (2, 3):
        raise NotImplementedError(
            "Can only display 2D or 3D arrays, not shape {}".format(img.shape)
        )

    if img.ndim == 2:
        # expand 2d image to 3d with 1 band
        slicer = [slice(None)] * 3
        slicer[bands_axis] = np.newaxis
        img = img[tuple(slicer)]

    nbands = img.shape[bands_axis]
    if nbands not in (1, 3, 4):
        raise NotImplementedError(
            "Can only display images with 1 or 3 bands currently, not {}. "
            "Is axis {} actually your bands axis?".format(nbands, bands_axis)
        )

    if nbands == 4:
        # don't include alpha band in min max
        slicer = [slice(None)] * 3
        slicer[bands_axis] = slice(None, 3)
        spectrals = img[tuple(slicer)]
    else:
        spectrals = img

    # calculate min and max; `robust` is local to this call, so falling back
    # to min/max for one image does not affect how other images are scaled
    if robust:
        if hasattr(spectrals, "mask"):
            # don't include masked values in percentile
            spectrals = spectrals.compressed()
        vmin, vmax = np.nanpercentile(spectrals, (2, 98))
        if vmin == vmax:
            robust = False
    if not robust:
        vmin, vmax = np.ma.min(spectrals), np.ma.max(spectrals)

    # matplotlib requires shape (n, m, band)
    disp = np.moveaxis(img, bands_axis, -1).astype(np.float64)

    # rescale
    disp -= vmin
    span = vmax - vmin
    if not (span == 0):
        disp /= span
    # else: a constant image has zero range; leave it at 0 rather than
    # dividing by zero
    np.clip(disp, 0, 1, out=disp)  # to coerce away any floating-point errors

    if hasattr(disp, "mask"):
        # turn mask into alpha
        if nbands == 4:
            raise NotImplementedError(
                "Currently can't supply an image with an explicit alpha band as well as a mask"
            )
        if np.isscalar(disp.mask):
            alpha = np.ones(disp.shape[:2]) * (not disp.mask)
        else:
            alpha = (~disp.mask.any(axis=-1)).astype(disp.dtype)

        if nbands == 1:
            if colormap is not None:
                disp = colormap(disp[:, :, 0])
                # Removes the alpha channel the color map always adds
                disp = disp[:, :, :3]
            else:
                # to use an alpha channel, matplotlib must have a 4-band image,
                # so just duplicate the 1 band for r, g, and b
                # TODO: unnecessary copy of disp's mask
                disp = np.concatenate([disp] * 3, axis=-1)
        disp = np.concatenate([disp, alpha[:, :, np.newaxis]], axis=-1)

    if disp.shape[-1] == 1:
        # matplotlib takes 1 band images as (n, m), not (n, m, 1)
        disp = disp[:, :, 0]
        if colormap is not None:
            disp = colormap(disp)

    return disp


def _display_or_save(filename, *imgs, **kwargs):
    """
    Draw ``imgs`` on a grid of matplotlib axes, then show or save the figure.

    Does nothing if no images are given.

    Parameters
    ----------
    filename: str or None
        If not None, the figure is saved under this name and then closed.
        If None, the figure is shown instead.
    *imgs: ndarrays
        The images to draw, one per axes.
    **kwargs
        ``bands_axis``, ``title``, ``size``, ``robust``, ``interpolation``,
        ``colormap``, ``figsize``, ``nrows``, ``ncols`` and
        ``layoutdirection`` (also accepted as ``layout_direction``).

    Raises
    ------
    ImportError
        If matplotlib is not installed.
    TypeError
        If an unexpected keyword argument is given, if an image is not an
        ndarray, or if a single list, tuple or 4D ndarray is passed instead of
        being unpacked.
    ValueError
        If ``title`` is a sequence whose length differs from the number of
        images.
    NotImplementedError
        If an image has an unsupported shape or number of bands.
    """
    if not imgs:
        return

    bands_axis = kwargs.pop("bands_axis", 0)
    titles = kwargs.pop("title", None)
    size = kwargs.pop("size", 10)
    robust = kwargs.pop("robust", True)
    interpolation = kwargs.pop("interpolation", "bilinear")
    colormap_name = kwargs.pop("colormap", None)
    figsize = kwargs.pop("figsize", None)
    nrows = kwargs.pop("nrows", None)
    ncols = kwargs.pop("ncols", 1)
    # Accept both spellings of this option; ``layout_direction`` wins if both
    # are given.
    direction = kwargs.pop("layoutdirection", LayoutDirection.left_to_right)
    direction = kwargs.pop("layout_direction", direction)
    layout_direction = LayoutDirection(direction)
    if kwargs:
        raise TypeError(
            "Unexpected keyword arguments for display: {}".format(", ".join(kwargs))
        )

    matplotlib = _import_matplotlib_pyplot()
    plt = matplotlib.pyplot

    if len(imgs) == 1:
        only = imgs[0]
        if isinstance(only, (list, tuple)):
            raise TypeError(
                "To display a sequence of images, unpack it: `display(*images_list)`"
            )
        if isinstance(only, np.ndarray) and only.ndim == 4:
            raise TypeError(
                "To display a 4D ndarray (image stack), unpack it: `display(*stack)`"
            )

    # TODO: facet grid
    # TODO: leaves huge gaps between images that aren't very square
    # would need to calculate figsize better based on shapes of each img
    # or use seaborn?
    if nrows is None:
        nrows = math.ceil(len(imgs) / ncols)
    if figsize is None:
        figsize = size, (size / ncols) * nrows

    if isinstance(titles, (list, tuple, np.ndarray)):
        if len(titles) != len(imgs):
            raise ValueError("Different number of titles given than images")
    else:
        titles = [titles] * len(imgs)

    # pyplot.get_cmap is used because matplotlib.cm.get_cmap no longer exists
    # in recent matplotlib releases
    colormap = plt.get_cmap(colormap_name) if colormap_name else None

    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    try:
        axes = _flatten(axs, layout_direction == LayoutDirection.left_to_right)
        # zip stops at the shortest input, so images beyond the
        # nrows * ncols available axes are not drawn
        for ax, img, title in zip(axes, imgs, titles):
            disp = _to_displayable(img, bands_axis, robust, colormap)
            ax.grid(False)  # just to be sure
            ax.imshow(disp, aspect="equal", interpolation=interpolation)
            if title is not None:
                ax.set_title(str(title))

        fig.tight_layout()
        if filename is not None:
            fig.savefig(filename)
    except Exception:
        # don't leave a half-drawn figure registered with pyplot
        plt.close(fig)
        raise

    if filename is None:
        plt.show()
    else:
        # the figure has been written out; release it so that repeated saves
        # do not accumulate open figures
        plt.close(fig)