Skip to content

Commit

Permalink
Add configuration validation after it is loaded (#175)
Browse files Browse the repository at this point in the history
* Remove unused previous iteration of configuration validation

* Add configuration validation after it is loaded

* Expand comment in Validate function

* Add tests for validation function

* Fix false positive validation constraint for Required Together

* Show the option as required in the CLI description

* Add missing helper function to create Configuration structs

* Panic if the user attempts to make required a boolean field

* Mark flags as required if they are
  • Loading branch information
shackra authored Jul 16, 2024
1 parent f0fdc2b commit c4d7f74
Show file tree
Hide file tree
Showing 7 changed files with 257 additions and 68 deletions.
13 changes: 13 additions & 0 deletions pkg/cli/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
v1 "github.com/conductorone/baton-sdk/pb/c1/connector_wrapper/v1"
"github.com/conductorone/baton-sdk/pkg/connectorrunner"
"github.com/conductorone/baton-sdk/pkg/field"
"github.com/conductorone/baton-sdk/pkg/logging"
"github.com/conductorone/baton-sdk/pkg/types"
"github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap"
Expand All @@ -28,10 +29,16 @@ func MakeMainCommand(
ctx context.Context,
name string,
v *viper.Viper,
confschema field.Configuration,
getconnector GetConnectorFunc,
opts ...connectorrunner.Option,
) func(*cobra.Command, []string) error {
return func(*cobra.Command, []string) error {
// validate required fields and relationship constraints
if err := field.Validate(confschema, v); err != nil {
return err
}

runCtx, err := initLogger(
ctx,
name,
Expand Down Expand Up @@ -162,9 +169,15 @@ func MakeGRPCServerCommand(
ctx context.Context,
name string,
v *viper.Viper,
confschema field.Configuration,
getconnector GetConnectorFunc,
) func(*cobra.Command, []string) error {
return func(*cobra.Command, []string) error {
// validate required fields and relationship constraints
if err := field.Validate(confschema, v); err != nil {
return err
}

runCtx, err := initLogger(
ctx,
name,
Expand Down
21 changes: 19 additions & 2 deletions pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func DefineConfiguration(
Short: connectorName,
SilenceErrors: true,
SilenceUsage: true,
RunE: cli.MakeMainCommand(ctx, connectorName, v, connector, options...),
RunE: cli.MakeMainCommand(ctx, connectorName, v, schema, connector, options...),
}

// add options to the main command
Expand Down Expand Up @@ -114,6 +114,23 @@ func DefineConfiguration(
)
}
}

// mark required
if field.Required {
if field.FieldType == reflect.Bool {
return nil, nil, fmt.Errorf("requiring %s of type %s does not make sense", field.FieldName, field.FieldType)
}

err := mainCMD.MarkFlagRequired(field.FieldName)
if err != nil {
return nil, nil, fmt.Errorf(
"cannot require field %s, %s: %w",
field.FieldName,
field.FieldType,
err,
)
}
}
}

// apply constrains
Expand All @@ -137,7 +154,7 @@ func DefineConfiguration(
Use: "_connector-service",
Short: "Start the connector service",
Hidden: true,
RunE: cli.MakeGRPCServerCommand(ctx, connectorName, v, connector),
RunE: cli.MakeGRPCServerCommand(ctx, connectorName, v, schema, connector),
}
mainCMD.AddCommand(grpcServerCmd)

Expand Down
64 changes: 0 additions & 64 deletions pkg/config/validation.go

This file was deleted.

15 changes: 13 additions & 2 deletions pkg/field/fields.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,18 @@ func (s SchemaField) String() (string, error) {
}

func (s SchemaField) GetDescription() string {
var line string
if s.Description == "" {
return fmt.Sprintf("($BATON_%s)", toUpperCase(s.FieldName))
line = fmt.Sprintf("($BATON_%s)", toUpperCase(s.FieldName))
} else {
line = fmt.Sprintf("%s ($BATON_%s)", s.Description, toUpperCase(s.FieldName))
}

return fmt.Sprintf("%s ($BATON_%s)", s.Description, toUpperCase(s.FieldName))
if s.Required {
line = fmt.Sprintf("required: %s", line)
}

return line
}

func (s SchemaField) GetName() string {
Expand All @@ -78,6 +85,10 @@ func BoolField(name string, optional ...fieldOption) SchemaField {
field = o(field)
}

if field.Required {
panic(fmt.Sprintf("requiring %s of type %s does not make sense", field.FieldName, field.FieldType))
}

return field
}

Expand Down
7 changes: 7 additions & 0 deletions pkg/field/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,10 @@ type Configuration struct {
Fields []SchemaField
Constraints []SchemaFieldRelationship
}

func NewConfiguration(fields []SchemaField, constraints ...SchemaFieldRelationship) Configuration {
return Configuration{
Fields: fields,
Constraints: constraints,
}
}
108 changes: 108 additions & 0 deletions pkg/field/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package field

import (
"fmt"
"reflect"
"strings"

"github.com/spf13/viper"
)

type ErrConfigurationMissingFields struct {
errors []error
}

func (e *ErrConfigurationMissingFields) Error() string {
var messages []string

for _, err := range e.errors {
messages = append(messages, err.Error())
}

return fmt.Sprintf("errors found:\n%s", strings.Join(messages, "\n"))
}

func (e *ErrConfigurationMissingFields) Push(err error) {
e.errors = append(e.errors, err)
}

// Validate perform validation of field requirement and constraints
// relationships after the configuration is read.
// We don't check the following:
// - if required fields are mutually exclusive
// - repeated fields (by name) are defined
// - if sets of fields are mutually exclusive and required
// together at the same time
func Validate(c Configuration, v *viper.Viper) error {
present := make(map[string]int)
missingFieldsError := &ErrConfigurationMissingFields{}

// check if required fields are present
for _, f := range c.Fields {
isNonZero := false
switch f.FieldType {
case reflect.Bool:
isNonZero = v.GetBool(f.FieldName)
case reflect.Int:
isNonZero = v.GetInt(f.FieldName) != 0
case reflect.String:
isNonZero = v.GetString(f.FieldName) != ""
default:
return fmt.Errorf("field %s has unsupported type %s", f.FieldName, f.FieldType)
}

if isNonZero {
present[f.FieldName] = 1
}

if f.Required && !isNonZero {
missingFieldsError.Push(fmt.Errorf("field %s of type %s is marked as required but it has a zero-value", f.FieldName, f.FieldType))
}
}

if len(missingFieldsError.errors) > 0 {
return missingFieldsError
}

// check constraints
return validateConstraints(present, c.Constraints)
}

func validateConstraints(fieldsPresent map[string]int, relationships []SchemaFieldRelationship) error {
for _, relationship := range relationships {
var present int
for _, f := range relationship.Fields {
present += fieldsPresent[f.FieldName]
}
if present > 1 && relationship.Kind == MutuallyExclusive {
return makeMutuallyExclusiveError(fieldsPresent, relationship)
}
if present > 0 && present < len(relationship.Fields) && relationship.Kind == RequiredTogether {
return makeNeededTogetherError(fieldsPresent, relationship)
}
}

return nil
}

func makeMutuallyExclusiveError(fields map[string]int, relation SchemaFieldRelationship) error {
var found []string
for _, f := range relation.Fields {
if fields[f.FieldName] == 1 {
found = append(found, f.FieldName)
}
}

return fmt.Errorf("fields marked as mutually exclusive were set: %s", strings.Join(found, ", "))
}

func makeNeededTogetherError(fields map[string]int, relation SchemaFieldRelationship) error {
var found []string
for _, f := range relation.Fields {
if fields[f.FieldName] == 0 {
found = append(found, f.FieldName)
}
}

return fmt.Errorf("fields marked as needed together are missing: %s", strings.Join(found, ", "))
}
97 changes: 97 additions & 0 deletions pkg/field/validation_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package field

import (
"testing"

"github.com/spf13/viper"
"github.com/stretchr/testify/require"
)

func TestValidateRequiredFieldsNotFound(t *testing.T) {
carrier := Configuration{
Fields: []SchemaField{
StringField("foo", WithRequired(true)),
StringField("bar", WithRequired(false)),
},
}

// create configuration using viper
v := viper.New()
v.Set("foo", "")
v.Set("bar", "")

err := Validate(carrier, v)
require.Error(t, err)
require.EqualError(t, err, "errors found:\nfield foo of type string is marked as required but it has a zero-value")
}

func TestValidateRelationshipMutuallyExclusiveAllPresent(t *testing.T) {
foo := StringField("foo")
bar := StringField("bar")

carrier := Configuration{
Fields: []SchemaField{
foo,
bar,
},
Constraints: []SchemaFieldRelationship{
FieldsMutuallyExclusive(foo, bar),
},
}

// create configuration using viper
v := viper.New()
v.Set("foo", "hello")
v.Set("bar", "world")

err := Validate(carrier, v)
require.Error(t, err)
require.EqualError(t, err, "fields marked as mutually exclusive were set: foo, bar")
}

func TestValidationRequiredTogetherOneMissing(t *testing.T) {
foo := StringField("foo")
bar := StringField("bar")

carrier := Configuration{
Fields: []SchemaField{
foo,
bar,
},
Constraints: []SchemaFieldRelationship{
FieldsRequiredTogether(foo, bar),
},
}

// create configuration using viper
v := viper.New()
v.Set("foo", "hello")
v.Set("bar", "")

err := Validate(carrier, v)
require.Error(t, err)
require.EqualError(t, err, "fields marked as needed together are missing: bar")
}

func TestValidationRequiredTogetherAllMissing(t *testing.T) {
foo := StringField("foo")
bar := StringField("bar")

carrier := Configuration{
Fields: []SchemaField{
foo,
bar,
},
Constraints: []SchemaFieldRelationship{
FieldsRequiredTogether(foo, bar),
},
}

// create configuration using viper
v := viper.New()
v.Set("foo", "")
v.Set("bar", "")

err := Validate(carrier, v)
require.NoError(t, err)
}

0 comments on commit c4d7f74

Please sign in to comment.