# Source code for clisops.ops.regrid

"""Regridding operation for xarray datasets."""

import warnings
from datetime import datetime as dt
from pathlib import Path

import xarray as xr
from loguru import logger

from clisops.core import Grid, Weights
from clisops.core import regrid as core_regrid
from clisops.exceptions import InvalidParameterValue
from clisops.ops.base_operation import Operation
from clisops.utils.file_namers import get_file_namer

__all__ = [
    "regrid",
]

supported_regridding_methods = ["conservative", "patch", "nearest_s2d", "bilinear"]


class Regrid(Operation):
    """
    Regridding operation, extends clisops.ops.base_operation.Operation.

    ``_resolve_params`` builds the source and target ``Grid`` objects and, when their
    hashes differ, the remapping ``Weights``. ``_calculate`` then performs the
    regridding via ``clisops.core.regrid.regrid``, or passes the original input
    dataset through (with a warning) when the source and target grids are identical.
    """

    @staticmethod
    def _get_grid_in(
        grid_desc: xr.Dataset | xr.DataArray,
        compute_bounds: bool,
    ) -> Grid:
        """
        Create a clisops.core.regrid.Grid object as input grid of the regridding operation.

        Parameters
        ----------
        grid_desc : xarray.Dataset or xarray.DataArray
            Source data from which the input grid is built.
        compute_bounds : bool
            Passed on to ``Grid``.

        Returns
        -------
        Grid

        Raises
        ------
        InvalidParameterValue
            If ``grid_desc`` is neither an xarray.Dataset nor an xarray.DataArray.
        """
        if isinstance(grid_desc, (xr.Dataset, xr.DataArray)):
            return Grid(ds=grid_desc, compute_bounds=compute_bounds)
        raise InvalidParameterValue(
            "An xarray.Dataset or xarray.DataArray has to be provided as input for the source grid."
        )

    def _get_grid_out(
        self,
        grid_desc: xr.Dataset | xr.DataArray | int | float | tuple | str,
        compute_bounds: bool,
        mask: str | None = None,
    ) -> Grid:
        """
        Create a clisops.core.regrid.Grid object as target grid of the regridding operation.

        Parameters
        ----------
        grid_desc : xarray.Dataset, xarray.DataArray, int, float, tuple or str
            Description of the target grid:

            - ``"auto"`` or ``"adaptive"``: passed as ``grid_id`` together with the
              operation's input dataset (``self.ds``).
            - any other str: passed to ``Grid`` as ``grid_id``.
            - int, float or tuple: passed to ``Grid`` as ``grid_instructor``.
            - xarray.Dataset or xarray.DataArray: passed to ``Grid`` as ``ds``.

            Anything else results in a bare ``Grid()`` call, which is expected to
            raise the appropriate exception.
        compute_bounds : bool
            Passed on to ``Grid``.
        mask : str, optional
            Passed on to ``Grid``.

        Returns
        -------
        Grid
        """
        if isinstance(grid_desc, str):
            if grid_desc in ["auto", "adaptive"]:
                # These grid ids are resolved relative to the input dataset.
                return Grid(
                    ds=self.ds,
                    grid_id=grid_desc,
                    compute_bounds=compute_bounds,
                    mask=mask,
                )
            return Grid(grid_id=grid_desc, compute_bounds=compute_bounds, mask=mask)
        if isinstance(grid_desc, (float, int, tuple)):
            return Grid(grid_instructor=grid_desc, compute_bounds=compute_bounds, mask=mask)
        if isinstance(grid_desc, (xr.Dataset, xr.DataArray)):
            return Grid(ds=grid_desc, compute_bounds=compute_bounds, mask=mask)
        # clisops.core.regrid.Grid will raise the exception
        return Grid()

    @staticmethod
    def _get_weights(grid_in: Grid, grid_out: Grid, method: str) -> Weights:
        """
        Generate the remapping weights using clisops.core.regrid.Weights.

        Parameters
        ----------
        grid_in : Grid
            Source grid.
        grid_out : Grid
            Target grid.
        method : str
            Regridding method, one of ``supported_regridding_methods``.

        Returns
        -------
        Weights
            An instance of the Weights object.
        """
        return Weights(grid_in=grid_in, grid_out=grid_out, method=method)

    def _resolve_params(self, **params) -> None:
        """
        Resolve the regrid-specific parameters and store them in ``self.params``.

        Parameters
        ----------
        **params
            Recognised keys: ``adaptive_masking_threshold``, ``grid``, ``method``,
            ``keep_attrs`` and ``mask``. Missing keys default to None.

        Raises
        ------
        ValueError
            If ``mask`` is not one of "land", "ocean", False or None.
        InvalidParameterValue
            If ``method`` is not one of ``supported_regridding_methods``.

        Notes
        -----
        Besides setting ``self.params``, this replaces ``self.ds`` with the dataset
        held by the input ``Grid``. The dataset as it was before this call is kept
        under ``self.params["orig_ds"]``.
        """
        adaptive_masking_threshold = params.get("adaptive_masking_threshold", None)
        grid = params.get("grid", None)
        method = params.get("method", None)
        keep_attrs = params.get("keep_attrs", None)
        mask = params.get("mask", None)

        if mask not in ["land", "ocean", False, None]:
            raise ValueError(f"mask must be one of 'land', 'ocean' or None, not '{mask}'.")

        if method not in supported_regridding_methods:
            # Raise the module's dedicated parameter exception instead of a bare Exception.
            # NOTE(review): assumes InvalidParameterValue derives from Exception, so callers
            # that caught the previous bare Exception still catch it -- confirm in clisops.exceptions.
            raise InvalidParameterValue(
                f"The selected regridding method is not supported. Please choose one of: "
                f"{', '.join(supported_regridding_methods)}."
            )

        logger.debug(
            f"Input parameters: method: {method}, grid: {grid}, "
            f"adaptive_masking: {adaptive_masking_threshold}, "
            f"mask: {mask}, keep_attrs: {keep_attrs}"
        )

        # Compute cell bounds only when required, i.e. for the conservative method.
        compute_bounds = method == "conservative"

        # Create and check source and target grids
        grid_in = self._get_grid_in(self.ds, compute_bounds)
        grid_out = self._get_grid_out(grid, compute_bounds, mask=mask)

        if grid_in.hash == grid_out.hash:
            # Identical grids: _calculate() passes the input through, so no weights are needed.
            weights = None
            regridder = None
            weights_filename = None
        else:
            # Compute (or retrieve) the remapping weights and log how long that took.
            t_start = dt.now()
            weights = self._get_weights(grid_in=grid_in, grid_out=grid_out, method=method)
            regridder = weights.regridder
            weights_filename = regridder.filename
            t_end = dt.now()
            logger.info(f"Computed/Retrieved weights in {(t_end - t_start).total_seconds()} seconds.")

        self.params = {
            "orig_ds": self.ds,
            "grid_in": grid_in,
            "grid_out": grid_out,
            "method": method,
            "regridder": regridder,
            "weights": weights,
            "adaptive_masking_threshold": adaptive_masking_threshold,
            "keep_attrs": keep_attrs,
        }

        # Continue with the dataset held by the input Grid; the previous one stays in "orig_ds".
        self.ds = grid_in.ds

        # The Regridder object has no __str__() method, so its filename attribute is logged
        # instead. That default filename does not correspond to the name the weight file is
        # eventually given.
        # TODO: let the Weights class extend the Regridder class, or define a __str__()
        #  method for the Weights class.
        logger.debug(
            f"Resolved parameters: grid_in: {grid_in!s}, grid_out: {grid_out!s}, "
            f"regridder: {weights_filename}"
        )

    def _get_file_namer(self) -> object:
        """
        Return the file namer object used to name the output files.

        The ``extra`` suffix, which goes at the end of the file name before the
        extension (e.g. ``.nc``), encodes the regridding method and the target grid
        as ``_regrid-<method>-<grid_out>``.
        """
        extra = f"_regrid-{self.params.get('method')}-{self.params.get('grid_out')!s}"
        return get_file_namer(self._file_namer)(extra=extra)

    def _calculate(self):
        """
        Process the regridding request by calling clisops.core.regrid.regrid().

        Returns
        -------
        xarray.Dataset
            The regridded dataset, or the original input dataset (with a warning)
            when the source and target grids have the same hash. Saving the result
            is left to the ``process()`` method of the base class.
        """
        # Pass through the input dataset if grid_in and grid_out are equal
        if self.params.get("grid_in").hash == self.params.get("grid_out").hash:
            warnings.warn("The selected source and target grids are the same. No regridding operation required.")
            return self.params.get("orig_ds")

        regridded_ds = core_regrid(
            self.params.get("grid_in", None),
            self.params.get("grid_out", None),
            self.params.get("weights", None),
            self.params.get("adaptive_masking_threshold", None),
            self.params.get("keep_attrs", None),
        )

        return regridded_ds


