Skip to content

Commit

Permalink
Merge pull request #14 from railsware/unmarshal-text-support
Browse files Browse the repository at this point in the history
Support encoding.TextUnmarshaler
  • Loading branch information
leonid-shevtsov authored Sep 10, 2024
2 parents 636fddf + 03ecd96 commit d971cf4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
32 changes: 32 additions & 0 deletions tree/leaf.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tree

import (
"encoding"
"fmt"
"reflect"
"strconv"
Expand All @@ -11,6 +12,11 @@ func (paramTree Node) writeLeafValue(destination reflect.Value) WriteErrors { //
return newWriteErrors("value is not writable")
}

// If destination implements TextUnmarshaler, it takes priority
if handled, unmarshalerErrors := tryWriteUnmarshaler(paramTree.Value, destination); handled {
return unmarshalerErrors
}

switch destination.Kind() { //nolint:exhaustive // we don't cover all types
case reflect.String:
destination.SetString(paramTree.Value)
Expand Down Expand Up @@ -73,3 +79,29 @@ func writeBool(source string, destination reflect.Value) WriteErrors {
}
return WriteErrors{}
}

func tryWriteUnmarshaler(source string, destination reflect.Value) (bool, WriteErrors) {
if destination.CanInterface() {
if unmarshaler, ok := destination.Interface().(encoding.TextUnmarshaler); ok {
return true, writeUnmarshaler(source, unmarshaler)
}
}

// In some cases unmarshaling requires a pointer receiver. So if the value itself does not implement the interface,
// check a pointer to it as well.
if destination.CanAddr() {
if unmarshaler, ok := destination.Addr().Interface().(encoding.TextUnmarshaler); ok {
return true, writeUnmarshaler(source, unmarshaler)
}
}

return false, WriteErrors{}
}

func writeUnmarshaler(source string, unmarshaler encoding.TextUnmarshaler) WriteErrors {
err := unmarshaler.UnmarshalText([]byte(source))
if err != nil {
return newWriteErrors(fmt.Sprintf("cannot write param: UnmarshalText returned error: %v", err))
}
return WriteErrors{}
}
15 changes: 14 additions & 1 deletion tree/write_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ type testStructType struct {
NestedSlicePtr []*testStructType `json:"nsliceptr"`
SliceOfMap []map[string]string `json:"sliceofmap"`
MapOfSlice map[string][]string `json:"mapofslice"`
Unmarshalable unmarshalableType `json:"unmarshalable"`
// Fields just for testing errors
Complex64 complex64 `global:"complex"`
BadMap map[int]string `json:"badmap"`
Expand All @@ -36,6 +37,15 @@ type nestedType struct {
NestedField string `json:"nested_field"`
}

type unmarshalableType struct {
value string
}

func (u *unmarshalableType) UnmarshalText(text []byte) error {
u.value = string(text)
return nil
}

// TODO: maybe split the test into atomic parts so it's not so hard to review
func TestWrite(t *testing.T) { //nolint:maintidx
t.Parallel()
Expand Down Expand Up @@ -139,11 +149,12 @@ func TestWrite(t *testing.T) { //nolint:maintidx
},
},
},
"unmarshalable": {Value: "unmarshaled_value"},
},
}

errors := tree.Write(reflect.ValueOf(&testStruct))
require.False(t, errors.Present())
require.False(t, errors.Present(), "had errors: %v", errors.Join())

expectedStruct := testStructType{
Str: "foo",
Expand Down Expand Up @@ -178,6 +189,7 @@ func TestWrite(t *testing.T) { //nolint:maintidx
NestedSlicePtr: []*testStructType{{Str: "nested_foo_in_slice_ptr"}},
SliceOfMap: []map[string]string{{"foo": "map_in_slice"}},
MapOfSlice: map[string][]string{"foo": {"slice_in_map"}},
Unmarshalable: unmarshalableType{value: "unmarshaled_value"},
}

assert.Equal(t, expectedStruct, testStruct, "assignment works correctly")
Expand Down Expand Up @@ -328,6 +340,7 @@ func TestWrite(t *testing.T) { //nolint:maintidx
NestedSlicePtr: []*testStructType{{Str: "updated_foo_in_slice_ptr"}},
SliceOfMap: []map[string]string{{"foo": "updated_map_in_slice"}},
MapOfSlice: map[string][]string{"foo": {"updated_slice_in_map"}},
Unmarshalable: unmarshalableType{value: "unmarshaled_value"},
}

assert.Equal(t, expectedMergedStruct, testStruct, "merging changes works correctly")
Expand Down

0 comments on commit d971cf4

Please sign in to comment.