"""Provides ESMF representations of grids/UGRID meshes and a modified regridder."""
import numpy as np
from numpy import ma
import scipy.sparse
from scipy.sparse import safely_cast_index_arrays
import esmf_regrid
from esmf_regrid import Constants, check_method, check_norm
from . import esmpy
from ._esmf_sdo import GridInfo, RefinedGridInfo
# Names exported as the public API of this module.
__all__ = [
"GridInfo",
"RefinedGridInfo",
"Regridder",
]
# Placeholder recorded as ``Regridder.esmf_version`` when precomputed weights
# are supplied and :mod:`esmpy` is therefore bypassed.
ESMF_NO_VERSION = "N/A"
def _get_regrid_weights_dict(src_field, tgt_field, regrid_method, esmf_args=None):
    """Calculate regridding weights with ESMF and return them as a dictionary.

    Parameters
    ----------
    src_field
        The source field passed to :class:`esmpy.Regrid`.
    tgt_field
        The target field passed to :class:`esmpy.Regrid`.
    regrid_method
        Forwarded to :class:`esmpy.Regrid` as ``regrid_method``.
    esmf_args : dict, optional
        Extra keyword arguments for :class:`esmpy.Regrid`. The dictionary is
        copied, so the caller's object is never modified. Unless given here,
        ``ignore_degenerate`` defaults to ``True`` and ``unmapped_action``
        defaults to ``esmpy.UnmappedAction.IGNORE``. Keys that this function
        already passes explicitly (``regrid_method``, ``norm_type``,
        ``src_mask_values``, ``dst_mask_values``, ``factors``) must not be
        included, since duplicate keyword arguments raise ``TypeError``.

    Returns
    -------
    The deep-copied result of ``get_weights_dict`` on the ESMF regridder.
    Within this module it is consumed by ``_weights_dict_to_sparse_array``,
    which reads its ``"weights"``, ``"row_dst"`` and ``"col_src"`` entries.
    """
    esmf_args = {} if esmf_args is None else esmf_args.copy()

    # Provide default values without overriding anything the caller set.
    esmf_args.setdefault("ignore_degenerate", True)
    esmf_args.setdefault("unmapped_action", esmpy.UnmappedAction.IGNORE)

    # The value, in array form, that ESMF should treat as an affirmative mask.
    expected_mask = np.array([True])
    regridder = esmpy.Regrid(
        src_field,
        tgt_field,
        regrid_method=regrid_method,
        # Choosing the norm_type DSTAREA allows for mdtol type operations
        # to be performed using the weights information later on.
        norm_type=esmpy.NormType.DSTAREA,
        src_mask_values=expected_mask,
        dst_mask_values=expected_mask,
        factors=True,
        **esmf_args,
    )
    try:
        # Without specifying deep_copy=True, the information in weights_dict
        # would be corrupted when the ESMF regridder is destroyed.
        weights_dict = regridder.get_weights_dict(deep_copy=True)
    finally:
        # The weights_dict contains all the information needed for regridding,
        # so the ESMF object can be removed. Doing this in ``finally`` ensures
        # it is also released if extracting the weights fails.
        regridder.destroy()
    return weights_dict
