# Source code for plotpy.io

# -*- coding: utf-8 -*-
#
# Licensed under the terms of the BSD 3-Clause
# (see plotpy/LICENSE for details)

# pylint: disable=C0103

"""
Input/Output helper functions
-----------------------------

Overview
^^^^^^^^

The `io` module provides the following input/output helper functions:

* :py:func:`.io.imread`: load an image (.png, .tiff,
    .dcm, etc.) and return its data as a NumPy array
* :py:func:`.io.imwrite`: save an array to an image file
* :py:func:`.io.load_items`: load plot items from HDF5
* :py:func:`.io.save_items`: save plot items to HDF5

Reference
^^^^^^^^^

.. autofunction:: imread
.. autofunction:: imwrite
.. autofunction:: load_items
.. autofunction:: save_items
"""

from __future__ import annotations

import logging
import os.path as osp
import re
import sys
from typing import TYPE_CHECKING, Any

import numpy as np
import PIL.Image
import PIL.TiffImagePlugin  # py2exe

from plotpy.config import _

if TYPE_CHECKING:
    import guidata.io


def scale_data_to_dtype(data: np.ndarray, dtype: np.dtype) -> np.ndarray:
    """Scale data to fit datatype dynamic range

    The minimum of `data` is mapped to the minimum of the integer `dtype` and
    the maximum of `data` to its maximum; values in between are scaled
    linearly and then truncated by the final cast. Constant data is mapped
    to the minimum of `dtype`.

    Args:
        data: data to scale
        dtype: target integer datatype

    Returns:
        scaled data, as a new array of type `dtype`

    Raises:
        ValueError: if `dtype` is not an integer datatype (from ``np.iinfo``)

    .. note::

        Computations are done on a float64 copy of `data`, so integer inputs
        cannot overflow (e.g. ``int8`` data spanning [-128, 127]) and the
        input array is left untouched.
    """
    info = np.iinfo(dtype)
    # Work on a float64 copy: subtracting the minimum in the input's own
    # integer type would wrap around when the data range exceeds that type
    fdata = np.array(data, dtype=np.float64)
    dmin = fdata.min()
    dmax = fdata.max()
    if dmax == dmin:
        # Constant data: avoid 0/0 division (which would yield NaN). Map every
        # sample to the dtype range minimum so downstream image rendering keeps
        # working instead of seeing NaNs.
        return np.full(fdata.shape, info.min, dtype=dtype)
    fdata -= dmin
    # Multiply before dividing (same order as before), so that the maximum
    # maps exactly onto info.max instead of a value just below it
    fdata *= float(info.max - info.min)
    fdata /= dmax - dmin
    fdata += float(info.min)
    return fdata.astype(dtype)


# ===============================================================================
# I/O File type definitions
# ===============================================================================
class FileType:
    """Description of a supported image file type

    Attributes:
        name: description of the file type (shown in file dialog filters)
        extensions: filename extensions, each including its leading dot
        read_func: read callback (None if reading is not supported)
        write_func: write callback (None if writing is not supported)
        data_types: supported data types (None: no restriction)
        requires_template: True if saving requires a template object
            (e.g. DICOM files)
    """

    def __init__(
        self,
        name,
        extensions,
        read_func=None,
        write_func=None,
        data_types=None,
        requires_template=False,
    ):
        """
        Args:
            name: description of the file type
            extensions: extension patterns (e.g. ``"*.png"``, ``".png"``) or
                filenames, given as a list, a tuple or a space-separated
                string (e.g. ``"*.tif *.tiff"``)
            read_func: read callback
            write_func: write callback
            data_types: supported data types (None: no restriction)
            requires_template: True if saving requires a template object
        """
        self.name = name
        self.read_func = read_func
        self.write_func = write_func
        self.data_types = data_types
        self.requires_template = requires_template
        patterns = extensions.split() if isinstance(extensions, str) else extensions
        # A space is prepended so that a bare extension such as ".png" is not
        # treated by osp.splitext as a dot-file without extension
        self.extensions = [osp.splitext(" " + pattern)[1] for pattern in patterns]

    def matches(self, action, dtype, template):
        """Tell whether this file type may be used for `action`

        Args:
            action: 'load' or 'save'
            dtype: data type to check (None: any data type matches)
            template: template object (only checked when saving a file type
                that requires one)

        Returns:
            True if `dtype` is supported and, when saving a file type that
            requires a template, `template` is not None
        """
        assert action in ("load", "save")
        if dtype is None or self.data_types is None:
            compatible = True
        else:
            compatible = dtype in self.data_types
        if action == "save" and self.requires_template:
            return compatible and template is not None
        return compatible

    @property
    def wcards(self):
        """Space-separated wildcard patterns, e.g. ``"*.tif *.tiff"``"""
        joined = " *".join(self.extensions)
        return f"*{joined}"

    def filters(self, action, dtype, template):
        """Return the file dialog filter entry of this file type

        Args:
            action: 'load' or 'save'
            dtype: data type (None: any data type)
            template: template object (see :meth:`matches`)

        Returns:
            ``"\\n<name> (<wildcards>)"`` if this file type matches, else ``""``
        """
        assert action in ("load", "save")
        if not self.matches(action, dtype, template):
            return ""
        return f"\n{self.name} ({self.wcards})"