def regrid(
    ds: xr.Dataset | str | Path,
    *,
    method: str | None = "nearest_s2d",
    adaptive_masking_threshold: int | float | None = 0.5,
    grid: None | (xr.Dataset | xr.DataArray | int | float | tuple | str) = "adaptive",
    mask: str | None = None,
    output_dir: str | Path | None = None,
    output_type: str | None = "netcdf",
    split_method: str | None = "time:auto",
    file_namer: str | None = "standard",
    keep_attrs: bool | str | None = True,
) -> list[xr.Dataset | str]:
    """
    Regrid specified input file or xarray object.

    Parameters
    ----------
    ds : xarray.Dataset or str or Path
        Dataset to regrid, or a path to a file or files (wildcards allowed).
    method : {"nearest_s2d", "conservative", "patch", "bilinear"}
        The regridding method to use. Any other value raises an exception.
        Cell bounds are only computed for "conservative". Default is "nearest_s2d".
    adaptive_masking_threshold : int or float, optional
        Threshold for adaptive masking. If None, adaptive masking is not applied.
        Default is 0.5.
    grid : xarray.Dataset or xarray.DataArray or int or float or tuple or str
        The target grid for regridding.
        If "adaptive", an adaptive grid will be used based on the input dataset.
        If "auto", the grid will be automatically determined based on the input dataset.
        Any other str is used as a grid identifier (e.g. "1deg").
        An int, float or tuple is used as a grid instructor; see
        clisops.core.regrid.Grid for the accepted formats.
        An xarray.Dataset or xarray.DataArray provides the target grid directly.
        None is not mapped to a default grid: it results in a bare ``Grid()`` call,
        which is expected to raise. Default is "adaptive".
    mask : {"ocean", "land"}, optional
        The mask to apply to the regridded data. If None, no mask is applied
        (False is accepted as well). Any other value raises a ValueError.
    output_dir : str or Path, optional
        The directory where the output files will be saved.
        If None, the output will not be saved to disk.
    output_type : {"netcdf", "nc", "zarr", "xarray"}
        The format of the output files. If "xarray", the output will be an xarray Dataset.
        If "netcdf", "nc", or "zarr", the output will be saved to files in the specified format.
        Default is "netcdf".
    split_method : {"time:auto"}
        The method to split the output files. Currently only "time:auto" is supported,
        which will split the output files by time slices automatically.
        Default is "time:auto".
    file_namer : {"standard", "simple"}
        File namer to use for generating output file names.
        "standard" uses the dataset name and adds a suffix for the operation.
        "simple" uses a numbered sequence for the output files.
        Default is "standard".
    keep_attrs : {True, False, "target"}
        Whether to keep the attributes of the input dataset in the output dataset.
        If "target", the attributes of the target grid will be kept. Default is True.

    Returns
    -------
    list of xr.Dataset or list of str
        A list of the regridded outputs in the format selected; str corresponds to file paths if the
        output format selected is a file.

    Notes
    -----
    If the source and target grids are identical, no regridding is performed and a
    warning is emitted.

    Examples
    --------
    | ds: xarray Dataset or "cmip5.output1.MOHC.HadGEM2-ES.rcp85.mon.atmos.Amon.r1i1p1.latest.tas"
    | method: "nearest_s2d"
    | adaptive_masking_threshold:
    | grid: "1deg"
    | mask: "land"
    | output_dir: "/cache/wps/procs/req0111"
    | output_type: "netcdf"
    | split_method: "time:auto"
    | file_namer: "standard"
    | keep_attrs: True
    """
    # locals() must be captured before any other local name is bound: at this point it
    # holds exactly the call arguments, which are forwarded to the operation as keywords.
    op = Regrid(**locals())
    return op.process()