"""Provides representations of ESMF's Spatial Discretisation Objects."""
from abc import ABC, abstractmethod
import cartopy.crs as ccrs
import numpy as np
import scipy.sparse
from . import esmpy
class SDO(ABC):
    """
    Abstract base class for spatial discretisation objects.

    Gathers the state and helpers common to the three kinds of spatial
    discretisation object that ESMPy supports: Grids, Meshes and LocStreams.
    """

    def __init__(self, shape, index_offset, field_kwargs, mask=None):
        self._mask = mask
        self._field_kwargs = field_kwargs
        self._index_offset = index_offset
        self._shape = shape

    @abstractmethod
    def _make_esmf_sdo(self):
        """Build the underlying ESMF object; supplied by each subclass."""

    def make_esmf_field(self):
        """Return an ESMF field built on this spatial discretisation object."""
        return esmpy.Field(self._make_esmf_sdo(), **self._field_kwargs)

    @property
    def shape(self):
        """Return the shape of the represented object."""
        return self._shape

    @property
    def _refined_shape(self):
        """Return the shape handed to ESMF."""
        return self._shape

    @property
    def _refined_mask(self):
        """Return the mask handed to ESMF."""
        return self._mask

    @property
    def dims(self):
        """Return how many dimensions the shape has."""
        return len(self._shape)

    @property
    def size(self):
        """Return the total number of cells in the sdo."""
        n_cells = np.prod(self._shape)
        return n_cells

    @property
    def _refined_size(self):
        """Return the total number of cells handed to ESMF."""
        n_refined_cells = np.prod(self._refined_shape)
        return n_refined_cells

    @property
    def index_offset(self):
        """Return the index offset."""
        return self._index_offset

    @property
    def mask(self):
        """Return the mask."""
        return self._mask

    def _array_to_matrix(self, array):
        """
        Flatten data into a 2D matrix that a weight matrix can act on.

        ESMF-derived weight matrices number cells by flattening the grid in
        fortran order, so the same ordering is needed here. Transposing first
        and then reshaping yields that ordering, and leaves the cell axis
        first so that matrix multiplication runs over the right axis.
        """
        transposed = array.T
        return transposed.reshape((self.size, -1))

    def _matrix_to_array(self, array, extra_dims):
        """
        Undo `_array_to_matrix`, restoring the original dimensions.

        ``extra_dims`` is the shape of the dimensions which are not part of
        this sdo; it is concatenated with the sdo shape.
        """
        reversed_shape = self._shape[::-1] + extra_dims[::-1]
        return array.reshape(reversed_shape).T