class ImageIOHandler:
    """Registry of the supported image file types

    Each registered :class:`FileType` associates filename extensions with
    read and/or write callbacks. The handler looks these callbacks up by
    extension and builds file dialog filter strings.
    """

    def __init__(self):
        # Registered FileType objects, in registration order
        self.filetypes = []

    def allfilters(self, action, dtype, template):
        """Return the "All supported files" filter entry

        Args:
            action: 'load' or 'save'
            dtype: data type (None: all data types)
            template: template object (only relevant when saving file types
                that require one)

        Returns:
            filter string gathering the wildcards of every matching file type
        """
        patterns = [
            ftype.wcards
            for ftype in self.filetypes
            if ftype.matches(action, dtype, template)
        ]
        label = _("All supported files")
        joined = " ".join(patterns)
        return f"{label} ({joined})"

    def get_filters(self, action, dtype=None, template=None):
        """Return file type filters for `action` (string: 'save' or 'load'),
        `dtype` data type (None: all data types), and `template` (True if save
        function requires a template (e.g. DICOM files), False otherwise)

        The "All supported files" entry comes first, followed by one entry per
        matching file type.
        """
        parts = [self.allfilters(action, dtype, template)]
        parts.extend(ftype.filters(action, dtype, template) for ftype in self.filetypes)
        return "".join(parts)

    def add(
        self,
        name,
        extensions,
        read_func=None,
        write_func=None,
        import_func=None,
        data_types=None,
        requires_template=None,
    ):
        """Register a new file type

        Args:
            name: description of the file type
            extensions: extension patterns (see :class:`FileType`)
            read_func: read callback
            write_func: write callback
            import_func: optional callable checking that the libraries needed
                by this file type are importable; if it raises ImportError,
                the file type is silently not registered
            data_types: supported data types (None: no restriction)
            requires_template: True if saving requires a template object
        """
        if import_func is not None:
            try:
                import_func()
            except ImportError:
                # Optional dependency is missing: skip this file type
                return
        assert read_func is not None or write_func is not None
        self.filetypes.append(
            FileType(
                name,
                extensions,
                read_func=read_func,
                write_func=write_func,
                data_types=data_types,
                requires_template=requires_template,
            )
        )

    def _get_filetype(self, ext):
        """Return the first registered FileType whose extensions contain
        `ext` (lowercased before comparison)

        Raises:
            RuntimeError: if no registered file type handles `ext`
        """
        found = next(
            (ftype for ftype in self.filetypes if ext.lower() in ftype.extensions),
            None,
        )
        if found is None:
            raise RuntimeError(f"Unsupported file type: '{ext}'")
        return found

    def get_readfunc(self, ext):
        """Return the read function associated to file extension `ext`

        Raises:
            RuntimeError: if `ext` is unknown or cannot be read
        """
        read_func = self._get_filetype(ext).read_func
        if read_func is None:
            raise RuntimeError(f"Unsupported file type (read): '{ext}'")
        return read_func

    def get_writefunc(self, ext):
        """Return the write function associated to file extension `ext`

        Raises:
            RuntimeError: if `ext` is unknown or cannot be written
        """
        write_func = self._get_filetype(ext).write_func
        if write_func is None:
            raise RuntimeError(f"Unsupported file type (write): '{ext}'")
        return write_func


