# Source code for esmf_regrid.experimental.io

"""Provides load/save functions for regridders."""

from contextlib import contextmanager

import iris
from iris.coords import AuxCoord
from iris.cube import Cube, CubeList
import numpy as np
import scipy.sparse

import esmf_regrid
from esmf_regrid import Constants, _load_context, check_method, esmpy
from esmf_regrid.experimental.unstructured_scheme import (
    GridToMeshESMFRegridder,
    MeshToGridESMFRegridder,
)
from esmf_regrid.schemes import (
    ESMFAreaWeightedRegridder,
    ESMFBilinearRegridder,
    ESMFNearestRegridder,
    GridRecord,
    MeshRecord,
)

# Regridder classes which save_regridder/load_regridder can round-trip.
SUPPORTED_REGRIDDERS = [
    ESMFAreaWeightedRegridder,
    ESMFBilinearRegridder,
    ESMFNearestRegridder,
    GridToMeshESMFRegridder,
    MeshToGridESMFRegridder,
]
# Maps the class name recorded in a saved file back to the regridder class.
_REGRIDDER_NAME_MAP = {cls.__name__: cls for cls in SUPPORTED_REGRIDDERS}

# Names (used as both var_name and long_name) of the cubes and coordinates
# written to a saved regridder file.
_SOURCE_NAME = "regridder_source_field"
_SOURCE_MASK_NAME = "regridder_source_mask"
_TARGET_NAME = "regridder_target_field"
_TARGET_MASK_NAME = "regridder_target_mask"
_WEIGHTS_NAME = "regridder_weights"
_WEIGHTS_SHAPE_NAME = "weights_shape"
_WEIGHTS_ROW_NAME = "weight_matrix_rows"
_WEIGHTS_COL_NAME = "weight_matrix_columns"

# Keys of the attributes attached to the saved cubes.
_REGRIDDER_TYPE = "regridder_type"
_VERSION_ESMF = "ESMF_version"
_VERSION_INITIAL = "esmf_regrid_version_on_initialise"
_MDTOL = "mdtol"
_METHOD = "method"
_RESOLUTION = "resolution"
_SOURCE_RESOLUTION = "src_resolution"
_TARGET_RESOLUTION = "tgt_resolution"

# Name of the scalar coordinate whose attributes hold the ESMF keyword arguments.
_ESMF_ARGS = "esmf_args"
# Keyword arguments which are accepted for forwarding to ESMF and saving.
_VALID_ESMF_KWARGS = [
    "pole_method",
    "regrid_pole_npoints",
    "line_type",
    "extrap_method",
    "extrap_num_src_pnts",
    "extrap_dist_exponent",
    "extrap_num_levels",
    "unmapped_action",
    "ignore_degenerate",
    "large_file",
]


def _members_by_name(enum_type):
    """Return a dict mapping each member name of ``enum_type`` to that member."""
    return {member.name: member for member in enum_type}


_POLE_METHOD_DICT = _members_by_name(esmpy.PoleMethod)
_LINE_TYPE_DICT = _members_by_name(esmpy.LineType)
_EXTRAP_METHOD_DICT = _members_by_name(esmpy.ExtrapMethod)
_UNMAPPED_ACTION_DICT = _members_by_name(esmpy.UnmappedAction)
# ESMF arguments whose values are enum members. They are saved by member name
# and looked up again through these tables on load.
_ESMF_ENUM_ARGS = {
    "pole_method": _POLE_METHOD_DICT,
    "line_type": _LINE_TYPE_DICT,
    "extrap_method": _EXTRAP_METHOD_DICT,
    "unmapped_action": _UNMAPPED_ACTION_DICT,
}


def _add_mask_to_cube(mask, cube, name):
    if isinstance(mask, np.ndarray):
        mask = mask.astype(int)
        mask_coord = AuxCoord(mask, var_name=name, long_name=name)
        cube.add_aux_coord(mask_coord, list(range(cube.ndim)))


@contextmanager
def _managed_var_name(src_cube, tgt_cube):
    src_coord_names = []
    src_mesh_coords = []
    if src_cube.mesh is not None:
        src_mesh = src_cube.mesh
        src_mesh_coords = src_mesh.coords()
        for coord in src_mesh_coords:
            src_coord_names.append(coord.var_name)
    tgt_coord_names = []
    tgt_mesh_coords = []
    if tgt_cube.mesh is not None:
        tgt_mesh = tgt_cube.mesh
        tgt_mesh_coords = tgt_mesh.coords()
        for coord in tgt_mesh_coords:
            tgt_coord_names.append(coord.var_name)

    try:
        for coord in src_mesh_coords:
            coord.var_name = "_".join([_SOURCE_NAME, "mesh", coord.name()])
        for coord in tgt_mesh_coords:
            coord.var_name = "_".join([_TARGET_NAME, "mesh", coord.name()])
        yield None
    finally:
        for coord, var_name in zip(src_mesh_coords, src_coord_names, strict=False):
            coord.var_name = var_name
        for coord, var_name in zip(tgt_mesh_coords, tgt_coord_names, strict=False):
            coord.var_name = var_name


