# Source code for clisops.ops.regrid

from datetime import datetime as dt
from pathlib import Path
from typing import List, Optional, Union

import xarray as xr
from loguru import logger
from roocs_utils.exceptions import InvalidParameterValue

from clisops.core import Grid, Weights
from clisops.core import regrid as core_regrid
from clisops.ops.base_operation import Operation
from clisops.utils.file_namers import get_file_namer

# Public API of this module: only the ``regrid`` entry point is exported.
__all__ = ["regrid"]

# Regridding methods accepted by the regrid operation; any other value of
# ``method`` is rejected when the operation parameters are resolved.
supported_regridding_methods = [
    "conservative",
    "patch",
    "nearest_s2d",
    "bilinear",
]


class Regrid(Operation):
    """Class for regridding operation, extends clisops.ops.base_operation.Operation.

    The operation resolves its parameters into a source grid, a target grid and
    the remapping weights (see ``_resolve_params``) and then delegates the
    remapping itself to ``clisops.core.regrid.regrid`` (see ``_calculate``).
    """

    @staticmethod
    def _get_grid_in(
        grid_desc: Union[xr.Dataset, xr.DataArray],
        compute_bounds: bool,
    ) -> Grid:
        """Create clisops.core.regrid.Grid object as input grid of the regridding operation.

        Parameters
        ----------
        grid_desc : Union[xr.Dataset, xr.DataArray]
            xarray object describing the source grid.
        compute_bounds : bool
            Passed through to ``Grid`` unchanged.

        Returns
        -------
        Grid

        Raises
        ------
        InvalidParameterValue
            If `grid_desc` is neither an xarray.Dataset nor an xarray.DataArray.
        """
        if not isinstance(grid_desc, (xr.Dataset, xr.DataArray)):
            raise InvalidParameterValue(
                "An xarray.Dataset or xarray.DataArray has to be provided as input for the source grid."
            )
        return Grid(ds=grid_desc, compute_bounds=compute_bounds)

    def _get_grid_out(
        self,
        grid_desc: Union[xr.Dataset, xr.DataArray, int, float, tuple, str],
        compute_bounds: bool,
    ) -> Grid:
        """Create clisops.core.regrid.Grid object as target grid of the regridding operation.

        Parameters
        ----------
        grid_desc : Union[xr.Dataset, xr.DataArray, int, float, tuple, str]
            Description of the target grid. How it is handed to ``Grid``
            depends on its type:

            - ``"auto"`` / ``"adaptive"``: passed as ``grid_id`` together with
              the operation's own dataset (``self.ds``).
            - any other string: passed as ``grid_id``.
            - float, int or tuple: passed as ``grid_instructor``.
            - xarray.Dataset / xarray.DataArray: passed as ``ds``.
        compute_bounds : bool
            Passed through to ``Grid`` unchanged.

        Returns
        -------
        Grid
        """
        if isinstance(grid_desc, str):
            if grid_desc in ("auto", "adaptive"):
                # These grid ids need the source dataset in addition to the id.
                return Grid(
                    ds=self.ds, grid_id=grid_desc, compute_bounds=compute_bounds
                )
            return Grid(grid_id=grid_desc, compute_bounds=compute_bounds)
        if isinstance(grid_desc, (float, int, tuple)):
            return Grid(grid_instructor=grid_desc, compute_bounds=compute_bounds)
        if isinstance(grid_desc, (xr.Dataset, xr.DataArray)):
            return Grid(ds=grid_desc, compute_bounds=compute_bounds)
        # Unsupported description: clisops.core.regrid.Grid will raise the exception
        return Grid()

    @staticmethod
    def _get_weights(grid_in: Grid, grid_out: Grid, method: str) -> Weights:
        """Generate the remapping weights using clisops.core.regrid.Weights.

        Parameters
        ----------
        grid_in : Grid
            Source grid.
        grid_out : Grid
            Target grid.
        method : str
            Regridding method handed to ``Weights``.

        Returns
        -------
        Weights
            An instance of the Weights object.
        """
        return Weights(grid_in=grid_in, grid_out=grid_out, method=method)

    def _resolve_params(self, **params) -> None:
        """Resolve the regrid specific parameters and store them in ``self.params``.

        Reads ``adaptive_masking_threshold``, ``grid``, ``method`` and
        ``keep_attrs`` from `params` (each defaults to None when absent),
        validates the method, creates the source and target grids, computes the
        remapping weights and finally replaces ``self.ds`` with the dataset held
        by the source grid.

        Raises
        ------
        InvalidParameterValue
            If ``method`` is not one of ``supported_regridding_methods``
            (this includes a missing method).
        """
        adaptive_masking_threshold = params.get("adaptive_masking_threshold", None)
        grid = params.get("grid", None)
        method = params.get("method", None)
        keep_attrs = params.get("keep_attrs", None)

        if method not in supported_regridding_methods:
            # Raise the parameter-specific exception also used for the source
            # grid check instead of a bare Exception; the message is unchanged.
            raise InvalidParameterValue(
                "The selected regridding method is not supported. "
                f"Please choose one of {', '.join(supported_regridding_methods)}."
            )

        logger.debug(
            f"Input parameters: method: {method}, grid: {grid}, adaptive_masking: {adaptive_masking_threshold}"
        )

        # Compute bounds only when required
        compute_bounds = "conservative" in method

        # Create and check source and target grids
        grid_in = self._get_grid_in(self.ds, compute_bounds)
        grid_out = self._get_grid_out(grid, compute_bounds)

        # Compute the remapping weights
        t_start = dt.now()
        weights = self._get_weights(grid_in=grid_in, grid_out=grid_out, method=method)
        t_end = dt.now()
        logger.info(
            f"Computed/Retrieved weights in {(t_end-t_start).total_seconds()} seconds."
        )

        # Define params dict
        self.params = {
            "grid_in": grid_in,
            "grid_out": grid_out,
            "method": method,
            "regridder": weights.regridder,
            "weights": weights,
            "adaptive_masking_threshold": adaptive_masking_threshold,
            "keep_attrs": keep_attrs,
        }

        # Input grid / Dataset: continue with the dataset held by the source grid
        self.ds = grid_in.ds

        # There is no __str__() method for the Regridder object, so its filename attribute is used,
        # which specifies a default filename (does not correspond with the filename we would give the weight file).
        # todo: Better option might be to have the Weights class extend the Regridder class or to define
        #  a __str__() method for the Weights class.
        logger.debug(
            f"Resolved parameters: grid_in: {grid_in!s}, grid_out: {grid_out!s}, "
            f"regridder: {weights.regridder.filename}"
        )

    def _get_file_namer(self) -> object:
        """Return the appropriate file namer object.

        The namer is looked up via ``get_file_namer(self._file_namer)`` and is
        given an ``extra`` suffix built from the regridding method and the
        string representation of the target grid.
        """
        # "extra" is what will go at the end of the file name before .nc
        method = self.params.get("method")
        grid_out = self.params.get("grid_out")
        extra = f"_regrid-{method}-{grid_out!s}"

        return get_file_namer(self._file_namer)(extra=extra)

    def _calculate(self):
        """
        Process the regridding request, calls clisops.core.regrid.regrid().

        The source grid, target grid, weights, adaptive masking threshold and
        keep_attrs setting are taken from ``self.params`` and passed
        positionally, in that order.

        Returns the resulting xarray.Dataset.
        """
        # the result is saved by the process() method on the base class
        return core_regrid(
            self.params.get("grid_in", None),
            self.params.get("grid_out", None),
            self.params.get("weights", None),
            self.params.get("adaptive_masking_threshold", None),
            self.params.get("keep_attrs", None),
        )