# Module-level registry of supported image file types: the iohandler.add(...)
# calls further down in this module register each format, and imread/imwrite
# look read/write callbacks up in it by file extension
iohandler = ImageIOHandler()


# ==============================================================================
# tifffile-based Private I/O functions
# ==============================================================================


def _imread_tiff(filename):
    """Read a TIFF file and return its data as a NumPy array

    tifffile is used when it can be imported; otherwise (or if tifffile
    raises ImportError while reading) the file is read with PIL instead.
    """
    try:
        import tifffile

        data = tifffile.imread(filename)
    except ImportError:
        data = _imread_pil(filename)
    return data


def _imwrite_tiff(filename, arr):
    """Write `arr` NumPy array to a TIFF file

    tifffile is used when it can be imported; otherwise (or if tifffile
    raises ImportError while writing) the array is written with PIL instead.
    The writer's return value is passed through.
    """
    try:
        import tifffile

        result = tifffile.imwrite(filename, arr)
    except ImportError:
        result = _imwrite_pil(filename, arr)
    return result


# ==============================================================================
# PIL-based Private I/O functions
# ==============================================================================
# NumPy byte-order prefix of the running machine ("<" little, ">" big endian)
_ENDIAN = "<" if sys.byteorder == "little" else ">"

# PIL image mode -> (NumPy dtype string, number of bands or None for
# single-band modes).
# NOTE: insertion order matters: _imwrite_pil picks the first matching entry.
# NOTE(review): "I;16B" is mapped to native byte order like "I;16"; this
# presumably relies on np.array(img, dtype=...) converting the values — confirm.
DTYPES = {
    "1": ("|b1", None),
    "L": ("|u1", None),
    "I": (f"{_ENDIAN}i4", None),
    "F": (f"{_ENDIAN}f4", None),
    "I;16": (f"{_ENDIAN}u2", None),
    "I;16B": (f"{_ENDIAN}u2", None),
    "I;16S": (f"{_ENDIAN}i2", None),
    "P": ("|u1", None),
    "RGB": ("|u1", 3),
    "RGBX": ("|u1", 4),
    "RGBA": ("|u1", 4),
    "CMYK": ("|u1", 4),
    "YCbCr": ("|u1", 4),
}


