Source code for eogrow.pipelines.batch_to_eopatch

"""Pipeline for conversion of batch results to EOPatches."""

from __future__ import annotations

from typing import Any, List, Optional

import numpy as np
from pydantic import Field, validator

from eolearn.core import (
    CreateEOPatchTask,
    EONode,
    EOWorkflow,
    FeatureType,
    MergeEOPatchesTask,
    MergeFeatureTask,
    OverwritePermission,
    RemoveFeatureTask,
    RenameFeatureTask,
    SaveTask,
)
from eolearn.core.types import Feature
from eolearn.io import ImportFromTiffTask

from ..core.pipeline import Pipeline
from ..core.schemas import BaseSchema
from ..tasks.batch_to_eopatch import DeleteFilesTask, FixImportedTimeDependentFeatureTask, LoadUserDataTask
from ..tasks.common import LinearFunctionTask
from ..types import ExecKwargs, PatchList, RawSchemaDict
from ..utils.filter import get_patches_with_missing_features
from ..utils.validators import ensure_storage_key_presence, optional_field_validator, parse_dtype


class FeatureMappingSchema(BaseSchema):
    """Defines a mapping between 1 or more batch outputs into an EOPatch feature"""

    # Names of the batch output files that together form one feature; the order given here is the band order.
    batch_files: List[str] = Field(
        description=(
            "A list of files that will be converted into an EOPatch feature. "
            "If you specify multiple tiff files they will be concatenated together "
            "over the bands dimension in the specified order."
        ),
    )

    # Target (feature type, feature name) pair in the resulting EOPatch.
    feature: Feature

    # Scaling applied to the imported values; the default of 1 leaves them untouched.
    multiply_factor: float = Field(default=1, description="Factor used to multiply feature values with.")

    # Optional output dtype; no cast is requested when it is left unset.
    dtype: Optional[np.dtype] = Field(description="Dtype of the output feature.")
    # Registered with `pre=True`, i.e. `parse_dtype` is applied to the raw config value before field validation.
    _parse_dtype = optional_field_validator("dtype", parse_dtype, pre=True)