def _weights_dict_to_sparse_array(weights, shape, index_offsets):
matrix = scipy.sparse.csr_matrix(
(
weights["weights"],
(
weights["row_dst"] - index_offsets[0],
weights["col_src"] - index_offsets[1],
),
),
shape=shape,
)
return matrix
# Public regridder class (exported via ``__all__``).
class Regridder:
    """Regridder for directly interfacing with :mod:`esmpy`.

    Instances expose the attributes ``src``, ``tgt``, ``method``,
    ``esmf_regrid_version``, ``esmf_version``, ``weight_matrix`` and
    ``minimal_regridder``, all of which are set by ``__init__``.
    """

    def __init__(
        self,
        src,
        tgt,
        method=Constants.Method.CONSERVATIVE,
        precomputed_weights=None,
        esmf_args=None,
    ):
        """Create a regridder from descriptions of horizontal grids/meshes.

        Weights will be calculated using :mod:`esmpy` and stored as a
        :class:`scipy.sparse.csr_matrix`
        for use in regridding. If precomputed weights are provided,
        these will be used instead of calculating via :mod:`esmpy`.

        Parameters
        ----------
        src : :class:`~esmf_regrid.experimental.unstructured_regrid.MeshInfo` or :class:`GridInfo`
            Describes the source mesh/grid.
            Data supplied to this regridder should be in a :class:`numpy.ndarray`
            whose shape is compatible with ``src``.
        tgt : :class:`~esmf_regrid.experimental.unstructured_regrid.MeshInfo` or :class:`GridInfo`
            Describes the target mesh/grid.
            Data output by this regridder will be a :class:`numpy.ndarray` whose
            shape is compatible with ``tgt``.
        method : :class:`Constants.Method`
            The method to be used to calculate weights.
        precomputed_weights : :class:`scipy.sparse.spmatrix`, optional
            If ``None``, :mod:`esmpy` will be used to
            calculate regridding weights. Otherwise, :mod:`esmpy` will be bypassed
            and ``precomputed_weights`` will be used as the regridding weights.
        esmf_args : dict, optional
            A dictionary of arguments to pass to ESMF.

        Raises
        ------
        ValueError
            If ``precomputed_weights`` is given but is not a sparse matrix, or
            if its shape is not ``(tgt.size, src.size)``.
        """
        self.src = src
        self.tgt = tgt
        # Validate the method once. The validated result (which check_method
        # may have normalised) is what is used everywhere below, rather than
        # the raw argument.
        self.method = check_method(method)

        self.esmf_regrid_version = esmf_regrid.__version__
        if precomputed_weights is None:
            self.esmf_version = esmpy.__version__
            self.weight_matrix = self._compute_weight_matrix(esmf_args)
        else:
            if not scipy.sparse.issparse(precomputed_weights):
                e_msg = "Precomputed weights must be given as a sparse matrix."
                raise ValueError(e_msg)
            expected_shape = (self.tgt.size, self.src.size)
            if precomputed_weights.shape != expected_shape:
                msg = "Expected precomputed weights to have shape {}, got shape {} instead."
                raise ValueError(
                    msg.format(expected_shape, precomputed_weights.shape)
                )
            # esmpy was not involved, so no ESMF version can be recorded.
            self.esmf_version = ESMF_NO_VERSION
            self.weight_matrix = precomputed_weights

        # The array-level regridding is delegated to a lightweight object that
        # only holds shapes, the method and the weight matrix.
        self.minimal_regridder = MinimalRegridder(
            self.src.shape, self.tgt.shape, self.method, self.weight_matrix
        )

    def _compute_weight_matrix(self, esmf_args):
        """Calculate the ``self.src`` to ``self.tgt`` weight matrix via :mod:`esmpy`.

        The temporary ESMF fields, and the objects returned by their ``grid``
        attribute, are destroyed whether or not the calculation succeeds.
        """
        src, tgt = self.src, self.tgt

        src_field = src.make_esmf_field()
        src_sdo = src_field.grid
        try:
            # Created inside the outer ``try`` so the source objects are still
            # released if building the target field fails.
            tgt_field = tgt.make_esmf_field()
            tgt_sdo = tgt_field.grid
            try:
                weights_dict = _get_regrid_weights_dict(
                    src_field,
                    tgt_field,
                    regrid_method=self.method.value,
                    esmf_args=esmf_args,
                )
            finally:
                tgt_field.destroy()
                tgt_sdo.destroy()
        finally:
            src_field.destroy()
            src_sdo.destroy()

        weight_matrix = _weights_dict_to_sparse_array(
            weights_dict,
            (tgt._refined_size, src._refined_size),
            (tgt.index_offset, src.index_offset),
        )
        if isinstance(tgt, RefinedGridInfo):
            # At this point, the weight matrix represents more target points than
            # tgt represents. In order to collapse these points, we collapse the
            # weights matrix by the appropriate matrix multiplication.
            weight_matrix = tgt._collapse_weights(is_tgt=True) @ weight_matrix
        if isinstance(src, RefinedGridInfo):
            # At this point, the weight matrix represents more source points than
            # src represents. In order to collapse these points, we collapse the
            # weights matrix by the appropriate matrix multiplication.
            weight_matrix = weight_matrix @ src._collapse_weights(is_tgt=False)
        return weight_matrix

    def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1):
        """Perform regridding on an array of data.

        This delegates to ``self.minimal_regridder.regrid``.

        Parameters
        ----------
        src_array : :obj:`~numpy.typing.ArrayLike`
            Array whose shape is compatible with ``self.src``.
        norm_type : :class:`Constants.NormType`
            Either ``Constants.NormType.FRACAREA`` or ``Constants.NormType.DSTAREA``.
            Determines the type of normalisation applied to the weights.
        mdtol : float, default=1
            A number between 0 and 1 describing the missing data tolerance.
            Depending on the value of ``mdtol``, if a cell in the target grid is not
            sufficiently covered by unmasked cells of the source grid, then it will
            be masked. ``mdtol=1`` means that only target cells which are not
            covered at all will be masked, ``mdtol=0`` means that all target
            cells that are not entirely covered will be masked, and ``mdtol=0.5``
            means that all target cells that are less than half covered will
            be masked.

        Returns
        -------
        :obj:`~numpy.typing.ArrayLike`
            An array whose shape is compatible with ``self.tgt``.
        """
        return self.minimal_regridder.regrid(
            src_array, norm_type=norm_type, mdtol=mdtol
        )