def _imread_pil(filename, to_grayscale=False, **kwargs):
    """Open image with PIL and return a NumPy array

    Multi-page TIFF files (.tif/.tiff) are handled by saving every page next
    to the original file as ``<base>_<i><ext>`` (side effect on disk, with the
    extension lowercased) and then reading one of those page files: the
    second page when there are exactly two pages (possibly a thumbnail
    followed by the full image), the first page otherwise.

    Mode conversions before building the array:

    * CMYK and YCbCr images are converted to RGB
    * if `to_grayscale` is True, RGB/RGBA/RGBX images are converted to "L"
    * otherwise, modes containing an alpha band ("A") and palette images with
      transparency info are converted to RGBA, other palette images to RGB

    Args:
        filename: image filename
        to_grayscale: convert RGB(A/X) images to grayscale ("L" mode)
        kwargs: accepted for compatibility with other read callbacks; unused

    Returns:
        NumPy array of shape (img.size[1], img.size[0]) for single-band modes,
        with an extra trailing axis of length N for N-band modes (see
        `DTYPES`), i.e. (height, width[, bands]) given PIL's (width, height)
        size convention

    Raises:
        RuntimeError: if the (converted) image mode is not in `DTYPES`
    """
    # Map this TIFF layout key (PIL.TiffImagePlugin.II byte order, 16 bits per
    # sample) to PIL mode "I;16". NOTE(review): presumably meant to make PIL
    # open such 16-bit grayscale TIFFs in "I;16" mode; this mutates PIL's
    # global OPEN_INFO table on every call — confirm it is still needed.
    PIL.TiffImagePlugin.OPEN_INFO[(PIL.TiffImagePlugin.II, 0, 1, 1, (16,), ())] = (
        "I;16",
        "I;16",
    )
    with PIL.Image.open(filename) as img:
        base, ext = osp.splitext(filename)
        ext = ext.lower()
        if ext in [".tif", ".tiff"]:
            # try to know if multiple pages
            # (count pages by seeking until PIL raises EOFError)
            nb_pages = 0
            while True:
                try:
                    img.seek(nb_pages)
                    nb_pages += 1
                except EOFError:
                    break
            if nb_pages > 1:
                # Write each page as a separate file: <base>_<i><ext>
                for i in range(nb_pages):
                    img.seek(i)
                    filename = base
                    filename += "_{i:d}".format(i=i)
                    filename += ext
                    img.save(filename)
                if nb_pages == 2:
                    # possibility to be TIFF file with thumbnail and full image
                    # --> try to load full image (second one)
                    filename = base + "_{i:d}".format(i=1) + ext
                else:
                    # we don't know which one must be loaded --> load first image
                    filename = base + "_{i:d}".format(i=0) + ext

    # `filename` is either the original file or the selected page file
    with PIL.Image.open(filename) as img:
        if img.mode in ("CMYK", "YCbCr"):
            # Converting to RGB
            img = img.convert("RGB")
        if to_grayscale and img.mode in ("RGB", "RGBA", "RGBX"):
            # Converting to grayscale
            img = img.convert("L")
        elif "A" in img.mode or (img.mode == "P" and "transparency" in img.info):
            # Keep transparency information as an alpha channel
            img = img.convert("RGBA")
        elif img.mode == "P":
            # Palette images without transparency: expand palette to RGB
            img = img.convert("RGB")
        try:
            dtype, extra = DTYPES[img.mode]
        except KeyError:
            raise RuntimeError(f"{img.mode} mode is not supported")
        shape = (img.size[1], img.size[0])
        if extra is not None:
            # Multi-band mode: add a trailing band axis
            shape += (extra,)
        try:
            return np.array(img, dtype=np.dtype(dtype)).reshape(shape)
        except SystemError:
            # Fallback when the direct conversion raises SystemError
            # (presumably for some PIL/NumPy version combinations)
            return np.array(img.getdata(), dtype=np.dtype(dtype)).reshape(shape)


def _imwrite_pil(filename, arr):
    """Write `arr` NumPy array to `filename` using PIL

    The PIL mode is the first `DTYPES` entry whose dtype string equals
    ``arr.dtype.str`` and whose band count fits the array:

    * arrays with at most 2 dimensions use single-band modes
    * other arrays use a multi-band mode whose band count equals
      ``arr.shape[-1]``

    Native-endian float64 arrays that match no entry are converted to float32
    and written in "F" mode. 4-band uint8 arrays written to PNG or TIFF files
    use "RGBA" mode (keeping the alpha channel) instead of "RGBX".

    Args:
        filename: image filename (PIL chooses the file format from it)
        arr: NumPy array to write

    Raises:
        RuntimeError: if no PIL mode can be determined for `arr`
    """
    multiband = len(arr.shape[2:]) > 0
    for mode, (dtype_str, extra) in DTYPES.items():
        if dtype_str != arr.dtype.str:
            continue
        if extra is None:
            if not multiband:
                break  # single-band mode for a grayscale image
        elif multiband and arr.shape[-1] == extra:
            break  # multi-band mode with the right number of bands
    else:
        # F Chermette 2022
        if arr.dtype.str == "%sf8" % _ENDIAN:
            # PIL has no 64-bit float mode: write as 32-bit float ("F")
            arr = np.array(arr, dtype="f4")
            mode = "F"
        else:
            raise RuntimeError("Cannot determine PIL data type")
    if mode == "RGBX" and osp.splitext(filename)[1].lower() in (
        ".png",
        ".tif",
        ".tiff",
    ):
        # "RGBX" is the first 4-band uint8 entry of DTYPES, but it discards
        # the alpha channel, and PIL cannot write RGBX images to PNG files
        mode = "RGBA"
    img = PIL.Image.fromarray(arr, mode)
    img.save(filename)


