Source code for eogrow.utils.zipmap

"""A collection of functions to be used with the ZipMapPipeline."""

from __future__ import annotations

from typing import Dict, Union

import numpy as np

from ..core.schemas import BaseSchema
from .validators import optional_field_validator, parse_dtype


[docs]class MapParams(BaseSchema): mapping: Dict[int, int] default: Union[int, None] dtype: np.dtype _parse_dtype = optional_field_validator("dtype", parse_dtype, pre=True)
[docs]def map_values( array: np.ndarray, *, mapping: dict[int, int], default: int | None = None, dtype: np.dtype | None = None, ) -> np.ndarray: """Maps all values from `array` according to a dictionary. A default value can be given, which is assigned to values without a corresponding key in the mapping dictionary. Vectorizing the `.get` method is faster, but difficult to do with a default (and also loses speed advantage).""" if default is not None: flat_mapped_array = np.array([mapping.get(x, default) for x in array.ravel()], dtype=dtype) return flat_mapped_array.reshape(array.shape) mapped_array = np.vectorize(mapping.get)(array) return mapped_array.astype(dtype)