class MinimalRegridder:
    """Regridder holding only what is needed to regrid numpy arrays.

    It stores the source and target shapes, the regridding method and the
    weight matrix; it holds no references to grid/mesh descriptions.
    """

    def __init__(self, src_shape, tgt_shape, method, weights):
        """Create a minimal version of the regridder.

        This regridder object contains only the information required to perform
        regridding on numpy arrays.

        Parameters
        ----------
        src_shape : tuple of int
            Shape of the source array.
        tgt_shape : tuple of int
            Shape of the target array.
        method : :class:`Constants.Method`
            The method to be used to calculate weights.
        weights : :class:`scipy.sparse.spmatrix`
            The weights matrix to apply.
        """
        # Shapes are normalised to tuples: ``regrid`` compares ``src_shape``
        # with a slice of ``ndarray.shape`` (a tuple) and ``_matrix_to_array``
        # concatenates ``tgt_shape`` with a tuple, so a list would break both.
        self.src_shape = tuple(src_shape)
        # Stored as a plain Python int rather than a NumPy scalar.
        self.src_size = int(np.prod(self.src_shape))
        self.src_dims = len(self.src_shape)
        self.tgt_shape = tuple(tgt_shape)
        self.method = method
        self.weight_matrix = weights

    def _array_to_matrix(self, array):
        """Reshape data to a form that is compatible with weight matrices.

        The data should be presented in the form of a matrix (i.e. 2D) in order
        to be compatible with the weight matrix.
        Weight matrices deriving from ESMF use fortran ordering when flattening
        grids to determine cell indices so we use the same order for reshaping.
        We then take the transpose so that matrix multiplication happens over
        the appropriate axes.
        """
        return array.T.reshape((self.src_size, -1))

    def _matrix_to_array(self, array, extra_dims):
        """Reshape data to restore original dimensions.

        This is the inverse operation of `_array_to_matrix`, using
        ``self.tgt_shape`` for the trailing dimensions and ``extra_dims``
        (a tuple) for the leading ones.
        """
        return array.reshape((extra_dims + self.tgt_shape)[::-1]).T

    def _gen_weights_and_data(self, src_array):
        """Apply the weight matrix to the data and to its inverted mask.

        Returns
        -------
        tuple
            ``(weight_sums, flat_tgt, extra_shape)`` where ``weight_sums`` is
            the weight matrix applied to the unmasked-ness of the source,
            ``flat_tgt`` is the weight matrix applied to the source data with
            masked points filled with zero, and ``extra_shape`` is the leading
            part of ``src_array.shape`` not covered by the source grid.
        """
        extra_shape = src_array.shape[: -self.src_dims]

        if self.method == Constants.Method.NEAREST:
            # Cast the weights to the source dtype so that the product below
            # is computed in that dtype.
            weight_matrix = self.weight_matrix.astype(src_array.dtype)
        else:
            weight_matrix = self.weight_matrix

        # Masked points contribute nothing to the weighted sums.
        flat_src = self._array_to_matrix(ma.filled(src_array, 0.0))
        flat_tgt = weight_matrix @ flat_src

        src_inverted_mask = self._array_to_matrix(~ma.getmaskarray(src_array))
        weight_sums = weight_matrix @ src_inverted_mask
        return weight_sums, flat_tgt, extra_shape

    def _regrid_from_weights_and_data(
        self,
        tgt_weights,
        tgt_data,
        extra,
        norm_type=Constants.NormType.FRACAREA,
        mdtol=1,
    ):
        """Mask, normalise and reshape the output of `_gen_weights_and_data`.

        Parameters
        ----------
        tgt_weights
            Weight sums, as returned first by `_gen_weights_and_data`.
        tgt_data
            Weighted data, as returned second by `_gen_weights_and_data`.
        extra : tuple of int
            The leading (non-grid) dimensions of the original source array.
        norm_type : :class:`Constants.NormType`
            With ``FRACAREA`` the data is divided by the weight sums; any
            other value leaves the data unnormalised.
        mdtol : float, default=1
            Missing data tolerance; target points whose weight sum is not
            greater than ``1 - mdtol`` are masked.

        Returns
        -------
        :class:`numpy.ma.MaskedArray`
            The regridded data with dimensions ``extra + self.tgt_shape``.
        """
        # Set the minimum mdtol to be slightly higher than 0 to account for rounding
        # errors.
        mdtol = max(mdtol, 1e-8)
        tgt_mask = tgt_weights > 1 - mdtol

        normalisations = np.ones_like(tgt_data)
        if (
            self.method != Constants.Method.NEAREST
            and norm_type == Constants.NormType.FRACAREA
        ):
            # Only unmasked target points are normalised; DSTAREA (and the
            # NEAREST method) need no normalisation.
            normalisations[tgt_mask] /= tgt_weights[tgt_mask]
        normalisations = ma.array(normalisations, mask=np.logical_not(tgt_mask))

        # Multiplying by the masked normalisations also applies the mask.
        tgt_array = tgt_data * normalisations
        return self._matrix_to_array(tgt_array, extra)

    def _out_dtype(self, in_dtype):
        """Return the expected output dtype for a given input dtype."""
        if self.method == Constants.Method.NEAREST:
            out_dtype = in_dtype
        else:
            weight_dtype = self.weight_matrix.dtype
            # Let NumPy determine the dtype produced by multiplying data of
            # the input dtype with the weights.
            out_dtype = (
                np.ones(1, dtype=in_dtype) * np.ones(1, dtype=weight_dtype)
            ).dtype
        return out_dtype

    def regrid(self, src_array, norm_type=Constants.NormType.FRACAREA, mdtol=1):
        """Perform regridding on an array of data.

        Parameters
        ----------
        src_array : :class:`numpy.ndarray` or :class:`numpy.ma.MaskedArray`
            Array whose trailing dimensions equal ``self.src_shape``.
        norm_type : :class:`Constants.NormType`
            Either ``Constants.NormType.FRACAREA`` or ``Constants.NormType.DSTAREA``.
            Determines the type of normalisation applied to the weights.
        mdtol : float, default=1
            A number between 0 and 1 describing the missing data tolerance.
            Depending on the value of ``mdtol``, if a cell in the target grid is not
            sufficiently covered by unmasked cells of the source grid, then it will
            be masked. ``mdtol=1`` means that only target cells which are not
            covered at all will be masked, ``mdtol=0`` means that all target
            cells that are not entirely covered will be masked, and ``mdtol=0.5``
            means that all target cells that are less than half covered will
            be masked.

        Returns
        -------
        :class:`numpy.ma.MaskedArray`
            An array whose trailing dimensions equal ``self.tgt_shape``.

        Raises
        ------
        ValueError
            If the trailing dimensions of ``src_array`` differ from
            ``self.src_shape``.
        """
        # check_norm returns the norm type to use from here on.
        norm_type = check_norm(norm_type)

        array_shape = src_array.shape
        main_shape = array_shape[-self.src_dims :]
        if main_shape != self.src_shape:
            e_msg = (
                f"Expected an array whose shape ends in {self.src_shape}, "
                f"got an array with shape ending in {main_shape}."
            )
            raise ValueError(e_msg)

        tgt_weights, tgt_data, extra = self._gen_weights_and_data(src_array)
        return self._regrid_from_weights_and_data(
            tgt_weights, tgt_data, extra, norm_type=norm_type, mdtol=mdtol
        )