Source code for esmf_regrid.experimental.io

"""Provides load/save functions for regridders."""

from contextlib import contextmanager

import iris
from iris.coords import AuxCoord
from iris.cube import Cube, CubeList
import numpy as np
import scipy.sparse

import esmf_regrid
from esmf_regrid import Constants, _load_context, check_method
from esmf_regrid.experimental.unstructured_scheme import (
    GridToMeshESMFRegridder,
    MeshToGridESMFRegridder,
)
from esmf_regrid.schemes import (
    ESMFAreaWeightedRegridder,
    ESMFBilinearRegridder,
    ESMFNearestRegridder,
    GridRecord,
    MeshRecord,
)

# Regridder classes that save_regridder / load_regridder know how to handle.
SUPPORTED_REGRIDDERS = [
    ESMFAreaWeightedRegridder,
    ESMFBilinearRegridder,
    ESMFNearestRegridder,
    GridToMeshESMFRegridder,
    MeshToGridESMFRegridder,
]
# Maps the class name recorded in a saved file back to the class object.
REGRIDDER_NAME_MAP = {cls.__name__: cls for cls in SUPPORTED_REGRIDDERS}
# Names (var_name / long_name) of the source and target description cubes
# and of their optional mask coordinates.
SOURCE_NAME = "regridder_source_field"
SOURCE_MASK_NAME = "regridder_source_mask"
TARGET_NAME = "regridder_target_field"
TARGET_MASK_NAME = "regridder_target_mask"

# Names of the cubes/coords that hold the sparse weight matrix.
WEIGHTS_NAME = "regridder_weights"
WEIGHTS_SHAPE_NAME = "weights_shape"
WEIGHTS_ROW_NAME = "weight_matrix_rows"
WEIGHTS_COL_NAME = "weight_matrix_columns"

# Attribute keys recording regridder metadata.
REGRIDDER_TYPE = "regridder_type"
VERSION_ESMF = "ESMF_version"
VERSION_INITIAL = "esmf_regrid_version_on_initialise"
MDTOL = "mdtol"
METHOD = "method"

# Optional resolution attribute keys.
RESOLUTION = "resolution"
SOURCE_RESOLUTION = "src_resolution"
TARGET_RESOLUTION = "tgt_resolution"


def _add_mask_to_cube(mask, cube, name):
    if isinstance(mask, np.ndarray):
        mask = mask.astype(int)
        mask_coord = AuxCoord(mask, var_name=name, long_name=name)
        cube.add_aux_coord(mask_coord, list(range(cube.ndim)))


@contextmanager
def _managed_var_name(src_cube, tgt_cube):
    """Temporarily rename the mesh coordinates of the source and target cubes.

    While the context is active, every mesh coordinate of ``src_cube`` is
    given the var_name ``<SOURCE_NAME>_mesh_<coord.name()>``. Every mesh
    coordinate of ``tgt_cube`` gets ``<TARGET_NAME>_mesh_<coord.name()>``.
    This keeps the two meshes' var_names distinct when they are saved
    together. On exit, the original var_names are restored, even if the
    body raised. Cubes without a mesh are left alone.

    Yields
    ------
    None
    """

    def _coords_of_mesh(cube):
        # Mesh coordinates of ``cube``, or an empty list when it has no mesh.
        mesh = cube.mesh
        return mesh.coords() if mesh is not None else []

    src_mesh_coords = _coords_of_mesh(src_cube)
    tgt_mesh_coords = _coords_of_mesh(tgt_cube)

    # Snapshot every original var_name before anything is modified. Source
    # entries come first so restoration happens in source-then-target order.
    originals = [(coord, coord.var_name) for coord in src_mesh_coords]
    originals += [(coord, coord.var_name) for coord in tgt_mesh_coords]

    try:
        for prefix, coords in (
            (SOURCE_NAME, src_mesh_coords),
            (TARGET_NAME, tgt_mesh_coords),
        ):
            for coord in coords:
                coord.var_name = "_".join((prefix, "mesh", coord.name()))
        yield None
    finally:
        for coord, original_name in originals:
            coord.var_name = original_name


