Source code for eogrow.pipelines.features

"""Implements a pipeline to construct features for training/prediction."""

from __future__ import annotations

import logging
from typing import Dict, Optional, Tuple

import numpy as np
from pydantic import Field

from eolearn.core import (
    CopyTask,
    EONode,
    EOWorkflow,
    FeatureType,
    LoadTask,
    MergeFeatureTask,
    OverwritePermission,
    SaveTask,
)
from eolearn.core.types import Feature
from eolearn.features import NormalizedDifferenceIndexTask, SimpleFilterTask
from eolearn.mask import JoinMasksTask

from ..core.pipeline import Pipeline
from ..core.schemas import BaseSchema
from ..tasks.features import (
    MaxNDVIMosaickingTask,
    MedianMosaickingTask,
    MosaickingTask,
    ValidDataFractionPredicate,
    join_valid_and_cloud_masks,
)
from ..types import PatchList, TimePeriod
from ..utils.filter import get_patches_with_missing_features
from ..utils.validators import (
    ensure_storage_key_presence,
    field_validator,
    optional_field_validator,
    parse_dtype,
    parse_time_period,
)

LOGGER = logging.getLogger(__name__)


class ValidityFiltering(BaseSchema):
    # Settings that control how frames are filtered by data validity while the bands feature is prepared.

    # Optional: when a name is given, the features pipeline joins this cloud mask with the valid-data mask.
    cloud_mask_feature_name: Optional[str] = Field(
        description="Name of cloud mask to enable additional filtering by cloud"
    )

    # Name of the mask feature used as the valid-data mask by the features pipeline.
    valid_data_feature_name: str = Field(description="Name of the valid-data mask to use for filtering.")

    # Optional: frame filtering is only added to the workflow when this is not None.
    validity_threshold: Optional[float] = Field(
        description="Threshold to remove frames with valid data lower than threshold"
    )
