Skip to content

Commit

Permalink
fix(backend): ignore unknown fields for pb json unmarshaling (#11662)
Browse files Browse the repository at this point in the history
Signed-off-by: Humair Khan <HumairAK@users.noreply.github.com>
  • Loading branch information
HumairAK authored Feb 24, 2025
1 parent ebaaf75 commit 9afe23e
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 32 deletions.
5 changes: 2 additions & 3 deletions backend/src/apiserver/server/list_request_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"strconv"
"strings"

"github.com/golang/protobuf/jsonpb"
apiv1beta1 "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client"
apiv2beta1 "github.com/kubeflow/pipelines/backend/api/v2beta1/go_client"
"github.com/kubeflow/pipelines/backend/src/apiserver/common"
Expand Down Expand Up @@ -148,13 +147,13 @@ func parseAPIFilter(encoded string, apiVersion string) (interface{}, error) {
switch apiVersion {
case "v2beta1":
f := &apiv2beta1.Filter{}
if err := jsonpb.UnmarshalString(decoded, f); err != nil {
if err := util.UnmarshalString(decoded, f); err != nil {
return nil, util.NewInvalidInputError("failed to parse valid filter from %q: %v", encoded, err)
}
return f, nil
case "v1beta1":
f := &apiv1beta1.Filter{}
if err := jsonpb.UnmarshalString(decoded, f); err != nil {
if err := util.UnmarshalString(decoded, f); err != nil {
return nil, util.NewInvalidInputError("failed to parse valid filter from %q: %v", encoded, err)
}
return f, nil
Expand Down
12 changes: 11 additions & 1 deletion backend/src/common/util/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@
package util

import (
"encoding/json"
"github.com/golang/protobuf/jsonpb"

"encoding/json"
"github.com/golang/glog"
"github.com/golang/protobuf/proto"
"strings"
)

func UnmarshalJsonOrFail(data string, v interface{}) {
Expand Down Expand Up @@ -63,3 +66,10 @@ func UnmarshalJsonWithError(data interface{}, v *interface{}) error {
}
return nil
}

// UnmarshalString unmarshals a JSON object from s into m.
// Allows unknown fields
func UnmarshalString(s string, m proto.Message) error {
unmarshaler := jsonpb.Unmarshaler{AllowUnknownFields: true}
return unmarshaler.Unmarshal(strings.NewReader(s), m)
}
3 changes: 1 addition & 2 deletions backend/src/common/util/pipelinerun.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (

"github.com/ghodss/yaml"
"github.com/golang/glog"
"github.com/golang/protobuf/jsonpb"
api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client"
exec "github.com/kubeflow/pipelines/backend/src/common"
swfregister "github.com/kubeflow/pipelines/backend/src/crd/pkg/apis/scheduledworkflow"
Expand Down Expand Up @@ -664,7 +663,7 @@ func collectTaskRunMetricsOrNil(
// ReportRunMetricsRequest as a workaround to hold user's metrics, which is a superset of what
// user can provide.
reportMetricsRequest := new(api.ReportRunMetricsRequest)
err = jsonpb.UnmarshalString(metricsJSON, reportMetricsRequest)
err = UnmarshalString(metricsJSON, reportMetricsRequest)
if err != nil {
// User writes invalid metrics JSON.
// TODO(#1426): report the error back to api server to notify user
Expand Down
3 changes: 1 addition & 2 deletions backend/src/common/util/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ import (
"github.com/argoproj/argo-workflows/v3/workflow/packer"
"github.com/argoproj/argo-workflows/v3/workflow/validate"
"github.com/golang/glog"
"github.com/golang/protobuf/jsonpb"
api "github.com/kubeflow/pipelines/backend/api/v1beta1/go_client"
exec "github.com/kubeflow/pipelines/backend/src/common"
swfregister "github.com/kubeflow/pipelines/backend/src/crd/pkg/apis/scheduledworkflow"
Expand Down Expand Up @@ -517,7 +516,7 @@ func collectNodeMetricsOrNil(runID string, nodeStatus *workflowapi.NodeStatus, r
// ReportRunMetricsRequest as a workaround to hold user's metrics, which is a superset of what
// user can provide.
reportMetricsRequest := new(api.ReportRunMetricsRequest)
err = jsonpb.UnmarshalString(metricsJSON, reportMetricsRequest)
err = UnmarshalString(metricsJSON, reportMetricsRequest)
if err != nil {
// User writes invalid metrics JSON.
// TODO(#1426): report the error back to api server to notify user
Expand Down
47 changes: 28 additions & 19 deletions backend/src/v2/cmd/driver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ import (
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"strconv"

"github.com/kubeflow/pipelines/backend/src/v2/cacheutils"
"github.com/kubeflow/pipelines/backend/src/v2/driver"
"github.com/kubeflow/pipelines/kubernetes_platform/go/kubernetesplatform"
"github.com/kubeflow/pipelines/backend/src/common/util"

"github.com/golang/glog"
"github.com/golang/protobuf/jsonpb"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/v2/cacheutils"
"github.com/kubeflow/pipelines/backend/src/v2/config"
"github.com/kubeflow/pipelines/backend/src/v2/driver"
"github.com/kubeflow/pipelines/backend/src/v2/metadata"
"github.com/kubeflow/pipelines/kubernetes_platform/go/kubernetesplatform"
"os"
"path/filepath"
"strconv"
)

const (
Expand Down Expand Up @@ -116,37 +116,33 @@ func drive() (err error) {
}
glog.Infof("input ComponentSpec:%s\n", prettyPrint(*componentSpecJson))
componentSpec := &pipelinespec.ComponentSpec{}
if err := jsonpb.UnmarshalString(*componentSpecJson, componentSpec); err != nil {
if err := util.UnmarshalString(*componentSpecJson, componentSpec); err != nil {
return fmt.Errorf("failed to unmarshal component spec, error: %w\ncomponentSpec: %v", err, prettyPrint(*componentSpecJson))
}
var taskSpec *pipelinespec.PipelineTaskSpec
if *taskSpecJson != "" {
glog.Infof("input TaskSpec:%s\n", prettyPrint(*taskSpecJson))
taskSpec = &pipelinespec.PipelineTaskSpec{}
if err := jsonpb.UnmarshalString(*taskSpecJson, taskSpec); err != nil {
if err := util.UnmarshalString(*taskSpecJson, taskSpec); err != nil {
return fmt.Errorf("failed to unmarshal task spec, error: %w\ntask: %v", err, taskSpecJson)
}
}
glog.Infof("input ContainerSpec:%s\n", prettyPrint(*containerSpecJson))
containerSpec := &pipelinespec.PipelineDeploymentConfig_PipelineContainerSpec{}
if err := jsonpb.UnmarshalString(*containerSpecJson, containerSpec); err != nil {
if err := util.UnmarshalString(*containerSpecJson, containerSpec); err != nil {
return fmt.Errorf("failed to unmarshal container spec, error: %w\ncontainerSpec: %v", err, containerSpecJson)
}
var runtimeConfig *pipelinespec.PipelineJob_RuntimeConfig
if *runtimeConfigJson != "" {
glog.Infof("input RuntimeConfig:%s\n", prettyPrint(*runtimeConfigJson))
runtimeConfig = &pipelinespec.PipelineJob_RuntimeConfig{}
if err := jsonpb.UnmarshalString(*runtimeConfigJson, runtimeConfig); err != nil {
if err := util.UnmarshalString(*runtimeConfigJson, runtimeConfig); err != nil {
return fmt.Errorf("failed to unmarshal runtime config, error: %w\nruntimeConfig: %v", err, runtimeConfigJson)
}
}
var k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig
if *k8sExecConfigJson != "" {
glog.Infof("input kubernetesConfig:%s\n", prettyPrint(*k8sExecConfigJson))
k8sExecCfg = &kubernetesplatform.KubernetesExecutorConfig{}
if err := jsonpb.UnmarshalString(*k8sExecConfigJson, k8sExecCfg); err != nil {
return fmt.Errorf("failed to unmarshal Kubernetes config, error: %w\nKubernetesConfig: %v", err, k8sExecConfigJson)
}
k8sExecCfg, err := parseExecConfigJson(k8sExecConfigJson)
if err != nil {
return err
}
namespace, err := config.InPodNamespace()
if err != nil {
Expand Down Expand Up @@ -204,11 +200,24 @@ func drive() (err error) {
IterationCount: *iterationCountPath,
CachedDecision: *cachedDecisionPath,
Condition: *conditionPath,
PodSpecPatch: *podSpecPatchPath}
PodSpecPatch: *podSpecPatchPath,
}

return handleExecution(execution, *driverType, executionPaths)
}

func parseExecConfigJson(k8sExecConfigJson *string) (*kubernetesplatform.KubernetesExecutorConfig, error) {
var k8sExecCfg *kubernetesplatform.KubernetesExecutorConfig
if *k8sExecConfigJson != "" {
glog.Infof("input kubernetesConfig:%s\n", prettyPrint(*k8sExecConfigJson))
k8sExecCfg = &kubernetesplatform.KubernetesExecutorConfig{}
if err := util.UnmarshalString(*k8sExecConfigJson, k8sExecCfg); err != nil {
return nil, fmt.Errorf("failed to unmarshal Kubernetes config, error: %w\nKubernetesConfig: %v", err, k8sExecConfigJson)
}
}
return k8sExecCfg, nil
}

func handleExecution(execution *driver.Execution, driverType string, executionPaths *ExecutionPaths) error {
if execution.ID != 0 {
glog.Infof("output execution.ID=%v", execution.ID)
Expand Down
44 changes: 44 additions & 0 deletions backend/src/v2/cmd/driver/main_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,55 @@
package main

import (
"github.com/golang/protobuf/proto"
"github.com/kubeflow/pipelines/backend/src/v2/driver"
"github.com/kubeflow/pipelines/kubernetes_platform/go/kubernetesplatform"
"github.com/stretchr/testify/assert"
"os"
"testing"
)

func strPtr(s string) *string {
return &s
}

func TestSpecParsing(t *testing.T) {
tt := []struct {
name string
input *string
expected *kubernetesplatform.KubernetesExecutorConfig
wantErr bool
}{
{
"Valid - test kubecfg value parse.",
strPtr("{\"imagePullSecret\":[{\"secret_name\":\"value1\"}]}"),
&kubernetesplatform.KubernetesExecutorConfig{
ImagePullSecret: []*kubernetesplatform.ImagePullSecret{
{SecretName: "value1"},
},
},
false,
},
{
"Valid - test kubecfg value ignores unknown field.",
strPtr("{\"imagePullSecret\":[{\"secret_name\":\"value1\"}], \"unknown_field\": \"something\"}"),
&kubernetesplatform.KubernetesExecutorConfig{
ImagePullSecret: []*kubernetesplatform.ImagePullSecret{
{SecretName: "value1"},
},
},
false,
},
}

for _, tc := range tt {
t.Logf("Running test case: %s", tc.name)
cfg, err := parseExecConfigJson(tc.input)
assert.Equal(t, tc.wantErr, err != nil)
assert.True(t, proto.Equal(tc.expected, cfg))
}
}

func Test_handleExecutionContainer(t *testing.T) {
execution := &driver.Execution{}

Expand Down
4 changes: 2 additions & 2 deletions backend/src/v2/compiler/argocompiler/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ package argocompiler

import (
"fmt"
"github.com/kubeflow/pipelines/backend/src/common/util"
"os"
"strconv"
"strings"

wfapi "github.com/argoproj/argo-workflows/v3/pkg/apis/workflow/v1alpha1"
"github.com/golang/glog"
"github.com/golang/protobuf/jsonpb"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/v2/component"
"github.com/kubeflow/pipelines/kubernetes_platform/go/kubernetesplatform"
Expand Down Expand Up @@ -419,7 +419,7 @@ func (c *workflowCompiler) addContainerExecutorTemplate(refName string) string {

if kubernetesConfigParam != nil {
k8sExecCfg := &kubernetesplatform.KubernetesExecutorConfig{}
if err := jsonpb.UnmarshalString(string(*kubernetesConfigParam.Value), k8sExecCfg); err == nil {
if err := util.UnmarshalString(string(*kubernetesConfigParam.Value), k8sExecCfg); err == nil {
extendPodMetadata(&executor.Metadata, k8sExecCfg)
}
}
Expand Down
3 changes: 2 additions & 1 deletion backend/src/v2/compiler/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ package compiler
import (
"bytes"
"fmt"
"github.com/kubeflow/pipelines/backend/src/common/util"
"sort"

"github.com/golang/protobuf/jsonpb"
Expand Down Expand Up @@ -187,7 +188,7 @@ func GetPipelineSpec(job *pipelinespec.PipelineJob) (*pipelinespec.PipelineSpec,
return nil, fmt.Errorf("failed marshal pipeline spec to json: %w", err)
}
spec := &pipelinespec.PipelineSpec{}
if err := jsonpb.UnmarshalString(json, spec); err != nil {
if err := util.UnmarshalString(json, spec); err != nil {
return nil, fmt.Errorf("failed to parse pipeline spec: %v", err)
}
return spec, nil
Expand Down
4 changes: 2 additions & 2 deletions backend/src/v2/compiler/visitor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ package compiler_test

import (
"fmt"
"github.com/kubeflow/pipelines/backend/src/common/util"
"os"
"testing"

"github.com/golang/protobuf/jsonpb"
"github.com/google/go-cmp/cmp"
"github.com/kubeflow/pipelines/api/v2alpha1/go/pipelinespec"
"github.com/kubeflow/pipelines/backend/src/v2/compiler"
Expand Down Expand Up @@ -93,7 +93,7 @@ func load(t *testing.T, path string) *pipelinespec.PipelineJob {
}
json := string(content)
job := &pipelinespec.PipelineJob{}
if err := jsonpb.UnmarshalString(json, job); err != nil {
if err := util.UnmarshalString(json, job); err != nil {
t.Errorf("Failed to parse pipeline job, error: %s, job: %v", err, json)
}
return job
Expand Down

0 comments on commit 9afe23e

Please sign in to comment.