"""Module implementing utilities for chained configs."""
from __future__ import annotations
from typing import Any, Dict
import ray
from pydantic import Field, ValidationError
from ..core.config import RawConfig
from ..core.schemas import BaseSchema
from .meta import collect_schema, load_pipeline_class
class PipelineRunSchema(BaseSchema):
    """Schema of a single element of a pipeline chain.

    Each element carries the raw configuration of one pipeline and,
    optionally, the ray options used when launching that pipeline.
    """

    # Raw configuration of the pipeline; it is handed to `load_pipeline_class`
    # and validated against the schema collected from the resulting class.
    pipeline_config: dict
    # Unpacked as keyword arguments into `_pipeline_runner.options(...)`;
    # defaults to an empty dict, i.e. no extra ray options.
    pipeline_resources: Dict[str, Any] = Field(
        default_factory=dict,
        description=(
            "Keyword arguments passed to ray when executing the main pipeline process. The options are specified [here]"
            "(https://docs.ray.io/en/latest/ray-core/api/doc/ray.remote_function.RemoteFunction.options.html)."
        ),
    )
def validate_pipeline_chain(pipeline_chain: list[RawConfig]) -> None:
    """Validate every element of a pipeline chain without running anything.

    Each element is first parsed as a `PipelineRunSchema`. Its
    `pipeline_config` is then parsed against the schema collected from the
    pipeline class that the config refers to.

    Args:
        pipeline_chain: Raw run configurations, one per pipeline in the chain.

    Raises:
        TypeError: If an element cannot be parsed as a `PipelineRunSchema`.
            The original `ValidationError` is chained as the cause.

    Errors raised while loading the pipeline class or while parsing the
    pipeline config against its own schema are not caught here and propagate
    to the caller.
    """
    for index, run_config in enumerate(pipeline_chain):
        try:
            run_schema = PipelineRunSchema.parse_obj(run_config)
        except ValidationError as err:
            raise TypeError(
                f"Pipeline-chain element {index} should be a dictionary with the fields `pipeline_config` and the optional"
                " `pipeline_resources`."
            ) from err

        # Second stage: validate the pipeline's own config against the schema
        # of the pipeline class it names.
        pipeline_cls = load_pipeline_class(run_schema.pipeline_config)
        pipeline_schema = collect_schema(pipeline_cls)
        pipeline_schema.parse_obj(run_schema.pipeline_config)
def run_pipeline_chain(pipeline_chain: list[RawConfig]) -> None:
    """Run the pipelines of a chain one after the other as ray tasks.

    Each element is parsed as a `PipelineRunSchema`, and its
    `pipeline_resources` are applied as ray options to the remote runner.
    `ray.get` is called inside the loop, so a pipeline is only submitted once
    the previous one has returned.

    Args:
        pipeline_chain: Raw run configurations, one per pipeline in the chain.
    """
    for run_config in pipeline_chain:
        run_schema = PipelineRunSchema.parse_obj(run_config)
        # Apply the per-run ray options (empty dict when none were given).
        runner = _pipeline_runner.options(**run_schema.pipeline_resources)
        # Block on the result so the chain is executed sequentially.
        ray.get(runner.remote(run_schema.pipeline_config))  # type: ignore [arg-type]
@ray.remote(max_retries=0)
def _pipeline_runner(config: RawConfig) -> None:
    """Build the pipeline described by `config` and run it as a ray task.

    The pipeline class is resolved from the config, instantiated through its
    `from_raw_config` constructor and executed; whatever `run()` returns is
    passed through. The task is declared with `max_retries=0`.
    """
    pipeline_cls = load_pipeline_class(config)
    pipeline = pipeline_cls.from_raw_config(config)
    return pipeline.run()