class FeaturesPipeline(Pipeline):
    """Pipeline that merges band data and derived NDIs into a single output data feature for ML."""

    class Schema(Pipeline.Schema):
        # --- storage locations ---
        input_folder_key: str = Field(
            description="The storage manager key pointing to the input folder for the features pipeline."
        )
        _ensure_input_folder_key = ensure_storage_key_presence("input_folder_key")
        output_folder_key: str = Field(
            description="The storage manager key pointing to the output folder for the features pipeline."
        )
        _ensure_output_folder_key = ensure_storage_key_presence("output_folder_key")

        # --- inputs and their preparation ---
        bands_feature_name: str = Field(description="Name of data feature containing band data")
        data_preparation: ValidityFiltering

        # --- derived indices; defaults to an empty dict, in which case no NDI task is added ---
        ndis: Dict[str, Tuple[int, int]] = Field(
            default_factory=dict,
            description=(
                "A dictionary of kind `{feature_name: (id1, id2)}` that specifies how to calculate the NDIs of bands "
                "(with indices `id1` and `id2` in the bands feature) and save it under `feature_name`."
            ),
        )

        # --- output ---
        dtype: Optional[np.dtype] = Field(description="The dtype under which the concatenated features should be saved")
        _parse_dtype = optional_field_validator("dtype", parse_dtype, pre=True)
        output_feature_name: str = Field(description="Name of output data feature encompassing bands and NDIs")

    config: Schema

    def filter_patch_list(self, patch_list: PatchList) -> PatchList:
        """Filter EOPatches according to the existence of the output features in the output folder.

        The decision is delegated to `get_patches_with_missing_features`, with timestamp checking enabled.

        :param patch_list: Candidate patches.
        :return: The filtered patch list.
        """
        filesystem = self.storage.filesystem
        output_folder = self.storage.get_folder(self.config.output_folder_key)
        expected_features = self._get_output_features()
        return get_patches_with_missing_features(
            filesystem,
            output_folder,
            patch_list,
            expected_features,
            check_timestamps=True,
        )

    def _get_output_features(self) -> list[Feature]:
        """Return the features that are saved when the pipeline completes."""
        output_feature = (FeatureType.DATA, self.config.output_feature_name)
        return [output_feature]

    def _get_bands_feature(self) -> Feature:
        """Return the data feature holding the band data."""
        return (FeatureType.DATA, self.config.bands_feature_name)

    def _get_valid_data_feature(self) -> Feature:
        """Return the mask feature configured as the valid-data mask."""
        return (FeatureType.MASK, self.config.data_preparation.valid_data_feature_name)

    def build_workflow(self) -> EOWorkflow:
        """Chain the pipeline stages into a workflow.

        Stages, in execution order:

        1. Data preparation: loading plus optional mask joining and frame filtering.
        2. Temporal regularization (this base class passes the node through unchanged).
        3. NDI calculation from the bands feature.
        4. Post-processing, which builds the output feature.
        5. Saving of the features listed by `_get_output_features`.
        """
        prepared = self.get_data_preparation_node()
        regularized = self.get_temporal_regularization_node(prepared)
        with_ndis = self.get_ndi_node(regularized)
        postprocessed = self.get_postprocessing_node(with_ndis)

        output_folder = self.storage.get_folder(self.config.output_folder_key)
        saver = SaveTask(
            output_folder,
            filesystem=self.storage.filesystem,
            features=self._get_output_features(),
            overwrite_permission=OverwritePermission.OVERWRITE_FEATURES,
            use_zarr=self.storage.config.use_zarr,
        )
        return EOWorkflow.from_endnodes(EONode(saver, inputs=[postprocessed]))

    def get_data_preparation_node(self) -> EONode:
        """Build the nodes that load the input and, depending on the config, join masks and filter frames.

        :return: The last node of the preparation chain.
        """
        filtering = self.config.data_preparation

        loader = LoadTask(
            self.storage.get_folder(self.config.input_folder_key),
            filesystem=self.storage.filesystem,
            lazy_loading=True,
        )
        node = EONode(loader)

        # Optional step: combine the cloud mask with the valid-data mask, writing back to the valid-data feature.
        if filtering.cloud_mask_feature_name:
            masks_to_join = [
                self._get_valid_data_feature(),
                (FeatureType.MASK, filtering.cloud_mask_feature_name),
            ]
            joiner = JoinMasksTask(
                masks_to_join,
                self._get_valid_data_feature(),
                join_operation=join_valid_and_cloud_masks,
            )
            node = EONode(joiner, inputs=[node], name="Combine validity mask and cloud mask")

        # Optional step: filter frames with a predicate built from the threshold (0.0 is still a valid threshold,
        # hence the explicit comparison with None).
        if filtering.validity_threshold is not None:
            predicate = ValidDataFractionPredicate(filtering.validity_threshold)
            frame_filter = SimpleFilterTask(self._get_valid_data_feature(), predicate)
            node = EONode(frame_filter, inputs=[node])

        return node

    def get_temporal_regularization_node(self, previous_node: EONode) -> EONode:
        """Hook for temporal regularization; the base pipeline returns `previous_node` unchanged."""
        return previous_node

    def get_ndi_node(self, previous_node: EONode) -> EONode:
        """Append one Normalized Difference Index node per entry of `config.ndis`.

        :param previous_node: Node the NDI chain is attached to.
        :return: The last NDI node, or `previous_node` itself when no NDIs are configured.
        """
        node = previous_node
        for ndi_name, (band_a, band_b) in self.config.ndis.items():
            task = NormalizedDifferenceIndexTask(
                self._get_bands_feature(),
                (FeatureType.DATA, ndi_name),
                (band_a, band_b),
            )
            node = EONode(task, inputs=[node])
        return node

    def get_postprocessing_node(self, previous_node: EONode) -> EONode:
        """Merge the bands feature and all NDI features into the output feature, using the configured dtype.

        :param previous_node: Node the merge is attached to.
        :return: Node running the merge task.
        """
        ndi_features = [(FeatureType.DATA, ndi_name) for ndi_name in self.config.ndis]
        merge_inputs = [self._get_bands_feature(), *ndi_features]
        merged_feature = (FeatureType.DATA, self.config.output_feature_name)
        merger = MergeFeatureTask(merge_inputs, merged_feature, dtype=self.config.dtype)
        return EONode(merger, inputs=[previous_node])