# ==============================================================================
# DICOM Private I/O functions
# ==============================================================================
def _import_dcm():
    """DICOM Import function (checking for required libraries):
    DICOM support requires library `pydicom`"""
    logger = logging.getLogger("pydicom")
    logger.setLevel(logging.CRITICAL)

    # This import statement must stay here because the purpose of this function
    # is to check if pydicom is installed:
    # pylint: disable=import-outside-toplevel
    # pylint: disable=import-error
    from pydicom import dcmread  # type:ignore # noqa: F401

    logger.setLevel(logging.WARNING)


def _imread_dcm(filename, **kwargs):
    """Open DICOM image with pydicom and return a NumPy array

    Args:
        filename: DICOM filename
        kwargs: accepted for compatibility with other read callbacks; unused

    Returns:
        NumPy array with shape:

        * (rows, columns) for single-frame, single-sample images
        * (frames, rows, columns) for multi-frame, single-sample images
        * (samples, rows, columns) for single-frame, multi-sample images
        * (samples, frames, rows, columns) for multi-frame, multi-sample images

    Raises:
        TypeError: if PixelRepresentation/BitsAllocated do not correspond to a
            NumPy integer data type
        NotImplementedError: for single-frame, multi-sample images whose
            BitsAllocated is not 8
    """
    # pylint: disable=import-outside-toplevel
    # pylint: disable=import-error
    from pydicom import dcmread  # type:ignore

    dcm = dcmread(filename, force=True)
    # **********************************************************************
    # The following is necessary until pydicom numpy support is improved:
    # (after that, a simple: 'arr = dcm.PixelArray' will work the same)
    format_str = "%sint%s" % (("u", "")[dcm.PixelRepresentation], dcm.BitsAllocated)
    try:
        dtype = np.dtype(format_str)
    except TypeError:
        raise TypeError(
            "Data type not understood by NumPy: "
            "PixelRepresentation=%d, BitsAllocated=%d"
            % (dcm.PixelRepresentation, dcm.BitsAllocated)
        ) from None
    arr = np.frombuffer(dcm.PixelData, dtype)
    try:
        # pydicom 0.9.3:
        dcm_is_little_endian = dcm.isLittleEndian
    except AttributeError:
        # pydicom 0.9.4:
        dcm_is_little_endian = dcm.is_little_endian
    if dcm_is_little_endian != (sys.byteorder == "little"):
        # np.frombuffer returns a read-only view of the PixelData buffer, so
        # it cannot be byte-swapped in place: swap into a new array instead
        arr = arr.byteswap()
    rows, cols = dcm.Rows, dcm.Columns
    # DICOM keywords are case-sensitive: "SamplesPerPixel", "NumberOfFrames"
    spp = getattr(dcm, "SamplesPerPixel", 1)
    multiframe = hasattr(dcm, "NumberOfFrames") and dcm.NumberOfFrames > 1
    nframes = int(dcm.NumberOfFrames) if multiframe else 1
    if spp <= 1:
        arr = arr.reshape((nframes, rows, cols) if multiframe else (rows, cols))
    else:
        if not multiframe and dcm.BitsAllocated != 8:
            raise NotImplementedError(
                "This code only handles SamplesPerPixel > 1 if Bits Allocated = 8"
            )
        if getattr(dcm, "PlanarConfiguration", 0) == 1:
            # Color-by-plane: within each frame, sample planes follow each other
            arr = arr.reshape(nframes, spp, rows, cols)
        else:
            # Color-by-pixel: the samples of each pixel are stored together
            arr = arr.reshape(nframes, rows, cols, spp).transpose(0, 3, 1, 2)
        # (frames, samples, rows, cols) -> (samples, frames, rows, cols)
        arr = arr.transpose(1, 0, 2, 3)
        if not multiframe:
            arr = arr[:, 0]
    # **********************************************************************
    return arr