[docs]
class GridInfo(SDO):
    """
    Class for handling structured grids.

    This class holds information about lat-lon type grids. That is, grids
    defined by lists of latitude and longitude values for points/bounds
    (with respect to some coordinate reference system i.e. rotated pole).
    It contains methods for translating this information into :mod:`esmpy` objects.
    In particular, there are methods for representing as a
    :class:`esmpy.api.grid.Grid` and
    as a :class:`esmpy.api.field.Field` containing that
    :class:`~esmpy.api.grid.Grid`. This esmpy :class:`~esmpy.api.field.Field`
    is designed to
    contain enough information for area weighted regridding and may be
    inappropriate for other :mod:`esmpy` regridding schemes.
    """

    def __init__(
        self,
        lons,
        lats,
        lonbounds,
        latbounds,
        crs=None,
        circular=False,
        areas=None,
        mask=None,
        center=False,
    ):
        """
        Create a :class:`GridInfo` object describing the grid.

        Parameters
        ----------
        lons : :obj:`~numpy.typing.ArrayLike`
            A 1D or 2D array or list describing the longitudes of the
            grid points.
        lats : :obj:`~numpy.typing.ArrayLike`
            A 1D or 2D array or list describing the latitudes of the
            grid points.
        lonbounds : :obj:`~numpy.typing.ArrayLike`
            A 1D or 2D array or list describing the longitude bounds of
            the grid. Should have length one greater than ``lons``.
        latbounds : :obj:`~numpy.typing.ArrayLike`
            A 1D or 2D array or list describing the latitude bounds of
            the grid. Should have length one greater than ``lats``.
        crs : :class:`cartopy.crs.CRS`, optional
            Describes how to interpret the
            above arguments. If ``None``, defaults to :class:`~cartopy.crs.Geodetic`.
        circular : bool, default=False
            Describes if the final longitude bounds should
            be considered contiguous with the first.
        areas : :obj:`~numpy.typing.ArrayLike`, optional
            Array describing the areas associated with
            each face. If ``None``, then :mod:`esmpy` will use its own
            calculated areas.
        mask : :obj:`~numpy.typing.ArrayLike`, optional
            Array describing which elements :mod:`esmpy` will ignore.
        center : bool, default=False
            Describes if the center points of the grid cells are used in regridding
            calculations.

        Raises
        ------
        ValueError
            If the points and bounds do not share a dimensionality, if the
            longitudes and latitudes do not share a dimensionality, or if
            that dimensionality is neither 1 nor 2.
        """
        # The coordinates are documented as array-likes, but everything below
        # relies on ``.shape`` and numpy indexing. ``asanyarray`` returns
        # existing arrays (including subclasses) unchanged and converts
        # lists/tuples.
        lons = np.asanyarray(lons)
        lats = np.asanyarray(lats)
        lonbounds = np.asanyarray(lonbounds)
        latbounds = np.asanyarray(latbounds)

        self.lons = lons
        self.lats = lats

        londims = len(self.lons.shape)
        if len(lonbounds.shape) != londims:
            msg = (
                f"The dimensionality of longitude bounds "
                f"({len(lonbounds.shape)}) is incompatible with the "
                f"dimensionality of the longitude ({londims})."
            )
            raise ValueError(msg)
        latdims = len(self.lats.shape)
        if len(latbounds.shape) != latdims:
            msg = (
                f"The dimensionality of latitude bounds "
                f"({len(latbounds.shape)}) is incompatible with the "
                f"dimensionality of the latitude ({latdims})."
            )
            raise ValueError(msg)
        if londims != latdims:
            msg = (
                f"The dimensionality of the longitude "
                f"({londims}) is incompatible with the "
                f"dimensionality of the latitude ({latdims})."
            )
            raise ValueError(msg)
        if londims not in (1, 2):
            msg = (
                f"Expected a latitude/longitude with a dimensionality "
                f"of 1 or 2, got {londims}."
            )
            raise ValueError(msg)

        # 1D coordinates describe the two axes separately; 2D coordinates
        # already carry the full grid shape.
        if londims == 1:
            shape = (len(lats), len(lons))
        else:
            shape = self.lons.shape

        # The "refined" bounds are the ones actually passed to ESMF. Here they
        # are the bounds as given.
        self.lonbounds = lonbounds
        self._refined_lonbounds = lonbounds
        self.latbounds = latbounds
        self._refined_latbounds = latbounds
        if crs is None:
            self.crs = ccrs.Geodetic()
        else:
            self.crs = crs
        self.circular = circular
        self.areas = areas
        self.center = center
        super().__init__(
            shape=shape,
            index_offset=1,
            field_kwargs={"staggerloc": esmpy.StaggerLoc.CENTER},
            mask=mask,
        )

    def _as_esmf_info(self):
        """
        Gather the arrays and flags needed to build the :mod:`esmpy` grid.

        Returns
        -------
        tuple
            ``(shape, center lons, center lats, corner lons, corner lats,
            circular, areas)``, with all coordinates transformed from
            ``self.crs`` to geodetic and stored as fortran-ordered arrays.
        """
        shape = np.array(self._refined_shape)

        londims = len(self.lons.shape)
        if londims == 1:
            # On a circular grid the last longitude bound is dropped.
            if self.circular:
                adjustedlonbounds = self._refined_lonbounds[:-1]
            else:
                adjustedlonbounds = self._refined_lonbounds
            centerlons, centerlats = np.meshgrid(self.lons, self.lats)
            cornerlons, cornerlats = np.meshgrid(
                adjustedlonbounds, self._refined_latbounds
            )
        elif londims == 2:
            # Only the corner arrays lose their final column on a circular
            # grid. The centres must keep the full grid shape (the shape the
            # esmpy Grid is created with), exactly as in the 1D branch above.
            if self.circular:
                corner_selection = np.s_[:, :-1]
            else:
                corner_selection = np.s_[:]
            centerlons = self.lons
            centerlats = self.lats
            cornerlons = self._refined_lonbounds[corner_selection]
            cornerlats = self._refined_latbounds[corner_selection]

        # Express every coordinate as geodetic lon/lat, whatever CRS the
        # values were supplied in.
        geodetic = ccrs.Geodetic()
        truecenters = geodetic.transform_points(self.crs, centerlons, centerlats)
        truecorners = geodetic.transform_points(self.crs, cornerlons, cornerlats)

        # The following note in xESMF suggests that the arrays passed to ESMPy ought to
        # be fortran ordered:
        # https://xesmf.readthedocs.io/en/latest/internal_api.html#xesmf.backend.warn_f_contiguous
        # It is yet to be determined what effect this has on performance.
        truecenterlons = np.asfortranarray(truecenters[..., 0])
        truecenterlats = np.asfortranarray(truecenters[..., 1])
        truecornerlons = np.asfortranarray(truecorners[..., 0])
        truecornerlats = np.asfortranarray(truecorners[..., 1])

        info = (
            shape,
            truecenterlons,
            truecenterlats,
            truecornerlons,
            truecornerlats,
            self.circular,
            self.areas,
        )
        return info

    def _make_esmf_sdo(self):
        """Build and return the :class:`esmpy.api.grid.Grid` for this grid."""
        info = self._as_esmf_info()
        (
            shape,
            truecenterlons,
            truecenterlats,
            truecornerlons,
            truecornerlats,
            circular,
            areas,
        ) = info

        if circular:
            grid = esmpy.Grid(
                shape,
                pole_kind=[1, 1],
                num_peri_dims=1,
                periodic_dim=1,
                pole_dim=0,
            )
        else:
            grid = esmpy.Grid(shape, pole_kind=[1, 1])

        grid.add_coords(staggerloc=esmpy.StaggerLoc.CORNER)
        grid_corner_x = grid.get_coords(0, staggerloc=esmpy.StaggerLoc.CORNER)
        grid_corner_x[:] = truecornerlons
        grid_corner_y = grid.get_coords(1, staggerloc=esmpy.StaggerLoc.CORNER)
        grid_corner_y[:] = truecornerlats

        # Grid center points are added here, this is not necessary for
        # conservative area weighted regridding
        if self.center:
            grid.add_coords(staggerloc=esmpy.StaggerLoc.CENTER)
            grid_center_x = grid.get_coords(0, staggerloc=esmpy.StaggerLoc.CENTER)
            grid_center_x[:] = truecenterlons
            grid_center_y = grid.get_coords(1, staggerloc=esmpy.StaggerLoc.CENTER)
            grid_center_y[:] = truecenterlats

        def add_get_item(grid, **kwargs):
            # Create the item on the grid, then hand back its (writable) array.
            grid.add_item(**kwargs)
            return grid.get_item(**kwargs)

        if self.mask is not None:
            grid_mask = add_get_item(
                grid, item=esmpy.GridItem.MASK, staggerloc=esmpy.StaggerLoc.CENTER
            )
            grid_mask[:] = self._refined_mask

        if areas is not None:
            grid_areas = add_get_item(
                grid, item=esmpy.GridItem.AREA, staggerloc=esmpy.StaggerLoc.CENTER
            )
            # ``areas`` is documented as array-like, so convert before using
            # ``.T``.
            # NOTE(review): the areas are transposed here while the mask above
            # is not - confirm the axis order expected of ``areas``.
            grid_areas[:] = np.asanyarray(areas).T

        return grid