def _clean_var_names(cube):
    cube.var_name = None
    for coord in cube.coords():
        coord.var_name = None
    if cube.mesh is not None:
        cube.mesh.var_name = None
        for coord in cube.mesh.coords():
            coord.var_name = None
        for con in cube.mesh.connectivities():
            con.var_name = None


def save_regridder(rg, filename):
    """Save a regridder scheme instance.

    Saves any of the regridder classes, i.e.
    :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`,
    :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`,
    :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`,
    :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or
    :class:`~esmf_regrid.schemes.ESMFNearestRegridder`.

    The source and target descriptions (with any masks), the sparse weight
    matrix, its shape and the regridder settings are written as cubes to a
    single file via :func:`iris.fileformats.netcdf.save`. The settings are
    method, mdtol, resolutions, ESMF arguments and version information.

    Parameters
    ----------
    rg : :class:`~esmf_regrid.schemes._ESMFRegridder`
        The regridder instance to save.
    filename : str
        The file name to save to.

    Raises
    ------
    TypeError
        If ``rg`` is not an instance of one of the supported regridder classes.
    ValueError
        If the source or target of a scheme regridder is neither a
        ``GridRecord`` nor a ``MeshRecord``.
    KeyError
        If ``rg.esmf_args`` contains an argument not listed in
        ``_VALID_ESMF_KWARGS``.
    """
    regridder_type = rg.__class__.__name__

    def _standard_grid_cube(grid, name):
        # Build a dummy cube carrying a (y, x) pair of grid coordinates.
        # 1D coordinates become dim coords; otherwise they are aux coords
        # spanning both dimensions.
        if grid[0].ndim == 1:
            shape = [coord.points.size for coord in grid]
        else:
            shape = grid[0].shape
        data = np.zeros(shape)
        cube = Cube(data, var_name=name, long_name=name)
        if grid[0].ndim == 1:
            cube.add_dim_coord(grid[0], 0)
            cube.add_dim_coord(grid[1], 1)
        else:
            cube.add_aux_coord(grid[0], [0, 1])
            cube.add_aux_coord(grid[1], [0, 1])
        return cube

    def _standard_mesh_cube(mesh, location, name):
        # Build a dummy 1D cube carrying the MeshCoords of `mesh` at `location`.
        mesh_coords = mesh.to_MeshCoords(location)
        data = np.zeros(mesh_coords[0].points.shape[0])
        cube = Cube(data, var_name=name, long_name=name)
        for coord in mesh_coords:
            cube.add_aux_coord(coord, 0)
        return cube

    def _record_cube(record, name, attr_name):
        # Convert the GridRecord/MeshRecord held by a scheme regridder to a cube.
        if isinstance(record, GridRecord):
            return _standard_grid_cube((record.grid_y, record.grid_x), name)
        if isinstance(record, MeshRecord):
            mesh, location = record
            return _standard_mesh_cube(mesh, location, name)
        raise ValueError(f"Improper type for `{attr_name}`.")

    if regridder_type in [
        "ESMFAreaWeightedRegridder",
        "ESMFBilinearRegridder",
        "ESMFNearestRegridder",
    ]:
        src_cube = _record_cube(rg._src, _SOURCE_NAME, "rg._src")
        _add_mask_to_cube(rg.src_mask, src_cube, _SOURCE_MASK_NAME)
        tgt_cube = _record_cube(rg._tgt, _TARGET_NAME, "rg._tgt")
        _add_mask_to_cube(rg.tgt_mask, tgt_cube, _TARGET_MASK_NAME)
    elif regridder_type == "GridToMeshESMFRegridder":
        src_grid = (rg.grid_y, rg.grid_x)
        src_cube = _standard_grid_cube(src_grid, _SOURCE_NAME)
        _add_mask_to_cube(rg.src_mask, src_cube, _SOURCE_MASK_NAME)

        tgt_mesh = rg.mesh
        tgt_location = rg.location
        tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, _TARGET_NAME)
        _add_mask_to_cube(rg.tgt_mask, tgt_cube, _TARGET_MASK_NAME)
    elif regridder_type == "MeshToGridESMFRegridder":
        src_mesh = rg.mesh
        src_location = rg.location
        src_cube = _standard_mesh_cube(src_mesh, src_location, _SOURCE_NAME)
        _add_mask_to_cube(rg.src_mask, src_cube, _SOURCE_MASK_NAME)

        tgt_grid = (rg.grid_y, rg.grid_x)
        tgt_cube = _standard_grid_cube(tgt_grid, _TARGET_NAME)
        _add_mask_to_cube(rg.tgt_mask, tgt_cube, _TARGET_MASK_NAME)
    else:
        # List every supported type, not only the mesh regridders.
        supported = ", ".join(f"`{type_name}`" for type_name in _REGRIDDER_NAME_MAP)
        msg = (
            f"Expected a regridder of one of the types {supported}, "
            f"got type {regridder_type}."
        )
        raise TypeError(msg)

    method = str(check_method(rg.method).name)

    if regridder_type in ["GridToMeshESMFRegridder", "MeshToGridESMFRegridder"]:
        resolution = rg.resolution
        src_resolution = None
        tgt_resolution = None
    elif regridder_type == "ESMFAreaWeightedRegridder":
        resolution = None
        src_resolution = rg.src_resolution
        tgt_resolution = rg.tgt_resolution
    else:
        resolution = None
        src_resolution = None
        tgt_resolution = None

    # Store the weights in COO form: data plus explicit row/column indices.
    weight_matrix = rg.regridder.weight_matrix
    reformatted_weight_matrix = scipy.sparse.coo_matrix(weight_matrix)
    weight_data = reformatted_weight_matrix.data
    weight_rows = reformatted_weight_matrix.row
    weight_cols = reformatted_weight_matrix.col
    weight_shape = reformatted_weight_matrix.shape

    esmf_version = rg.regridder.esmf_version
    esmf_regrid_version = rg.regridder.esmf_regrid_version
    save_version = esmf_regrid.__version__

    # Currently, all schemes use the fracarea normalization.
    normalization = "fracarea"

    mdtol = rg.mdtol
    attributes = {
        "title": "iris-esmf-regrid regridding scheme",
        _REGRIDDER_TYPE: regridder_type,
        _VERSION_ESMF: esmf_version,
        _VERSION_INITIAL: esmf_regrid_version,
        "esmf_regrid_version_on_save": save_version,
        "normalization": normalization,
        _MDTOL: mdtol,
        _METHOD: method,
    }
    if resolution is not None:
        attributes[_RESOLUTION] = resolution
    if src_resolution is not None:
        attributes[_SOURCE_RESOLUTION] = src_resolution
    if tgt_resolution is not None:
        attributes[_TARGET_RESOLUTION] = tgt_resolution

    weights_cube = Cube(weight_data, var_name=_WEIGHTS_NAME, long_name=_WEIGHTS_NAME)
    row_coord = AuxCoord(
        weight_rows, var_name=_WEIGHTS_ROW_NAME, long_name=_WEIGHTS_ROW_NAME
    )
    col_coord = AuxCoord(
        weight_cols, var_name=_WEIGHTS_COL_NAME, long_name=_WEIGHTS_COL_NAME
    )
    weights_cube.add_aux_coord(row_coord, 0)
    weights_cube.add_aux_coord(col_coord, 0)

    esmf_args = rg.esmf_args
    if esmf_args is None:
        esmf_args = {}
    for arg in esmf_args:
        if arg not in _VALID_ESMF_KWARGS:
            raise KeyError(f"{arg} is not considered a valid argument to pass to ESMF.")
    # Enum members are stored by name and booleans as ints; other values as-is.
    esmf_arg_attributes = {
        k: v.name if hasattr(v, "name") else int(v) if isinstance(v, bool) else v
        for k, v in esmf_args.items()
    }
    # A scalar placeholder coordinate whose attributes carry the ESMF arguments.
    esmf_arg_coord = AuxCoord(
        0, var_name=_ESMF_ARGS, long_name=_ESMF_ARGS, attributes=esmf_arg_attributes
    )
    weights_cube.add_aux_coord(esmf_arg_coord)

    weight_shape_cube = Cube(
        weight_shape,
        var_name=_WEIGHTS_SHAPE_NAME,
        long_name=_WEIGHTS_SHAPE_NAME,
    )

    # Save cubes while ensuring var_names do not conflict for the sake of consistency.
    with _managed_var_name(src_cube, tgt_cube):
        cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube])

        for cube in cube_list:
            cube.attributes = attributes

        iris.fileformats.netcdf.save(cube_list, filename)