def _clean_var_names(cube):
    cube.var_name = None
    for coord in cube.coords():
        coord.var_name = None
    if cube.mesh is not None:
        cube.mesh.var_name = None
        for coord in cube.mesh.coords():
            coord.var_name = None
        for con in cube.mesh.connectivities():
            con.var_name = None


def save_regridder(rg, filename):
    """Save a regridder scheme instance.

    Saves any of the regridder classes, i.e.
    :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`,
    :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`,
    :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`,
    :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or
    :class:`~esmf_regrid.schemes.ESMFNearestRegridder`.

    The file contains four cubes, all written with the same metadata
    attributes via :func:`iris.fileformats.netcdf.save`:

    - a zero-filled cube describing the source grid/mesh (plus its mask);
    - a zero-filled cube describing the target grid/mesh (plus its mask);
    - the non-zero weight values, with their row and column indices attached
      as auxiliary coordinates;
    - the shape of the weight matrix.

    Parameters
    ----------
    rg : :class:`~esmf_regrid.schemes._ESMFRegridder`
        The regridder instance to save.
    filename : str
        The file name to save to.

    Raises
    ------
    TypeError
        If ``rg`` is not an instance of one of the supported regridder classes.
    ValueError
        If ``rg`` is an ``ESMF*Regridder`` whose source or target is neither
        a ``GridRecord`` nor a ``MeshRecord``.
    """
    regridder_type = rg.__class__.__name__

    def _standard_grid_cube(grid, name):
        # ``grid`` is a (y, x) pair of coords. 1D coords become dim coords,
        # 2D coords become aux coords spanning both dimensions.
        if grid[0].ndim == 1:
            shape = [coord.points.size for coord in grid]
        else:
            shape = grid[0].shape
        data = np.zeros(shape)
        cube = Cube(data, var_name=name, long_name=name)
        if grid[0].ndim == 1:
            cube.add_dim_coord(grid[0], 0)
            cube.add_dim_coord(grid[1], 1)
        else:
            cube.add_aux_coord(grid[0], [0, 1])
            cube.add_aux_coord(grid[1], [0, 1])
        return cube

    def _standard_mesh_cube(mesh, location, name):
        # A 1D cube on the given mesh location, carrying its MeshCoords.
        mesh_coords = mesh.to_MeshCoords(location)
        data = np.zeros(mesh_coords[0].points.shape[0])
        cube = Cube(data, var_name=name, long_name=name)
        for coord in mesh_coords:
            cube.add_aux_coord(coord, 0)
        return cube

    def _record_cube(record, name, attr_name):
        # Build a description cube from a GridRecord or MeshRecord.
        if isinstance(record, GridRecord):
            return _standard_grid_cube((record.grid_y, record.grid_x), name)
        if isinstance(record, MeshRecord):
            mesh, location = record
            return _standard_mesh_cube(mesh, location, name)
        raise ValueError(f"Improper type for `rg.{attr_name}`.")

    if regridder_type in [
        "ESMFAreaWeightedRegridder",
        "ESMFBilinearRegridder",
        "ESMFNearestRegridder",
    ]:
        src_cube = _record_cube(rg._src, SOURCE_NAME, "_src")
        _add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)
        tgt_cube = _record_cube(rg._tgt, TARGET_NAME, "_tgt")
        _add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME)
    elif regridder_type == "GridToMeshESMFRegridder":
        src_grid = (rg.grid_y, rg.grid_x)
        src_cube = _standard_grid_cube(src_grid, SOURCE_NAME)
        _add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)
        tgt_mesh = rg.mesh
        tgt_location = rg.location
        tgt_cube = _standard_mesh_cube(tgt_mesh, tgt_location, TARGET_NAME)
        _add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME)
    elif regridder_type == "MeshToGridESMFRegridder":
        src_mesh = rg.mesh
        src_location = rg.location
        src_cube = _standard_mesh_cube(src_mesh, src_location, SOURCE_NAME)
        _add_mask_to_cube(rg.src_mask, src_cube, SOURCE_MASK_NAME)
        tgt_grid = (rg.grid_y, rg.grid_x)
        tgt_cube = _standard_grid_cube(tgt_grid, TARGET_NAME)
        _add_mask_to_cube(rg.tgt_mask, tgt_cube, TARGET_MASK_NAME)
    else:
        # List every supported type; the old message named only two of them.
        supported = ", ".join(f"`{type_name}`" for type_name in REGRIDDER_NAME_MAP)
        msg = (
            f"Expected a regridder of one of the types {supported}, "
            f"got type {regridder_type}."
        )
        raise TypeError(msg)

    method = str(check_method(rg.method).name)

    # Only some regridder types carry resolution settings.
    if regridder_type in ["GridToMeshESMFRegridder", "MeshToGridESMFRegridder"]:
        resolution = rg.resolution
        src_resolution = None
        tgt_resolution = None
    elif regridder_type == "ESMFAreaWeightedRegridder":
        resolution = None
        src_resolution = rg.src_resolution
        tgt_resolution = rg.tgt_resolution
    else:
        resolution = None
        src_resolution = None
        tgt_resolution = None

    # Store the weight matrix in COO form: values plus row/column indices.
    weight_matrix = rg.regridder.weight_matrix
    reformatted_weight_matrix = scipy.sparse.coo_matrix(weight_matrix)
    weight_data = reformatted_weight_matrix.data
    weight_rows = reformatted_weight_matrix.row
    weight_cols = reformatted_weight_matrix.col
    weight_shape = reformatted_weight_matrix.shape

    esmf_version = rg.regridder.esmf_version
    esmf_regrid_version = rg.regridder.esmf_regrid_version
    save_version = esmf_regrid.__version__

    # Currently, all schemes use the fracarea normalization.
    normalization = "fracarea"

    mdtol = rg.mdtol
    attributes = {
        "title": "iris-esmf-regrid regridding scheme",
        REGRIDDER_TYPE: regridder_type,
        VERSION_ESMF: esmf_version,
        VERSION_INITIAL: esmf_regrid_version,
        "esmf_regrid_version_on_save": save_version,
        "normalization": normalization,
        MDTOL: mdtol,
        METHOD: method,
    }
    # Resolution attributes are only written when they are set.
    if resolution is not None:
        attributes[RESOLUTION] = resolution
    if src_resolution is not None:
        attributes[SOURCE_RESOLUTION] = src_resolution
    if tgt_resolution is not None:
        attributes[TARGET_RESOLUTION] = tgt_resolution

    weights_cube = Cube(weight_data, var_name=WEIGHTS_NAME, long_name=WEIGHTS_NAME)
    row_coord = AuxCoord(
        weight_rows, var_name=WEIGHTS_ROW_NAME, long_name=WEIGHTS_ROW_NAME
    )
    col_coord = AuxCoord(
        weight_cols, var_name=WEIGHTS_COL_NAME, long_name=WEIGHTS_COL_NAME
    )
    weights_cube.add_aux_coord(row_coord, 0)
    weights_cube.add_aux_coord(col_coord, 0)

    weight_shape_cube = Cube(
        weight_shape,
        var_name=WEIGHTS_SHAPE_NAME,
        long_name=WEIGHTS_SHAPE_NAME,
    )

    # Save cubes while ensuring var_names do not conflict for the sake of consistency.
    with _managed_var_name(src_cube, tgt_cube):
        cube_list = CubeList([src_cube, tgt_cube, weights_cube, weight_shape_cube])
        for cube in cube_list:
            cube.attributes = attributes

        iris.fileformats.netcdf.save(cube_list, filename)
