"""Base class for all Operations in clisops."""
from collections import ChainMap
from pathlib import Path
import xarray as xr
from loguru import logger
from clisops.utils.common import expand_wildcards
from clisops.utils.dataset_utils import open_xr_dataset
from clisops.utils.file_namers import get_file_namer
from clisops.utils.output_utils import fix_netcdf_attrs_encoding, get_output, get_time_slices
[docs]
class Operation:
    """
    Base class for all Operations.

    This class provides the common interface and functionality for all operations in clisops.

    Parameters
    ----------
    ds : str or Path or xr.Dataset
        The input dataset, which can be a path to a file or an xarray Dataset.
    file_namer : str, optional
        The file namer to use for output files. Default is "standard".
    split_method : str, optional
        The method to use for splitting the dataset into time slices. Default is "time:auto".
    output_dir : str or Path or None, optional
        The directory where output files will be saved. If None, no files will be saved. Default is None.
    output_type : str, optional
        The type of output to generate. Can be "netcdf", "zarr", or "xarray". Default is "netcdf".
    **params : dict, optional
        Additional parameters specific to the operation. These will be resolved in `self._resolve_params()`.
    """

    def __init__(
        self,
        ds: str | Path | xr.Dataset,
        file_namer: str = "standard",
        split_method: str = "time:auto",
        output_dir: str | Path | None = None,
        output_type: str = "netcdf",
        **params,
    ):
        """
        Constructor for each operation.

        Sets common input parameters as attributes.
        Parameters that are specific to each operation are handled in `self._resolve_params()`
        """
        self._file_namer = file_namer
        self._split_method = split_method
        self._output_dir = output_dir
        self._output_type = output_type
        self._resolve_dsets(ds)
        self._resolve_params(**params)

    def _resolve_dsets(self, ds) -> None:
        """
        Take in the `ds` object and load it as an xarray Dataset if it
        is a path/wildcard. Set the result to `self.ds`.

        Anything that is not a `str` or `Path` is stored on `self.ds` unchanged.
        """
        if isinstance(ds, (str, Path)):
            ds = expand_wildcards(ds)
            ds = open_xr_dataset(ds)
        self.ds = ds

    def _resolve_params(self, **params) -> None:
        """Resolve the operation-specific input parameters to `self.params`."""
        self.params = params

    def _get_file_namer(self):
        """Return the appropriate file namer object."""
        return get_file_namer(self._file_namer)()

    def _calculate(self):
        """The `_calculate()` method is implemented within each operation subclass."""
        raise NotImplementedError()

    @staticmethod
    def _variable_names(ds) -> list:
        """
        Return the names of the variables the clean-up helpers iterate over.

        Parameters
        ----------
        ds : xr.Dataset or xr.DataArray
            The object to inspect.

        Returns
        -------
        list
            Coordinates followed by data variables for a Dataset; coordinates only for a DataArray.

        Raises
        ------
        ValueError
            If `ds` is neither an xarray.Dataset nor an xarray.DataArray.
        """
        if isinstance(ds, xr.Dataset):
            return list(ds.coords) + list(ds.data_vars)
        if isinstance(ds, xr.DataArray):
            return list(ds.coords)
        raise ValueError(f"Expected xarray.Dataset or xarray.DataArray, got {type(ds)}")

    @staticmethod
    def _is_unset(value) -> bool:
        """
        Tell whether a `_FillValue` / `missing_value` entry counts as "not defined".

        Only `None` and empty strings are treated as unset, so that falsy but
        legitimate fill values such as `0` are preserved.
        """
        return value is None or (isinstance(value, (str, bytes)) and not value)

    def _remove_str_compression(self, ds):
        """
        netCDF4 datatypes of variable length are decoded to str by xarray<2023.11.0.
        As of xarray 2023.11.0 they are decoded to one of np.dtypes.StrDType (eg. "<U20")
        of variable length and stripped of all encoding settings. In netcdf-c versions >= 4.9.0
        and xarray < 2023.11.0 the latter part needs to be conducted manually to avoid an Exception
        when writing the xarray.Dataset to disk.

        See issue: https://github.com/Unidata/netcdf4-python/issues/1205
        See PR: https://github.com/roocs/clisops/pull/319.

        Parameters
        ----------
        ds : xr.Dataset or xr.DataArray
            The object whose variable encodings are cleaned in place.

        Returns
        -------
        xr.Dataset or xr.DataArray
            The same object that was passed in.

        Raises
        ------
        ValueError
            If `ds` is neither an xarray.Dataset nor an xarray.DataArray.
        """
        compression_keys = (
            "compression",
            "complevel",
            "shuffle",
            "fletcher32",
            "endian",
            "zlib",
        )
        for var in self._variable_names(ds):
            encoding = ds[var].encoding
            enc_dtype = encoding.get("dtype")
            # Match the built-in `str` type itself (`isinstance(str, str)` is False)
            # as well as a dtype given as a string instance.
            if enc_dtype is str or isinstance(enc_dtype, str):
                for key in compression_keys:
                    encoding.pop(key, None)
        return ds

    def _fix_netcdf_attrs_encoding(self, ds):
        """Executes output_utils.fix_netcdf_attrs_encoding for xarray.Datasets"""
        if isinstance(ds, xr.Dataset):
            ds = fix_netcdf_attrs_encoding(ds)
        return ds

    def _cap_deflate_level(self, ds):
        """
        For CMOR3 / CMIP6 it was investigated which netCDF4 deflate_level should be set to optimize
        the balance between reduction of file size and degradation in performance. The values found
        were deflate_level=1, shuffle=True. To keep the write times at a minimum, compression level 1
        is not exceeded.

        See issue: https://github.com/PCMDI/cmor/issues/403.

        The "complevel" and "compression_opts" encoding entries are capped independently
        of each other. Entries that are not real numbers (e.g. `None`) are left untouched.

        Parameters
        ----------
        ds : xr.Dataset or xr.DataArray
            The object whose variable encodings are adjusted in place.

        Returns
        -------
        xr.Dataset or xr.DataArray
            The same object that was passed in.

        Raises
        ------
        ValueError
            If `ds` is neither an xarray.Dataset nor an xarray.DataArray.
        """
        # Local stdlib import so this method does not depend on the module's import block.
        from numbers import Real

        for var in self._variable_names(ds):
            encoding = ds[var].encoding
            for key in ("complevel", "compression_opts"):
                level = encoding.get(key)
                if isinstance(level, Real) and level > 1:
                    encoding[key] = 1
        return ds

    @staticmethod
    def _remove_redundant_fill_values(ds):
        """
        Get coordinate and data variables and remove fill values added by xarray.

        CF-Conventions say that coordinate variables cannot have missing values.

        For each variable, `_FillValue` and `missing_value` are looked up in the attributes
        first and the encoding second:

        - neither defined: `encoding["_FillValue"]` is set to None;
        - only one defined: both encoding entries are set to that value and the
          corresponding attribute is removed;
        - both defined but different: both encoding entries are set to `missing_value`,
          both attributes are removed and a warning is logged.

        Parameters
        ----------
        ds : xr.Dataset or xr.DataArray
            The object whose variables are adjusted in place.

        Returns
        -------
        xr.Dataset or xr.DataArray
            The same object that was passed in.

        Raises
        ------
        ValueError
            If `ds` is neither an xarray.Dataset nor an xarray.DataArray.

        See Also
        --------
        https://github.com/roocs/clisops/issues/224
        """
        for var in Operation._variable_names(ds):
            attrs = ds[var].attrs
            encoding = ds[var].encoding
            # Attributes take precedence over encoding in the lookup.
            lookup = ChainMap(attrs, encoding)
            fval = lookup.get("_FillValue")
            mval = lookup.get("missing_value")
            no_fval = Operation._is_unset(fval)
            no_mval = Operation._is_unset(mval)

            if no_fval and no_mval:
                encoding["_FillValue"] = None
            elif no_mval:
                encoding["missing_value"] = fval
                encoding["_FillValue"] = fval
                attrs.pop("_FillValue", None)
            elif no_fval:
                encoding["_FillValue"] = mval
                encoding["missing_value"] = mval
                attrs.pop("missing_value", None)
            elif fval != mval:
                # Issue 308 - Assert missing_value and _FillValue are the same
                encoding["_FillValue"] = mval
                encoding["missing_value"] = mval
                attrs.pop("missing_value", None)
                attrs.pop("_FillValue", None)
                logger.warning(
                    f"The defined _FillValue and missing_value for '{var}' are not the same "
                    f"'{fval}' != '{mval}'. Setting '{mval}' for both."
                )
        return ds

    @staticmethod
    def _remove_redundant_coordinates_attr(ds):
        """
        Remove the coordinate attribute added by xarray.

        For each variable the "coordinates" entry is looked up in the attributes first and
        the encoding second. If it is missing or empty, `encoding["coordinates"]` is set to
        None; otherwise the value is stored in the encoding and removed from the attributes.

        Parameters
        ----------
        ds : xr.Dataset or xr.DataArray
            The object whose variables are adjusted in place.

        Returns
        -------
        xr.Dataset or xr.DataArray
            The same object that was passed in.

        Raises
        ------
        ValueError
            If `ds` is neither an xarray.Dataset nor an xarray.DataArray.

        See Also
        --------
        https://github.com/roocs/clisops/issues/224

        Examples
        --------
        If you have a dataset with a time_bnds variable that has a coordinate attribute:

        .. code-block:: cpp

            double time_bnds(time, bnds);
            time_bnds:coordinates = "height";

        Programs like `cdo` will complain about this:

        .. code-block:: shell

            Warning (cdf_set_var): Inconsistent variable definition for time_bnds!
        """
        for var in Operation._variable_names(ds):
            attrs = ds[var].attrs
            encoding = ds[var].encoding
            c_attr = ChainMap(attrs, encoding).get("coordinates", None)
            if not c_attr:
                encoding["coordinates"] = None
            else:
                encoding["coordinates"] = c_attr
                attrs.pop("coordinates", None)
        return ds

    def process(self) -> list[xr.Dataset | Path]:
        """
        Main processing method used by all subclasses.

        Returns
        -------
        list of xr.Dataset or Path
            A list of outputs, which might be NetCDF file paths, Zarr file paths, or xarray.Dataset.
        """
        # Create an empty list for outputs
        outputs = []

        # Get the file namer object for naming output files
        # NOTE: It won't be used if the output type required is "xarray"
        namer = self._get_file_namer()

        # Process the xarray Dataset - this will (usually) be lazily evaluated so
        # no actual data will be read
        processed_ds = self._calculate()

        # remove fill values from lat/lon/time if required
        processed_ds = self._remove_redundant_fill_values(processed_ds)
        # remove redundant coordinates from bounds
        processed_ds = self._remove_redundant_coordinates_attr(processed_ds)
        # remove compression for string variables (as it is not supported by netcdf-c >= 4.9.0)
        processed_ds = self._remove_str_compression(processed_ds)
        # cap deflate level at 1
        processed_ds = self._cap_deflate_level(processed_ds)
        # fix string encoding of xarray.Dataset.attrs (incl. variable attrs)
        processed_ds = self._fix_netcdf_attrs_encoding(processed_ds)

        # Work out how many outputs should be created based on the size
        # of the array. Manage this as a list of time slices.
        time_slices = get_time_slices(processed_ds, self._split_method)

        # Loop through each time slice
        for tslice in time_slices:
            if tslice is None:
                # If there is only one time slice, and it is None:
                # - then just set the result Dataset to the processed Dataset
                result_ds = processed_ds
            else:
                # If there is a time slice then extract the time slice from the
                # processed Dataset
                result_ds = processed_ds.sel(time=slice(tslice[0], tslice[1]))

            logger.info(f"Processing {self.__class__.__name__} for times: {tslice}")

            # Get the output (file or xarray Dataset)
            # When this is a file: xarray will read all the data and write the file
            output = get_output(result_ds, self._output_type, self._output_dir, namer)
            outputs.append(output)

        return outputs