diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 8cd02d46508..0c988446383 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -315,6 +315,7 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl // TODO(Bobgy): change execution state to pending, because this is driver, execution hasn't started. createdExecution, err := mlmd.CreateExecution(ctx, pipeline, ecfg) + if err != nil { return execution, err } @@ -350,11 +351,11 @@ func Container(ctx context.Context, opts Options, mlmd *metadata.Client, cacheCl return execution, err } if opts.KubernetesExecutorConfig != nil { - dagTasks, err := mlmd.GetExecutionsInDAG(ctx, dag, pipeline, true) + inputParams, _, err := dag.Execution.GetParameters() if err != nil { - return execution, err + return nil, fmt.Errorf("failed to fetch input parameters from execution: %w", err) } - err = extendPodSpecPatch(podSpec, opts.KubernetesExecutorConfig, dag, dagTasks) + err = extendPodSpecPatch(ctx, podSpec, opts, dag, pipeline, mlmd, inputParams) if err != nil { return execution, err } @@ -553,19 +554,27 @@ func initPodSpecPatch( } // Extends the PodSpec to include Kubernetes-specific executor config. +// inputParams is a map of the input parameter name to a resolvable value. func extendPodSpecPatch( + ctx context.Context, podSpec *k8score.PodSpec, - kubernetesExecutorConfig *kubernetesplatform.KubernetesExecutorConfig, + opts Options, dag *metadata.DAG, - dagTasks map[string]*metadata.Execution, + pipeline *metadata.Pipeline, + mlmd *metadata.Client, + inputParams map[string]*structpb.Value, ) error { + kubernetesExecutorConfig := opts.KubernetesExecutorConfig + // Return an error if the podSpec has no user container. if len(podSpec.Containers) == 0 { return fmt.Errorf("failed to patch the pod with kubernetes-specific config due to missing user container: %v", podSpec) } + // Get volume mount information if kubernetesExecutorConfig.GetPvcMount() != nil { - volumeMounts, volumes, err := makeVolumeMountPatch(kubernetesExecutorConfig.GetPvcMount(), dag, dagTasks) + volumeMounts, volumes, err := makeVolumeMountPatch(ctx, opts, kubernetesExecutorConfig.GetPvcMount(), + dag, pipeline, mlmd, inputParams) if err != nil { return fmt.Errorf("failed to extract volume mount info: %w", err) } @@ -594,7 +603,17 @@ func extendPodSpecPatch( // Get node selector information if kubernetesExecutorConfig.GetNodeSelector() != nil { - podSpec.NodeSelector = kubernetesExecutorConfig.GetNodeSelector().GetLabels() + if kubernetesExecutorConfig.GetNodeSelector().GetNodeSelectorJson() != nil { + var nodeSelector map[string]string + err := resolveK8sJsonParameter(ctx, opts, dag, pipeline, mlmd, + kubernetesExecutorConfig.GetNodeSelector().GetNodeSelectorJson(), inputParams, &nodeSelector) + if err != nil { + return fmt.Errorf("failed to resolve node selector: %w", err) + } + podSpec.NodeSelector = nodeSelector + } else { + podSpec.NodeSelector = kubernetesExecutorConfig.GetNodeSelector().GetLabels() + } } if tolerations := kubernetesExecutorConfig.GetTolerations(); tolerations != nil { @@ -604,32 +623,52 @@ func extendPodSpecPatch( for _, toleration := range tolerations { if toleration != nil { - k8sToleration := k8score.Toleration{ - Key: toleration.Key, - Operator: k8score.TolerationOperator(toleration.Operator), - Value: toleration.Value, - Effect: k8score.TaintEffect(toleration.Effect), - TolerationSeconds: toleration.TolerationSeconds, + k8sToleration := &k8score.Toleration{} + if toleration.TolerationJson != nil { + err := resolveK8sJsonParameter(ctx, opts, dag, pipeline, mlmd, + toleration.GetTolerationJson(), inputParams, k8sToleration) + if err != nil { + return fmt.Errorf("failed to resolve toleration: %w", err) + } + } else { + k8sToleration.Key = toleration.Key + k8sToleration.Operator = k8score.TolerationOperator(toleration.Operator) + k8sToleration.Value = toleration.Value + k8sToleration.Effect = k8score.TaintEffect(toleration.Effect) + k8sToleration.TolerationSeconds = toleration.TolerationSeconds } - - k8sTolerations = append(k8sTolerations, k8sToleration) + k8sTolerations = append(k8sTolerations, *k8sToleration) } } - podSpec.Tolerations = k8sTolerations } // Get secret mount information for _, secretAsVolume := range kubernetesExecutorConfig.GetSecretAsVolume() { + var secretName string + if secretAsVolume.SecretNameParameter != nil { + resolvedSecretName, err := resolveK8sParameter(ctx, opts, dag, pipeline, mlmd, + secretAsVolume.SecretNameParameter, inputParams) + if err != nil { + return fmt.Errorf("failed to resolve secret name: %w", err) + } + secretName = resolvedSecretName.GetStringValue() + } else { + secretName = secretAsVolume.SecretName + } + optional := secretAsVolume.Optional != nil && *secretAsVolume.Optional secretVolume := k8score.Volume{ - Name: secretAsVolume.GetSecretName(), + Name: secretName, VolumeSource: k8score.VolumeSource{ - Secret: &k8score.SecretVolumeSource{SecretName: secretAsVolume.GetSecretName(), Optional: &optional}, + Secret: &k8score.SecretVolumeSource{ + SecretName: secretName, + Optional: &optional, + }, }, } secretVolumeMount := k8score.VolumeMount{ - Name: secretAsVolume.GetSecretName(), + Name: secretName, MountPath: secretAsVolume.GetMountPath(), } podSpec.Volumes = append(podSpec.Volumes, secretVolume) @@ -647,23 +686,52 @@ func extendPodSpecPatch( }, }, } - secretEnvVar.ValueFrom.SecretKeyRef.LocalObjectReference.Name = secretAsEnv.GetSecretName() + + var secretName string + if secretAsEnv.SecretNameParameter != nil { + resolvedSecretName, err := resolveK8sParameter(ctx, opts, dag, pipeline, mlmd, + secretAsEnv.SecretNameParameter, inputParams) + if err != nil { + return fmt.Errorf("failed to resolve secret name: %w", err) + } + secretName = resolvedSecretName.GetStringValue() + } else { + secretName = secretAsEnv.SecretName + } + + secretEnvVar.ValueFrom.SecretKeyRef.LocalObjectReference.Name = secretName podSpec.Containers[0].Env = append(podSpec.Containers[0].Env, secretEnvVar) } } // Get config map mount information for _, configMapAsVolume := range kubernetesExecutorConfig.GetConfigMapAsVolume() { + var configMapName string + if configMapAsVolume.ConfigNameParameter != nil { + resolvedSecretName, err := resolveK8sParameter(ctx, opts, dag, pipeline, mlmd, + configMapAsVolume.ConfigNameParameter, inputParams) + if err != nil { + return fmt.Errorf("failed to resolve configmap name: %w", err) + } + configMapName = resolvedSecretName.GetStringValue() + } else { + configMapName = configMapAsVolume.ConfigMapName + } + optional := configMapAsVolume.Optional != nil && *configMapAsVolume.Optional configMapVolume := k8score.Volume{ - Name: configMapAsVolume.GetConfigMapName(), + Name: configMapName, VolumeSource: k8score.VolumeSource{ ConfigMap: &k8score.ConfigMapVolumeSource{ - LocalObjectReference: k8score.LocalObjectReference{Name: configMapAsVolume.GetConfigMapName()}, Optional: &optional}, + LocalObjectReference: k8score.LocalObjectReference{ + Name: configMapName, + }, + Optional: &optional, + }, }, } configMapVolumeMount := k8score.VolumeMount{ - Name: configMapAsVolume.GetConfigMapName(), + Name: configMapName, MountPath: configMapAsVolume.GetMountPath(), } podSpec.Volumes = append(podSpec.Volumes, configMapVolume) @@ -681,14 +749,44 @@ func extendPodSpecPatch( }, }, } - configMapEnvVar.ValueFrom.ConfigMapKeyRef.LocalObjectReference.Name = configMapAsEnv.GetConfigMapName() + + var configMapName string + if configMapAsEnv.ConfigNameParameter != nil { + resolvedSecretName, err := resolveK8sParameter(ctx, opts, dag, pipeline, mlmd, + configMapAsEnv.ConfigNameParameter, inputParams) + if err != nil { + return fmt.Errorf("failed to resolve configmap name: %w", err) + } + configMapName = resolvedSecretName.GetStringValue() + } else { + configMapName = configMapAsEnv.ConfigMapName + } + + configMapEnvVar.ValueFrom.ConfigMapKeyRef.LocalObjectReference.Name = configMapName podSpec.Containers[0].Env = append(podSpec.Containers[0].Env, configMapEnvVar) } } // Get image pull secret information for _, imagePullSecret := range kubernetesExecutorConfig.GetImagePullSecret() { - podSpec.ImagePullSecrets = append(podSpec.ImagePullSecrets, k8score.LocalObjectReference{Name: imagePullSecret.GetSecretName()}) + var secretName string + if imagePullSecret.SecretNameParameter != nil { + resolvedSecretName, err := resolveK8sParameter(ctx, opts, dag, pipeline, mlmd, + imagePullSecret.SecretNameParameter, inputParams) + if err != nil { + return fmt.Errorf("failed to resolve image pull secret name: %w", err) + } + secretName = resolvedSecretName.GetStringValue() + } else { + secretName = imagePullSecret.SecretName + } + + podSpec.ImagePullSecrets = append( + podSpec.ImagePullSecrets, + k8score.LocalObjectReference{ + Name: secretName, + }, + ) } // Get Kubernetes FieldPath Env information @@ -1039,7 +1137,15 @@ func validateNonRoot(opts Options) error { return nil } -func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, pipeline *metadata.Pipeline, opts Options, mlmd *metadata.Client, expr *expression.Expr) (inputs *pipelinespec.ExecutorInput_Inputs, err error) { +func resolveInputs( + ctx context.Context, + dag *metadata.DAG, + iterationIndex *int, + pipeline *metadata.Pipeline, + opts Options, + mlmd *metadata.Client, + expr *expression.Expr, +) (inputs *pipelinespec.ExecutorInput_Inputs, err error) { defer func() { if err != nil { err = fmt.Errorf("failed to resolve inputs: %w", err) @@ -1217,109 +1323,140 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, // Handle parameters. for name, paramSpec := range task.GetInputs().GetParameters() { - glog.V(4).Infof("name: %v", name) - glog.V(4).Infof("paramSpec: %v", paramSpec) - paramError := func(err error) error { - return fmt.Errorf("resolving input parameter %s with spec %s: %w", name, paramSpec, err) - } - switch t := paramSpec.Kind.(type) { - case *pipelinespec.TaskInputsSpec_InputParameterSpec_ComponentInputParameter: - componentInput := paramSpec.GetComponentInputParameter() - if componentInput == "" { - return nil, paramError(fmt.Errorf("empty component input")) - } - v, ok := inputParams[componentInput] - if !ok { - return nil, paramError(fmt.Errorf("parent DAG does not have input parameter %s", componentInput)) - } - inputs.ParameterValues[name] = v - - // This is the case where the input comes from the output of an upstream task. - case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter: - cfg := resolveUpstreamOutputsConfig{ - ctx: ctx, - paramSpec: paramSpec, - dag: dag, - pipeline: pipeline, - mlmd: mlmd, - inputs: inputs, - name: name, - err: paramError, - } - if err := resolveUpstreamParameters(cfg); err != nil { - return nil, err - } + v, err := resolveInputParameter(ctx, dag, pipeline, opts, mlmd, paramSpec, inputParams) + if err != nil { + return nil, err + } + inputs.ParameterValues[name] = v + } - case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue: - runtimeValue := paramSpec.GetRuntimeValue() - switch t := runtimeValue.Value.(type) { - case *pipelinespec.ValueOrRuntimeParameter_Constant: - val := runtimeValue.GetConstant() - - switch val.GetStringValue() { - case "{{$.pipeline_job_name}}": - inputs.ParameterValues[name] = structpb.NewStringValue(opts.RunDisplayName) - case "{{$.pipeline_job_resource_name}}": - inputs.ParameterValues[name] = structpb.NewStringValue(opts.RunName) - case "{{$.pipeline_job_uuid}}": - inputs.ParameterValues[name] = structpb.NewStringValue(opts.RunID) - case "{{$.pipeline_task_name}}": - inputs.ParameterValues[name] = structpb.NewStringValue(task.GetTaskInfo().GetName()) - case "{{$.pipeline_task_uuid}}": - inputs.ParameterValues[name] = structpb.NewStringValue(fmt.Sprintf("%d", opts.DAGExecutionID)) - default: - inputs.ParameterValues[name] = val - } + // Handle artifacts. + for name, artifactSpec := range task.GetInputs().GetArtifacts() { + v, err := resolveInputArtifact(ctx, dag, pipeline, mlmd, name, artifactSpec, inputArtifacts, task) + if err != nil { + return nil, err + } + inputs.Artifacts[name] = v + } + // TODO(Bobgy): validate executor inputs match component inputs definition + return inputs, nil +} + +func resolveInputParameter( + ctx context.Context, + dag *metadata.DAG, + pipeline *metadata.Pipeline, + opts Options, + mlmd *metadata.Client, + paramSpec *pipelinespec.TaskInputsSpec_InputParameterSpec, + inputParams map[string]*structpb.Value, +) (*structpb.Value, error) { + glog.V(4).Infof("paramSpec: %v", paramSpec) + paramError := func(err error) error { + return fmt.Errorf("resolving input parameter with spec %s: %w", paramSpec, err) + } + switch t := paramSpec.Kind.(type) { + case *pipelinespec.TaskInputsSpec_InputParameterSpec_ComponentInputParameter: + componentInput := paramSpec.GetComponentInputParameter() + if componentInput == "" { + return nil, paramError(fmt.Errorf("empty component input")) + } + v, ok := inputParams[componentInput] + if !ok { + return nil, paramError(fmt.Errorf("parent DAG does not have input parameter %s", componentInput)) + } + return v, nil + + // This is the case where the input comes from the output of an upstream task. + case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskOutputParameter: + cfg := resolveUpstreamOutputsConfig{ + ctx: ctx, + paramSpec: paramSpec, + dag: dag, + pipeline: pipeline, + mlmd: mlmd, + err: paramError, + } + v, err := resolveUpstreamParameters(cfg) + if err != nil { + return nil, err + } + return v, nil + case *pipelinespec.TaskInputsSpec_InputParameterSpec_RuntimeValue: + runtimeValue := paramSpec.GetRuntimeValue() + switch t := runtimeValue.Value.(type) { + case *pipelinespec.ValueOrRuntimeParameter_Constant: + val := runtimeValue.GetConstant() + var v *structpb.Value + switch val.GetStringValue() { + case "{{$.pipeline_job_name}}": + v = structpb.NewStringValue(opts.RunDisplayName) + case "{{$.pipeline_job_resource_name}}": + v = structpb.NewStringValue(opts.RunName) + case "{{$.pipeline_job_uuid}}": + v = structpb.NewStringValue(opts.RunID) + case "{{$.pipeline_task_name}}": + v = structpb.NewStringValue(opts.Task.GetTaskInfo().GetName()) + case "{{$.pipeline_task_uuid}}": + v = structpb.NewStringValue(fmt.Sprintf("%d", opts.DAGExecutionID)) default: - return nil, paramError(fmt.Errorf("param runtime value spec of type %T not implemented", t)) + v = val } - // TODO(Bobgy): implement the following cases - // case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskFinalStatus_: + return v, nil default: - return nil, paramError(fmt.Errorf("parameter spec of type %T not implemented yet", t)) + return nil, paramError(fmt.Errorf("param runtime value spec of type %T not implemented", t)) } + // TODO(Bobgy): implement the following cases + // case *pipelinespec.TaskInputsSpec_InputParameterSpec_TaskFinalStatus_: + default: + return nil, paramError(fmt.Errorf("parameter spec of type %T not implemented yet", t)) } +} - // Handle artifacts. - for name, artifactSpec := range task.GetInputs().GetArtifacts() { - glog.V(4).Infof("inputs: %#v", task.GetInputs()) - glog.V(4).Infof("artifacts: %#v", task.GetInputs().GetArtifacts()) - artifactError := func(err error) error { - return fmt.Errorf("failed to resolve input artifact %s with spec %s: %w", name, artifactSpec, err) - } - switch t := artifactSpec.Kind.(type) { - case *pipelinespec.TaskInputsSpec_InputArtifactSpec_ComponentInputArtifact: - inputArtifactName := artifactSpec.GetComponentInputArtifact() - if inputArtifactName == "" { - return nil, artifactError(fmt.Errorf("component input artifact key is empty")) - } - v, ok := inputArtifacts[inputArtifactName] - if !ok { - return nil, artifactError(fmt.Errorf("parent DAG does not have input artifact %s", inputArtifactName)) - } - inputs.Artifacts[name] = v - - case *pipelinespec.TaskInputsSpec_InputArtifactSpec_TaskOutputArtifact: - cfg := resolveUpstreamOutputsConfig{ - ctx: ctx, - artifactSpec: artifactSpec, - dag: dag, - pipeline: pipeline, - mlmd: mlmd, - inputs: inputs, - name: name, - err: artifactError, - } - if err := resolveUpstreamArtifacts(cfg); err != nil { - return nil, err - } - default: - return nil, artifactError(fmt.Errorf("artifact spec of type %T not implemented yet", t)) +func resolveInputArtifact( + ctx context.Context, + dag *metadata.DAG, + pipeline *metadata.Pipeline, + mlmd *metadata.Client, + name string, + artifactSpec *pipelinespec.TaskInputsSpec_InputArtifactSpec, + inputArtifacts map[string]*pipelinespec.ArtifactList, + task *pipelinespec.PipelineTaskSpec, +) (*pipelinespec.ArtifactList, error) { + glog.V(4).Infof("inputs: %#v", task.GetInputs()) + glog.V(4).Infof("artifacts: %#v", task.GetInputs().GetArtifacts()) + artifactError := func(err error) error { + return fmt.Errorf("failed to resolve input artifact %s with spec %s: %w", name, artifactSpec, err) + } + switch t := artifactSpec.Kind.(type) { + case *pipelinespec.TaskInputsSpec_InputArtifactSpec_ComponentInputArtifact: + inputArtifactName := artifactSpec.GetComponentInputArtifact() + if inputArtifactName == "" { + return nil, artifactError(fmt.Errorf("component input artifact key is empty")) + } + v, ok := inputArtifacts[inputArtifactName] + if !ok { + return nil, artifactError(fmt.Errorf("parent DAG does not have input artifact %s", inputArtifactName)) + } + return v, nil + case *pipelinespec.TaskInputsSpec_InputArtifactSpec_TaskOutputArtifact: + cfg := resolveUpstreamOutputsConfig{ + ctx: ctx, + artifactSpec: artifactSpec, + dag: dag, + pipeline: pipeline, + mlmd: mlmd, + err: artifactError, + } + artifacts, err := resolveUpstreamArtifacts(cfg) + if err != nil { + return nil, err } + return artifacts, nil + default: + return nil, artifactError(fmt.Errorf("artifact spec of type %T not implemented yet", t)) } - // TODO(Bobgy): validate executor inputs match component inputs definition - return inputs, nil } // getDAGTasks is a recursive function that returns a map of all tasks across all DAGs in the context of nested DAGs. @@ -1380,8 +1517,6 @@ type resolveUpstreamOutputsConfig struct { dag *metadata.DAG pipeline *metadata.Pipeline mlmd *metadata.Client - inputs *pipelinespec.ExecutorInput_Inputs - name string err func(error) error } @@ -1389,35 +1524,35 @@ type resolveUpstreamOutputsConfig struct { // tasks. These tasks can be components/containers, which is relatively // straightforward, or DAGs, in which case, we need to traverse the graph until // we arrive at a component/container (since there can be n nested DAGs). -func resolveUpstreamParameters(cfg resolveUpstreamOutputsConfig) error { +func resolveUpstreamParameters(cfg resolveUpstreamOutputsConfig) (*structpb.Value, error) { taskOutput := cfg.paramSpec.GetTaskOutputParameter() glog.V(4).Info("taskOutput: ", taskOutput) producerTaskName := taskOutput.GetProducerTask() if producerTaskName == "" { - return cfg.err(fmt.Errorf("producerTaskName is empty")) + return nil, cfg.err(fmt.Errorf("producerTaskName is empty")) } outputParameterKey := taskOutput.GetOutputParameterKey() if outputParameterKey == "" { - return cfg.err(fmt.Errorf("output parameter key is empty")) + return nil, cfg.err(fmt.Errorf("output parameter key is empty")) } // Get a list of tasks for the current DAG first. - // The reason we use gatDAGTasks instead of mlmd.GetExecutionsInDAG is because the latter does not handle task name collisions in the map which results in a bunch of unhandled edge cases and test failures. + // The reason we use gatDAGTasks instead of mlmd.GetExecutionsInDAG is because the latter does not handle + // task name collisions in the map which results in a bunch of unhandled edge cases and test failures. tasks, err := getDAGTasks(cfg.ctx, cfg.dag, cfg.pipeline, cfg.mlmd, nil) if err != nil { - return cfg.err(err) + return nil, cfg.err(err) } producer, ok := tasks[producerTaskName] if !ok { - return cfg.err(fmt.Errorf("producer task, %v, not in tasks", producerTaskName)) + return nil, cfg.err(fmt.Errorf("producer task, %v, not in tasks", producerTaskName)) } glog.V(4).Info("producer: ", producer) glog.V(4).Infof("tasks: %#v", tasks) currentTask := producer - currentSubTaskMaybeDAG := true // Continue looping until we reach a sub-task that is NOT a DAG. - for currentSubTaskMaybeDAG { + for { glog.V(4).Info("currentTask: ", currentTask.TaskName()) // If the current task is a DAG: if *currentTask.GetExecution().Type == "system.DAGExecution" { @@ -1427,7 +1562,7 @@ func resolveUpstreamParameters(cfg resolveUpstreamOutputsConfig) error { // and iterate through this loop again. outputParametersCustomProperty, ok := currentTask.GetExecution().GetCustomProperties()["parameter_producer_task"] if !ok { - return cfg.err(fmt.Errorf("task, %v, does not have a parameter_producer_task custom property", currentTask.TaskName())) + return nil, cfg.err(fmt.Errorf("task, %v, does not have a parameter_producer_task custom property", currentTask.TaskName())) } glog.V(4).Infof("outputParametersCustomProperty: %#v", outputParametersCustomProperty) @@ -1438,7 +1573,7 @@ func resolveUpstreamParameters(cfg resolveUpstreamOutputsConfig) error { outputSpec := &pipelinespec.DagOutputsSpec_DagOutputParameterSpec{} err := protojson.Unmarshal([]byte(value.GetStringValue()), outputSpec) if err != nil { - return err + return nil, err } dagOutputParametersMap[name] = outputSpec } @@ -1475,14 +1610,14 @@ func resolveUpstreamParameters(cfg resolveUpstreamOutputsConfig) error { } } if !successfulOneOfTask { - return cfg.err(fmt.Errorf("processing OneOf: No successful task found")) + return nil, cfg.err(fmt.Errorf("processing OneOf: No successful task found")) } } } glog.V(4).Infof("SubTaskName from outputParams: %v", subTaskName) glog.V(4).Infof("OutputParameterKey from outputParams: %v", outputParameterKey) if subTaskName == "" { - return cfg.err(fmt.Errorf("producer_subtask not in outputParams")) + return nil, cfg.err(fmt.Errorf("producer_subtask not in outputParams")) } glog.V(4).Infof( "Overriding currentTask, %v, output with currentTask's producer_subtask, %v, output.", @@ -1491,32 +1626,28 @@ func resolveUpstreamParameters(cfg resolveUpstreamOutputsConfig) error { ) currentTask, ok = tasks[subTaskName] if !ok { - return cfg.err(fmt.Errorf("subTaskName, %v, not in tasks", subTaskName)) + return nil, cfg.err(fmt.Errorf("subTaskName, %v, not in tasks", subTaskName)) } - } else { _, outputParametersCustomProperty, err := currentTask.GetParameters() if err != nil { - return err + return nil, err } - cfg.inputs.ParameterValues[cfg.name] = outputParametersCustomProperty[outputParameterKey] - // Exit the loop. - currentSubTaskMaybeDAG = false + // Base case + return outputParametersCustomProperty[outputParameterKey], nil } } - - return nil } // resolveUpstreamArtifacts resolves input artifacts that come from upstream // tasks. These tasks can be components/containers, which is relatively // straightforward, or DAGs, in which case, we need to traverse the graph until // we arrive at a component/container (since there can be n nested DAGs). -func resolveUpstreamArtifacts(cfg resolveUpstreamOutputsConfig) error { +func resolveUpstreamArtifacts(cfg resolveUpstreamOutputsConfig) (*pipelinespec.ArtifactList, error) { glog.V(4).Infof("artifactSpec: %#v", cfg.artifactSpec) taskOutput := cfg.artifactSpec.GetTaskOutputArtifact() if taskOutput.GetProducerTask() == "" { - return cfg.err(fmt.Errorf("producer task is empty")) + return nil, cfg.err(fmt.Errorf("producer task is empty")) } if taskOutput.GetOutputArtifactKey() == "" { cfg.err(fmt.Errorf("output artifact key is empty")) @@ -1535,9 +1666,9 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamOutputsConfig) error { glog.V(4).Info("producer: ", producer) currentTask := producer outputArtifactKey := taskOutput.GetOutputArtifactKey() - currentSubTaskMaybeDAG := true + // Continue looping until we reach a sub-task that is NOT a DAG. - for currentSubTaskMaybeDAG { + for { glog.V(4).Info("currentTask: ", currentTask.TaskName()) // If the current task is a DAG: if *currentTask.GetExecution().Type == "system.DAGExecution" { @@ -1547,7 +1678,7 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamOutputsConfig) error { var outputArtifacts map[string]*pipelinespec.DagOutputsSpec_DagOutputArtifactSpec err := json.Unmarshal([]byte(outputArtifactsCustomProperty.GetStringValue()), &outputArtifacts) if err != nil { - return err + return nil, err } glog.V(4).Infof("Deserialized outputArtifacts: %v", outputArtifacts) // Adding support for multiple output artifacts @@ -1586,15 +1717,12 @@ func resolveUpstreamArtifacts(cfg resolveUpstreamOutputsConfig) error { if err != nil { cfg.err(err) } - cfg.inputs.Artifacts[cfg.name] = &pipelinespec.ArtifactList{ + // Base case + return &pipelinespec.ArtifactList{ Artifacts: []*pipelinespec.RuntimeArtifact{runtimeArtifact}, - } - // Since we are in the base case, escape the loop. - currentSubTaskMaybeDAG = false + }, nil } } - - return nil } func provisionOutputs(pipelineRoot, taskName string, outputsSpec *pipelinespec.ComponentOutputsSpec, outputUriSalt string) *pipelinespec.ExecutorInput_Outputs { @@ -2030,67 +2158,59 @@ func createK8sClient() (*kubernetes.Clientset, error) { return k8sClient, nil } -func makeVolumeMountPatch(pvcMount []*kubernetesplatform.PvcMount, dag *metadata.DAG, dagTasks map[string]*metadata.Execution) ([]k8score.VolumeMount, []k8score.Volume, error) { - if pvcMount == nil { +func makeVolumeMountPatch( + ctx context.Context, + opts Options, + pvcMounts []*kubernetesplatform.PvcMount, + dag *metadata.DAG, + pipeline *metadata.Pipeline, + mlmd *metadata.Client, + inputParams map[string]*structpb.Value, +) ([]k8score.VolumeMount, []k8score.Volume, error) { + if pvcMounts == nil { return nil, nil, nil } var volumeMounts []k8score.VolumeMount var volumes []k8score.Volume - for _, vmc := range pvcMount { - // Find mount path - if vmc.GetMountPath() == "" { - return nil, nil, fmt.Errorf("failed to make podSpecPatch: volume mount: volume mount path not provided") + for _, pvcMount := range pvcMounts { + var pvcNameParameter *kubernetesplatform.InputParameterSpec + if pvcMount.PvcNameParameter != nil { + pvcNameParameter = pvcMount.PvcNameParameter + } else { // Support deprecated fields + if pvcMount.GetConstant() != "" { + pvcNameParameter = strInputParamConstant(pvcMount.GetConstant()) + } else if pvcMount.GetTaskOutputParameter() != nil { + pvcNameParameter = strInputParamTaskOutput( + pvcMount.GetTaskOutputParameter().GetProducerTask(), + pvcMount.GetTaskOutputParameter().GetOutputParameterKey(), + ) + } else if pvcMount.GetComponentInputParameter() != "" { + pvcNameParameter = strInputParamComponent(pvcMount.GetComponentInputParameter()) + } else { + return nil, nil, fmt.Errorf("failed to make podSpecPatch: volume mount: volume name not provided") + } + } + + resolvedPvcName, err := resolveK8sParameter(ctx, opts, dag, pipeline, mlmd, + pvcNameParameter, inputParams) + if err != nil { + return nil, nil, fmt.Errorf("failed to resolve pvc name: %w", err) + } + pvcName := resolvedPvcName.GetStringValue() + + pvcMountPath := pvcMount.GetMountPath() + if pvcName == "" || pvcMountPath == "" { + return nil, nil, fmt.Errorf("failed to mount volume, missing mountpath or pvc name") } volumeMount := k8score.VolumeMount{ - MountPath: vmc.GetMountPath(), - } - volume := k8score.Volume{} - - // Volume name may come from three different sources: - // 1) A constant - // 2) As a task output parameter - // 3) As a component input parameter - if vmc.GetConstant() != "" { - volumeMount.Name = vmc.GetConstant() - volume.Name = vmc.GetConstant() - volume.PersistentVolumeClaim = &k8score.PersistentVolumeClaimVolumeSource{ClaimName: vmc.GetConstant()} - } else if vmc.GetTaskOutputParameter() != nil { - if vmc.GetTaskOutputParameter().GetProducerTask() == "" { - return nil, nil, fmt.Errorf("failed to make podSpecPatch: volume mount: producer task empty") - } - if vmc.GetTaskOutputParameter().GetOutputParameterKey() == "" { - return nil, nil, fmt.Errorf("failed to make podSpecPatch: volume mount: OutputParameterKey") - } - producer, ok := dagTasks[vmc.GetTaskOutputParameter().GetProducerTask()] - if !ok { - return nil, nil, fmt.Errorf("failed to make podSpecPatch: volume mount: cannot find producer task %s", vmc.GetTaskOutputParameter().GetProducerTask()) - } - _, outputs, err := producer.GetParameters() - if err != nil { - return nil, nil, fmt.Errorf("failed to make podSpecPatch: volume mount: cannot get producer output: %w", err) - } - pvcName, ok := outputs[vmc.GetTaskOutputParameter().GetOutputParameterKey()] - if !ok { - return nil, nil, fmt.Errorf("failed to make podSpecPatch: volume mount: cannot find output parameter %s from producer task %s", vmc.GetTaskOutputParameter().GetOutputParameterKey(), vmc.GetTaskOutputParameter().GetProducerTask()) - } - volumeMount.Name = pvcName.GetStringValue() - volume.Name = pvcName.GetStringValue() - volume.PersistentVolumeClaim = &k8score.PersistentVolumeClaimVolumeSource{ClaimName: pvcName.GetStringValue()} - } else if vmc.GetComponentInputParameter() != "" { - inputParams, _, err := dag.Execution.GetParameters() - if err != nil { - return nil, nil, fmt.Errorf("failed to make podSpecPatch: volume mount: error getting input parameters") - } - glog.Infof("parent DAG input parameters %+v", inputParams) - pvcName, ok := inputParams[vmc.GetComponentInputParameter()] - if !ok { - return nil, nil, fmt.Errorf("failed to make podSpecPatch: volume mount:component input parameters %s doesn't exist", vmc.GetComponentInputParameter()) - } - volumeMount.Name = pvcName.GetStringValue() - volume.Name = pvcName.GetStringValue() - volume.PersistentVolumeClaim = &k8score.PersistentVolumeClaimVolumeSource{ClaimName: pvcName.GetStringValue()} - } else { - return nil, nil, fmt.Errorf("failed to make podSpecPatch: volume mount: volume name not provided") + Name: pvcName, + MountPath: pvcMountPath, + } + volume := k8score.Volume{ + Name: pvcName, + VolumeSource: k8score.VolumeSource{ + PersistentVolumeClaim: &k8score.PersistentVolumeClaimVolumeSource{ClaimName: pvcName}, + }, } volumeMounts = append(volumeMounts, volumeMount) volumes = append(volumes, volume) diff --git a/backend/src/v2/driver/driver_test.go b/backend/src/v2/driver/driver_test.go index c84896c41de..975cc9a6f6b 100644 --- a/backend/src/v2/driver/driver_test.go +++ b/backend/src/v2/driver/driver_test.go @@ -14,6 +14,7 @@ package driver import ( + "context" "encoding/json" "testing" @@ -397,18 +398,18 @@ func Test_makeVolumeMountPatch(t *testing.T) { dag *metadata.DAG dagTasks map[string]*metadata.Execution } - // TODO(lingqinggan): add more test cases for task output parameter and component input. - // Omitted now due to type Execution defined in metadata has unexported fields. + tests := []struct { - name string - args args - wantPath string - wantName string - wantErr bool - errMsg string + name string + args args + wantPath string + wantName string + wantErr bool + inputParams map[string]*structpb.Value + errMsg string }{ { - "pvc name: constant", + "pvc name: constant (deprecated)", args{ []*kubernetesplatform.PvcMount{ { @@ -422,12 +423,61 @@ func Test_makeVolumeMountPatch(t *testing.T) { "/mnt/path", "pvc-name", false, + nil, + "", + }, + { + "pvc name: constant parameter", + args{ + []*kubernetesplatform.PvcMount{ + { + MountPath: "/mnt/path", + PvcReference: &kubernetesplatform.PvcMount_Constant{Constant: "not-used"}, + PvcNameParameter: strInputParamConstant("pvc-name"), + }, + }, + nil, + nil, + }, + "/mnt/path", + "pvc-name", + false, + nil, + "", + }, + { + "pvc name: component input parameter", + args{ + []*kubernetesplatform.PvcMount{ + { + MountPath: "/mnt/path", + PvcNameParameter: strInputParamComponent("param_1"), + }, + }, + nil, + nil, + }, + "/mnt/path", + "pvc-name", + false, + map[string]*structpb.Value{ + "param_1": structpb.NewStringValue("pvc-name"), + }, "", }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - volumeMounts, volumes, err := makeVolumeMountPatch(tt.args.pvcMount, tt.args.dag, tt.args.dagTasks) + volumeMounts, volumes, err := makeVolumeMountPatch( + context.Background(), + Options{}, + tt.args.pvcMount, + tt.args.dag, + nil, + nil, + tt.inputParams, + ) if tt.wantErr { assert.NotNil(t, err) assert.Nil(t, volumeMounts) @@ -538,7 +588,14 @@ func Test_initPodSpecPatch_resourceRequests(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - podSpec, err := initPodSpecPatch(tt.args.container, tt.args.componentSpec, tt.args.executorInput, tt.args.executionID, tt.args.pipelineName, tt.args.runID) + podSpec, err := initPodSpecPatch( + tt.args.container, + tt.args.componentSpec, + tt.args.executorInput, + tt.args.executionID, + tt.args.pipelineName, + tt.args.runID, + ) assert.Nil(t, err) assert.NotEmpty(t, podSpec) podSpecString, err := json.Marshal(podSpec) @@ -557,9 +614,10 @@ func Test_makePodSpecPatch_nodeSelector(t *testing.T) { viper.Set("KFP_POD_NAME", "MyWorkflowPod") viper.Set("KFP_POD_UID", "a1b2c3d4-a1b2-a1b2-a1b2-a1b2c3d4e5f6") tests := []struct { - name string - k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig - expected *k8score.PodSpec + name string + k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig + expected *k8score.PodSpec + inputParams map[string]*structpb.Value }{ { "Valid - NVIDIA GPU on GKE", @@ -578,6 +636,7 @@ func Test_makePodSpecPatch_nodeSelector(t *testing.T) { }, NodeSelector: map[string]string{"cloud.google.com/gke-accelerator": "nvidia-tesla-k80"}, }, + nil, }, { "Valid - operating system and arch", @@ -597,6 +656,29 @@ func Test_makePodSpecPatch_nodeSelector(t *testing.T) { }, NodeSelector: map[string]string{"beta.kubernetes.io/arch": "amd64", "beta.kubernetes.io/os": "linux"}, }, + nil, + }, + { + "Valid - Json Parameter", + &kubernetesplatform.KubernetesExecutorConfig{ + NodeSelector: &kubernetesplatform.NodeSelector{ + NodeSelectorJson: strInputParamComponent("param_1"), + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + NodeSelector: map[string]string{"beta.kubernetes.io/arch": "amd64", "beta.kubernetes.io/os": "linux"}, + }, + map[string]*structpb.Value{ + "param_1": validValueStructOrPanic(map[string]interface{}{ + "beta.kubernetes.io/arch": "amd64", + "beta.kubernetes.io/os": "linux", + }), + }, }, { "Valid - empty", @@ -608,6 +690,27 @@ func Test_makePodSpecPatch_nodeSelector(t *testing.T) { }, }, }, + nil, + }, + { + "Valid - empty json", + &kubernetesplatform.KubernetesExecutorConfig{ + NodeSelector: &kubernetesplatform.NodeSelector{ + NodeSelectorJson: strInputParamComponent("param_1"), + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + // valid node selector, pod can be scheduled on any node + NodeSelector: map[string]string{}, + }, + map[string]*structpb.Value{ + "param_1": validValueStructOrPanic(map[string]interface{}{}), + }, }, } for _, tt := range tests { @@ -617,7 +720,15 @@ func Test_makePodSpecPatch_nodeSelector(t *testing.T) { Name: "main", }, }} - err := extendPodSpecPatch(got, tt.k8sExecCfg, nil, nil) + err := extendPodSpecPatch( + context.Background(), + got, + Options{KubernetesExecutorConfig: tt.k8sExecCfg}, + nil, + nil, + nil, + tt.inputParams, + ) assert.Nil(t, err) assert.NotNil(t, got) assert.Equal(t, tt.expected, got) @@ -627,13 +738,14 @@ func Test_makePodSpecPatch_nodeSelector(t *testing.T) { func Test_extendPodSpecPatch_Secret(t *testing.T) { tests := []struct { - name string - k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig - podSpec *k8score.PodSpec - expected *k8score.PodSpec + name string + k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig + podSpec *k8score.PodSpec + expected *k8score.PodSpec + inputParams map[string]*structpb.Value }{ { - "Valid - secret as volume", + "Valid - secret as volume (deprecated)", &kubernetesplatform.KubernetesExecutorConfig{ SecretAsVolume: []*kubernetesplatform.SecretAsVolume{ { @@ -670,15 +782,58 @@ func Test_extendPodSpecPatch_Secret(t *testing.T) { }, }, }, + nil, + }, + { + "Valid - secret as volume", + &kubernetesplatform.KubernetesExecutorConfig{ + SecretAsVolume: []*kubernetesplatform.SecretAsVolume{ + { + SecretName: "not-used", + SecretNameParameter: strInputParamConstant("secret1"), + MountPath: "/data/path", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + VolumeMounts: []k8score.VolumeMount{ + { + Name: "secret1", + MountPath: "/data/path", + }, + }, + }, + }, + Volumes: []k8score.Volume{ + { + Name: "secret1", + VolumeSource: k8score.VolumeSource{ + Secret: &k8score.SecretVolumeSource{SecretName: "secret1", Optional: &[]bool{false}[0]}, + }, + }, + }, + }, + nil, }, { "Valid - secret as volume with optional false", &kubernetesplatform.KubernetesExecutorConfig{ SecretAsVolume: []*kubernetesplatform.SecretAsVolume{ { - SecretName: "secret1", - MountPath: "/data/path", - Optional: &[]bool{false}[0], + SecretName: "not-used", + SecretNameParameter: strInputParamConstant("secret1"), + MountPath: "/data/path", + Optional: &[]bool{false}[0], }, }, }, @@ -710,15 +865,17 @@ func Test_extendPodSpecPatch_Secret(t *testing.T) { }, }, }, + nil, }, { "Valid - secret as volume with optional true", &kubernetesplatform.KubernetesExecutorConfig{ SecretAsVolume: []*kubernetesplatform.SecretAsVolume{ { - SecretName: "secret1", - MountPath: "/data/path", - Optional: &[]bool{true}[0], + SecretName: "not-used", + SecretNameParameter: strInputParamConstant("secret1"), + MountPath: "/data/path", + Optional: &[]bool{true}[0], }, }, }, @@ -750,6 +907,7 @@ func Test_extendPodSpecPatch_Secret(t *testing.T) { }, }, }, + nil, }, { "Valid - secret not specified", @@ -768,9 +926,54 @@ func Test_extendPodSpecPatch_Secret(t *testing.T) { }, }, }, + nil, }, { - "Valid - secret as env", + "Valid - secret as volume: component input parameter", + &kubernetesplatform.KubernetesExecutorConfig{ + SecretAsVolume: []*kubernetesplatform.SecretAsVolume{ + { + SecretName: "not-used", + SecretNameParameter: strInputParamComponent("param_1"), + MountPath: "/data/path", + Optional: &[]bool{true}[0], + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + VolumeMounts: []k8score.VolumeMount{ + { + Name: "secret-name", + MountPath: "/data/path", + }, + }, + }, + }, + Volumes: []k8score.Volume{ + { + Name: "secret-name", + VolumeSource: k8score.VolumeSource{ + Secret: &k8score.SecretVolumeSource{SecretName: "secret-name", Optional: &[]bool{true}[0]}, + }, + }, + }, + }, + map[string]*structpb.Value{ + "param_1": structpb.NewStringValue("secret-name"), + }, + }, + { + "Valid - secret as env (deprecated)", &kubernetesplatform.KubernetesExecutorConfig{ SecretAsEnv: []*kubernetesplatform.SecretAsEnv{ { @@ -810,11 +1013,109 @@ func Test_extendPodSpecPatch_Secret(t *testing.T) { }, }, }, + nil, + }, + { + "Valid - secret as env", + &kubernetesplatform.KubernetesExecutorConfig{ + SecretAsEnv: []*kubernetesplatform.SecretAsEnv{ + { + SecretName: "not-used", + SecretNameParameter: strInputParamConstant("my-secret"), + KeyToEnv: []*kubernetesplatform.SecretAsEnv_SecretKeyToEnvMap{ + { + SecretKey: "password", + EnvVar: "SECRET_VAR", + }, + }, + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + Env: []k8score.EnvVar{ + { + Name: "SECRET_VAR", + ValueFrom: &k8score.EnvVarSource{ + SecretKeyRef: &k8score.SecretKeySelector{ + k8score.LocalObjectReference{Name: "my-secret"}, + "password", + nil, + }, + }, + }, + }, + }, + }, + }, + nil, + }, + { + "Valid - secret as env: component input parameter", + &kubernetesplatform.KubernetesExecutorConfig{ + SecretAsEnv: []*kubernetesplatform.SecretAsEnv{ + { + SecretNameParameter: strInputParamComponent("param_1"), + KeyToEnv: []*kubernetesplatform.SecretAsEnv_SecretKeyToEnvMap{ + { + SecretKey: "password", + EnvVar: "SECRET_VAR", + }, + }, + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + Env: []k8score.EnvVar{ + { + Name: "SECRET_VAR", + ValueFrom: &k8score.EnvVarSource{ + SecretKeyRef: &k8score.SecretKeySelector{ + k8score.LocalObjectReference{Name: "secret-name"}, + "password", + nil, + }, + }, + }, + }, + }, + }, + }, + map[string]*structpb.Value{ + "param_1": structpb.NewStringValue("secret-name"), + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := extendPodSpecPatch(tt.podSpec, tt.k8sExecCfg, nil, nil) + err := extendPodSpecPatch( + context.Background(), + tt.podSpec, + Options{KubernetesExecutorConfig: tt.k8sExecCfg}, + nil, + nil, + nil, + tt.inputParams, + ) assert.Nil(t, err) assert.Equal(t, tt.expected, tt.podSpec) }) @@ -823,13 +1124,14 @@ func Test_extendPodSpecPatch_Secret(t *testing.T) { func Test_extendPodSpecPatch_ConfigMap(t *testing.T) { tests := []struct { - name string - k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig - podSpec *k8score.PodSpec - expected *k8score.PodSpec + name string + k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig + podSpec *k8score.PodSpec + expected *k8score.PodSpec + inputParams map[string]*structpb.Value }{ { - "Valid - config map as volume", + "Valid - config map as volume (deprecated)", &kubernetesplatform.KubernetesExecutorConfig{ ConfigMapAsVolume: []*kubernetesplatform.ConfigMapAsVolume{ { @@ -863,20 +1165,172 @@ func Test_extendPodSpecPatch_ConfigMap(t *testing.T) { VolumeSource: k8score.VolumeSource{ ConfigMap: &k8score.ConfigMapVolumeSource{ LocalObjectReference: k8score.LocalObjectReference{Name: "cm1"}, - Optional: &[]bool{false}[0]}, + Optional: &[]bool{false}[0]}, + }, + }, + }, + }, + nil, + }, + { + "Valid - config map as volume", + &kubernetesplatform.KubernetesExecutorConfig{ + ConfigMapAsVolume: []*kubernetesplatform.ConfigMapAsVolume{ + { + ConfigMapName: "not-used", + ConfigNameParameter: strInputParamConstant("cm1"), + MountPath: "/data/path", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + VolumeMounts: []k8score.VolumeMount{ + { + Name: "cm1", + MountPath: "/data/path", + }, + }, + }, + }, + Volumes: []k8score.Volume{ + { + Name: "cm1", + VolumeSource: k8score.VolumeSource{ + ConfigMap: &k8score.ConfigMapVolumeSource{ + LocalObjectReference: k8score.LocalObjectReference{Name: "cm1"}, + Optional: &[]bool{false}[0]}, + }, + }, + }, + }, + nil, + }, + { + "Valid - config map as volume with optional false", + &kubernetesplatform.KubernetesExecutorConfig{ + ConfigMapAsVolume: []*kubernetesplatform.ConfigMapAsVolume{ + { + ConfigMapName: "not-used", + ConfigNameParameter: strInputParamConstant("cm1"), + MountPath: "/data/path", + Optional: &[]bool{false}[0], + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + VolumeMounts: []k8score.VolumeMount{ + { + Name: "cm1", + MountPath: "/data/path", + }, + }, + }, + }, + Volumes: []k8score.Volume{ + { + Name: "cm1", + VolumeSource: k8score.VolumeSource{ + ConfigMap: &k8score.ConfigMapVolumeSource{ + LocalObjectReference: k8score.LocalObjectReference{Name: "cm1"}, + Optional: &[]bool{false}[0]}, + }, + }, + }, + }, + nil, + }, + { + "Valid - config map as volume with optional true", + &kubernetesplatform.KubernetesExecutorConfig{ + ConfigMapAsVolume: []*kubernetesplatform.ConfigMapAsVolume{ + { + ConfigMapName: "not-used", + ConfigNameParameter: strInputParamConstant("cm1"), + MountPath: "/data/path", + Optional: &[]bool{true}[0], + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + VolumeMounts: []k8score.VolumeMount{ + { + Name: "cm1", + MountPath: "/data/path", + }, + }, + }, + }, + Volumes: []k8score.Volume{ + { + Name: "cm1", + VolumeSource: k8score.VolumeSource{ + ConfigMap: &k8score.ConfigMapVolumeSource{ + LocalObjectReference: k8score.LocalObjectReference{Name: "cm1"}, + Optional: &[]bool{true}[0]}, }, }, }, }, + nil, }, { - "Valid - config map as volume with optional false", + "Valid - config map not specified", + &kubernetesplatform.KubernetesExecutorConfig{}, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + }, + nil, + }, + { + "Valid - config map volume: component input parameter", &kubernetesplatform.KubernetesExecutorConfig{ ConfigMapAsVolume: []*kubernetesplatform.ConfigMapAsVolume{ { - ConfigMapName: "cm1", - MountPath: "/data/path", - Optional: &[]bool{false}[0], + ConfigMapName: "not-used", + ConfigNameParameter: strInputParamComponent("param_1"), + MountPath: "/data/path", + Optional: &[]bool{true}[0], }, }, }, @@ -893,7 +1347,7 @@ func Test_extendPodSpecPatch_ConfigMap(t *testing.T) { Name: "main", VolumeMounts: []k8score.VolumeMount{ { - Name: "cm1", + Name: "cm-name", MountPath: "/data/path", }, }, @@ -901,24 +1355,31 @@ func Test_extendPodSpecPatch_ConfigMap(t *testing.T) { }, Volumes: []k8score.Volume{ { - Name: "cm1", + Name: "cm-name", VolumeSource: k8score.VolumeSource{ ConfigMap: &k8score.ConfigMapVolumeSource{ - LocalObjectReference: k8score.LocalObjectReference{Name: "cm1"}, - Optional: &[]bool{false}[0]}, + LocalObjectReference: k8score.LocalObjectReference{Name: "cm-name"}, + Optional: &[]bool{true}[0]}, }, }, }, }, + map[string]*structpb.Value{ + "param_1": structpb.NewStringValue("cm-name"), + }, }, { - "Valid - config map as volume with optional true", + "Valid - config map as env (deprecated)", &kubernetesplatform.KubernetesExecutorConfig{ - ConfigMapAsVolume: []*kubernetesplatform.ConfigMapAsVolume{ + ConfigMapAsEnv: []*kubernetesplatform.ConfigMapAsEnv{ { - ConfigMapName: "cm1", - MountPath: "/data/path", - Optional: &[]bool{true}[0], + ConfigMapName: "my-cm", + KeyToEnv: []*kubernetesplatform.ConfigMapAsEnv_ConfigMapKeyToEnvMap{ + { + ConfigMapKey: "foo", + EnvVar: "CONFIG_MAP_VAR", + }, + }, }, }, }, @@ -933,29 +1394,39 @@ func Test_extendPodSpecPatch_ConfigMap(t *testing.T) { Containers: []k8score.Container{ { Name: "main", - VolumeMounts: []k8score.VolumeMount{ + Env: []k8score.EnvVar{ { - Name: "cm1", - MountPath: "/data/path", + Name: "CONFIG_MAP_VAR", + ValueFrom: &k8score.EnvVarSource{ + ConfigMapKeyRef: &k8score.ConfigMapKeySelector{ + k8score.LocalObjectReference{Name: "my-cm"}, + "foo", + nil, + }, + }, }, }, }, }, - Volumes: []k8score.Volume{ + }, + nil, + }, + { + "Valid - config map as env", + &kubernetesplatform.KubernetesExecutorConfig{ + ConfigMapAsEnv: []*kubernetesplatform.ConfigMapAsEnv{ { - Name: "cm1", - VolumeSource: k8score.VolumeSource{ - ConfigMap: &k8score.ConfigMapVolumeSource{ - LocalObjectReference: k8score.LocalObjectReference{Name: "cm1"}, - Optional: &[]bool{true}[0]}, + ConfigMapName: "not-used", + ConfigNameParameter: strInputParamConstant("my-cm"), + KeyToEnv: []*kubernetesplatform.ConfigMapAsEnv_ConfigMapKeyToEnvMap{ + { + ConfigMapKey: "foo", + EnvVar: "CONFIG_MAP_VAR", + }, }, }, }, }, - }, - { - "Valid - config map not specified", - &kubernetesplatform.KubernetesExecutorConfig{}, &k8score.PodSpec{ Containers: []k8score.Container{ { @@ -967,16 +1438,30 @@ func Test_extendPodSpecPatch_ConfigMap(t *testing.T) { Containers: []k8score.Container{ { Name: "main", + Env: []k8score.EnvVar{ + { + Name: "CONFIG_MAP_VAR", + ValueFrom: &k8score.EnvVarSource{ + ConfigMapKeyRef: &k8score.ConfigMapKeySelector{ + k8score.LocalObjectReference{Name: "my-cm"}, + "foo", + nil, + }, + }, + }, + }, }, }, }, + nil, }, { - "Valid - config map as env", + "Valid - config map as env: component input parameter", &kubernetesplatform.KubernetesExecutorConfig{ ConfigMapAsEnv: []*kubernetesplatform.ConfigMapAsEnv{ { - ConfigMapName: "my-cm", + ConfigMapName: "not-used", + ConfigNameParameter: strInputParamComponent("param_1"), KeyToEnv: []*kubernetesplatform.ConfigMapAsEnv_ConfigMapKeyToEnvMap{ { ConfigMapKey: "foo", @@ -1002,7 +1487,7 @@ func Test_extendPodSpecPatch_ConfigMap(t *testing.T) { Name: "CONFIG_MAP_VAR", ValueFrom: &k8score.EnvVarSource{ ConfigMapKeyRef: &k8score.ConfigMapKeySelector{ - k8score.LocalObjectReference{Name: "my-cm"}, + k8score.LocalObjectReference{Name: "cm-name"}, "foo", nil, }, @@ -1012,11 +1497,22 @@ func Test_extendPodSpecPatch_ConfigMap(t *testing.T) { }, }, }, + map[string]*structpb.Value{ + "param_1": structpb.NewStringValue("cm-name"), + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := extendPodSpecPatch(tt.podSpec, tt.k8sExecCfg, nil, nil) + err := extendPodSpecPatch( + context.Background(), + tt.podSpec, + Options{KubernetesExecutorConfig: tt.k8sExecCfg}, + nil, + nil, + nil, + tt.inputParams, + ) assert.Nil(t, err) assert.Equal(t, tt.expected, tt.podSpec) }) @@ -1175,7 +1671,15 @@ func Test_extendPodSpecPatch_EmptyVolumeMount(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := extendPodSpecPatch(tt.podSpec, tt.k8sExecCfg, nil, nil) + err := extendPodSpecPatch( + context.Background(), + tt.podSpec, + Options{KubernetesExecutorConfig: tt.k8sExecCfg}, + nil, + nil, + nil, + map[string]*structpb.Value{}, + ) assert.Nil(t, err) assert.Equal(t, tt.expected, tt.podSpec) }) @@ -1184,12 +1688,13 @@ func Test_extendPodSpecPatch_EmptyVolumeMount(t *testing.T) { func Test_extendPodSpecPatch_ImagePullSecrets(t *testing.T) { tests := []struct { - name string - k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig - expected *k8score.PodSpec + name string + k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig + expected *k8score.PodSpec + inputParams map[string]*structpb.Value }{ { - "Valid - SecretA and SecretB", + "Valid - SecretA and SecretB (deprecated)", &kubernetesplatform.KubernetesExecutorConfig{ ImagePullSecret: []*kubernetesplatform.ImagePullSecret{ {SecretName: "SecretA"}, @@ -1207,6 +1712,28 @@ func Test_extendPodSpecPatch_ImagePullSecrets(t *testing.T) { {Name: "SecretB"}, }, }, + nil, + }, + { + "Valid - SecretA and SecretB", + &kubernetesplatform.KubernetesExecutorConfig{ + ImagePullSecret: []*kubernetesplatform.ImagePullSecret{ + {SecretName: "SecretA", SecretNameParameter: strInputParamConstant("SecretA")}, + {SecretName: "SecretB", SecretNameParameter: strInputParamConstant("SecretB")}, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + ImagePullSecrets: []k8score.LocalObjectReference{ + {Name: "SecretA"}, + {Name: "SecretB"}, + }, + }, + nil, }, { "Valid - No ImagePullSecrets", @@ -1220,6 +1747,7 @@ func Test_extendPodSpecPatch_ImagePullSecrets(t *testing.T) { }, }, }, + nil, }, { "Valid - empty", @@ -1231,6 +1759,31 @@ func Test_extendPodSpecPatch_ImagePullSecrets(t *testing.T) { }, }, }, + nil, + }, + { + "Valid - multiple input parameter secret names", + &kubernetesplatform.KubernetesExecutorConfig{ + ImagePullSecret: []*kubernetesplatform.ImagePullSecret{ + {SecretName: "not-used1", SecretNameParameter: strInputParamComponent("param_1")}, + {SecretName: "not-used2", SecretNameParameter: strInputParamComponent("param_2")}, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + ImagePullSecrets: []k8score.LocalObjectReference{ + {Name: "secret-name-1"}, + {Name: "secret-name-2"}, + }, + }, + map[string]*structpb.Value{ + "param_1": structpb.NewStringValue("secret-name-1"), + "param_2": structpb.NewStringValue("secret-name-2"), + }, }, } for _, tt := range tests { @@ -1240,7 +1793,15 @@ func Test_extendPodSpecPatch_ImagePullSecrets(t *testing.T) { Name: "main", }, }} - err := extendPodSpecPatch(got, tt.k8sExecCfg, nil, nil) + err := extendPodSpecPatch( + context.Background(), + got, + Options{KubernetesExecutorConfig: tt.k8sExecCfg}, + nil, + nil, + nil, + tt.inputParams, + ) assert.Nil(t, err) assert.NotNil(t, got) assert.Equal(t, tt.expected, got) @@ -1250,9 +1811,10 @@ func Test_extendPodSpecPatch_ImagePullSecrets(t *testing.T) { func Test_extendPodSpecPatch_Tolerations(t *testing.T) { tests := []struct { - name string - k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig - expected *k8score.PodSpec + name string + k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig + expected *k8score.PodSpec + inputParams map[string]*structpb.Value }{ { "Valid - toleration", @@ -1282,6 +1844,7 @@ func Test_extendPodSpecPatch_Tolerations(t *testing.T) { }, }, }, + nil, }, { "Valid - no tolerations", @@ -1293,6 +1856,7 @@ func Test_extendPodSpecPatch_Tolerations(t *testing.T) { }, }, }, + nil, }, { "Valid - only pass operator", @@ -1315,6 +1879,143 @@ func Test_extendPodSpecPatch_Tolerations(t *testing.T) { }, }, }, + nil, + }, + { + "Valid - toleration json - constant", + &kubernetesplatform.KubernetesExecutorConfig{ + Tolerations: []*kubernetesplatform.Toleration{ + { + TolerationJson: structInputParamConstant(map[string]interface{}{ + "key": "key1", + "operator": "Equal", + "value": "value1", + "effect": "NoSchedule", + "tolerationSeconds": nil, + }), + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + Tolerations: []k8score.Toleration{ + { + Key: "key1", + Operator: "Equal", + Value: "value1", + Effect: "NoSchedule", + TolerationSeconds: nil, + }, + }, + }, + nil, + }, + { + "Valid - toleration json - component input", + &kubernetesplatform.KubernetesExecutorConfig{ + Tolerations: []*kubernetesplatform.Toleration{ + { + TolerationJson: strInputParamComponent("param_1"), + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + Tolerations: []k8score.Toleration{ + { + Key: "key1", + Operator: "Equal", + Value: "value1", + Effect: "NoSchedule", + TolerationSeconds: int64_ptr(3600), + }, + }, + }, + map[string]*structpb.Value{ + "param_1": validValueStructOrPanic(map[string]interface{}{ + "key": "key1", + "operator": "Equal", + "value": "value1", + "effect": "NoSchedule", + "tolerationSeconds": 3600, + }), + }, + }, + { + "Valid - toleration json - multiple input types", + &kubernetesplatform.KubernetesExecutorConfig{ + Tolerations: []*kubernetesplatform.Toleration{ + { + TolerationJson: strInputParamComponent("param_1"), + }, + { + TolerationJson: structInputParamConstant(map[string]interface{}{ + "key": "key2", + "operator": "Equal", + "value": "value2", + "effect": "NoSchedule", + "tolerationSeconds": 3602, + }), + // Json takes precedence, these should not get used + Key: "key3", + Value: "value3", + }, + { + Key: "key4", + Operator: "Equal", + Value: "value4", + Effect: "NoSchedule", + TolerationSeconds: int64_ptr(3604), + }, + }, + }, + &k8score.PodSpec{ + Containers: []k8score.Container{ + { + Name: "main", + }, + }, + Tolerations: []k8score.Toleration{ + { + Key: "key1", + Operator: "Equal", + Value: "value1", + Effect: "NoSchedule", + TolerationSeconds: int64_ptr(3601), + }, + { + Key: "key2", + Operator: "Equal", + Value: "value2", + Effect: "NoSchedule", + TolerationSeconds: int64_ptr(3602), + }, + { + Key: "key4", + Operator: "Equal", + Value: "value4", + Effect: "NoSchedule", + TolerationSeconds: int64_ptr(3604), + }, + }, + }, + map[string]*structpb.Value{ + "param_1": validValueStructOrPanic(map[string]interface{}{ + "key": "key1", + "operator": "Equal", + "value": "value1", + "effect": "NoSchedule", + "tolerationSeconds": 3601, + }), + }, }, } for _, tt := range tests { @@ -1324,7 +2025,15 @@ func Test_extendPodSpecPatch_Tolerations(t *testing.T) { Name: "main", }, }} - err := extendPodSpecPatch(got, tt.k8sExecCfg, nil, nil) + err := extendPodSpecPatch( + context.Background(), + got, + Options{KubernetesExecutorConfig: tt.k8sExecCfg}, + nil, + nil, + nil, + tt.inputParams, + ) assert.Nil(t, err) assert.NotNil(t, got) assert.Equal(t, tt.expected, got) @@ -1368,7 +2077,8 @@ func Test_extendPodSpecPatch_FieldPathAsEnv(t *testing.T) { &kubernetesplatform.KubernetesExecutorConfig{ SecretAsEnv: []*kubernetesplatform.SecretAsEnv{ { - SecretName: "my-secret", + SecretName: "my-secret", + SecretNameParameter: strInputParamConstant("my-secret"), KeyToEnv: []*kubernetesplatform.SecretAsEnv_SecretKeyToEnvMap{ { SecretKey: "password", @@ -1417,7 +2127,15 @@ func Test_extendPodSpecPatch_FieldPathAsEnv(t *testing.T) { Name: "main", }, }} - err := extendPodSpecPatch(got, tt.k8sExecCfg, nil, nil) + err := extendPodSpecPatch( + context.Background(), + got, + Options{KubernetesExecutorConfig: tt.k8sExecCfg}, + nil, + nil, + nil, + map[string]*structpb.Value{}, + ) assert.Nil(t, err) assert.NotNil(t, got) assert.Equal(t, tt.expected, got) @@ -1479,7 +2197,15 @@ func Test_extendPodSpecPatch_ActiveDeadlineSeconds(t *testing.T) { Name: "main", }, }} - err := extendPodSpecPatch(got, tt.k8sExecCfg, nil, nil) + err := extendPodSpecPatch( + context.Background(), + got, + Options{KubernetesExecutorConfig: tt.k8sExecCfg}, + nil, + nil, + nil, + map[string]*structpb.Value{}, + ) assert.Nil(t, err) assert.NotNil(t, got) assert.Equal(t, tt.expected, got) @@ -1560,7 +2286,15 @@ func Test_extendPodSpecPatch_ImagePullPolicy(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := extendPodSpecPatch(tt.podSpec, tt.k8sExecCfg, nil, nil) + err := extendPodSpecPatch( + context.Background(), + tt.podSpec, + Options{KubernetesExecutorConfig: tt.k8sExecCfg}, + nil, + nil, + nil, + map[string]*structpb.Value{}, + ) assert.Nil(t, err) assert.Equal(t, tt.expected, tt.podSpec) }) @@ -1747,9 +2481,42 @@ func Test_extendPodSpecPatch_GenericEphemeralVolume(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := extendPodSpecPatch(tt.podSpec, tt.k8sExecCfg, nil, nil) + err := extendPodSpecPatch( + context.Background(), + tt.podSpec, + Options{KubernetesExecutorConfig: tt.k8sExecCfg}, + nil, + nil, + nil, + map[string]*structpb.Value{}, + ) assert.Nil(t, err) assert.Equal(t, tt.expected, tt.podSpec) }) } } + +func validValueStructOrPanic(data map[string]interface{}) *structpb.Value { + s, err := structpb.NewStruct(data) + if err != nil { + panic(err) + } + return structpb.NewStructValue(s) +} + +// TODO: combine with strInputParamConstant in util, maybe use template +func structInputParamConstant(value map[string]interface{}) *kubernetesplatform.InputParameterSpec { + return &kubernetesplatform.InputParameterSpec{ + Kind: &kubernetesplatform.InputParameterSpec_RuntimeValue{ + RuntimeValue: &kubernetesplatform.ValueOrRuntimeParameter{ + Value: &kubernetesplatform.ValueOrRuntimeParameter_Constant{ + Constant: validValueStructOrPanic(value), + }, + }, + }, + } +} + +func int64_ptr(val int64) *int64 { + return &val +} diff --git a/backend/src/v2/driver/util.go b/backend/src/v2/driver/util.go index b85e08ffe10..e219b04e928 100644 --- a/backend/src/v2/driver/util.go +++ b/backend/src/v2/driver/util.go @@ -15,8 +15,15 @@ package driver import ( + "context" + "encoding/json" "fmt" + "github.com/golang/protobuf/proto" "github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec" + "github.com/kubeflow/pipelines/backend/src/v2/metadata" + "github.com/kubeflow/pipelines/kubernetes_platform/go/kubernetesplatform" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/types/known/structpb" "regexp" ) @@ -76,3 +83,91 @@ func resolvePodSpecInputRuntimeParameter(parameterValue string, executorInput *p } return parameterValue, nil } + +func resolveK8sParameter( + ctx context.Context, + opts Options, + dag *metadata.DAG, + pipeline *metadata.Pipeline, + mlmd *metadata.Client, + k8sParamSpec *kubernetesplatform.InputParameterSpec, + inputParams map[string]*structpb.Value, +) (*structpb.Value, error) { + pipelineParamSpec := &pipelinespec.TaskInputsSpec_InputParameterSpec{} + err := convertToProtoMessages(k8sParamSpec, pipelineParamSpec) + if err != nil { + return nil, fmt.Errorf("failed to convert input parameter spec to pipeline spec: %v", err) + } + resolvedSecretName, err := resolveInputParameter(ctx, dag, pipeline, + opts, mlmd, pipelineParamSpec, inputParams) + if err != nil { + return nil, fmt.Errorf("failed to resolve input parameter name: %w", err) + } + return resolvedSecretName, nil +} + +func resolveK8sJsonParameter[k8sResource any]( + ctx context.Context, + opts Options, + dag *metadata.DAG, + pipeline *metadata.Pipeline, + mlmd *metadata.Client, + k8sParamSpec *kubernetesplatform.InputParameterSpec, + inputParams map[string]*structpb.Value, + res *k8sResource, +) error { + resolvedParam, err := resolveK8sParameter(ctx, opts, dag, pipeline, mlmd, + k8sParamSpec, inputParams) + if err != nil { + return fmt.Errorf("failed to resolve k8s parameter: %w", err) + } + paramJSON, err := resolvedParam.GetStructValue().MarshalJSON() + if err != nil { + return err + } + err = json.Unmarshal(paramJSON, &res) + if err != nil { + return fmt.Errorf("failed to unmarshal k8s Resource json "+ + "ensure that k8s Resource json correctly adheres to its respective k8s spec: %w", err) + } + return nil +} + +func convertToProtoMessages(src *kubernetesplatform.InputParameterSpec, dst *pipelinespec.TaskInputsSpec_InputParameterSpec) error { + data, err := protojson.Marshal(proto.MessageV2(src)) + if err != nil { + return err + } + return protojson.Unmarshal(data, proto.MessageV2(dst)) +} + +func strInputParamConstant(value string) *kubernetesplatform.InputParameterSpec { + return &kubernetesplatform.InputParameterSpec{ + Kind: &kubernetesplatform.InputParameterSpec_RuntimeValue{ + RuntimeValue: &kubernetesplatform.ValueOrRuntimeParameter{ + Value: &kubernetesplatform.ValueOrRuntimeParameter_Constant{ + Constant: structpb.NewStringValue(value), + }, + }, + }, + } +} + +func strInputParamComponent(value string) *kubernetesplatform.InputParameterSpec { + return &kubernetesplatform.InputParameterSpec{ + Kind: &kubernetesplatform.InputParameterSpec_ComponentInputParameter{ + ComponentInputParameter: value, + }, + } +} + +func strInputParamTaskOutput(producerTask, outputParamKey string) *kubernetesplatform.InputParameterSpec { + return &kubernetesplatform.InputParameterSpec{ + Kind: &kubernetesplatform.InputParameterSpec_TaskOutputParameter{ + TaskOutputParameter: &kubernetesplatform.InputParameterSpec_TaskOutputParameterSpec{ + ProducerTask: producerTask, + OutputParameterKey: outputParamKey, + }, + }, + } +}