def regrid(
    ds: Union[xr.Dataset, str, Path],
    *,
    method: Optional[str] = "nearest_s2d",
    adaptive_masking_threshold: Optional[Union[int, float]] = 0.5,
    grid: Optional[
        Union[xr.Dataset, xr.DataArray, int, float, tuple, str]
    ] = "adaptive",
    output_dir: Optional[Union[str, Path]] = None,
    output_type: Optional[str] = "netcdf",
    split_method: Optional[str] = "time:auto",
    file_namer: Optional[str] = "standard",
    keep_attrs: Optional[Union[bool, str]] = True,
) -> List[Union[xr.Dataset, str]]:
    """Regrid specified input file or xarray object.

    All arguments are forwarded unchanged to the ``Regrid`` operation, whose
    ``process()`` result is returned.

    Parameters
    ----------
    ds : Union[xr.Dataset, str, Path]
        Input to regrid.
    method : {"nearest_s2d", "conservative", "patch", "bilinear"}
        Regridding method. Default: "nearest_s2d".
    adaptive_masking_threshold : Optional[Union[int, float]]
        Default: 0.5.
    grid : Union[xr.Dataset, xr.DataArray, int, float, tuple, str]
        Description of the target grid. Default: "adaptive".
    output_dir : Optional[Union[str, Path]] = None
    output_type : {"netcdf", "nc", "zarr", "xarray"}
    split_method : {"time:auto"}
    file_namer : {"standard", "simple"}
    keep_attrs : {True, False, "target"}

    Returns
    -------
    List[Union[xr.Dataset, str]]
        A list of the regridded outputs in the format selected; str corresponds to file paths if the
        output format selected is a file.

    Examples
    --------
    | ds: xarray Dataset or "cmip5.output1.MOHC.HadGEM2-ES.rcp85.mon.atmos.Amon.r1i1p1.latest.tas"
    | method: "nearest_s2d"
    | adaptive_masking_threshold: 0.5
    | grid: "1deg"
    | output_dir: "/cache/wps/procs/req0111"
    | output_type: "netcdf"
    | split_method: "time:auto"
    | file_namer: "standard"
    | keep_attrs: True
    """
    # NOTE: locals() must be read before any other local name is bound, so
    # that exactly the function arguments are forwarded to the operation.
    op = Regrid(**locals())
    return op.process()