class MosaickingSpecifications(BaseSchema):
    # Parameters of the mosaicking step used by the mosaicking features pipeline.

    # Unpacked by the pipeline as (start_date, end_date); raw input goes through `parse_time_period` first.
    time_period: TimePeriod
    _parse_time_period = field_validator("time_period", parse_time_period, pre=True)

    # Passed to the mosaicking task together with the start and end date.
    n_mosaics: int

    # Selects the mosaicking strategy, as explained in the description below.
    max_ndi_indices: Optional[Tuple[int, int]] = Field(
        description=(
            "When omitted uses median value mosaicking. If set, uses max NDI mosaicking for the NDI of the bands at"
            " specified indices. For example, to use max NDVI when using all 13 bands of L1C set parameter to `[7, 3]`"
            " (uses B08 and B04)"
        )
    )
class MosaickingFeaturesPipeline(FeaturesPipeline):
    """Features pipeline variant whose temporal regularization step mosaics the bands feature."""

    # Helper feature holding the NDI that is handed to the max-NDI mosaicking task.
    _NDI_FEATURE = FeatureType.DATA, "_NDI_FEATURE_OF_MOSAICKING_PIPELINE"

    class Schema(FeaturesPipeline.Schema):
        # NOTE(review): the description says "If not set", but the field declares no default - confirm
        # whether it is meant to be optional.
        mosaicking: MosaickingSpecifications = Field(
            description=(
                "Fine-tuning of mosaicking parameters. If not set, the interpolation will work on current timestamps"
            )
        )

    config: Schema

    def get_data_preparation_node(self) -> EONode:
        """Extend the parent preparation chain with the helper NDI when `max_ndi_indices` is configured.

        :return: The parent's preparation node, or a node computing `_NDI_FEATURE` on top of it.
        """
        node = super().get_data_preparation_node()

        ndi_indices = self.config.mosaicking.max_ndi_indices
        if not ndi_indices:
            return node

        ndi_task = NormalizedDifferenceIndexTask(self._get_bands_feature(), self._NDI_FEATURE, ndi_indices)
        return EONode(ndi_task, inputs=[node])

    def get_temporal_regularization_node(self, previous_node: EONode) -> EONode:
        """Mosaic the bands feature, then copy only the bands feature onward.

        Uses `MaxNDVIMosaickingTask` (fed with `_NDI_FEATURE`) when `max_ndi_indices` is configured and
        `MedianMosaickingTask` otherwise.

        :param previous_node: Node the mosaicking is attached to.
        :return: Node running the copy task that follows the mosaicking.
        """
        settings = self.config.mosaicking
        start_date, end_date = settings.time_period

        mosaicking_task: MosaickingTask
        if settings.max_ndi_indices:
            mosaicking_task = MaxNDVIMosaickingTask(
                self._get_bands_feature(),
                (start_date, end_date, settings.n_mosaics),
                self._NDI_FEATURE,
                self._get_valid_data_feature(),
            )
        else:
            mosaicking_task = MedianMosaickingTask(
                self._get_bands_feature(),
                (start_date, end_date, settings.n_mosaics),
                self._get_valid_data_feature(),
            )
        mosaicking_node = EONode(mosaicking_task, inputs=[previous_node])

        # The copy task is restricted to the bands feature, so no other feature is carried forward.
        keep_bands_only = CopyTask(features=[self._get_bands_feature()])
        return EONode(keep_bands_only, inputs=[mosaicking_node], name="Remove non-mosaicked features")