class BatchToEOPatchPipeline(Pipeline):
    """Transforms the `tiff` files into `EOPatch` objects.

    Temporal tiffs are expected to be of shape `(h, w, t)` and one tiff per band. If the pipeline knows that
    `timestamps=[]` (through `userdata.json`) and the tiffs contain only `0` elements, it transforms them into
    "temporally empty" `EOPatch` objects of shape `(0, h, w, b)`.
    """

    class Schema(Pipeline.Schema):
        input_folder_key: str = Field(description="Storage manager key pointing to the path with Batch results")
        _ensure_input_folder_key = ensure_storage_key_presence("input_folder_key")
        output_folder_key: str = Field(description="Storage manager key pointing to where the EOPatches are saved")
        _ensure_output_folder_key = ensure_storage_key_presence("output_folder_key")

        userdata_feature_name: Optional[str] = Field(
            description="A name of META_INFO feature in which userdata.json would be stored."
        )
        userdata_timestamp_reader: Optional[str] = Field(
            description=(
                "Either an import path to a utility function or a Python code describing how to read "
                "dates from userdata dictionary."
            ),
            example="\"[info['date'] for info in json.loads(userdata['metadata'])]\"",
        )

        mapping: List[FeatureMappingSchema] = Field(
            description="A list of mapping from batch files into EOPatch features."
        )

        @validator("mapping")
        def check_nonempty_input(cls, value: list, values: RawSchemaDict) -> list:
            """Rejects a config that has an empty `mapping` and neither of the userdata parameters set.

            :param value: The parsed `mapping` list.
            :param values: Fields of the schema that were validated before `mapping`.
            :return: The unchanged `mapping` list.
            :raises ValueError: If nothing at all would be produced by the pipeline.
            """
            if not value:
                params = "userdata_feature_name", "userdata_timestamp_reader"
                # An explicit raise is used instead of `assert`, because asserts are removed when Python runs
                # with `-O`, which would silently disable this check.
                if all(values.get(param) is None for param in params):
                    raise ValueError(
                        "At least one of `userdata_feature_name`, `userdata_timestamp_reader`, or `mapping` has to"
                        " be set."
                    )
            return value

        remove_batch_data: bool = Field(
            True, description="Remove the raw batch data after the conversion is complete"
        )

    config: Schema

    def __init__(self, *args: Any, **kwargs: Any):
        """Additionally sets some basic parameters calculated from config parameters"""
        super().__init__(*args, **kwargs)

        self._input_folder = self.storage.get_folder(self.config.input_folder_key)
        # Coerced to a real boolean; it is only ever used as a flag.
        self._has_userdata = bool(self.config.userdata_feature_name or self.config.userdata_timestamp_reader)
        self._all_batch_files = self._get_all_batch_files()

    def filter_patch_list(self, patch_list: PatchList) -> PatchList:
        """EOPatches are filtered according to existence of specified output features

        :param patch_list: Candidate patches to process.
        :return: The result of `get_patches_with_missing_features` for the output folder and output features.
            Timestamps are included in the check only when `userdata_timestamp_reader` is configured.
        """
        return get_patches_with_missing_features(
            self.storage.filesystem,
            self.storage.get_folder(self.config.output_folder_key),
            patch_list,
            self._get_output_features(),
            check_timestamps=self.config.userdata_timestamp_reader is not None,
        )

    def _get_output_features(self) -> list[Feature]:
        """Lists all features that the pipeline outputs.

        These are the mapped features from the config plus, when `userdata_feature_name` is set, the
        `META_INFO` feature of that name.
        """
        features = [feature_mapping.feature for feature_mapping in self.config.mapping]

        if self.config.userdata_feature_name:
            features.append((FeatureType.META_INFO, self.config.userdata_feature_name))

        return features

    def build_workflow(self) -> EOWorkflow:
        """Builds the workflow

        The node chain is: create an EOPatch -> (optionally) load userdata -> import every mapped feature ->
        merge -> `get_processing_node` hook -> save -> (optionally) delete the batch files.
        """
        metadata_node = EONode(CreateEOPatchTask(), name="Establish BBox")

        if self._has_userdata:
            metadata_node = EONode(
                LoadUserDataTask(
                    path=self._input_folder,
                    filesystem=self.storage.filesystem,
                    userdata_feature_name=self.config.userdata_feature_name,
                    userdata_timestamp_reader=self.config.userdata_timestamp_reader,
                ),
                inputs=[metadata_node],
            )

        mapping_nodes = [
            self._get_tiff_mapping_node(feature_mapping, metadata_node) for feature_mapping in self.config.mapping
        ]

        # Without any mapping the workflow continues straight from the metadata node.
        last_node = metadata_node
        if len(mapping_nodes) == 1:
            last_node = mapping_nodes[0]
        elif len(mapping_nodes) > 1:
            last_node = EONode(MergeEOPatchesTask(), inputs=mapping_nodes)

        processing_node = self.get_processing_node(last_node)

        save_task = SaveTask(
            path=self.storage.get_folder(self.config.output_folder_key),
            filesystem=self.storage.filesystem,
            features=self._get_output_features(),
            overwrite_permission=OverwritePermission.OVERWRITE_FEATURES,
        )
        save_node = EONode(save_task, inputs=([processing_node] if processing_node else []))

        cleanup_node = None
        if self.config.remove_batch_data:
            delete_task = DeleteFilesTask(
                path=self._input_folder,
                filesystem=self.storage.filesystem,
                filenames=self._all_batch_files,
            )
            # Deletion depends on the save node, so it is ordered after saving in the workflow.
            cleanup_node = EONode(delete_task, inputs=[save_node], name="Delete batch data")

        return EOWorkflow.from_endnodes(cleanup_node or save_node)

    def _get_tiff_mapping_node(self, mapping: FeatureMappingSchema, previous_node: EONode | None) -> EONode:
        """Prepares tasks and dependencies that convert tiff files into an EOPatch feature

        :param mapping: Description of which batch files form which feature.
        :param previous_node: Node that every import node depends on, if any.
        :return: The last node of the sub-graph that produces `mapping.feature`.
        :raises ValueError: If no batch files are given, if a batch file does not end with `.tif`, or if the
            target feature type is not an image type.
        """
        tif_suffix = ".tif"

        # Fail early with a clear message; an empty list would otherwise surface as an IndexError further down.
        if not mapping.batch_files:
            raise ValueError(f"At least one batch file is required to create feature {mapping.feature}.")

        if not all(batch_file.endswith(tif_suffix) for batch_file in mapping.batch_files):
            raise ValueError(f"All batch files should end with .tif but found {mapping.batch_files}")

        feature_type, feature_name = mapping.feature
        if not feature_type.is_image():
            raise ValueError(f"Tiffs can only be read into spatial raster feature types, but {feature_type} was given.")

        tmp_features = []
        end_nodes = []
        for batch_file in mapping.batch_files:
            # Only the trailing extension is swapped, so a ".tif" inside the file stem is left intact.
            feature = feature_type, f"{batch_file[: -len(tif_suffix)]}_tmp"
            tmp_features.append(feature)

            # Tiffs are always imported into a timeless feature first.
            tmp_timeless_feature = (
                FeatureType.MASK_TIMELESS if feature_type.is_discrete() else FeatureType.DATA_TIMELESS
            ), feature[1]

            import_task = ImportFromTiffTask(
                tmp_timeless_feature, self._input_folder, filesystem=self.storage.filesystem
            )
            # Filename is written into the dependency name to be used later for execution arguments:
            import_node = EONode(
                import_task,
                inputs=[previous_node] if previous_node else [],
                name=f"{batch_file} import",
            )

            if feature_type.is_temporal():
                fix_task = FixImportedTimeDependentFeatureTask(tmp_timeless_feature, feature)
                end_nodes.append(EONode(fix_task, inputs=[import_node]))
            else:
                end_nodes.append(import_node)

        previous_node = EONode(MergeEOPatchesTask(), inputs=end_nodes) if len(end_nodes) > 1 else end_nodes[0]

        final_feature = feature_type, feature_name
        end_node = self._get_feature_merge_node(previous_node, tmp_features, final_feature)

        # The extra task is added only when it would actually change values or dtype.
        if mapping.multiply_factor != 1 or mapping.dtype is not None:
            multiply_task = LinearFunctionTask(final_feature, slope=mapping.multiply_factor, dtype=mapping.dtype)
            end_node = EONode(multiply_task, inputs=[end_node])

        return end_node

    @staticmethod
    def _get_feature_merge_node(
        previous_node: EONode, input_features: list[Feature], output_feature: Feature
    ) -> EONode:
        """Merges input features into a single output feature and removes input features. In case there is a
        single input feature this method just renames it into the output feature. This way it avoids memory
        duplication that otherwise happens in `MergeFeatureTask`.

        :param previous_node: Node providing the EOPatch with all input features.
        :param input_features: Temporary features to be combined.
        :param output_feature: Feature under which the combined data is stored.
        :return: The node after which only the output feature remains.
        """
        if len(input_features) == 1:
            feature_type, input_feature_name = input_features[0]
            _, output_feature_name = output_feature
            rename_task = RenameFeatureTask([(feature_type, input_feature_name, output_feature_name)])
            return EONode(rename_task, inputs=[previous_node])

        merge_feature_task = MergeFeatureTask(input_features=input_features, output_feature=output_feature)
        merge_node = EONode(merge_feature_task, inputs=[previous_node])
        remove_task = RemoveFeatureTask(input_features)
        return EONode(remove_task, inputs=[merge_node])

    def get_processing_node(self, previous_node: EONode) -> EONode:
        """This method can be overwritten to add more tasks that process loaded data before saving it."""
        return previous_node

    def get_execution_arguments(self, workflow: EOWorkflow, patch_list: PatchList) -> ExecKwargs:
        """Prepare execution arguments per each EOPatch

        :param workflow: The workflow built by `build_workflow`.
        :param patch_list: Patches for which the arguments are prepared.
        :return: Execution arguments from the parent class, extended with a `filename` argument for every tiff
            import node and a `folder` argument for every delete/userdata node.
        :raises RuntimeError: If a tiff import node has no name to read the filename from.
        """
        exec_args = super().get_execution_arguments(workflow, patch_list)

        nodes = workflow.get_nodes()

        for patch_name, patch_args in exec_args.items():
            for node in nodes:
                if isinstance(node.task, ImportFromTiffTask):
                    if node.name is None:
                        raise RuntimeError("One of the ImportFromTiffTask nodes has not been tagged with filename.")

                    # Import nodes are named "<filename> import". Only the trailing tag is cut off, so a
                    # filename that itself contains whitespace is recovered in full.
                    filename = node.name.rsplit(" ", 1)[0]
                    path = f"{patch_name}/{filename}"
                    patch_args[node] = dict(filename=path)

                if isinstance(node.task, (DeleteFilesTask, LoadUserDataTask)):
                    patch_args[node] = dict(folder=patch_name)

        return exec_args

    def _get_all_batch_files(self) -> list[str]:
        """Provides a list of batch files used in this pipeline

        The list holds every configured batch file once, plus `userdata.json` when userdata is used. It is
        sorted so that the result is the same on every run.
        """
        files = [file for feature_mapping in self.config.mapping for file in feature_mapping.batch_files]

        if self._has_userdata:
            files.append("userdata.json")

        return sorted(set(files))