Source code for astrocut.image_cutout

import json
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Tuple, Union

import numpy as np
from astropy.coordinates import SkyCoord
from astropy.units import Quantity
from astropy.visualization import (
    AsinhStretch,
    AsymmetricPercentileInterval,
    LinearStretch,
    LogStretch,
    ManualInterval,
    MinMaxInterval,
    SinhStretch,
    SqrtStretch,
)
from astropy.wcs.utils import proj_plane_pixel_scales
from PIL.Image import Image, Transpose, fromarray, registered_extensions
from PIL.PngImagePlugin import PngInfo
from s3path import S3Path

from . import __version__, log
from .cutout import Cutout
from .exceptions import DataWarning, InputWarning, InvalidInputError


class ImageCutout(Cutout, ABC):
    """
    Abstract class for creating cutouts from images. This class defines attributes and methods that are common to all
    image cutout classes.

    Parameters
    ----------
    input_files : list
        List of input image files.
    coordinates : str | `~astropy.coordinates.SkyCoord`
        Coordinates of the center of the cutout.
    cutout_size : int | array | list | tuple | `~astropy.units.Quantity`
        Size of the cutout array.
    fill_value : int | float
        Value to fill the cutout with if the cutout is outside the image.
    limit_rounding_method : str
        Method to use for rounding the cutout limits. Options are 'round', 'ceil', and 'floor'.
    verbose : bool
        If True, log messages are printed to the console.

    Attributes
    ----------
    cutouts_by_file : dict
        Dictionary containing the cutouts for each input file.
    image_cutouts : list
        List of `~PIL.Image` objects representing the cutouts.

    Methods
    -------
    get_image_cutouts(stretch, minmax_percent, minmax_value, invert, colorize, flip_orientation)
        Get the cutouts as `~PIL.Image` objects.
    cutout()
        Generate the cutouts.
    write_as_img(stretch, minmax_percent, minmax_value, invert, colorize, output_format, output_dir, cutout_prefix,
                 flip_orientation)
        Write the cutouts to a file in an image format.
    normalize_img(img_arr, stretch, minmax_percent, minmax_value, invert)
        Apply given stretch and scaling to an image array.
    """

    def __init__(
        self,
        input_files: List[Union[str, Path, S3Path]],
        coordinates: Union[SkyCoord, str],
        cutout_size: Union[int, np.ndarray, Quantity, List[int], Tuple[int]] = 25,
        fill_value: Union[int, float] = np.nan,
        limit_rounding_method: str = "round",
        verbose: bool = False,
    ):
        """
        Initialize the common cutout state and the (initially empty) image cache.

        All arguments are forwarded, in order, to the parent ``Cutout`` initializer.
        """
        super().__init__(
            input_files,
            coordinates,
            cutout_size,
            fill_value,
            limit_rounding_method,
            verbose,
        )

        # Cache of rendered `PIL.Image` objects; stays None until images are first requested
        self._image_cutouts = None

    @property
    def image_cutouts(self) -> List[Image]:
        """
        The cutouts rendered as a list of `PIL.Image` objects.

        When no images are cached yet (the cache is ``None`` or empty), they are generated
        by calling `get_image_cutouts` with its default normalization parameters.
        """
        cached = self._image_cutouts
        if cached:
            return cached

        # Nothing cached: render with the defaults and remember the result
        self._image_cutouts = self.get_image_cutouts()
        return self._image_cutouts

    def get_image_cutouts(
        self,
        stretch: Optional[str] = "asinh",
        minmax_percent: Optional[List[int]] = None,
        minmax_value: Optional[List[int]] = None,
        invert: Optional[bool] = False,
        colorize: Optional[bool] = False,
        flip_orientation: Optional[bool] = True,
    ) -> List[Image]:
        """
        Get the cutouts as `~PIL.Image` objects given certain normalization parameters. This method also sets
        the `image_cutouts` attribute.

        Parameters
        ----------
        stretch : str
            Optional, default 'asinh'. The stretch to apply to the image array (case-insensitive).
            Valid values are: asinh, sinh, sqrt, log, linear
        minmax_percent : array
            Optional. Interval based on keeping a specified fraction of pixels (can be asymmetric)
            when scaling the image. The format is [lower percentile, upper percentile], where pixel
            values below the lower percentile and above the upper percentile are clipped.
            Only one of minmax_percent and minmax_value should be specified. If neither is given,
            [0.5, 99.5] is used.
        minmax_value : array
            Optional. Interval based on user-specified pixel values when scaling the image.
            The format is [min value, max value], where pixel values below the min value and above
            the max value are clipped.
            Only one of minmax_percent and minmax_value should be specified.
        invert : bool
            Optional, default False.  If True the image is inverted (light pixels become dark and vice versa).
        colorize : bool
            Optional, default False. If True, the first three cutouts will be combined into a single RGB image.
        flip_orientation : bool
            Optional, default True. If True, the cutout images are flipped vertically to match the orientation
            of the input images.

        Returns
        -------
        image_cutouts : list
            List of `~PIL.Image` objects representing the cutouts. For a color cutout the list holds a
            single RGB image.

        Raises
        ------
        InvalidInputError
            If the stretch is not recognized, or if fewer than three cutouts are available for a
            color cutout.
        """
        # Validate the stretch parameter
        valid_stretches = ["asinh", "sinh", "sqrt", "log", "linear"]
        if not isinstance(stretch, str) or stretch.lower() not in valid_stretches:
            raise InvalidInputError(f"Stretch {stretch} is not recognized. Valid options are {valid_stretches}.")
        stretch = stretch.lower()

        # Apply default scaling for image outputs
        if (minmax_percent is None) and (minmax_value is None):
            minmax_percent = [0.5, 99.5]

        if colorize:  # color cutout
            # Gather cutouts in input-file order, remembering which file each one came from
            all_cutouts = []
            all_cutout_files = []
            for file in self._input_files:
                file_cutouts = self.cutouts_by_file.get(file, [])
                all_cutouts.extend(file_cutouts)
                all_cutout_files.extend([file] * len(file_cutouts))

                # Stop as soon as more than three have been gathered; only three channels exist
                if len(all_cutouts) > 3:
                    warnings.warn(
                        "Too many inputs for a color cutout, only the first three will be used.", InputWarning
                    )
                    all_cutouts = all_cutouts[:3]
                    all_cutout_files = all_cutout_files[:3]
                    break

            # Check for the correct number of cutouts
            if len(all_cutouts) < 3:
                raise InvalidInputError(
                    "Color cutouts require 3 input images (RGB). "
                    "If you supplied 3 images one of the cutouts may have been empty."
                )

            # Normalize each channel with the requested parameters
            img_arrs = [
                self.normalize_img(cutout.data, stretch, minmax_percent, minmax_value, invert)
                for cutout in all_cutouts
            ]

            # Combine the three cutouts into a single RGB image
            color_img = fromarray(np.dstack(img_arrs).astype(np.uint8))
            if flip_orientation:
                # Flip the image vertically to match the orientation of the input cutouts
                color_img = color_img.transpose(Transpose.FLIP_TOP_BOTTOM)
            color_img.info.update(self._build_cutout_metadata(all_cutout_files))
            self._image_cutouts = [color_img]
        else:  # one image per cutout
            image_cutouts = []
            for file, cutout_list in self.cutouts_by_file.items():
                for cutout in cutout_list:
                    # Apply the appropriate normalization parameters
                    img = fromarray(self.normalize_img(cutout.data, stretch, minmax_percent, minmax_value, invert))
                    if flip_orientation:
                        # Flip the image vertically to match the orientation of the input cutouts
                        img = img.transpose(Transpose.FLIP_TOP_BOTTOM)
                    img.info.update(self._build_cutout_metadata([file]))
                    image_cutouts.append(img)

            self._image_cutouts = image_cutouts

        return self._image_cutouts

    @abstractmethod
    def _cutout_file(self, file: Union[str, Path, S3Path]):
        """
        Make the cutout(s) for a single input image file.

        Abstract hook: concrete subclasses must provide the implementation.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    @abstractmethod
    def cutout(self):
        """
        Produce the cutout(s).

        Abstract hook: concrete subclasses must provide the implementation.
        """
        raise NotImplementedError("Subclasses must implement this method.")

    def _parse_output_format(self, output_format: str) -> str:
        """
        Normalize an output format string to a lowercase extension with a leading dot.

        Parameters
        ----------
        output_format : str
            The requested output format, with or without a leading dot (e.g. ``"JPG"`` or ``".png"``).

        Returns
        -------
        out_format : str
            The lowercase extension including the leading dot.

        Raises
        ------
        InvalidInputError
            If the normalized extension is not among the extensions registered with PIL.
        """
        lowered = output_format.lower()
        if output_format.startswith("."):
            normalized = lowered
        else:
            normalized = f".{lowered}"

        # Accept only extensions that PIL knows how to handle
        if normalized in registered_extensions():
            return normalized

        raise InvalidInputError(f"Output format {normalized} is not supported.")

    def _build_cutout_metadata(self, input_files: List[Union[str, Path, S3Path]]) -> dict:
        """
        Assemble the metadata dictionary that is attached to a cutout image.

        Parameters
        ----------
        input_files : list
            The input file or files the cutout image was made from.

        Returns
        -------
        metadata : dict
            Metadata for the cutout image: source file name(s), cutout size and pixel scale
            (when a cutout is found for the first file), cutout center, origin, and package version.
        """
        # A single source is stored under "input_file"; several are joined under "input_files"
        if len(input_files) == 1:
            metadata = {"input_file": Path(input_files[0]).name}
        else:
            metadata = {"input_files": ", ".join(Path(name).name for name in input_files)}

        # Size and pixel scale are read from the first cutout of the first file; for color
        # cutouts the inputs are assumed to share the same WCS
        candidates = self.cutouts_by_file.get(input_files[0], [None])
        first_cutout = candidates[0]
        if first_cutout is not None:
            metadata["cutout_size_x_pix"] = first_cutout.shape[1]
            metadata["cutout_size_y_pix"] = first_cutout.shape[0]
            metadata["pixel_scale_arcsec_per_pix"] = proj_plane_pixel_scales(first_cutout.wcs)[0] * 3600

        center = self._coordinates
        metadata["center_ra_deg"] = center.ra.deg.item()
        metadata["center_dec_deg"] = center.dec.deg.item()
        metadata["origin"] = "STScI/MAST"
        metadata["version"] = __version__

        return metadata

    def _get_img_save_kwargs(self, im: Image, file_path: str) -> dict:
        """
        Return format-specific keyword arguments for saving an image.

        The metadata is taken from ``im.info``. PNG output stores each entry as a text chunk;
        JPEG and TIFF output store the whole dictionary as a JSON string in EXIF tag 270.
        Any other suffix gets no extra keyword arguments.

        Parameters
        ----------
        im : `~PIL.Image`
            The image to save.
        file_path : str
            The output file path. Its suffix (case-insensitive) selects the metadata encoding.

        Returns
        -------
        save_kwargs : dict
            Keyword arguments to pass to `PIL.Image.Image.save`. Empty when the image carries no
            metadata or the suffix has no metadata encoding defined here.
        """
        # Get the cutout metadata from top-level Image.info
        metadata = im.info
        if not metadata:
            return {}
        output_suffix = Path(file_path).suffix.lower()

        # Format for saving to file depends on output format
        if output_suffix == ".png":
            pnginfo = PngInfo()
            for key, value in metadata.items():
                pnginfo.add_text(key, str(value))
            return {"pnginfo": pnginfo}

        if output_suffix in {".jpg", ".jpeg", ".tif", ".tiff"}:
            exif = im.getexif()
            # Use tag for image description to store cutout metadata as json string.
            # Serialize only here so formats that never use the JSON cannot fail on it.
            exif[270] = json.dumps(metadata)
            return {"exif": exif.tobytes()}

        return {}

    def _save_img_to_file(self, im: Image, file_path: str) -> bool:
        """
        Write a `~PIL.Image` object to disk, downgrading known save failures to warnings.

        Parameters
        ----------
        im : `~PIL.Image`
            The image to save.
        file_path : str
            Destination path for the image.

        Returns
        -------
        success : bool
            True when the image was written; False when a ``ValueError``, ``KeyError`` or
            ``OSError`` was raised, in which case a `DataWarning` is issued instead.
        """
        try:
            # Format-specific extras (metadata) are derived from the image and the target suffix
            im.save(file_path, **self._get_img_save_kwargs(im, file_path))
        except ValueError as e:
            output_format = Path(file_path).suffix
            warnings.warn(
                f"Cutout could not be saved in {output_format} format: {e}. Please try a different output format.",
                DataWarning,
            )
            return False
        except KeyError as e:
            output_format = Path(file_path).suffix
            warnings.warn(
                f"Cutout could not be saved in {output_format} format due to a KeyError: {e}. "
                "Please try a different output format.",
                DataWarning,
            )
            return False
        except OSError as e:
            warnings.warn(f"Cutout could not be saved: {e}", DataWarning)
            return False

        return True

    def write_as_img(
        self,
        stretch: Optional[str] = "asinh",
        minmax_percent: Optional[List[int]] = None,
        minmax_value: Optional[List[int]] = None,
        invert: Optional[bool] = False,
        colorize: Optional[bool] = False,
        output_format: str = ".jpg",
        output_dir: Union[str, Path] = ".",
        cutout_prefix: str = "cutout",
        flip_orientation: Optional[bool] = True,
    ) -> Union[Optional[str], List[Optional[str]]]:
        """
        Write the cutout(s) to file(s) in an image format. If colorize is set, the first 3 cutouts
        will be combined into a single RGB image. Otherwise, each cutout will be written to a separate file.

        Parameters
        ----------
        stretch : str
            Optional, default 'asinh'. The stretch to apply to the image array.
            Valid values are: asinh, sinh, sqrt, log, linear
        minmax_percent : array
            Optional. Interval based on keeping a specified fraction of pixels (can be asymmetric)
            when scaling the image. The format is [lower percentile, upper percentile], where pixel
            values below the lower percentile and above the upper percentile are clipped.
            Only one of minmax_percent and minmax_value should be specified.
        minmax_value : array
            Optional. Interval based on user-specified pixel values when scaling the image.
            The format is [min value, max value], where pixel values below the min value and above
            the max value are clipped.
            Only one of minmax_percent and minmax_value should be specified.
        invert : bool
            Optional, default False.  If True the image is inverted (light pixels become dark and vice versa).
        colorize : bool
            Optional, default False. If True, the first three cutouts will be combined into a single RGB image.
        output_format : str
            Optional, default '.jpg'. The output format for the cutout image(s).
        output_dir : str | `~pathlib.Path`
            Optional, default '.'. The directory to write the cutout image(s) to. It is created
            if it does not exist.
        cutout_prefix : str
            Optional, default 'cutout'. The prefix of the file name of a colorized cutout.
            Individual (non-colorized) cutouts are named after their input file instead.
        flip_orientation : bool
            Optional, default True. If True, the cutout images are flipped vertically to match the
            orientation of the input images.

        Returns
        -------
        cutout_path : str | None | list
            For a colorized cutout, the path of the written file as a string, or None if it could
            not be written. Otherwise a list with one entry per cutout image: the path as a string,
            or None for an image that could not be written.

        Raises
        ------
        InvalidInputError
            If the output format is not supported, the stretch is not recognized, or fewer than
            three inputs were provided for a colorized cutout.
        """
        # Parse the output format
        output_format = self._parse_output_format(output_format)

        # Get the image cutouts with the given normalization parameters
        image_cutouts = self.get_image_cutouts(
            stretch, minmax_percent, minmax_value, invert, colorize, flip_orientation
        )

        # Create the output directory if it does not exist
        Path(output_dir).mkdir(parents=True, exist_ok=True)

        # Pieces of the file name shared by every output image
        ra = self._coordinates.ra.value
        dec = self._coordinates.dec.value
        size_x = str(self._cutout_size[0]).replace(" ", "")
        size_y = str(self._cutout_size[1]).replace(" ", "")

        # Set up output files and write them
        if colorize:  # Combine first three cutouts into a single RGB image
            filename = f"{cutout_prefix}_{ra:.7f}_{dec:.7f}_{size_x}-x-{size_y}_astrocut{output_format}"

            # Attempt to write image to file
            cutout_paths = Path(output_dir, filename).as_posix()
            if not self._save_img_to_file(image_cutouts[0], cutout_paths):
                cutout_paths = None

        else:  # Write each cutout to a separate image file
            # One image exists per cutout per file, in the same iteration order as
            # `cutouts_by_file`; repeat each file once per cutout so that every image is
            # paired with the file it actually came from
            source_files = [file for file, cutouts in self.cutouts_by_file.items() for _ in cutouts]

            cutout_paths = []  # Store the paths of the written cutout files
            for i, (file, image) in enumerate(zip(source_files, image_cutouts)):
                # The running image index keeps file names unique when a file yields several cutouts
                filename = (
                    f"{Path(file).stem}_{ra:.7f}_{dec:.7f}_{size_x}-x-{size_y}_astrocut_{i}{output_format}"
                )

                # Attempt to write image to file
                cutout_path = Path(output_dir, filename).as_posix()

                # Append the path to the written file; if the image could not be written, append None
                if not self._save_img_to_file(image, cutout_path):
                    cutout_path = None
                cutout_paths.append(cutout_path)

        log.debug("Cutout filepaths: {}".format(cutout_paths))
        return cutout_paths

    @staticmethod
    def normalize_img(
        img_arr: np.ndarray,
        stretch: str = "asinh",
        minmax_percent: Optional[List[int]] = None,
        minmax_value: Optional[List[int]] = None,
        invert: bool = False,
    ) -> np.ndarray:
        """
        Apply given stretch and scaling to an image array.

        Parameters
        ----------
        img_arr : array
            The input image array.
        stretch : str
            Optional, default 'asinh'. The stretch to apply to the image array.
            Valid values are: asinh, sinh, sqrt, log, linear
        minmax_percent : array
            Optional. Interval based on keeping a specified fraction of pixels (can be asymmetric)
            when scaling the image. The format is [lower percentile, upper percentile], where pixel
            values below the lower percentile and above the upper percentile are clipped.
            Only one of minmax_percent and minmax_value should be specified. If both are given,
            minmax_value is ignored and a warning is issued.
        minmax_value : array
            Optional. Interval based on user-specified pixel values when scaling the image.
            The format is [min value, max value], where pixel values below the min value and above
            the max value are clipped.
            Only one of minmax_percent and minmax_value should be specified.
        invert : bool
            Optional, default False.  If True the image is inverted (light pixels become dark and vice versa).

        Returns
        -------
        response : array
            The normalized image array, as an unsigned 8-bit integer array with values in the range 0-255.
            Pixels that are NaN after normalization are set to 0 before the optional inversion.

        Raises
        ------
        InvalidInputError
            If the input array is empty or the stretch is not supported.
        """

        # Check if the input image array is empty
        if img_arr.size == 0:
            raise InvalidInputError("Input image array is empty.")

        # Setting up the transform with the stretch
        stretch_classes = {
            "asinh": AsinhStretch,
            "sinh": SinhStretch,
            "sqrt": SqrtStretch,
            "log": LogStretch,
            "linear": LinearStretch,
        }
        if stretch not in stretch_classes:
            raise InvalidInputError(
                f"Stretch {stretch} is not supported! Valid options are: asinh, sinh, sqrt, log, linear."
            )
        transform = stretch_classes[stretch]()

        # Adding the scaling to the transform
        if minmax_percent is not None:
            transform += AsymmetricPercentileInterval(*minmax_percent)

            if minmax_value is not None:
                warnings.warn(
                    "Both minmax_percent and minmax_value are set, minmax_value will be ignored.", InputWarning
                )
        elif minmax_value is not None:
            transform += ManualInterval(*minmax_value)
        else:  # Default, scale the entire image range to [0,1]
            transform += MinMaxInterval()

        # Performing the transform and then putting it into the integer range 0-255
        norm_img = transform(img_arr)

        # Casting NaN to an integer type is undefined (platform-dependent, and warns in recent
        # NumPy), so replace non-finite pixels with in-range values before scaling and casting
        norm_img = np.nan_to_num(norm_img, copy=False, nan=0.0, posinf=1.0, neginf=0.0)

        np.multiply(255, norm_img, out=norm_img)
        norm_img = norm_img.astype(np.uint8)

        # Applying invert if requested
        np.subtract(255, norm_img, out=norm_img, where=invert)

        return norm_img


def normalize_img(
    img_arr: np.ndarray,
    stretch: str = "asinh",
    minmax_percent: Optional[List[int]] = None,
    minmax_value: Optional[List[int]] = None,
    invert: bool = False,
) -> np.ndarray:
    """
    Apply given stretch and scaling to an image array.

    This is a module-level convenience wrapper that forwards all arguments to
    `ImageCutout.normalize_img`.

    Parameters
    ----------
    img_arr : array
        The input image array.
    stretch : str
        Optional, default 'asinh'. The stretch to apply to the image array.
        Valid values are: asinh, sinh, sqrt, log, linear
    minmax_percent : array
        Optional. Interval based on keeping a specified fraction of pixels (can be asymmetric)
        when scaling the image. The format is [lower percentile, upper percentile], where pixel
        values below the lower percentile and above the upper percentile are clipped.
        Only one of minmax_percent and minmax_value should be specified.
    minmax_value : array
        Optional. Interval based on user-specified pixel values when scaling the image.
        The format is [min value, max value], where pixel values below the min value and above
        the max value are clipped.
        Only one of minmax_percent and minmax_value should be specified.
    invert : bool
        Optional, default False.  If True the image is inverted (light pixels become dark and vice versa).

    Returns
    -------
    response : array
        The normalized image array, as an integer array with values in the range 0-255.
    """
    return ImageCutout.normalize_img(
        img_arr=img_arr, stretch=stretch, minmax_percent=minmax_percent, minmax_value=minmax_value, invert=invert
    )