def _imwrite_dcm(filename, arr, template=None):
    """Save a numpy array `arr` into a DICOM image file `filename`
    based on DICOM structure `template`

    `template` is modified in place (bit depth, pixel representation,
    dimensions, min/max pixel values, photometric interpretation and pixel
    data) before being saved.

    Args:
        filename: DICOM filename
        arr: integer array; its first two dimensions give Rows and Columns
        template: DICOM structure used as a template (required)

    Raises:
        AssertionError: if `template` is None
        ValueError: if `arr` is not of an integer type (from ``np.iinfo``)
    """
    # Note: due to IOHandler formalism, `template` has to be a keyword argument
    assert template is not None, (
        "The `template` keyword argument is required to save DICOM files\n"
        "(that is the template DICOM structure object)"
    )
    info = np.iinfo(arr.dtype)
    template.BitsAllocated = info.bits
    template.BitsStored = info.bits
    template.HighBit = info.bits - 1
    # 0: unsigned integers, 1: signed (two's complement) integers
    template.PixelRepresentation = ("u", "i").index(info.kind)
    data_vr = ("US", "SS")[template.PixelRepresentation]
    template.Rows = arr.shape[0]
    template.Columns = arr.shape[1]
    template.SmallestImagePixelValue = int(arr.min())
    template[0x00280106].VR = data_vr
    template.LargestImagePixelValue = int(arr.max())
    template[0x00280107].VR = data_vr
    if not template.PhotometricInterpretation.startswith("MONOCHROME"):
        template.PhotometricInterpretation = "MONOCHROME1"
    # `ndarray.tostring` is deprecated (NumPy 1.19) and removed in NumPy 2.x
    template.PixelData = arr.tobytes()
    # Native Pixel Data: OB is only valid when Bits Allocated <= 8, wider
    # samples must be encoded as OW
    template[0x7FE00010].VR = "OB" if info.bits <= 8 else "OW"
    template.save_as(filename)


# ==============================================================================
# Text files Private I/O functions
# ==============================================================================
def _imread_txt(filename, **kwargs):
    """Open text file image and return a NumPy array"""
    for delimiter in ("\t", ",", " ", ";"):
        try:
            return np.loadtxt(filename, delimiter=delimiter)
        except ValueError:
            continue
    else:
        raise ValueError(f"Could not load {filename!r}")


def _imwrite_txt(filename, arr):
    """Write `arr` NumPy array to text file `filename`"""
    if arr.dtype in (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32):
        fmt = "%d"
    else:
        fmt = "%.18e"
    ext = osp.splitext(filename)[1]
    if ext.lower() in (".txt", ".asc", ""):
        np.savetxt(filename, arr, fmt=fmt)
    elif ext.lower() == ".csv":
        np.savetxt(filename, arr, fmt=fmt, delimiter=",")


# ==============================================================================
# Registering I/O functions
# ==============================================================================
iohandler.add(
    _("PNG files"),
    "*.png",
    read_func=_imread_pil,
    write_func=_imwrite_pil,
    data_types=(np.uint8, np.uint16),
)
iohandler.add(
    _("TIFF files"), "*.tif *.tiff", read_func=_imread_tiff, write_func=_imwrite_tiff
)
# ".jpeg" is the other common spelling of the JPEG extension: without it,
# e.g. "photo.jpeg" is rejected as an unsupported file type
iohandler.add(
    _("8-bit images"),
    "*.jpg *.jpeg *.gif",
    read_func=_imread_pil,
    write_func=_imwrite_pil,
    data_types=(np.uint8,),
)
iohandler.add(_("NumPy arrays"), "*.npy", read_func=np.load, write_func=np.save)
iohandler.add(
    _("Text files"), "*.txt *.csv *.asc", read_func=_imread_txt, write_func=_imwrite_txt
)
# DICOM is only registered if pydicom can be imported (see _import_dcm);
# saving requires a template DICOM structure
iohandler.add(
    _("DICOM files"),
    "*.dcm",
    read_func=_imread_dcm,
    write_func=_imwrite_dcm,
    import_func=_import_dcm,
    data_types=(np.int8, np.uint8, np.int16, np.uint16),
    requires_template=True,
)


