diff --git a/samples/core/loop_output/loop_output_test.py b/samples/core/loop_output/loop_output_test.py index a770a009653..93e545b255b 100644 --- a/samples/core/loop_output/loop_output_test.py +++ b/samples/core/loop_output/loop_output_test.py @@ -12,11 +12,72 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import unittest import kfp +import kfp_server_api +from ml_metadata.proto import Execution from .loop_output import my_pipeline -from ...test.util import run_pipeline_func, TestCase +from .loop_output_v2 import my_pipeline as my_pipeline_v2 +from ...test.util import KfpTask, run_pipeline_func, TestCase + + +def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, + tasks: dict[str, KfpTask], **kwargs): + t.assertEqual(run.status, 'Succeeded') + # assert DAG structure + t.assertCountEqual(tasks.keys(), ['args-generator-op', 'for-loop-1']) + t.assertCountEqual( + ['for-loop-1-#0', 'for-loop-1-#1', 'for-loop-1-#2'], + tasks['for-loop-1'].children.keys(), + ) + # assert all iteration parameters + t.assertCountEqual( + [1.1, 1.2, 1.3], + [ + x.inputs + .parameters['pipelinechannel--args-generator-op-Output-loop-item'] + for x in tasks['for-loop-1'].children.values() + ], + ) + # assert 1 iteration task + t.assertEqual( + { + 'name': 'for-loop-1-#0', + 'type': 'system.DAGExecution', + 'state': + Execution.State.RUNNING, # TODO(Bobgy): this should be COMPLETE + 'inputs': { + 'parameters': { + 'pipelinechannel--args-generator-op-Output-loop-item': 1.1 + } + } + }, + tasks['for-loop-1'].children['for-loop-1-#0'].get_dict(), + ) + t.assertEqual( + { + 'name': 'print-op', + 'type': 'system.ContainerExecution', + 'state': Execution.State.COMPLETE, + 'inputs': { + 'parameters': { + 's': 1.1 + } + } + }, + tasks['for-loop-1'].children['for-loop-1-#0'].children['print-op'] + .get_dict(), + ) + run_pipeline_func([ + TestCase( + pipeline_func=my_pipeline_v2, + mode=kfp.dsl.PipelineExecutionMode.V2_ENGINE, + verify_func=verify, + ), TestCase( pipeline_func=my_pipeline, mode=kfp.dsl.PipelineExecutionMode.V1_LEGACY, diff --git a/samples/core/loop_output/loop_output_v2.py b/samples/core/loop_output/loop_output_v2.py new file mode 100644 index 00000000000..1ba4ab7b930 --- /dev/null +++ b/samples/core/loop_output/loop_output_v2.py @@ -0,0 +1,41 @@ +# Copyright 2021 The Kubeflow Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from kfp.v2 import dsl + +# In tests, we install a KFP package from the PR under test. Users should not +# normally need to specify `kfp_package_path` in their component definitions. +_KFP_PACKAGE_PATH = os.getenv('KFP_PACKAGE_PATH') + + +@dsl.component(kfp_package_path=_KFP_PACKAGE_PATH) +def args_generator_op() -> str: + return '[1.1, 1.2, 1.3]' + + +# TODO(Bobgy): how can we make this component with type float? 
+# got error: kfp.v2.components.types.type_utils.InconsistentTypeException: +# Incompatible argument passed to the input "s" of component "Print op": Argument +# type "STRING" is incompatible with the input type "NUMBER_DOUBLE" +@dsl.component(kfp_package_path=_KFP_PACKAGE_PATH) +def print_op(s: str): + print(s) + + +@dsl.pipeline(name='pipeline-with-loop-output-v2') +def my_pipeline(): + args_generator = args_generator_op() + with dsl.ParallelFor(args_generator.output) as item: + print_op(s=item) diff --git a/samples/core/loop_static/loop_static_v2.py b/samples/core/loop_static/loop_static_v2.py new file mode 100644 index 00000000000..bb7ee405457 --- /dev/null +++ b/samples/core/loop_static/loop_static_v2.py @@ -0,0 +1,30 @@ +from kfp.v2 import components, dsl +from typing import List + + +@dsl.component +def print_op(text: str) -> str: + print(text) + return text + + +@dsl.component +def concat_op(a: str, b: str) -> str: + print(a + b) + return a + b + + +_DEFAULT_LOOP_ARGUMENTS = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}] + + +@dsl.pipeline(name='pipeline-with-loop-static') +def my_pipeline( + static_loop_arguments: List[dict] = _DEFAULT_LOOP_ARGUMENTS, + greeting: str = 'this is a test for looping through parameters', +): + print_task = print_op(text=greeting) + + with dsl.ParallelFor(static_loop_arguments) as item: + concat_task = concat_op(a=item.a, b=item.b) + concat_task.after(print_task) + print_task_2 = print_op(text=concat_task.output) diff --git a/samples/test/lightweight_python_functions_v2_pipeline_test.py b/samples/test/lightweight_python_functions_v2_pipeline_test.py index 21825386249..8311a60b535 100644 --- a/samples/test/lightweight_python_functions_v2_pipeline_test.py +++ b/samples/test/lightweight_python_functions_v2_pipeline_test.py @@ -39,7 +39,6 @@ def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs): t.assertEqual( { 'inputs': { - 'artifacts': [], 'parameters': { 'message': 'message', } @@ -93,8 +92,8 @@ def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs): 'parameters': { 'input_bool': True, 'input_dict': { - "A": 1, - "B": 2 + "A": 1.0, + "B": 2.0, }, 'input_list': ["a", "b", "c"], 'message': 'message' @@ -110,7 +109,6 @@ def verify(run: kfp_server_api.ApiRun, mlmd_connection_config, **kwargs): 'name': 'model', 'type': 'system.Model' }], - 'parameters': {} }, 'type': 'system.ContainerExecution', 'state': Execution.State.COMPLETE, diff --git a/samples/test/metrics_visualization_v2_test.py b/samples/test/metrics_visualization_v2_test.py index 7473490dab8..26589c484df 100644 --- a/samples/test/metrics_visualization_v2_test.py +++ b/samples/test/metrics_visualization_v2_test.py @@ -40,10 +40,6 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, t.assertEqual( { - 'inputs': { - 'artifacts': [], - 'parameters': {} - }, 'name': 'wine-classification', 'outputs': { 'artifacts': [{ @@ -100,7 +96,6 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, 'name': 'metrics', 'type': 'system.ClassificationMetrics' }], - 'parameters': {} }, 'type': 'system.ContainerExecution', 'state': Execution.State.COMPLETE, @@ -108,7 +103,6 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, t.assertEqual( { 'inputs': { - 'artifacts': [], 'parameters': { 'test_samples_fraction': 0.3 } @@ -144,7 +138,6 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, 'name': 'metrics', 'type': 'system.ClassificationMetrics' }], - 'parameters': {} }, 'type': 'system.ContainerExecution', 'state': Execution.State.COMPLETE, 
@@ -160,10 +153,6 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, t.assertEqual( { - 'inputs': { - 'artifacts': [], - 'parameters': {} - }, 'name': 'digit-classification', 'outputs': { 'artifacts': [{ @@ -174,7 +163,6 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, 'name': 'metrics', 'type': 'system.Metrics' }], - 'parameters': {} }, 'type': 'system.ContainerExecution', 'state': Execution.State.COMPLETE, @@ -182,10 +170,6 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, t.assertEqual( { - 'inputs': { - 'artifacts': [], - 'parameters': {} - }, 'name': 'html-visualization', 'outputs': { 'artifacts': [{ @@ -195,7 +179,6 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, 'name': 'html_artifact', 'type': 'system.HTML' }], - 'parameters': {} }, 'state': Execution.State.COMPLETE, 'type': 'system.ContainerExecution' @@ -203,10 +186,6 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, t.assertEqual( { - 'inputs': { - 'artifacts': [], - 'parameters': {} - }, 'name': 'markdown-visualization', 'outputs': { 'artifacts': [{ @@ -216,7 +195,6 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, 'name': 'markdown_artifact', 'type': 'system.Markdown' }], - 'parameters': {} }, 'state': Execution.State.COMPLETE, 'type': 'system.ContainerExecution' diff --git a/samples/test/two_step_with_uri_placeholder_test.py b/samples/test/two_step_with_uri_placeholder_test.py index 63f480e1311..34cf4cd55fe 100644 --- a/samples/test/two_step_with_uri_placeholder_test.py +++ b/samples/test/two_step_with_uri_placeholder_test.py @@ -26,24 +26,16 @@ def verify_tasks(t: unittest.TestCase, tasks: Dict[str, KfpTask]): - task_names = [*tasks.keys()] - t.assertCountEqual(task_names, ['read-from-gcs', 'write-to-gcs'], + t.assertCountEqual(tasks.keys(), ['read-from-gcs', 'write-to-gcs'], 'task names') write_task = tasks['write-to-gcs'] read_task = tasks['read-from-gcs'] - pprint('======= preprocess task =======') - pprint(write_task.get_dict()) - pprint('======= train task =======') - pprint(read_task.get_dict()) - pprint('==============') - t.assertEqual( { 'name': 'write-to-gcs', 'inputs': { - 'artifacts': [], 'parameters': { 'msg': 'Hello world!', } @@ -56,7 +48,6 @@ def verify_tasks(t: unittest.TestCase, tasks: Dict[str, KfpTask]): 'name': 'artifact', 'type': 'system.Artifact' }], - 'parameters': {} }, 'type': 'system.ContainerExecution', 'state': Execution.State.COMPLETE, @@ -72,11 +63,6 @@ def verify_tasks(t: unittest.TestCase, tasks: Dict[str, KfpTask]): 'name': 'artifact', 'type': 'system.Artifact', }], - 'parameters': {} - }, - 'outputs': { - 'artifacts': [], - 'parameters': {} }, 'type': 'system.ContainerExecution', 'state': Execution.State.COMPLETE, diff --git a/samples/test/util.py b/samples/test/util.py index 16e3e518bdc..2bdc3bf1d42 100644 --- a/samples/test/util.py +++ b/samples/test/util.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
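+# Note: `from __future__ import annotations` (PEP 563) defers annotation
+# evaluation, so the builtin-generic hints used in this file (e.g.
+# dict[str, KfpTask]) also work on Python versions older than 3.9.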
+from __future__ import annotations + import json import logging import os @@ -19,7 +21,7 @@ import random from dataclasses import dataclass, asdict from pprint import pprint -from typing import Dict, List, Callable, Optional, Mapping +from typing import Callable, Optional import unittest from google.protobuf.json_format import MessageToDict @@ -29,7 +31,7 @@ import kfp_server_api from ml_metadata import metadata_store from ml_metadata.metadata_store.metadata_store import ListOptions -from ml_metadata.proto import metadata_store_pb2 +from ml_metadata.proto import Event, Execution, metadata_store_pb2 MINUTE = 60 @@ -63,11 +65,11 @@ class TestCase: pipeline_func: Callable mode: kfp.dsl.PipelineExecutionMode = kfp.dsl.PipelineExecutionMode.V2_COMPATIBLE enable_caching: bool = False - arguments: Optional[Dict[str, str]] = None + arguments: Optional[dict[str, str]] = None verify_func: Verifier = _default_verify_func -def run_pipeline_func(test_cases: List[TestCase]): +def run_pipeline_func(test_cases: list[TestCase]): """Run a pipeline function and wait for its result. :param pipeline_func: pipeline function to run @@ -317,7 +319,7 @@ def run_v2_pipeline( launcher_v2_image: Optional[str], pipeline_root: Optional[str], enable_caching: bool, - arguments: Mapping[str, str], + arguments: dict[str, str], ): import tempfile import subprocess @@ -415,13 +417,13 @@ def new( @dataclass class TaskInputs: parameters: dict - artifacts: List[KfpArtifact] + artifacts: list[KfpArtifact] @dataclass class TaskOutputs: parameters: dict - artifacts: List[KfpArtifact] + artifacts: list[KfpArtifact] @dataclass @@ -432,26 +434,44 @@ class KfpTask: state: int inputs: TaskInputs outputs: TaskOutputs + children: Optional[dict[str, KfpTask]] = None def get_dict(self): - d = asdict(self) + ignore_zero_values = lambda x: {k: v for (k, v) in x if v} + d = asdict(self, dict_factory=ignore_zero_values) # remove uri, because they are not deterministic - for artifact in d.get('inputs').get('artifacts'): + for artifact in d.get('inputs', {}).get('artifacts', []): artifact.pop('uri') - for artifact in d.get('outputs').get('artifacts'): + for artifact in d.get('outputs', {}).get('artifacts', []): artifact.pop('uri') + # children should be accessed separately + if d.get('children') is not None: + d.pop('children') return d + def __repr__(self, depth=1): + return_string = [str(self.get_dict())] + if self.children: + for child in self.children.values(): + return_string.extend( + ["\n", "--" * depth, + child.__repr__(depth + 1)]) + return "".join(return_string) + @classmethod def new( - cls, - context: metadata_store_pb2.Context, - execution: metadata_store_pb2.Execution, - execution_types_by_id, # dict[int, metadata_store_pb2.ExecutionType] - events_by_execution_id, # dict[int, List[metadata_store_pb2.Event]] - artifacts_by_id, # dict[int, metadata_store_pb2.Artifact] - artifact_types_by_id, # dict[int, metadata_store_pb2.ArtifactType] + cls, + execution: metadata_store_pb2.Execution, + execution_types_by_id: dict[int, metadata_store_pb2.ExecutionType], + events_by_execution_id: dict[int, list[metadata_store_pb2.Event]], + artifacts_by_id: dict[int, metadata_store_pb2.Artifact], + artifact_types_by_id: dict[int, metadata_store_pb2.ArtifactType], + children: Optional[dict[str, KfpTask]], ): + name = execution.custom_properties.get('task_name').string_value + iteration_index = execution.custom_properties.get('iteration_index') + if iteration_index: + name += f'-#{iteration_index.int_value}' execution_type = 
execution_types_by_id[execution.type_id] params = _parse_parameters(execution) events = events_by_execution_id.get(execution.id, []) @@ -487,13 +507,14 @@ def kfp_artifact(aid: int, output_artifacts.sort(key=lambda a: a.name) return cls( - name=execution.custom_properties.get('task_name').string_value, + name=name, type=execution_type.name, state=execution.last_known_state, inputs=TaskInputs( parameters=params['inputs'], artifacts=input_artifacts), outputs=TaskOutputs( parameters=params['outputs'], artifacts=output_artifacts), + children=children or None, ) @@ -511,6 +532,8 @@ def __init__( port=8080, ) self.mlmd_store = metadata_store.MetadataStore(mlmd_connection_config) + self.dag_type = self.mlmd_store.get_execution_type( + type_name='system.DAGExecution') def get_tasks(self, run_id: str): run_context = self.mlmd_store.get_context_by_type_and_name( @@ -530,11 +553,14 @@ def get_tasks(self, run_id: str): raise Exception( f'Cannot find system.DAGExecution execution "run/{run_id}"') logger.info(f'root_dag: name={root.name} id={root.id}') + return self._get_tasks(root.id, run_context.id) + def _get_tasks(self, dag_id: int, + run_context_id: int) -> dict[str, KfpTask]: # Note, we only need to query by parent_dag_id. However, there is no index # on parent_dag_id. To speed up the query, we also limit the query to the # run context (contexts have index). - filter_query = f'contexts_run.id = {run_context.id} AND custom_properties.parent_dag_id.int_value = {root.id}' + filter_query = f'contexts_run.id = {run_context_id} AND custom_properties.parent_dag_id.int_value = {dag_id}' executions = self.mlmd_store.get_executions( list_options=ListOptions(filter_query=filter_query)) execution_types = self.mlmd_store.get_execution_types_by_id( @@ -546,21 +572,28 @@ def get_tasks(self, run_id: str): for e in events: events_by_execution_id[e.execution_id] = ( events_by_execution_id.get(e.execution_id) or []) + [e] - artifacts = self.mlmd_store.get_artifacts_by_context( - context_id=run_context.id) + artifacts = self.mlmd_store.get_artifacts_by_id( + artifact_ids=[e.artifact_id for e in events]) artifacts_by_id = {a.id: a for a in artifacts} artifact_types = self.mlmd_store.get_artifact_types_by_id( list(set([a.type_id for a in artifacts]))) artifact_types_by_id = {at.id: at for at in artifact_types} _validate_executions_have_task_names(executions) + + def get_children(e: Execution) -> Optional[dict[str, KfpTask]]: + if e.type_id == self.dag_type.id: + children = self._get_tasks(e.id, run_context_id) + return children + return None + tasks = [ KfpTask.new( - context=run_context, execution=e, execution_types_by_id=execution_types_by_id, events_by_execution_id=events_by_execution_id, artifacts_by_id=artifacts_by_id, artifact_types_by_id=artifact_types_by_id, + children=get_children(e), ) for e in executions ] tasks_by_name = {t.name: t for t in tasks} diff --git a/v2/Makefile b/v2/Makefile index c9d3c8a940c..9d8ee62db65 100644 --- a/v2/Makefile +++ b/v2/Makefile @@ -37,6 +37,11 @@ MLMD_VERSION=$$(cat ../third_party/ml-metadata/VERSION) test: mlmd # a MLMD server running in background is required by some tests go test ./... +.PHONY: test-update +test-update: mlmd + # Updating compiled argo YAML golden files... + go test ./compiler --args --update + # make test-watch watches file system changes and rerun go unit tests automatically # install inotifywait by: # sudo apt-get install inotify-tools @@ -67,8 +72,7 @@ image-launcher-dev: # make dev runs developing actions from end to end for KFP v2. 
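+# (After this change the heavy lifting lives in the pipeline/% rule below,
+# which now depends on build/compiler and image-dev itself; for example,
+# "make pipeline/v2/hello_world" compiles, lints and submits a single sample.)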
# Build images, push them to dev image registry, build backend compiler, compile pipelines and run them.
.PHONY: dev
-dev: image-dev build/compiler \
-	pipeline/v2/hello_world
+dev: pipeline/v2/hello_world
# TODO(v2): migrate v1 samples to v2.
# pipeline/test/two_step \
# pipeline/test/lightweight_python_functions_v2_pipeline
@@ -105,16 +109,17 @@ pipeline/v2/pipeline_with_importer:
pipeline/test/two_step:
pipeline/test/lightweight_python_functions_v2_pipeline:
pipeline/test/lightweight_python_functions_v2_with_outputs:
+pipeline/core/loop_output/loop_output_v2:

# Run a test pipeline using v2 CLI compiler (v2 engine mode).
.PHONY: pipeline/%
-pipeline/%:
+pipeline/%: build/compiler image-dev
	tmp="$$(mktemp -d)" \
	&& mkdir -p "$$(dirname $${tmp}/$*)" \
-	&& dsl-compile-v2 --py $(REPO_ROOT)/samples/$*.py --out "$${tmp}/$*.json" \
-	&& build/compiler --spec "$${tmp}/$*.json" --driver "$(DEV_IMAGE_PREFIX)driver:latest" --launcher "$(DEV_IMAGE_PREFIX)launcher-v2:latest" > "$${tmp}/$*.yaml" \
-	&& if which argo >/dev/null; then argo lint "$${tmp}/$*.yaml"; else echo "argo CLI not found, skip linting"; fi \
-	&& kfp run submit -f "$${tmp}/$*.yaml" -e default -r "$*_$${RANDOM}"
+	&& echo "SDK Compiling to $${tmp}/$*.json" && dsl-compile-v2 --py $(REPO_ROOT)/samples/$*.py --out "$${tmp}/$*.json" \
+	&& echo "Backend Compiler compiling to $${tmp}/$*.yaml" && build/compiler --spec "$${tmp}/$*.json" --driver "$(DEV_IMAGE_PREFIX)driver:latest" --launcher "$(DEV_IMAGE_PREFIX)launcher-v2:latest" > "$${tmp}/$*.yaml" \
+	&& echo "Linting..." && if which argo >/dev/null; then argo lint "$${tmp}/$*.yaml"; else echo "argo CLI not found, skip linting"; fi \
+	&& echo "Running the pipeline..." && kfp run submit -f "$${tmp}/$*.yaml" -e default -r "$*_$${RANDOM}"

###### Common target implementation details ######

diff --git a/v2/cmd/driver/main.go b/v2/cmd/driver/main.go
index e4e7a8ceb33..8374911f5b5 100644
--- a/v2/cmd/driver/main.go
+++ b/v2/cmd/driver/main.go
@@ -39,12 +39,13 @@ const (

var (
	// inputs
-	driverType        = flag.String(driverTypeArg, "", "task driver type, one of ROOT_DAG, CONTAINER")
+	driverType        = flag.String(driverTypeArg, "", "task driver type, one of ROOT_DAG, DAG, CONTAINER")
	pipelineName      = flag.String("pipeline_name", "", "pipeline context name")
	runID             = flag.String("run_id", "", "pipeline run uid")
	componentSpecJson = flag.String("component", "{}", "component spec")
	taskSpecJson      = flag.String("task", "{}", "task spec")
	runtimeConfigJson = flag.String("runtime_config", "{}", "jobruntime config")
+	iterationIndex    = flag.Int("iteration_index", -1, "iteration index, -1 means not an iteration")

	// container inputs
	dagExecutionID = flag.Int64("dag_execution_id", 0, "DAG execution ID")
@@ -55,8 +56,9 @@ var (
	mlmdServerPort = flag.String("mlmd_server_port", "", "MLMD server port")

	// output paths
-	executionIDPath   = flag.String("execution_id_path", "", "Exeucution ID output path")
-	executorInputPath = flag.String("executor_input_path", "", "Executor Input output path")
+	executionIDPath    = flag.String("execution_id_path", "", "Execution ID output path")
+	executorInputPath  = flag.String("executor_input_path", "", "Executor Input output path")
+	iterationCountPath = flag.String("iteration_count_path", "", "Iteration Count output path")
	// output paths, the value stored in the paths will be either 'true' or 'false'
	cachedDecisionPath = flag.String("cached_decision_path", "", "Cached Decision output path")
)
@@ -135,26 +137,32 @@ func drive() (err error) {
		Component:      componentSpec,
		Task:           taskSpec,
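+		// IterationIndex defaults to -1, meaning this task is not an
+		// iteration of a loop (mirrors the iteration_index flag above).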
DAGExecutionID: *dagExecutionID, + IterationIndex: *iterationIndex, } var execution *driver.Execution + var driverErr error switch *driverType { case "ROOT_DAG": options.RuntimeConfig = runtimeConfig - execution, err = driver.RootDAG(ctx, options, client) + execution, driverErr = driver.RootDAG(ctx, options, client) + case "DAG": + execution, driverErr = driver.DAG(ctx, options, client) case "CONTAINER": options.Container = containerSpec - execution, err = driver.Container(ctx, options, client, cacheClient) + execution, driverErr = driver.Container(ctx, options, client, cacheClient) default: err = fmt.Errorf("unknown driverType %s", *driverType) } - if err != nil { - return err - } - if execution.ID != 0 { - glog.Infof("output execution.ID=%v", execution.ID) - if err = writeFile(*executionIDPath, []byte(fmt.Sprint(execution.ID))); err != nil { - return fmt.Errorf("failed to write execution ID to file: %w", err) + if driverErr != nil { + if execution == nil { + return driverErr } + defer func() { + // Override error with driver error, because driver error is more important. + // However, we continue running, because the following code prints debug info that + // may be helpful for figuring out why this failed. + err = driverErr + }() } if execution.ExecutorInput != nil { marshaler := jsonpb.Marshaler{} @@ -163,11 +171,23 @@ func drive() (err error) { return fmt.Errorf("failed to marshal ExecutorInput to JSON: %w", err) } glog.Infof("output ExecutorInput:%s\n", prettyPrint(executorInputJSON)) - if err = writeFile(*executorInputPath, []byte(executorInputJSON)); err != nil { - return fmt.Errorf("failed to write ExecutorInput to file: %w", err) + if *driverType == "CONTAINER" { + if err = writeFile(*executorInputPath, []byte(executorInputJSON)); err != nil { + return fmt.Errorf("failed to write ExecutorInput to file: %w", err) + } + } + } + if execution.ID != 0 { + glog.Infof("output execution.ID=%v", execution.ID) + if err = writeFile(*executionIDPath, []byte(fmt.Sprint(execution.ID))); err != nil { + return fmt.Errorf("failed to write execution ID to file: %w", err) + } + } + if execution.IterationCount != nil { + if err = writeFile(*iterationCountPath, []byte(fmt.Sprintf("%v", *execution.IterationCount))); err != nil { + return fmt.Errorf("failed to write iteration count to file: %w", err) } } - if execution.Cached { if err = writeFile(*cachedDecisionPath, []byte("true")); err != nil { return fmt.Errorf("failed to write cached decision to file: %w", err) diff --git a/v2/compiler/argo.go b/v2/compiler/argo.go index a32de308fec..76124d874ba 100644 --- a/v2/compiler/argo.go +++ b/v2/compiler/argo.go @@ -100,9 +100,9 @@ func Compile(jobArg *pipelinespec.PipelineJob, opts *Options) (*wfapi.Workflow, } // compile - Accept(job, compiler) + err = Accept(job, compiler) - return compiler.wf, nil + return compiler.wf, err } func retrieveLastValidString(s string) string { @@ -152,7 +152,10 @@ const ( paramRuntimeConfig = "runtime-config" // job runtime config, pipeline level inputs paramDAGExecutionID = "dag-execution-id" paramExecutionID = "execution-id" + paramIterationCount = "iteration-count" + paramIterationIndex = "iteration-index" paramExecutorInput = "executor-input" + paramDriverType = "driver-type" paramCachedDecision = "cached-decision" // indicate hit cache or not ) @@ -182,3 +185,8 @@ func outputPath(parameter string) string { func taskOutputParameter(task string, param string) string { return fmt.Sprintf("{{tasks.%s.outputs.parameters.%s}}", task, param) } + +func loopItem() string { + // 
https://github.com/argoproj/argo-workflows/blob/13bf15309567ff10ec23b6e5cfed846312d6a1ab/examples/loops-sequence.yaml#L20 + return "{{item}}" +} diff --git a/v2/compiler/argo_test.go b/v2/compiler/argo_test.go index bc28311c821..17f5d909d5c 100644 --- a/v2/compiler/argo_test.go +++ b/v2/compiler/argo_test.go @@ -1,6 +1,9 @@ package compiler_test import ( + "flag" + "fmt" + "io/ioutil" "testing" wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" @@ -9,443 +12,57 @@ import ( "github.com/kubeflow/pipelines/v2/compiler" ) +var update = flag.Bool("update", false, "update golden files") + func Test_argo_compiler(t *testing.T) { tests := []struct { - jobPath string - expectedText string + jobPath string // path of input PipelineJob to compile + argoYAMLPath string // path of expected output argo workflow YAML }{ { - jobPath: "testdata/hello_world.json", - expectedText: ` - apiVersion: argoproj.io/v1alpha1 - kind: Workflow - metadata: - annotations: - pipelines.kubeflow.org/v2_pipeline: "true" - creationTimestamp: null - generateName: hello-world- - spec: - arguments: {} - entrypoint: root - podMetadata: - annotations: - pipelines.kubeflow.org/v2_component: "true" - labels: - pipelines.kubeflow.org/v2_component: "true" - serviceAccountName: pipeline-runner - templates: - - container: - args: - - --type - - CONTAINER - - --pipeline_name - - namespace/n1/pipeline/hello-world - - --run_id - - '{{workflow.uid}}' - - --dag_execution_id - - '{{inputs.parameters.dag-execution-id}}' - - --component - - '{{inputs.parameters.component}}' - - --task - - '{{inputs.parameters.task}}' - - --container - - '{{inputs.parameters.container}}' - - --execution_id_path - - '{{outputs.parameters.execution-id.path}}' - - --executor_input_path - - '{{outputs.parameters.executor-input.path}}' - - --cached_decision_path - - '{{outputs.parameters.cached-decision.path}}' - command: - - driver - image: gcr.io/ml-pipeline/kfp-driver:latest - name: "" - resources: {} - inputs: - parameters: - - name: component - - name: task - - name: container - - name: dag-execution-id - metadata: {} - name: system-container-driver - outputs: - parameters: - - name: execution-id - valueFrom: - path: /tmp/outputs/execution-id - - name: executor-input - valueFrom: - path: /tmp/outputs/executor-input - - default: "false" - name: cached-decision - valueFrom: - default: "false" - path: /tmp/outputs/cached-decision - - container: - args: - - sh - - -ec - - | - program_path=$(mktemp) - printf "%s" "$0" > "$program_path" - python3 -u "$program_path" "$@" - - | - def hello_world(text): - print(text) - return text - - import argparse - _parser = argparse.ArgumentParser(prog='Hello world', description='') - _parser.add_argument("--text", dest="text", type=str, required=True, default=argparse.SUPPRESS) - _parsed_args = vars(_parser.parse_args()) - - _outputs = hello_world(**_parsed_args) - - --text - - '{{$.inputs.parameters[''text'']}}' - command: - - /kfp-launcher/launch - - --pipeline_name - - namespace/n1/pipeline/hello-world - - --run_id - - '{{workflow.uid}}' - - --execution_id - - '{{inputs.parameters.execution-id}}' - - --executor_input - - '{{inputs.parameters.executor-input}}' - - --component_spec - - '{{inputs.parameters.component}}' - - --pod_name - - $(KFP_POD_NAME) - - --pod_uid - - $(KFP_POD_UID) - - --mlmd_server_address - - $(METADATA_GRPC_SERVICE_HOST) - - --mlmd_server_port - - $(METADATA_GRPC_SERVICE_PORT) - - -- - env: - - name: KFP_POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: 
KFP_POD_UID - valueFrom: - fieldRef: - fieldPath: metadata.uid - envFrom: - - configMapRef: - name: metadata-grpc-configmap - optional: true - image: python:3.7 - name: "" - resources: {} - volumeMounts: - - mountPath: /kfp-launcher - name: kfp-launcher - initContainers: - - command: - - launcher-v2 - - --copy - - /kfp-launcher/launch - image: gcr.io/ml-pipeline/kfp-launcher-v2:latest - imagePullPolicy: Always - name: kfp-launcher - resources: {} - volumeMounts: - - mountPath: /kfp-launcher - name: kfp-launcher - inputs: - parameters: - - name: executor-input - - name: execution-id - - name: component - metadata: {} - name: comp-hello-world-container - outputs: {} - volumes: - - emptyDir: {} - name: kfp-launcher - - dag: - tasks: - - arguments: - parameters: - - name: component - value: '{{inputs.parameters.component}}' - - name: task - value: '{{inputs.parameters.task}}' - - name: container - value: '{"image":"python:3.7","command":["sh","-ec","program_path=$(mktemp)\nprintf - \"%s\" \"$0\" \u003e \"$program_path\"\npython3 -u \"$program_path\" - \"$@\"\n","def hello_world(text):\n print(text)\n return text\n\nimport - argparse\n_parser = argparse.ArgumentParser(prog=''Hello world'', description='''')\n_parser.add_argument(\"--text\", - dest=\"text\", type=str, required=True, default=argparse.SUPPRESS)\n_parsed_args - = vars(_parser.parse_args())\n\n_outputs = hello_world(**_parsed_args)\n"],"args":["--text","{{$.inputs.parameters[''text'']}}"]}' - - name: dag-execution-id - value: '{{inputs.parameters.dag-execution-id}}' - name: driver - template: system-container-driver - - arguments: - parameters: - - name: executor-input - value: '{{tasks.driver.outputs.parameters.executor-input}}' - - name: execution-id - value: '{{tasks.driver.outputs.parameters.execution-id}}' - - name: component - value: '{{inputs.parameters.component}}' - dependencies: - - driver - name: container - template: comp-hello-world-container - when: '{{tasks.driver.outputs.parameters.cached-decision}} != true' - inputs: - parameters: - - name: task - - name: dag-execution-id - - default: '{"inputDefinitions":{"parameters":{"text":{"type":"STRING"}}},"executorLabel":"exec-hello-world"}' - name: component - metadata: {} - name: comp-hello-world - outputs: {} - - dag: - tasks: - - arguments: - parameters: - - name: dag-execution-id - value: '{{inputs.parameters.dag-execution-id}}' - - name: task - value: '{"taskInfo":{"name":"hello-world"},"inputs":{"parameters":{"text":{"componentInputParameter":"text"}}},"cachingOptions":{"enableCache":true},"componentRef":{"name":"comp-hello-world"}}' - name: hello-world - template: comp-hello-world - inputs: - parameters: - - name: dag-execution-id - metadata: {} - name: root-dag - outputs: {} - - container: - args: - - --type - - ROOT_DAG - - --pipeline_name - - namespace/n1/pipeline/hello-world - - --run_id - - '{{workflow.uid}}' - - --component - - '{{inputs.parameters.component}}' - - --runtime_config - - '{{inputs.parameters.runtime-config}}' - - --execution_id_path - - '{{outputs.parameters.execution-id.path}}' - command: - - driver - image: gcr.io/ml-pipeline/kfp-driver:latest - name: "" - resources: {} - inputs: - parameters: - - name: component - - name: runtime-config - metadata: {} - name: system-dag-driver - outputs: - parameters: - - name: execution-id - valueFrom: - path: /tmp/outputs/execution-id - - dag: - tasks: - - arguments: - parameters: - - name: component - value: 
'{"inputDefinitions":{"parameters":{"text":{"type":"STRING"}}},"dag":{"tasks":{"hello-world":{"taskInfo":{"name":"hello-world"},"inputs":{"parameters":{"text":{"componentInputParameter":"text"}}},"cachingOptions":{"enableCache":true},"componentRef":{"name":"comp-hello-world"}}}}}' - - name: task - value: '{}' - - name: runtime-config - value: '{"parameters":{"text":{"stringValue":"hi there"}}}' - name: driver - template: system-dag-driver - - arguments: - parameters: - - name: dag-execution-id - value: '{{tasks.driver.outputs.parameters.execution-id}}' - dependencies: - - driver - name: dag - template: root-dag - inputs: {} - metadata: {} - name: root - outputs: {} - status: - finishedAt: null - startedAt: null - `, + jobPath: "testdata/hello_world.json", + argoYAMLPath: "testdata/hello_world.yaml", }, { - jobPath: "testdata/importer.json", - expectedText: ` - apiVersion: argoproj.io/v1alpha1 - kind: Workflow - metadata: - annotations: - pipelines.kubeflow.org/v2_pipeline: "true" - creationTimestamp: null - generateName: pipeline-with-importer- - spec: - arguments: {} - entrypoint: root - podMetadata: - annotations: - pipelines.kubeflow.org/v2_component: "true" - labels: - pipelines.kubeflow.org/v2_component: "true" - serviceAccountName: pipeline-runner - templates: - - container: - args: - - --executor_type - - importer - - --task_spec - - '{{inputs.parameters.task}}' - - --component_spec - - '{{inputs.parameters.component}}' - - --importer_spec - - '{{inputs.parameters.importer}}' - - --pipeline_name - - pipeline-with-importer - - --run_id - - '{{workflow.uid}}' - - --pod_name - - $(KFP_POD_NAME) - - --pod_uid - - $(KFP_POD_UID) - - --mlmd_server_address - - $(METADATA_GRPC_SERVICE_HOST) - - --mlmd_server_port - - $(METADATA_GRPC_SERVICE_PORT) - command: - - launcher-v2 - env: - - name: KFP_POD_NAME - valueFrom: - fieldRef: - fieldPath: metadata.name - - name: KFP_POD_UID - valueFrom: - fieldRef: - fieldPath: metadata.uid - envFrom: - - configMapRef: - name: metadata-grpc-configmap - optional: true - image: gcr.io/ml-pipeline/kfp-launcher-v2:latest - name: "" - resources: {} - inputs: - parameters: - - name: task - - default: '{"inputDefinitions":{"parameters":{"uri":{"type":"STRING"}}},"outputDefinitions":{"artifacts":{"artifact":{"artifactType":{"schemaTitle":"system.Dataset"}}}},"executorLabel":"exec-importer"}' - name: component - - default: '{"artifactUri":{"constantValue":{"stringValue":"gs://ml-pipeline-playground/shakespeare1.txt"}},"typeSchema":{"schemaTitle":"system.Dataset"}}' - name: importer - metadata: {} - name: comp-importer - outputs: {} - - dag: - tasks: - - arguments: - parameters: - - name: dag-execution-id - value: '{{inputs.parameters.dag-execution-id}}' - - name: task - value: '{"taskInfo":{"name":"importer"},"inputs":{"parameters":{"uri":{"runtimeValue":{"constantValue":{"stringValue":"gs://ml-pipeline-playground/shakespeare1.txt"}}}}},"cachingOptions":{"enableCache":true},"componentRef":{"name":"comp-importer"}}' - name: importer - template: comp-importer - inputs: - parameters: - - name: dag-execution-id - metadata: {} - name: root-dag - outputs: {} - - container: - args: - - --type - - ROOT_DAG - - --pipeline_name - - pipeline-with-importer - - --run_id - - '{{workflow.uid}}' - - --component - - '{{inputs.parameters.component}}' - - --runtime_config - - '{{inputs.parameters.runtime-config}}' - - --execution_id_path - - '{{outputs.parameters.execution-id.path}}' - command: - - driver - image: gcr.io/ml-pipeline/kfp-driver:latest - name: "" - resources: {} - 
inputs:
-            parameters:
-            - name: component
-            - name: runtime-config
-          metadata: {}
-          name: system-dag-driver
-          outputs:
-            parameters:
-            - name: execution-id
-              valueFrom:
-                path: /tmp/outputs/execution-id
-        - dag:
-            tasks:
-            - arguments:
-                parameters:
-                - name: component
-                  value: '{"inputDefinitions":{"parameters":{"dataset2":{"type":"STRING"}}},"dag":{"tasks":{"importer":{"taskInfo":{"name":"importer"},"inputs":{"parameters":{"uri":{"runtimeValue":{"constantValue":{"stringValue":"gs://ml-pipeline-playground/shakespeare1.txt"}}}}},"cachingOptions":{"enableCache":true},"componentRef":{"name":"comp-importer"}}}}}'
-                - name: task
-                  value: '{}'
-                - name: runtime-config
-                  value: '{}'
-              name: driver
-              template: system-dag-driver
-            - arguments:
-                parameters:
-                - name: dag-execution-id
-                  value: '{{tasks.driver.outputs.parameters.execution-id}}'
-              dependencies:
-              - driver
-              name: dag
-              template: root-dag
-          inputs: {}
-          metadata: {}
-          name: root
-          outputs: {}
-        status:
-          finishedAt: null
-          startedAt: null
-      `,
+			jobPath:      "testdata/importer.json",
+			argoYAMLPath: "testdata/importer.yaml",
		},
	}
	for _, tt := range tests {
-		job := load(t, tt.jobPath)
-		wf, err := compiler.Compile(job, nil)
-		if err != nil {
-			t.Error(err)
-		}
-		var expected wfapi.Workflow
-		err = yaml.Unmarshal([]byte(tt.expectedText), &expected)
-		if err != nil {
-			t.Fatal(err)
-		}
-		if !cmp.Equal(wf, &expected) {
-			got, err := yaml.Marshal(wf)
+		t.Run(fmt.Sprintf("%+v", tt), func(t *testing.T) {
+
+			job := load(t, tt.jobPath)
+			if *update {
+				wf, err := compiler.Compile(job, nil)
+				if err != nil {
+					t.Fatal(err)
+				}
+				got, err := yaml.Marshal(wf)
+				if err != nil {
+					t.Fatal(err)
+				}
+				err = ioutil.WriteFile(tt.argoYAMLPath, got, 0o664)
+				if err != nil {
+					t.Fatal(err)
+				}
+			}
+			argoYAML, err := ioutil.ReadFile(tt.argoYAMLPath)
			if err != nil {
				t.Fatal(err)
			}
-			t.Errorf("compiler.Compile(%s)!=expected, diff: %s\n got:\n%s\n", tt.jobPath, cmp.Diff(&expected, wf), string(got))
-		}
+			wf, err := compiler.Compile(job, nil)
+			if err != nil {
+				t.Error(err)
+			}
+			var expected wfapi.Workflow
+			err = yaml.Unmarshal(argoYAML, &expected)
+			if err != nil {
+				t.Fatal(err)
+			}
+			if !cmp.Equal(wf, &expected) {
+				t.Errorf("compiler.Compile(%s)!=expected, diff: %s\n", tt.jobPath, cmp.Diff(&expected, wf))
+			}
+		})
	}
diff --git a/v2/compiler/container.go b/v2/compiler/container.go
index 1a2c0f75bc8..0ebe0483e0c 100644
--- a/v2/compiler/container.go
+++ b/v2/compiler/container.go
@@ -24,10 +24,13 @@ func (c *workflowCompiler) Container(name string, component *pipelinespec.Compon
	}
	driverTask, driverOutputs := c.containerDriverTask(
		"driver",
-		inputParameter(paramComponent),
-		inputParameter(paramTask),
-		containerJson,
-		inputParameter(paramDAGExecutionID),
+		containerDriverInputs{
+			component:      inputParameter(paramComponent),
+			task:           inputParameter(paramTask),
+			container:      containerJson,
+			dagExecutionID: inputParameter(paramDAGExecutionID),
+			iterationIndex: inputParameter(paramIterationIndex),
+		},
	)
	t := containerExecutorTemplate(container, c.launcherImage, c.spec.PipelineInfo.GetName())
	// TODO(Bobgy): how can we avoid template name collisions?
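+	// Sketch of the iteration-index plumbing: the argo parameter defaults to
+	// "-1" (not an iteration, matching the driver flag) and is only set to
+	// "{{item}}" when this template is fanned out by an iterator DAG.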
@@ -42,23 +45,29 @@ func (c *workflowCompiler) Container(name string, component *pipelinespec.Compon {Name: paramDAGExecutionID}, // TODO(Bobgy): reuse the entire 2-step container template {Name: paramComponent, Default: wfapi.AnyStringPtr(componentJson)}, + {Name: paramIterationIndex, Default: wfapi.AnyStringPtr("-1")}, }, }, DAG: &wfapi.DAGTemplate{ Tasks: []wfapi.DAGTask{ *driverTask, - {Name: "container", Template: containerTemplateName, Dependencies: []string{driverTask.Name}, When: taskOutputParameter(driverTask.Name, paramCachedDecision) + " != true", Arguments: wfapi.Arguments{ - Parameters: []wfapi.Parameter{{ - Name: paramExecutorInput, - Value: wfapi.AnyStringPtr(driverOutputs.executorInput), - }, { - Name: paramExecutionID, - Value: wfapi.AnyStringPtr(driverOutputs.executionID), - }, { - Name: paramComponent, - Value: wfapi.AnyStringPtr(inputParameter(paramComponent)), + { + Name: "container", + Template: containerTemplateName, + Dependencies: []string{driverTask.Name}, + When: taskOutputParameter(driverTask.Name, paramCachedDecision) + " != true", + Arguments: wfapi.Arguments{ + Parameters: []wfapi.Parameter{{ + Name: paramExecutorInput, + Value: wfapi.AnyStringPtr(driverOutputs.executorInput), + }, { + Name: paramExecutionID, + Value: wfapi.AnyStringPtr(driverOutputs.executionID), + }, { + Name: paramComponent, + Value: wfapi.AnyStringPtr(inputParameter(paramComponent)), + }}, }}, - }}, }, }, } @@ -72,16 +81,25 @@ type containerDriverOutputs struct { cached string } -func (c *workflowCompiler) containerDriverTask(name, component, task, container, dagExecutionID string) (*wfapi.DAGTask, *containerDriverOutputs) { +type containerDriverInputs struct { + component string + task string + container string + dagExecutionID string + iterationIndex string // optional, when this is an iteration task +} + +func (c *workflowCompiler) containerDriverTask(name string, inputs containerDriverInputs) (*wfapi.DAGTask, *containerDriverOutputs) { dagTask := &wfapi.DAGTask{ Name: name, Template: c.addContainerDriverTemplate(), Arguments: wfapi.Arguments{ Parameters: []wfapi.Parameter{ - {Name: paramComponent, Value: wfapi.AnyStringPtr(component)}, - {Name: paramTask, Value: wfapi.AnyStringPtr(task)}, - {Name: paramContainer, Value: wfapi.AnyStringPtr(container)}, - {Name: paramDAGExecutionID, Value: wfapi.AnyStringPtr(dagExecutionID)}, + {Name: paramComponent, Value: wfapi.AnyStringPtr(inputs.component)}, + {Name: paramTask, Value: wfapi.AnyStringPtr(inputs.task)}, + {Name: paramContainer, Value: wfapi.AnyStringPtr(inputs.container)}, + {Name: paramDAGExecutionID, Value: wfapi.AnyStringPtr(inputs.dagExecutionID)}, + {Name: paramIterationIndex, Value: wfapi.AnyStringPtr(inputs.iterationIndex)}, }, }, } @@ -107,6 +125,7 @@ func (c *workflowCompiler) addContainerDriverTemplate() string { {Name: paramTask}, {Name: paramContainer}, {Name: paramDAGExecutionID}, + {Name: paramIterationIndex}, }, }, Outputs: wfapi.Outputs{ @@ -127,6 +146,7 @@ func (c *workflowCompiler) addContainerDriverTemplate() string { "--component", inputValue(paramComponent), "--task", inputValue(paramTask), "--container", inputValue(paramContainer), + "--iteration_index", inputValue(paramIterationIndex), "--execution_id_path", outputPath(paramExecutionID), "--executor_input_path", outputPath(paramExecutorInput), "--cached_decision_path", outputPath(paramCachedDecision), diff --git a/v2/compiler/dag.go b/v2/compiler/dag.go index d08aa55ada7..7585879ffb1 100644 --- a/v2/compiler/dag.go +++ b/v2/compiler/dag.go @@ -4,15 +4,13 @@ 
import ( "fmt" wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1" - "github.com/golang/protobuf/jsonpb" "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" + "google.golang.org/protobuf/encoding/protojson" k8score "k8s.io/api/core/v1" + "k8s.io/apimachinery/pkg/util/intstr" ) func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.ComponentSpec, dagSpec *pipelinespec.DagSpec) error { - if name != "root" { - return fmt.Errorf("SubDAG not implemented yet") - } err := addImplicitDependencies(dagSpec) if err != nil { return err @@ -25,15 +23,35 @@ func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.Componen }, DAG: &wfapi.DAGTemplate{}, } - for _, kfpTask := range dagSpec.GetTasks() { - marshaler := jsonpb.Marshaler{} - taskJson, err := marshaler.MarshalToString(kfpTask) + for taskName, kfpTask := range dagSpec.GetTasks() { + bytes, err := protojson.Marshal(kfpTask) if err != nil { - return fmt.Errorf("DAG: marshaling task spec to proto JSON failed: %w", err) + return fmt.Errorf("DAG: marshaling task %q's spec to proto JSON failed: %w", taskName, err) + } + taskJson := string(bytes) + if kfpTask.GetParameterIterator() != nil && kfpTask.GetArtifactIterator() != nil { + return fmt.Errorf("DAG: invalid task %q: parameterIterator and artifactIterator cannot be specified at the same time", taskName) + } + if kfpTask.GetArtifactIterator() != nil { + return fmt.Errorf("DAG: artifact iterator not implemented yet") + } + // For a normal task, we execute the component's template directly. + templateName := c.templateName(kfpTask.GetComponentRef().GetName()) + // For iterator task, we need to use argo withSequence to iterate. + if kfpTask.GetParameterIterator() != nil { + iterator, err := c.iteratorTemplate(taskName, kfpTask, taskJson) + if err != nil { + return fmt.Errorf("DAG: invalid parameter iterator: %w", err) + } + iteratorTemplateName, err := c.addTemplate(iterator, name+"-"+taskName) + if err != nil { + return fmt.Errorf("DAG: %w", err) + } + templateName = iteratorTemplateName } dag.DAG.Tasks = append(dag.DAG.Tasks, wfapi.DAGTask{ Name: kfpTask.GetTaskInfo().GetName(), - Template: c.templateName(kfpTask.GetComponentRef().GetName()), + Template: templateName, Dependencies: kfpTask.GetDependentTasks(), Arguments: wfapi.Arguments{ Parameters: []wfapi.Parameter{ @@ -51,84 +69,170 @@ func (c *workflowCompiler) DAG(name string, componentSpec *pipelinespec.Componen } // TODO(Bobgy): how can we avoid template name collisions? 
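+	// Sketch of what a parameter iterator compiles to (see iteratorTemplate
+	// below): a wrapper DAG whose driver resolves the loop input and outputs
+	// iteration-count, plus an "iterations" task that fans out via
+	//   withSequence: {count: '{{tasks.driver.outputs.parameters.iteration-count}}'}
+	// and passes iteration-index='{{item}}' to every iteration.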
dagName, err := c.addTemplate(dag, name+"-dag") - task := &pipelinespec.PipelineTaskSpec{} + if err != nil { + return fmt.Errorf("DAG: %w", err) + } var runtimeConfig *pipelinespec.PipelineJob_RuntimeConfig - if name == "root" { + if name == rootComponentName { // runtime config is input to the entire pipeline (root DAG) runtimeConfig = c.job.GetRuntimeConfig() } - driverTask, outputs, err := c.dagDriverTask("driver", componentSpec, task, runtimeConfig) + driverTask, driverOutputs, err := c.dagDriverTask("driver", dagDriverInputs{ + dagExecutionID: inputParameter(paramDAGExecutionID), + task: inputParameter(paramTask), + iterationIndex: inputParameter(paramIterationIndex), + component: componentSpec, + runtimeConfig: runtimeConfig, + }) if err != nil { return err } - wrapper := &wfapi.Template{} - wrapper.DAG = &wfapi.DAGTemplate{ - Tasks: []wfapi.DAGTask{ + wrapper := &wfapi.Template{ + Inputs: wfapi.Inputs{ + Parameters: []wfapi.Parameter{ + {Name: paramDAGExecutionID, Default: wfapi.AnyStringPtr("0")}, + {Name: paramTask, Default: wfapi.AnyStringPtr("{}")}, + {Name: paramIterationIndex, Default: wfapi.AnyStringPtr("-1")}, + }, + }, + DAG: &wfapi.DAGTemplate{Tasks: []wfapi.DAGTask{ *driverTask, { - Name: "dag", Template: dagName, Dependencies: []string{"driver"}, - Arguments: wfapi.Arguments{ - Parameters: []wfapi.Parameter{ - {Name: paramDAGExecutionID, Value: wfapi.AnyStringPtr(outputs.executionID)}, - }, - }, + Name: "dag", + Template: dagName, + Dependencies: []string{"driver"}, + Arguments: wfapi.Arguments{Parameters: []wfapi.Parameter{ + {Name: paramDAGExecutionID, Value: wfapi.AnyStringPtr(driverOutputs.executionID)}, + }}, }, - }, + }}, } _, err = c.addTemplate(wrapper, name) return err } +func (c *workflowCompiler) iteratorTemplate(taskName string, task *pipelinespec.PipelineTaskSpec, taskJson string) (tmpl *wfapi.Template, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("generating template for iterator task %q: %w", taskName, err) + } + }() + componentName := task.GetComponentRef().GetName() + component, ok := c.spec.GetComponents()[componentName] + if !ok { + return nil, fmt.Errorf("cannot find component %q in pipeline spec", componentName) + } + driverArgoName := "driver" + driverInputs := dagDriverInputs{ + component: component, + dagExecutionID: inputParameter(paramDAGExecutionID), + task: inputParameter(paramTask), + } + driverArgoTask, driverOutputs, err := c.dagDriverTask(driverArgoName, driverInputs) + if err != nil { + return nil, err + } + componentTemplateName := c.templateName(task.GetComponentRef().GetName()) + iterationCount := intstr.FromString(driverOutputs.iterationCount) + tmpl = &wfapi.Template{ + Inputs: wfapi.Inputs{ + Parameters: []wfapi.Parameter{ + {Name: paramDAGExecutionID}, + {Name: paramTask}, + }, + }, + DAG: &wfapi.DAGTemplate{Tasks: []wfapi.DAGTask{ + *driverArgoTask, + { + Name: "iterations", + Template: componentTemplateName, + Dependencies: []string{driverArgoName}, + Arguments: wfapi.Arguments{ + Parameters: []wfapi.Parameter{{ + Name: paramDAGExecutionID, + Value: wfapi.AnyStringPtr(driverOutputs.executionID), + }, { + Name: paramTask, + Value: wfapi.AnyStringPtr(inputParameter(paramTask)), + }, { + Name: paramIterationIndex, + Value: wfapi.AnyStringPtr(loopItem()), + }}, + }, + WithSequence: &wfapi.Sequence{Count: &iterationCount}, + }, + }}, + } + return tmpl, nil +} + type dagDriverOutputs struct { - executionID string + executionID string + iterationCount string // only returned for iterator DAG drivers +} + +type 
dagDriverInputs struct { + dagExecutionID string // parent DAG execution ID. optional, the root DAG does not have parent + component *pipelinespec.ComponentSpec + task string // optional, the root DAG does not have task spec. + runtimeConfig *pipelinespec.PipelineJob_RuntimeConfig // optional, only root DAG needs this + iterationIndex string // optional, iterator passes iteration index to iteration tasks } -func (c *workflowCompiler) dagDriverTask(name string, component *pipelinespec.ComponentSpec, task *pipelinespec.PipelineTaskSpec, runtimeConfig *pipelinespec.PipelineJob_RuntimeConfig) (*wfapi.DAGTask, *dagDriverOutputs, error) { +func (c *workflowCompiler) dagDriverTask(name string, inputs dagDriverInputs) (*wfapi.DAGTask, *dagDriverOutputs, error) { + component := inputs.component + runtimeConfig := inputs.runtimeConfig + if inputs.iterationIndex == "" { + inputs.iterationIndex = "-1" + } if component == nil { return nil, nil, fmt.Errorf("dagDriverTask: component must be non-nil") } - marshaler := jsonpb.Marshaler{} - componentJson, err := marshaler.MarshalToString(component) + componentBytes, err := protojson.Marshal(component) if err != nil { return nil, nil, fmt.Errorf("dagDriverTask: marlshaling component spec to proto JSON failed: %w", err) } - taskJson := "{}" - if task != nil { - taskJson, err = marshaler.MarshalToString(task) - if err != nil { - return nil, nil, fmt.Errorf("dagDriverTask: marshaling task spec to proto JSON failed: %w", err) - } - } + componentJson := string(componentBytes) runtimeConfigJson := "{}" if runtimeConfig != nil { - runtimeConfigJson, err = marshaler.MarshalToString(runtimeConfig) + bytes, err := protojson.Marshal(runtimeConfig) if err != nil { return nil, nil, fmt.Errorf("dagDriverTask: marshaling runtime config to proto JSON failed: %w", err) } + runtimeConfigJson = string(bytes) } templateName := c.addDAGDriverTemplate() t := &wfapi.DAGTask{ Name: name, Template: templateName, Arguments: wfapi.Arguments{ - Parameters: []wfapi.Parameter{ - { - Name: paramComponent, - Value: wfapi.AnyStringPtr(componentJson), - }, - { - Name: paramTask, - Value: wfapi.AnyStringPtr(taskJson), - }, - { - Name: paramRuntimeConfig, - Value: wfapi.AnyStringPtr(runtimeConfigJson), - }, - }, + Parameters: []wfapi.Parameter{{ + Name: paramDAGExecutionID, + Value: wfapi.AnyStringPtr(inputs.dagExecutionID), + }, { + Name: paramComponent, + Value: wfapi.AnyStringPtr(componentJson), + }, { + Name: paramTask, + Value: wfapi.AnyStringPtr(inputs.task), + }, { + Name: paramRuntimeConfig, + Value: wfapi.AnyStringPtr(runtimeConfigJson), + }, { + Name: paramIterationIndex, + Value: wfapi.AnyStringPtr(inputs.iterationIndex), + }}, }, } + if runtimeConfig != nil { + t.Arguments.Parameters = append(t.Arguments.Parameters, wfapi.Parameter{ + Name: paramDriverType, + Value: wfapi.AnyStringPtr("ROOT_DAG"), + }) + } return t, &dagDriverOutputs{ - executionID: taskOutputParameter(name, paramExecutionID), + executionID: taskOutputParameter(name, paramExecutionID), + iterationCount: taskOutputParameter(name, paramIterationCount), }, nil } @@ -144,23 +248,32 @@ func (c *workflowCompiler) addDAGDriverTemplate() string { Parameters: []wfapi.Parameter{ {Name: paramComponent}, {Name: paramRuntimeConfig}, + {Name: paramTask}, + {Name: paramDAGExecutionID, Default: wfapi.AnyStringPtr("0")}, + {Name: paramIterationIndex, Default: wfapi.AnyStringPtr("-1")}, + {Name: paramDriverType, Default: wfapi.AnyStringPtr("DAG")}, }, }, Outputs: wfapi.Outputs{ Parameters: []wfapi.Parameter{ {Name: paramExecutionID, 
ValueFrom: &wfapi.ValueFrom{Path: "/tmp/outputs/execution-id"}}, + {Name: paramIterationCount, ValueFrom: &wfapi.ValueFrom{Path: "/tmp/outputs/iteration-count", Default: wfapi.AnyStringPtr("0")}}, }, }, Container: &k8score.Container{ Image: c.driverImage, Command: []string{"driver"}, Args: []string{ - "--type", "ROOT_DAG", + "--type", inputValue(paramDriverType), "--pipeline_name", c.spec.GetPipelineInfo().GetName(), "--run_id", runID(), + "--dag_execution_id", inputValue(paramDAGExecutionID), "--component", inputValue(paramComponent), + "--task", inputValue(paramTask), "--runtime_config", inputValue(paramRuntimeConfig), + "--iteration_index", inputValue(paramIterationIndex), "--execution_id_path", outputPath(paramExecutionID), + "--iteration_count_path", outputPath(paramIterationCount), }, }, } diff --git a/v2/compiler/testdata/hello_world.yaml b/v2/compiler/testdata/hello_world.yaml new file mode 100644 index 00000000000..be318d5dea9 --- /dev/null +++ b/v2/compiler/testdata/hello_world.yaml @@ -0,0 +1,305 @@ +apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + annotations: + pipelines.kubeflow.org/v2_pipeline: "true" + creationTimestamp: null + generateName: hello-world- +spec: + arguments: {} + entrypoint: root + podMetadata: + annotations: + pipelines.kubeflow.org/v2_component: "true" + labels: + pipelines.kubeflow.org/v2_component: "true" + serviceAccountName: pipeline-runner + templates: + - container: + args: + - --type + - CONTAINER + - --pipeline_name + - namespace/n1/pipeline/hello-world + - --run_id + - '{{workflow.uid}}' + - --dag_execution_id + - '{{inputs.parameters.dag-execution-id}}' + - --component + - '{{inputs.parameters.component}}' + - --task + - '{{inputs.parameters.task}}' + - --container + - '{{inputs.parameters.container}}' + - --iteration_index + - '{{inputs.parameters.iteration-index}}' + - --execution_id_path + - '{{outputs.parameters.execution-id.path}}' + - --executor_input_path + - '{{outputs.parameters.executor-input.path}}' + - --cached_decision_path + - '{{outputs.parameters.cached-decision.path}}' + command: + - driver + image: gcr.io/ml-pipeline/kfp-driver:latest + name: "" + resources: {} + inputs: + parameters: + - name: component + - name: task + - name: container + - name: dag-execution-id + - name: iteration-index + metadata: {} + name: system-container-driver + outputs: + parameters: + - name: execution-id + valueFrom: + path: /tmp/outputs/execution-id + - name: executor-input + valueFrom: + path: /tmp/outputs/executor-input + - default: "false" + name: cached-decision + valueFrom: + default: "false" + path: /tmp/outputs/cached-decision + - container: + args: + - sh + - -ec + - | + program_path=$(mktemp) + printf "%s" "$0" > "$program_path" + python3 -u "$program_path" "$@" + - | + def hello_world(text): + print(text) + return text + + import argparse + _parser = argparse.ArgumentParser(prog='Hello world', description='') + _parser.add_argument("--text", dest="text", type=str, required=True, default=argparse.SUPPRESS) + _parsed_args = vars(_parser.parse_args()) + + _outputs = hello_world(**_parsed_args) + - --text + - '{{$.inputs.parameters[''text'']}}' + command: + - /kfp-launcher/launch + - --pipeline_name + - namespace/n1/pipeline/hello-world + - --run_id + - '{{workflow.uid}}' + - --execution_id + - '{{inputs.parameters.execution-id}}' + - --executor_input + - '{{inputs.parameters.executor-input}}' + - --component_spec + - '{{inputs.parameters.component}}' + - --pod_name + - $(KFP_POD_NAME) + - --pod_uid + - $(KFP_POD_UID) + - 
--mlmd_server_address + - $(METADATA_GRPC_SERVICE_HOST) + - --mlmd_server_port + - $(METADATA_GRPC_SERVICE_PORT) + - -- + env: + - name: KFP_POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: KFP_POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid + envFrom: + - configMapRef: + name: metadata-grpc-configmap + optional: true + image: python:3.7 + name: "" + resources: {} + volumeMounts: + - mountPath: /kfp-launcher + name: kfp-launcher + initContainers: + - command: + - launcher-v2 + - --copy + - /kfp-launcher/launch + image: gcr.io/ml-pipeline/kfp-launcher-v2:latest + imagePullPolicy: Always + name: kfp-launcher + resources: {} + volumeMounts: + - mountPath: /kfp-launcher + name: kfp-launcher + inputs: + parameters: + - name: executor-input + - name: execution-id + - name: component + metadata: {} + name: comp-hello-world-container + outputs: {} + volumes: + - emptyDir: {} + name: kfp-launcher + - dag: + tasks: + - arguments: + parameters: + - name: component + value: '{{inputs.parameters.component}}' + - name: task + value: '{{inputs.parameters.task}}' + - name: container + value: '{"image":"python:3.7","command":["sh","-ec","program_path=$(mktemp)\nprintf + \"%s\" \"$0\" \u003e \"$program_path\"\npython3 -u \"$program_path\" + \"$@\"\n","def hello_world(text):\n print(text)\n return text\n\nimport + argparse\n_parser = argparse.ArgumentParser(prog=''Hello world'', description='''')\n_parser.add_argument(\"--text\", + dest=\"text\", type=str, required=True, default=argparse.SUPPRESS)\n_parsed_args + = vars(_parser.parse_args())\n\n_outputs = hello_world(**_parsed_args)\n"],"args":["--text","{{$.inputs.parameters[''text'']}}"]}' + - name: dag-execution-id + value: '{{inputs.parameters.dag-execution-id}}' + - name: iteration-index + value: '{{inputs.parameters.iteration-index}}' + name: driver + template: system-container-driver + - arguments: + parameters: + - name: executor-input + value: '{{tasks.driver.outputs.parameters.executor-input}}' + - name: execution-id + value: '{{tasks.driver.outputs.parameters.execution-id}}' + - name: component + value: '{{inputs.parameters.component}}' + dependencies: + - driver + name: container + template: comp-hello-world-container + when: '{{tasks.driver.outputs.parameters.cached-decision}} != true' + inputs: + parameters: + - name: task + - name: dag-execution-id + - default: '{"inputDefinitions":{"parameters":{"text":{"type":"STRING"}}},"executorLabel":"exec-hello-world"}' + name: component + - default: "-1" + name: iteration-index + metadata: {} + name: comp-hello-world + outputs: {} + - dag: + tasks: + - arguments: + parameters: + - name: dag-execution-id + value: '{{inputs.parameters.dag-execution-id}}' + - name: task + value: '{"taskInfo":{"name":"hello-world"}, "inputs":{"parameters":{"text":{"componentInputParameter":"text"}}}, + "cachingOptions":{"enableCache":true}, "componentRef":{"name":"comp-hello-world"}}' + name: hello-world + template: comp-hello-world + inputs: + parameters: + - name: dag-execution-id + metadata: {} + name: root-dag + outputs: {} + - container: + args: + - --type + - '{{inputs.parameters.driver-type}}' + - --pipeline_name + - namespace/n1/pipeline/hello-world + - --run_id + - '{{workflow.uid}}' + - --dag_execution_id + - '{{inputs.parameters.dag-execution-id}}' + - --component + - '{{inputs.parameters.component}}' + - --task + - '{{inputs.parameters.task}}' + - --runtime_config + - '{{inputs.parameters.runtime-config}}' + - --iteration_index + - '{{inputs.parameters.iteration-index}}' + - 
--execution_id_path + - '{{outputs.parameters.execution-id.path}}' + - --iteration_count_path + - '{{outputs.parameters.iteration-count.path}}' + command: + - driver + image: gcr.io/ml-pipeline/kfp-driver:latest + name: "" + resources: {} + inputs: + parameters: + - name: component + - name: runtime-config + - name: task + - default: "0" + name: dag-execution-id + - default: "-1" + name: iteration-index + - default: DAG + name: driver-type + metadata: {} + name: system-dag-driver + outputs: + parameters: + - name: execution-id + valueFrom: + path: /tmp/outputs/execution-id + - name: iteration-count + valueFrom: + default: "0" + path: /tmp/outputs/iteration-count + - dag: + tasks: + - arguments: + parameters: + - name: dag-execution-id + value: '{{inputs.parameters.dag-execution-id}}' + - name: component + value: '{"inputDefinitions":{"parameters":{"text":{"type":"STRING"}}}, + "dag":{"tasks":{"hello-world":{"taskInfo":{"name":"hello-world"}, "inputs":{"parameters":{"text":{"componentInputParameter":"text"}}}, + "cachingOptions":{"enableCache":true}, "componentRef":{"name":"comp-hello-world"}}}}}' + - name: task + value: '{{inputs.parameters.task}}' + - name: runtime-config + value: '{"parameters":{"text":{"stringValue":"hi there"}}}' + - name: iteration-index + value: '{{inputs.parameters.iteration-index}}' + - name: driver-type + value: ROOT_DAG + name: driver + template: system-dag-driver + - arguments: + parameters: + - name: dag-execution-id + value: '{{tasks.driver.outputs.parameters.execution-id}}' + dependencies: + - driver + name: dag + template: root-dag + inputs: + parameters: + - default: "0" + name: dag-execution-id + - default: '{}' + name: task + - default: "-1" + name: iteration-index + metadata: {} + name: root + outputs: {} +status: + finishedAt: null + startedAt: null diff --git a/v2/compiler/testdata/importer.yaml b/v2/compiler/testdata/importer.yaml new file mode 100644 index 00000000000..ed7b526dec3 --- /dev/null +++ b/v2/compiler/testdata/importer.yaml @@ -0,0 +1,175 @@ +apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + annotations: + pipelines.kubeflow.org/v2_pipeline: "true" + creationTimestamp: null + generateName: pipeline-with-importer- +spec: + arguments: {} + entrypoint: root + podMetadata: + annotations: + pipelines.kubeflow.org/v2_component: "true" + labels: + pipelines.kubeflow.org/v2_component: "true" + serviceAccountName: pipeline-runner + templates: + - container: + args: + - --executor_type + - importer + - --task_spec + - '{{inputs.parameters.task}}' + - --component_spec + - '{{inputs.parameters.component}}' + - --importer_spec + - '{{inputs.parameters.importer}}' + - --pipeline_name + - pipeline-with-importer + - --run_id + - '{{workflow.uid}}' + - --pod_name + - $(KFP_POD_NAME) + - --pod_uid + - $(KFP_POD_UID) + - --mlmd_server_address + - $(METADATA_GRPC_SERVICE_HOST) + - --mlmd_server_port + - $(METADATA_GRPC_SERVICE_PORT) + command: + - launcher-v2 + env: + - name: KFP_POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: KFP_POD_UID + valueFrom: + fieldRef: + fieldPath: metadata.uid + envFrom: + - configMapRef: + name: metadata-grpc-configmap + optional: true + image: gcr.io/ml-pipeline/kfp-launcher-v2:latest + name: "" + resources: {} + inputs: + parameters: + - name: task + - default: '{"inputDefinitions":{"parameters":{"uri":{"type":"STRING"}}},"outputDefinitions":{"artifacts":{"artifact":{"artifactType":{"schemaTitle":"system.Dataset"}}}},"executorLabel":"exec-importer"}' + name: component + - default: 
'{"artifactUri":{"constantValue":{"stringValue":"gs://ml-pipeline-playground/shakespeare1.txt"}},"typeSchema":{"schemaTitle":"system.Dataset"}}' + name: importer + metadata: {} + name: comp-importer + outputs: {} + - dag: + tasks: + - arguments: + parameters: + - name: dag-execution-id + value: '{{inputs.parameters.dag-execution-id}}' + - name: task + value: '{"taskInfo":{"name":"importer"}, "inputs":{"parameters":{"uri":{"runtimeValue":{"constantValue":{"stringValue":"gs://ml-pipeline-playground/shakespeare1.txt"}}}}}, + "cachingOptions":{"enableCache":true}, "componentRef":{"name":"comp-importer"}}' + name: importer + template: comp-importer + inputs: + parameters: + - name: dag-execution-id + metadata: {} + name: root-dag + outputs: {} + - container: + args: + - --type + - '{{inputs.parameters.driver-type}}' + - --pipeline_name + - pipeline-with-importer + - --run_id + - '{{workflow.uid}}' + - --dag_execution_id + - '{{inputs.parameters.dag-execution-id}}' + - --component + - '{{inputs.parameters.component}}' + - --task + - '{{inputs.parameters.task}}' + - --runtime_config + - '{{inputs.parameters.runtime-config}}' + - --iteration_index + - '{{inputs.parameters.iteration-index}}' + - --execution_id_path + - '{{outputs.parameters.execution-id.path}}' + - --iteration_count_path + - '{{outputs.parameters.iteration-count.path}}' + command: + - driver + image: gcr.io/ml-pipeline/kfp-driver:latest + name: "" + resources: {} + inputs: + parameters: + - name: component + - name: runtime-config + - name: task + - default: "0" + name: dag-execution-id + - default: "-1" + name: iteration-index + - default: DAG + name: driver-type + metadata: {} + name: system-dag-driver + outputs: + parameters: + - name: execution-id + valueFrom: + path: /tmp/outputs/execution-id + - name: iteration-count + valueFrom: + default: "0" + path: /tmp/outputs/iteration-count + - dag: + tasks: + - arguments: + parameters: + - name: dag-execution-id + value: '{{inputs.parameters.dag-execution-id}}' + - name: component + value: '{"inputDefinitions":{"parameters":{"dataset2":{"type":"STRING"}}}, + "dag":{"tasks":{"importer":{"taskInfo":{"name":"importer"}, "inputs":{"parameters":{"uri":{"runtimeValue":{"constantValue":{"stringValue":"gs://ml-pipeline-playground/shakespeare1.txt"}}}}}, + "cachingOptions":{"enableCache":true}, "componentRef":{"name":"comp-importer"}}}}}' + - name: task + value: '{{inputs.parameters.task}}' + - name: runtime-config + value: '{}' + - name: iteration-index + value: '{{inputs.parameters.iteration-index}}' + - name: driver-type + value: ROOT_DAG + name: driver + template: system-dag-driver + - arguments: + parameters: + - name: dag-execution-id + value: '{{tasks.driver.outputs.parameters.execution-id}}' + dependencies: + - driver + name: dag + template: root-dag + inputs: + parameters: + - default: "0" + name: dag-execution-id + - default: '{}' + name: task + - default: "-1" + name: iteration-index + metadata: {} + name: root + outputs: {} +status: + finishedAt: null + startedAt: null diff --git a/v2/driver/driver.go b/v2/driver/driver.go index bf9268b5432..95e97808df4 100644 --- a/v2/driver/driver.go +++ b/v2/driver/driver.go @@ -8,18 +8,18 @@ import ( "strings" "github.com/golang/glog" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" pb "github.com/kubeflow/pipelines/third_party/ml-metadata/go/ml_metadata" "github.com/kubeflow/pipelines/v2/cacheutils" "github.com/kubeflow/pipelines/v2/component" 
"github.com/kubeflow/pipelines/v2/config" "github.com/kubeflow/pipelines/v2/metadata" + "google.golang.org/protobuf/types/known/structpb" "k8s.io/client-go/kubernetes" "k8s.io/client-go/rest" ) -// TODO Move driver to component package +// TODO(capri-xiyue): Move driver to component package // Driver options type Options struct { // required, pipeline context name @@ -28,15 +28,19 @@ type Options struct { RunID string // required, Component spec Component *pipelinespec.ComponentSpec - // required only by root DAG driver + // optional, iteration index. -1 means not an iteration. + IterationIndex int + + // optional, required only by root DAG driver RuntimeConfig *pipelinespec.PipelineJob_RuntimeConfig - // required by non-root drivers - Task *pipelinespec.PipelineTaskSpec - // required only by container driver + Namespace string + + // optional, required by non-root drivers + Task *pipelinespec.PipelineTaskSpec DAGExecutionID int64 - Container *pipelinespec.PipelineDeploymentConfig_PipelineContainerSpec - // required only by root DAG driver - Namespace string + + // optional, required only by container driver + Container *pipelinespec.PipelineDeploymentConfig_PipelineContainerSpec } // Identifying information used for error messages @@ -51,6 +55,9 @@ func (o Options) info() string { if o.DAGExecutionID != 0 { msg = msg + fmt.Sprintf(", dagExecutionID=%v", o.DAGExecutionID) } + if o.IterationIndex >= 0 { + msg = msg + fmt.Sprintf(", iterationIndex=%v", o.IterationIndex) + } if o.RuntimeConfig != nil { msg = msg + ", runtimeConfig" // this only means runtimeConfig is not empty } @@ -61,9 +68,10 @@ func (o Options) info() string { } type Execution struct { - ID int64 - ExecutorInput *pipelinespec.ExecutorInput - Cached bool // only specified when this is a Container execution + ID int64 + ExecutorInput *pipelinespec.ExecutorInput + IterationCount *int // number of iterations, -1 means not an iterator + Cached bool // only specified when this is a Container execution } func RootDAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *Execution, err error) { @@ -151,12 +159,14 @@ func validateRootDAG(opts Options) (err error) { return fmt.Errorf("DAG execution ID is unnecessary") } if opts.Container != nil { - return fmt.Errorf("container spec is unncessary") + return fmt.Errorf("container spec is unnecessary") + } + if opts.IterationIndex >= 0 { + return fmt.Errorf("iteration index is unnecessary") } return nil } -// TODO(Bobgy): 7-17, continue to build CLI args for container driver func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheClient *cacheutils.Client) (execution *Execution, err error) { defer func() { if err != nil { @@ -167,6 +177,11 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl if err != nil { return nil, err } + var iterationIndex *int + if opts.IterationIndex >= 0 { + index := opts.IterationIndex + iterationIndex = &index + } // TODO(Bobgy): there's no need to pass any parameters, because pipeline // and pipeline run context have been created by root DAG driver. 
pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "") @@ -178,7 +193,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl return nil, err } glog.Infof("parent DAG: %+v", dag.Execution) - inputs, err := resolveInputs(ctx, dag, pipeline, opts.Task, mlmd) + inputs, err := resolveInputs(ctx, dag, iterationIndex, pipeline, opts.Task, mlmd) if err != nil { return nil, err } @@ -193,6 +208,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl ecfg.TaskName = opts.Task.GetTaskInfo().GetName() ecfg.ExecutionType = metadata.ContainerExecutionTypeName ecfg.ParentDagID = dag.Execution.GetID() + ecfg.IterationIndex = iterationIndex if opts.Task.GetCachingOptions() != nil && opts.Task.GetCachingOptions().GetEnableCache() { glog.Infof("Task {%s} enables cache", opts.Task.GetTaskInfo().GetName()) @@ -214,7 +230,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl } glog.Infof("Created execution: %s", createdExecution) - if opts.Task.GetCachingOptions() != nil && opts.Task.GetCachingOptions().EnableCache && ecfg.CachedMLMDExecutionID != "" { + if opts.Task.GetCachingOptions().GetEnableCache() && ecfg.CachedMLMDExecutionID != "" { executorOutput, outputArtifacts, err := reuseCachedOutputs(ctx, executorInput, opts.Component.GetOutputDefinitions(), mlmd, ecfg.CachedMLMDExecutionID) if err != nil { return nil, err @@ -239,6 +255,100 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl }, nil } +// TODO(Bobgy): merge DAG driver and container driver, because they are very similar. +func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *Execution, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("driver.DAG(%s) failed: %w", opts.info(), err) + } + }() + err = validateDAG(opts) + if err != nil { + return nil, err + } + var iterationIndex *int + if opts.IterationIndex >= 0 { + index := opts.IterationIndex + iterationIndex = &index + } + // TODO(Bobgy): there's no need to pass any parameters, because pipeline + // and pipeline run context have been created by root DAG driver. 
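+	// As in the container driver, the pipeline/run contexts and the parent DAG
+	// execution are only looked up here; they were created earlier by the root
+	// DAG driver.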
+ pipeline, err := mlmd.GetPipeline(ctx, opts.PipelineName, opts.RunID, "", "", "") + if err != nil { + return nil, err + } + dag, err := mlmd.GetDAG(ctx, opts.DAGExecutionID) + if err != nil { + return nil, err + } + glog.Infof("parent DAG: %+v", dag.Execution) + inputs, err := resolveInputs(ctx, dag, iterationIndex, pipeline, opts.Task, mlmd) + if err != nil { + return nil, err + } + executorInput := &pipelinespec.ExecutorInput{ + Inputs: inputs, + Outputs: provisionOutputs(pipeline.GetPipelineRoot(), opts.Task.GetTaskInfo().GetName(), opts.Component.GetOutputDefinitions()), + } + execution = &Execution{ExecutorInput: executorInput} + ecfg, err := metadata.GenerateExecutionConfig(executorInput) + if err != nil { + return execution, err + } + ecfg.TaskName = opts.Task.GetTaskInfo().GetName() + ecfg.ExecutionType = metadata.DagExecutionTypeName + ecfg.ParentDagID = dag.Execution.GetID() + ecfg.IterationIndex = iterationIndex + if opts.Task.GetArtifactIterator() != nil { + return execution, fmt.Errorf("ArtifactIterator is not implemented") + } + isIterator := opts.Task.GetParameterIterator() != nil && opts.IterationIndex < 0 + if isIterator { + iterator := opts.Task.GetParameterIterator() + value, ok := executorInput.GetInputs().GetParameterValues()[iterator.GetItems().GetInputParameter()] + report := func(err error) error { + return fmt.Errorf("iterating on item input %q failed: %w", iterator.GetItemInput(), err) + } + if !ok { + return execution, report(fmt.Errorf("cannot find input parameter")) + } + items, err := getItems(value) + if err != nil { + return execution, report(err) + } + count := len(items) + ecfg.IterationCount = &count + execution.IterationCount = &count + } + // TODO(Bobgy): change execution state to pending, because this is driver, execution hasn't started. + createdExecution, err := mlmd.CreateExecution(ctx, pipeline, ecfg) + if err != nil { + return execution, err + } + glog.Infof("Created execution: %s", createdExecution) + execution.ID = createdExecution.GetID() + return execution, nil +} + +// Get iteration items from a structpb.Value. 
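+// Items arrive either as a native list value (inline loop arguments) or as a
+// JSON-encoded string (e.g. an upstream task's string output), so both
+// `[1, 2, 3]` and the string `"[1, 2, 3]"` yield three items.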
+// Return value may be +// * a list of JSON serializable structs +// * a list of structpb.Value +func getItems(value *structpb.Value) (items []*structpb.Value, err error) { + switch v := value.GetKind().(type) { + case *structpb.Value_ListValue: + return v.ListValue.GetValues(), nil + case *structpb.Value_StringValue: + listValue := structpb.Value{} + if err = listValue.UnmarshalJSON([]byte(v.StringValue)); err != nil { + return nil, err + } + return listValue.GetListValue().GetValues(), nil + default: + return nil, fmt.Errorf("value of type %T cannot be iterated", v) + } +} + func reuseCachedOutputs(ctx context.Context, executorInput *pipelinespec.ExecutorInput, outputDefinitions *pipelinespec.ComponentOutputsSpec, mlmd *metadata.Client, cachedMLMDExecutionID string) (*pipelinespec.ExecutorOutput, []*metadata.OutputArtifact, error) { cachedMLMDExecutionIDInt64, err := strconv.ParseInt(cachedMLMDExecutionID, 10, 64) if err != nil { @@ -310,6 +420,25 @@ func validateContainer(opts Options) (err error) { err = fmt.Errorf("invalid container driver args: %w", err) } }() + if opts.Container == nil { + return fmt.Errorf("container spec is required") + } + return validateNonRoot(opts) +} + +func validateDAG(opts Options) (err error) { + defer func() { + if err != nil { + err = fmt.Errorf("invalid DAG driver args: %w", err) + } + }() + if opts.Container != nil { + return fmt.Errorf("container spec is unnecessary") + } + return validateNonRoot(opts) +} + +func validateNonRoot(opts Options) error { if opts.PipelineName == "" { return fmt.Errorf("pipeline name is required") } @@ -328,22 +457,52 @@ func validateContainer(opts Options) (err error) { if opts.DAGExecutionID == 0 { return fmt.Errorf("DAG execution ID is required") } - if opts.Container == nil { - return fmt.Errorf("container spec is required") - } return nil } -func resolveInputs(ctx context.Context, dag *metadata.DAG, pipeline *metadata.Pipeline, task *pipelinespec.PipelineTaskSpec, mlmd *metadata.Client) (*pipelinespec.ExecutorInput_Inputs, error) { +func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, pipeline *metadata.Pipeline, task *pipelinespec.PipelineTaskSpec, mlmd *metadata.Client) (inputs *pipelinespec.ExecutorInput_Inputs, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("failed to resolve inputs: %w", err) + } + }() inputParams, _, err := dag.Execution.GetParameters() if err != nil { - return nil, fmt.Errorf("failed to resolve inputs: %w", err) + return nil, err } glog.Infof("parent DAG input parameters %+v", inputParams) - inputs := &pipelinespec.ExecutorInput_Inputs{ + inputs = &pipelinespec.ExecutorInput_Inputs{ ParameterValues: make(map[string]*structpb.Value), Artifacts: make(map[string]*pipelinespec.ArtifactList), } + isIterationDriver := iterationIndex != nil + if isIterationDriver { + // resolve inputs for iteration driver is very different + artifacts, err := mlmd.GetInputArtifactsByExecutionID(ctx, dag.Execution.GetID()) + if err != nil { + return nil, err + } + inputs.ParameterValues = inputParams + inputs.Artifacts = artifacts + switch { + case task.GetArtifactIterator() != nil: + return nil, fmt.Errorf("artifact iterator not implemented yet") + case task.GetParameterIterator() != nil: + itemsInput := task.GetParameterIterator().GetItems().GetInputParameter() + items, err := getItems(inputs.ParameterValues[itemsInput]) + if err != nil { + return nil, err + } + if *iterationIndex >= len(items) { + return nil, fmt.Errorf("bug: %v items found, but getting index %v", 
len(items), *iterationIndex) + } + delete(inputs.ParameterValues, itemsInput) + inputs.ParameterValues[task.GetParameterIterator().GetItemInput()] = items[*iterationIndex] + default: + return nil, fmt.Errorf("bug: iteration_index>=0, but task iterator is empty") + } + return inputs, nil + } // get executions in context on demand var tasksCache map[string]*metadata.Execution getDAGTasks := func() (map[string]*metadata.Execution, error) { @@ -359,7 +518,7 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, pipeline *metadata.Pi } for name, paramSpec := range task.GetInputs().GetParameters() { paramError := func(err error) error { - return fmt.Errorf("failed to resolve input parameter %s with spec %s: %w", name, paramSpec, err) + return fmt.Errorf("resolving input parameter %s with spec %s: %w", name, paramSpec, err) } if paramSpec.GetParameterExpressionSelector() != "" { return nil, paramError(fmt.Errorf("parameter expression selector not implemented yet")) diff --git a/v2/metadata/client.go b/v2/metadata/client.go index 1670d33ea0f..404c7bfc61c 100644 --- a/v2/metadata/client.go +++ b/v2/metadata/client.go @@ -112,10 +112,14 @@ type ExecutionConfig struct { ParentDagID int64 // parent DAG execution ID. Only the root DAG does not have a parent DAG. InputParameters map[string]*structpb.Value InputArtifactIDs map[string][]int64 + IterationIndex *int // Index of the iteration. // ContainerExecution custom properties Image, CachedMLMDExecutionID, FingerPrint string PodName, PodUID, Namespace string + + // DAGExecution custom properties + IterationCount *int // Number of iterations for an iterator DAG. } // InputArtifact is a wrapper around an MLMD artifact used as component inputs. @@ -238,25 +242,27 @@ func (e *Execution) FingerPrint() string { // GetPipeline returns the current pipeline represented by the specified // pipeline name and run ID. 
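+// The run context's pipeline root is derived as
+// <pipelineRoot>/<pipelineName>/<runID>, and the run context is registered as
+// a child of the shared pipeline context.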
-func (c *Client) GetPipeline(ctx context.Context, pipelineName, pipelineRunID, namespace, runResource, pipelineRoot string) (*Pipeline, error) { +func (c *Client) GetPipeline(ctx context.Context, pipelineName, runID, namespace, runResource, pipelineRoot string) (*Pipeline, error) { pipelineContext, err := c.getOrInsertContext(ctx, pipelineName, pipelineContextType, nil) if err != nil { return nil, err } - runMetadata := map[string]*pb.Value{ + glog.Infof("Pipeline Context: %+v", pipelineContext) + metadata := map[string]*pb.Value{ keyNamespace: stringValue(namespace), keyResourceName: stringValue(runResource), // pipeline root of this run - keyPipelineRoot: stringValue(strings.TrimRight(pipelineRoot, "/") + "/" + path.Join(pipelineName, pipelineRunID)), + keyPipelineRoot: stringValue(strings.TrimRight(pipelineRoot, "/") + "/" + path.Join(pipelineName, runID)), } - pipelineRunContext, err := c.getOrInsertContext(ctx, pipelineRunID, pipelineRunContextType, runMetadata) + runContext, err := c.getOrInsertContext(ctx, runID, pipelineRunContextType, metadata) + glog.Infof("Pipeline Run Context: %+v", runContext) if err != nil { return nil, err } err = c.putParentContexts(ctx, &pb.PutParentContextsRequest{ ParentContexts: []*pb.ParentContext{{ - ChildId: pipelineRunContext.Id, + ChildId: runContext.Id, ParentId: pipelineContext.Id, }}, }) @@ -266,7 +272,7 @@ func (c *Client) GetPipeline(ctx context.Context, pipelineName, pipelineRunID, n return &Pipeline{ pipelineCtx: pipelineContext, - pipelineRunCtx: pipelineRunContext, + pipelineRunCtx: runContext, }, nil } @@ -426,6 +432,8 @@ const ( keyInputs = "inputs" keyOutputs = "outputs" keyParentDagID = "parent_dag_id" // Parent DAG Execution ID. + keyIterationIndex = "iteration_index" + keyIterationCount = "iteration_count" ) // CreateExecution creates a new MLMD execution under the specified Pipeline. 
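For reference, the two keys above are attached by CreateExecution (next hunk) as integer custom properties. A minimal read-side sketch of how an iterator DAG's count could be recovered (a hypothetical helper, not part of this diff; it assumes only the `pb` ml-metadata alias imported in this file and that values were written with `intValue` as below):

	// iterationCountOf recovers the iteration count that CreateExecution stores
	// under keyIterationCount. ok is false for non-iterator executions, whose
	// custom properties do not contain the key.
	func iterationCountOf(e *pb.Execution) (count int, ok bool) {
		v, found := e.GetCustomProperties()[keyIterationCount]
		if !found {
			return 0, false
		}
		return int(v.GetIntValue()), true
	}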
@@ -455,6 +463,12 @@ func (c *Client) CreateExecution(ctx context.Context, pipeline *Pipeline, config if config.ParentDagID != 0 { e.CustomProperties[keyParentDagID] = intValue(config.ParentDagID) } + if config.IterationIndex != nil { + e.CustomProperties[keyIterationIndex] = intValue(int64(*config.IterationIndex)) + } + if config.IterationCount != nil { + e.CustomProperties[keyIterationCount] = intValue(int64(*config.IterationCount)) + } if config.ExecutionType == ContainerExecutionTypeName { e.CustomProperties[keyPodName] = stringValue(config.PodName) e.CustomProperties[keyPodUID] = stringValue(config.PodUID) @@ -713,6 +727,50 @@ func (c *Client) GetOutputArtifactsByExecutionId(ctx context.Context, executionI return outputArtifactsByName, nil } +func (c *Client) GetInputArtifactsByExecutionID(ctx context.Context, executionID int64) (inputs map[string]*pipelinespec.ArtifactList, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("GetInputArtifactsByExecution(id=%v) failed: %w", executionID, err) + } + }() + eventsReq := &pb.GetEventsByExecutionIDsRequest{ExecutionIds: []int64{executionID}} + eventsRes, err := c.svc.GetEventsByExecutionIDs(ctx, eventsReq) + if err != nil { + return nil, err + } + var artifactIDs []int64 + nameByID := make(map[int64]string) + for _, event := range eventsRes.Events { + if *event.Type == pb.Event_INPUT { + artifactIDs = append(artifactIDs, event.GetArtifactId()) + name, err := getArtifactName(event.Path) + if err != nil { + return nil, err + } + nameByID[event.GetArtifactId()] = name + } + } + artifacts, err := c.GetArtifacts(ctx, artifactIDs) + if err != nil { + return nil, err + } + inputs = make(map[string]*pipelinespec.ArtifactList) + for _, artifact := range artifacts { + name, ok := nameByID[artifact.GetId()] + if !ok { + return nil, fmt.Errorf("failed to get name of artifact with id %v", artifact.GetId()) + } + runtimeArtifact, err := toRuntimeArtifact(artifact) + if err != nil { + return nil, err + } + inputs[name] = &pipelinespec.ArtifactList{ + Artifacts: []*pipelinespec.RuntimeArtifact{runtimeArtifact}, + } + } + return inputs, nil +} + // Only supports schema titles for now. type schemaObject struct { Title string `yaml:"title"`