def load_regridder(filename):
    """Load a regridder scheme instance.

    Loads any of the regridder classes, i.e.
    :class:`~esmf_regrid.experimental.unstructured_scheme.GridToMeshESMFRegridder`,
    :class:`~esmf_regrid.experimental.unstructured_scheme.MeshToGridESMFRegridder`,
    :class:`~esmf_regrid.schemes.ESMFAreaWeightedRegridder`,
    :class:`~esmf_regrid.schemes.ESMFBilinearRegridder` or
    :class:`~esmf_regrid.schemes.ESMFNearestRegridder`.

    The regridder is rebuilt from the source/target description cubes
    stored in the file. The saved weight matrix is passed as
    ``precomputed_weights``. The ESMF and esmf_regrid versions recorded at
    initialisation are then copied back onto the result.

    Parameters
    ----------
    filename : str
        The file name to load from.

    Returns
    -------
    :class:`~esmf_regrid.schemes._ESMFRegridder`

    Raises
    ------
    TypeError
        If the file records a regridder type that is not supported.
    """
    with _load_context():
        cubes = iris.load(filename)

    # Extract the source, target and metadata information.
    src_cube = cubes.extract_cube(SOURCE_NAME)
    _clean_var_names(src_cube)
    tgt_cube = cubes.extract_cube(TARGET_NAME)
    _clean_var_names(tgt_cube)
    weights_cube = cubes.extract_cube(WEIGHTS_NAME)
    weight_shape_cube = cubes.extract_cube(WEIGHTS_SHAPE_NAME)

    # Determine the regridder type. Validate explicitly rather than with
    # ``assert``, which is stripped when Python runs with -O.
    regridder_type = weights_cube.attributes[REGRIDDER_TYPE]
    scheme = REGRIDDER_NAME_MAP.get(regridder_type)
    if scheme is None:
        supported = ", ".join(f"`{type_name}`" for type_name in REGRIDDER_NAME_MAP)
        msg = (
            f"Expected a saved regridder of one of the types {supported}, "
            f"got type {regridder_type}."
        )
        raise TypeError(msg)

    # Determine the regridding method, allowing for files created when
    # conservative regridding was the only method.
    method = getattr(
        Constants.Method, weights_cube.attributes.get(METHOD, "CONSERVATIVE")
    )

    # Resolution attributes are optional; convert them to int when present.
    resolution = weights_cube.attributes.get(RESOLUTION, None)
    src_resolution = weights_cube.attributes.get(SOURCE_RESOLUTION, None)
    tgt_resolution = weights_cube.attributes.get(TARGET_RESOLUTION, None)
    if resolution is not None:
        resolution = int(resolution)
    if src_resolution is not None:
        src_resolution = int(src_resolution)
    if tgt_resolution is not None:
        tgt_resolution = int(tgt_resolution)

    # Reconstruct the weight matrix from its COO components.
    weight_data = weights_cube.data
    weight_rows = weights_cube.coord(WEIGHTS_ROW_NAME).points
    weight_cols = weights_cube.coord(WEIGHTS_COL_NAME).points
    weight_shape = weight_shape_cube.data
    weight_matrix = scipy.sparse.csr_matrix(
        (weight_data, (weight_rows, weight_cols)), shape=weight_shape
    )

    mdtol = weights_cube.attributes[MDTOL]

    # A mask coordinate is only present when a mask array was saved.
    if src_cube.coords(SOURCE_MASK_NAME):
        use_src_mask = src_cube.coord(SOURCE_MASK_NAME).points
    else:
        use_src_mask = False
    if tgt_cube.coords(TARGET_MASK_NAME):
        use_tgt_mask = tgt_cube.coord(TARGET_MASK_NAME).points
    else:
        use_tgt_mask = False

    # Each scheme accepts a different set of keyword arguments.
    if scheme is GridToMeshESMFRegridder:
        kwargs = {SOURCE_RESOLUTION: resolution, "method": method, "mdtol": mdtol}
    elif scheme is MeshToGridESMFRegridder:
        kwargs = {TARGET_RESOLUTION: resolution, "method": method, "mdtol": mdtol}
    elif scheme is ESMFAreaWeightedRegridder:
        kwargs = {
            SOURCE_RESOLUTION: src_resolution,
            TARGET_RESOLUTION: tgt_resolution,
            "mdtol": mdtol,
        }
    elif scheme is ESMFBilinearRegridder:
        kwargs = {"mdtol": mdtol}
    else:
        kwargs = {}

    regridder = scheme(
        src_cube,
        tgt_cube,
        precomputed_weights=weight_matrix,
        use_src_mask=use_src_mask,
        use_tgt_mask=use_tgt_mask,
        **kwargs,
    )

    # Restore the version information recorded when the regridder was created.
    esmf_version = weights_cube.attributes[VERSION_ESMF]
    regridder.regridder.esmf_version = esmf_version
    esmf_regrid_version = weights_cube.attributes[VERSION_INITIAL]
    regridder.regridder.esmf_regrid_version = esmf_regrid_version
    return regridder