# ==============================================================================
# Generic image read/write functions
# ==============================================================================
def imread(
    fname: str, ext: str | None = None, to_grayscale: bool = False
) -> np.ndarray:
    """Read an image from a file as a NumPy array

    Args:
        fname: image filename
        ext: image file extension (if None, extension is guessed from filename)
        to_grayscale: convert RGB images to grayscale (arrays whose last axis
         has 3 or 4 channels; the color channels are averaged and the result
         is cast back to the original data type)

    Returns:
        image data

    Raises:
        RuntimeError: if the file type is not supported for reading
    """
    if not isinstance(fname, str):
        fname = str(fname)  # in case filename is a QString instance
    if ext is None:
        _base, ext = osp.splitext(fname)
    arr = iohandler.get_readfunc(ext)(fname)
    if to_grayscale and arr.ndim == 3 and arr.shape[-1] in (3, 4):
        # Average the color channels only: the 4th channel of RGBA data is
        # alpha (opacity) and must not be mixed into the gray level. Other 3D
        # arrays (e.g. multi-frame stacks) are returned unchanged.
        dtype = arr.dtype
        arr = arr[..., :3].mean(axis=2).astype(dtype)
    return arr
def imwrite(
    fname: str,
    arr: np.ndarray,
    ext: str | None = None,
    dtype: np.dtype | None = None,
    max_range: bool | None = None,
    **kwargs,
) -> None:
    """Write a NumPy array to an image file

    Args:
        fname: image filename
        arr: NumPy array
        ext: image file extension (if None, extension is guessed from filename)
        dtype: data type of the written data (if None, the array's data type
         is kept)
        max_range: scale data to fit dtype dynamic range (requires an integer
         data type)
        kwargs: additional keyword arguments passed to the image writer

    Raises:
        RuntimeError: if the file type is not supported for writing

    .. warning::

        If `max_range` is True, array data may be modified in place
    """
    if not isinstance(fname, str):
        fname = str(fname)  # in case filename is a QString instance
    if ext is None:
        _base, ext = osp.splitext(fname)
    if max_range:
        arr = scale_data_to_dtype(arr, arr.dtype if dtype is None else dtype)
    elif dtype is not None:
        # Honor the requested data type even when no rescaling is asked for
        arr = np.asarray(arr, dtype=dtype)
    iohandler.get_writefunc(ext)(fname, arr, **kwargs)
# ==============================================================================
# plotpy plot items I/O
# ==============================================================================
# Names of all serializable plot item classes
SERIALIZABLE_ITEMS = []
# Module name -> list of serializable class names defined in that module
ITEM_MODULES = {}


def register_serializable_items(modname, classnames):
    """Register serializable item from module name and class name

    Args:
        modname: name of the module defining the classes
        classnames: list of class names
    """
    global SERIALIZABLE_ITEMS, ITEM_MODULES
    SERIALIZABLE_ITEMS += classnames
    ITEM_MODULES[modname] = ITEM_MODULES.setdefault(modname, []) + classnames


# Curves
register_serializable_items(
    "plotpy.items",
    [
        "CurveItem",
        "QuiverItem",
        "PolygonMapItem",
        "ErrorBarCurveItem",
        "RawImageItem",
        "ImageItem",
        "XYImageItem",
        "RGBImageItem",
        "TrImageItem",
        "MaskedImageItem",
        "MaskedXYImageItem",
        "Marker",
        "XRangeSelection",
        "YRangeSelection",
        "PolygonShape",
        "PointShape",
        "SegmentShape",
        "RectangleShape",
        "ObliqueRectangleShape",
        "EllipseShape",
        "Axes",
        "AnnotatedPoint",
        "AnnotatedSegment",
        "AnnotatedXRange",
        "AnnotatedYRange",
        "AnnotatedRectangle",
        "AnnotatedObliqueRectangle",
        "AnnotatedEllipse",
        "AnnotatedCircle",
        "AnnotatedPolygon",
        "LabelItem",
        "LegendBoxItem",
        "SelectedLegendBoxItem",
    ],
)