[docs]
class RefinedGridInfo(GridInfo):
    """
    Class for handling structured grids represented in :mod:`esmpy` in higher resolution.

    A specialised version of :class:`GridInfo`. Designed to provide higher
    accuracy conservative regridding for rectilinear grids, especially those with
    particularly large cells which may not be well represented by :mod:`esmpy`. This
    class differs from :class:`GridInfo` primarily in the way it represents itself
    as a :class:`~esmpy.api.field.Field` in :mod:`esmpy`. This :class:`~esmpy.api.field.Field`
    is designed to be a higher resolution version of the given grid and should
    contain enough information for area weighted regridding but may be
    inappropriate for other :mod:`esmpy` regridding schemes.
    """

    def __init__(
        self,
        lonbounds,
        latbounds,
        resolution=3,
        crs=None,
        mask=None,
    ):
        """
        Create a :class:`RefinedGridInfo` object describing the grid.

        Parameters
        ----------
        lonbounds : :obj:`~numpy.typing.ArrayLike`
            A 1D array or list describing the longitude bounds of the grid.
            Must be strictly increasing (for example, if a bound goes from
            170 to -170 consider transposing -170 to 190).
        latbounds : :obj:`~numpy.typing.ArrayLike`
            A 1D array or list describing the latitude bounds of the grid.
            Must be strictly increasing.
        resolution : int, default=3
            A number describing how many slices each cell should be divided
            into when passing a higher resolution grid to ESMF. Must be at
            least 1. Not used when the grid has a single latitude cell
            spanning -90 to 90.
        crs : :class:`cartopy.crs.CRS`, optional
            Describes how to interpret the
            above arguments. If ``None``, defaults to :class:`~cartopy.crs.Geodetic`.
        mask : :obj:`~numpy.typing.ArrayLike`, optional
            Array describing which elements :mod:`esmpy` will ignore.

        Raises
        ------
        ValueError
            If either set of bounds is not strictly increasing, or if
            ``resolution`` is less than 1 when it is needed.
        """
        # Convert bounds to numpy arrays where necessary.
        lonbounds = np.asanyarray(lonbounds)
        latbounds = np.asanyarray(latbounds)

        # Ensure bounds are strictly increasing.
        if not np.all(lonbounds[:-1] < lonbounds[1:]):
            raise ValueError("The longitude bounds must be strictly increasing.")
        if not np.all(latbounds[:-1] < latbounds[1:]):
            raise ValueError("The latitude bounds must be strictly increasing.")

        self.resolution = resolution
        self.n_lons_orig = len(lonbounds) - 1
        self.n_lats_orig = len(latbounds) - 1

        # Create dummy lat/lon values
        lons = np.zeros(self.n_lons_orig)
        lats = np.zeros(self.n_lats_orig)
        super().__init__(lons, lats, lonbounds, latbounds, crs=crs, mask=mask)

        if self.n_lats_orig == 1 and np.allclose(latbounds, [-90, 90]):
            # A single latitude cell spanning -90 to 90 is split at the
            # equator; the longitudes are left as they are.
            self._refined_latbounds = np.array([-90, 0, 90])
            self._refined_lonbounds = lonbounds
        else:
            # A resolution below 1 would leave no refined cells at all, so
            # reject it here rather than build an empty refined grid.
            if resolution < 1:
                raise ValueError("The resolution must be at least 1.")
            self._refined_latbounds = latbounds
            # Each longitude cell is divided into ``resolution`` equal slices.
            # ``endpoint=False`` leaves out each cell's upper bound (it is the
            # next cell's lower bound), so the final bound is appended.
            self._refined_lonbounds = np.append(
                np.linspace(
                    lonbounds[:-1],
                    lonbounds[1:],
                    self.resolution,
                    endpoint=False,
                    axis=1,
                ).flatten(),
                lonbounds[-1],
            )

        # Number of refined cells per original cell along each axis.
        self.lon_expansion = (len(self._refined_lonbounds) - 1) // (
            len(self.lonbounds) - 1
        )
        self.lat_expansion = (len(self._refined_latbounds) - 1) // (
            len(self.latbounds) - 1
        )

    @property
    def _refined_shape(self):
        """Return shape passed to ESMF."""
        return (
            self.n_lats_orig * self.lat_expansion,
            self.n_lons_orig * self.lon_expansion,
        )

    @property
    def _refined_mask(self):
        """Return mask passed to ESMF, or ``None`` when there is no mask."""
        if self.mask is None:
            return None
        mask = np.asanyarray(self.mask)
        # Repeat every mask value over all the refined cells which its
        # original cell is divided into.
        new_mask = np.broadcast_to(
            mask[:, np.newaxis, :, np.newaxis],
            [
                self.n_lats_orig,
                self.lat_expansion,
                self.n_lons_orig,
                self.lon_expansion,
            ],
        )
        new_mask = new_mask.reshape(self._refined_shape)
        return new_mask

    def _collapse_weights(self, is_tgt):
        """
        Return a matrix to collapse the weight matrix.

        The refined grid may contain more cells than the represented grid. When this is
        the case, the generated weight matrix will refer to too many points and will have
        to be collapsed. This is done by multiplying by this matrix, pre-multiplying when
        the target grid is represented and post multiplying when the source grid is
        represented.

        Parameters
        ----------
        is_tgt : bool
            True if the target field is being represented, False otherwise.

        Returns
        -------
        scipy sparse matrix
            Of shape ``(size, refined size)`` when ``is_tgt`` is True, and the
            transpose of that otherwise.
        """
        # The column indices represent each of the cells in the refined grid.
        column_indices = np.arange(self._refined_size)

        # The row indices represent the cells of the unrefined grid. These are broadcast
        # so that each row index coincides with all column indices of the refined cells
        # which the unrefined cell is split into. They are held as integers since they
        # are used as matrix indices.
        if self.lat_expansion > 1:
            # The latitudes are expanded only in the case where there is one latitude
            # bound from -90 to 90. In this case, there is no longitude expansion.
            row_indices = np.empty([self.n_lons_orig, self.lat_expansion], dtype=int)
            row_indices[:] = np.arange(self.n_lons_orig)[:, np.newaxis]
        else:
            # The row indices are broadcast across a dimension representing the expansion
            # of the longitude. Each row index is broadcast and flattened so that all the
            # row indices representing the unrefined cell match up with the column indices
            # representing the refined cells it is split into.
            row_indices = np.empty(
                [self.n_lons_orig, self.lon_expansion, self.n_lats_orig], dtype=int
            )
            row_indices[:] = np.arange(self.n_lons_orig * self.n_lats_orig).reshape(
                [self.n_lons_orig, self.n_lats_orig]
            )[:, np.newaxis, :]
        row_indices = row_indices.flatten()

        matrix_shape = (self.size, self._refined_size)
        refinement_weights = scipy.sparse.csr_matrix(
            (
                np.ones(self._refined_size),
                (row_indices, column_indices),
            ),
            shape=matrix_shape,
        )
        if is_tgt:
            # When the RefinedGridInfo is the target of the regridder, we want to take
            # the average of the weights of each refined target cell. This is because
            # these weights represent the proportion of area of the target cells which
            # is covered by a given source cell. Since the refined cells are divided in
            # such a way that they have equal area, the weights for the unrefined cells
            # can be reconstructed by taking an average. This is done via matrix
            # multiplication, with the returned matrix pre-multiplying the weight matrix
            # so that it operates on the rows of the weight matrix (representing the
            # target cells). At this point the returned matrix consists of ones, so we
            # divide by the number of refined cells per unrefined cell.
            refinement_weights = refinement_weights / (
                self.lon_expansion * self.lat_expansion
            )
        else:
            # When the RefinedGridInfo is the source of the regridder, we want to take
            # the sum of the weights of each refined source cell. This is because those
            # weights represent the proportion of the area of a given target cell which
            # is covered by each refined source cell. The total proportion covered by
            # each unrefined source cell is then the sum of the weights from each of its
            # refined cells. This sum is done by matrix multiplication, the returned
            # matrix post-multiplying the weight matrix so that it operates on the columns
            # of the weight matrix (representing the source cells). In order for the
            # post-multiplication to work, the returned matrix must be transposed.
            refinement_weights = refinement_weights.T
        return refinement_weights