def load_regridder(filename):
    """Load a regridder scheme instance.

    Loads any of the regridder classes, i.e.
    :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`,
    :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`,
    :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`,
    :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or
    :class:`~esmf_regrid.schemes.ESMFNearestRegridder`.

    Parameters
    ----------
    filename : str
        The file name to load from.

    Returns
    -------
    :class:`~esmf_regrid.schemes._ESMFRegridder`

    Raises
    ------
    ValueError
        If the file records a regridder type which is not supported.
    """
    with _load_context():
        cubes = iris.load(filename)

    # Extract the source, target and metadata information.
    src_cube = cubes.extract_cube(_SOURCE_NAME)
    _clean_var_names(src_cube)
    tgt_cube = cubes.extract_cube(_TARGET_NAME)
    _clean_var_names(tgt_cube)
    weights_cube = cubes.extract_cube(_WEIGHTS_NAME)
    weight_shape_cube = cubes.extract_cube(_WEIGHTS_SHAPE_NAME)

    # Determine the regridder type. Validate explicitly rather than with
    # `assert`, which is stripped when Python runs with -O.
    regridder_type = weights_cube.attributes[_REGRIDDER_TYPE]
    scheme = _REGRIDDER_NAME_MAP.get(regridder_type)
    if scheme is None:
        supported = ", ".join(_REGRIDDER_NAME_MAP)
        msg = (
            f"Unsupported regridder type {regridder_type!r} found in {filename!r}; "
            f"expected one of: {supported}."
        )
        raise ValueError(msg)

    # Determine the regridding method, allowing for files created when
    # conservative regridding was the only method.
    method_string = weights_cube.attributes.get(_METHOD, "CONSERVATIVE")
    # Account for strings saved in previous versions.
    method_string = method_string.upper()
    method = getattr(Constants.Method, method_string)

    resolution = weights_cube.attributes.get(_RESOLUTION, None)
    src_resolution = weights_cube.attributes.get(_SOURCE_RESOLUTION, None)
    tgt_resolution = weights_cube.attributes.get(_TARGET_RESOLUTION, None)
    if resolution is not None:
        resolution = int(resolution)
    if src_resolution is not None:
        src_resolution = int(src_resolution)
    if tgt_resolution is not None:
        tgt_resolution = int(tgt_resolution)

    # Reconstruct the weight matrix.
    weight_data = weights_cube.data
    weight_rows = weights_cube.coord(_WEIGHTS_ROW_NAME).points
    weight_cols = weights_cube.coord(_WEIGHTS_COL_NAME).points
    weight_shape = weight_shape_cube.data
    weight_matrix = scipy.sparse.csr_matrix(
        (weight_data, (weight_rows, weight_cols)), shape=weight_shape
    )

    mdtol = weights_cube.attributes[_MDTOL]

    # Masks are stored as coordinates; their absence means no mask was used.
    if src_cube.coords(_SOURCE_MASK_NAME):
        use_src_mask = src_cube.coord(_SOURCE_MASK_NAME).points
    else:
        use_src_mask = False
    if tgt_cube.coords(_TARGET_MASK_NAME):
        use_tgt_mask = tgt_cube.coord(_TARGET_MASK_NAME).points
    else:
        use_tgt_mask = False

    # Allow for this coord not to exist for the sake of backwards compatibility.
    esmf_args_coords = weights_cube.coords(_ESMF_ARGS)
    if esmf_args_coords:
        # Copy into a plain dict so the loaded coordinate's attributes are not
        # mutated by the conversions below.
        esmf_args = dict(esmf_args_coords[0].attributes)
    else:
        esmf_args = {}
    # Enum-valued arguments were saved by member name; map them back to members.
    for arg, arg_dict in _ESMF_ENUM_ARGS.items():
        if arg in esmf_args:
            esmf_args[arg] = arg_dict[esmf_args[arg]]
    # save_regridder stores boolean arguments as ints; restore them to bool.
    for arg in ("ignore_degenerate", "large_file"):
        if arg in esmf_args:
            esmf_args[arg] = bool(esmf_args[arg])

    if scheme is GridToMeshESMFRegridder:
        resolution_keyword = _SOURCE_RESOLUTION
        kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol}
    elif scheme is MeshToGridESMFRegridder:
        resolution_keyword = _TARGET_RESOLUTION
        kwargs = {resolution_keyword: resolution, "method": method, "mdtol": mdtol}
    elif scheme is ESMFAreaWeightedRegridder:
        kwargs = {
            _SOURCE_RESOLUTION: src_resolution,
            _TARGET_RESOLUTION: tgt_resolution,
            "mdtol": mdtol,
        }
    elif scheme is ESMFBilinearRegridder:
        kwargs = {"mdtol": mdtol}
    else:
        kwargs = {}

    regridder = scheme(
        src_cube,
        tgt_cube,
        precomputed_weights=weight_matrix,
        use_src_mask=use_src_mask,
        use_tgt_mask=use_tgt_mask,
        esmf_args=esmf_args,
        **kwargs,
    )

    # Restore the version information recorded when the regridder was created.
    esmf_version = weights_cube.attributes[_VERSION_ESMF]
    regridder.regridder.esmf_version = esmf_version
    esmf_regrid_version = weights_cube.attributes[_VERSION_INITIAL]
    regridder.regridder.esmf_regrid_version = esmf_regrid_version
    return regridder