def item_class_from_name(name: str) -> type[Any] | None:
    """Return plot item class from class name

    The defining module (see ``ITEM_MODULES``) is imported on demand.

    Args:
        name: plot item class name

    Returns:
        plot item class

    Raises:
        AssertionError: if class name is unknown (item is not serializable)
    """
    if name not in SERIALIZABLE_ITEMS:
        # Explicit raise instead of `assert`: the check must not disappear
        # when Python runs with -O (documented exception type is kept)
        raise AssertionError("Unknown class %r" % name)
    for modname, names in ITEM_MODULES.items():
        if name in names:
            # Returning right after the import: iteration never resumes, even
            # if importing the module registers new items
            return getattr(__import__(modname, fromlist=[name]), name)
    return None


def item_name_from_object(obj: Any) -> str | None:
    """Return plot item class name from instance

    Args:
        obj: plot item instance

    Returns:
        plot item class name
    """
    return obj.__class__.__name__


def save_item(
    writer: guidata.io.HDF5Writer | guidata.io.INIWriter | guidata.io.JSONWriter,
    group_name,
    item: Any,
) -> None:
    """Save plot item to HDF5, INI or JSON file

    Args:
        writer: HDF5, INI or JSON writer
        group_name: group name
        item: serializable plot item (or None)
    """
    with writer.group(group_name):
        if item is None:
            writer.write_none()
        else:
            item.serialize(writer)
            # Class name is stored so that load_item can rebuild the item
            with writer.group("item_class_name"):
                writer.write_str(item_name_from_object(item))


def load_item(
    reader: guidata.io.HDF5Reader | guidata.io.INIReader | guidata.io.JSONReader,
    group_name,
) -> Any | None:
    """Load plot item from HDF5, INI or JSON file

    Args:
        reader: HDF5, INI or JSON reader
        group_name: group name

    Returns:
        Plot item instance (None if None was saved)
    """
    with reader.group(group_name):
        with reader.group("item_class_name"):
            try:
                klass_name = reader.read_str()
            except ValueError:
                # None was saved instead of a real item
                return None
        klass = item_class_from_name(klass_name)
        item = klass()
        item.deserialize(reader)
    return item
def save_items(
    writer: guidata.io.HDF5Writer | guidata.io.INIWriter | guidata.io.JSONWriter,
    items: list[Any],
) -> None:
    """Save items to HDF5, INI or JSON file

    Each item is serialized in its own group named ``<ClassName>_<NNN>``
    (NNN counting items of the same class from 001), together with its
    visibility state in a "visible" subgroup. The list of group names is
    stored (UTF-8 encoded) in the "plot_items" group.

    Args:
        writer: HDF5, INI or JSON writer
        items: list of serializable plot items
    """
    per_class_count = {}
    group_names = []
    for item in items:
        class_name = item_name_from_object(item)
        per_class_count[class_name] = per_class_count.get(class_name, 0) + 1
        group_name = "%s_%03d" % (class_name, per_class_count[class_name])
        group_names.append(group_name.encode("utf-8"))
        with writer.group(group_name):
            item.serialize(writer)
            writer.write(item.isVisible(), group_name="visible")
    with writer.group("plot_items"):
        writer.write_sequence(group_names)
def load_items(
    reader: guidata.io.HDF5Reader | guidata.io.INIReader | guidata.io.JSONReader,
) -> list[Any]:
    """Load items from HDF5, INI or JSON file

    Group names are read from the "plot_items" group; the class of each item
    is derived from its group name (``<ClassName>_<NNN>``), and its
    visibility is restored from the "visible" subgroup (default: visible).

    Args:
        reader: HDF5, INI or JSON reader

    Returns:
        list of plot item instances
    """
    with reader.group("plot_items"):
        group_names = reader.read_sequence()
    name_pattern = re.compile(r"([A-Z]+[A-Za-z0-9\_]*)\_([0-9]*)")
    loaded = []
    for group_name in group_names:
        try:
            text = group_name.decode()
        except AttributeError:
            text = group_name  # already a str
        klass = item_class_from_name(name_pattern.match(text).group(1))
        item = klass()
        with reader.group(group_name):
            item.deserialize(reader)
            item.setVisible(reader.read("visible", default=True))
        loaded.append(item)
    return loaded
if __name__ == "__main__":
    # Smoke test: resolve every registered serializable item name to its
    # class, which imports the module defining it
    for item_name in SERIALIZABLE_ITEMS:
        print(item_name, "-->", item_class_from_name(item_name))