Skip to content

Commit

Permalink
split the scan/marshal stuff a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
guregu committed Nov 2, 2024
1 parent 409eaf6 commit 05e0092
Show file tree
Hide file tree
Showing 4 changed files with 452 additions and 437 deletions.
308 changes: 308 additions & 0 deletions trealla/decode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,308 @@
package trealla

import (
"fmt"
"reflect"
"strconv"
"strings"
)

var (
compoundType = reflect.TypeFor[Compound]()
functorType = reflect.TypeFor[Functor]()
termType = reflect.TypeFor[Term]()
atomType = reflect.TypeFor[Atom]()
)

func scan(sub Substitution, rv reflect.Value) error {
switch rv.Kind() {
case reflect.Map:
vtype := rv.Type().Elem()
for k, v := range sub {
vv := reflect.ValueOf(v)
if !vv.CanConvert(vtype) {
return fmt.Errorf("trealla: invalid element type for Scan: %v", vtype)
}
rv.SetMapIndex(reflect.ValueOf(k), vv.Convert(vtype))
}
return nil
case reflect.Pointer:
rv = rv.Elem()
// we can't set the inner elements of *map and *interface directly,
// so they need to be swapped out with a new inner value
switch rv.Kind() {
case reflect.Map:
ev := reflect.MakeMap(rv.Type())
if err := scan(sub, ev); err != nil {
return err
}
rv.Set(ev)
return nil
case reflect.Interface:
ev := reflect.New(rv.Elem().Type())
if err := scan(sub, ev); err != nil {
return err
}
rv.Set(ev.Elem())
return nil
case reflect.Struct:
// happy path
default:
return fmt.Errorf("trealla: must pass pointer to struct or map for Scan. got: %v", rv.Type())
}

rtype := rv.Type()
fieldnum := rtype.NumField()
fields := make(map[string]reflect.Value, fieldnum)
info := make(map[string]reflect.StructField, fieldnum)
for i := 0; i < fieldnum; i++ {
f := rtype.Field(i)
name := f.Name
if tag := f.Tag.Get("prolog"); tag != "" {
name = tag
}
fields[name] = rv.Field(i)
info[name] = f
}

for k, v := range sub {
fv, ok := fields[k]
if !ok {
continue
}
if err := convert(fv, reflect.ValueOf(v), info[k]); err != nil {
return fmt.Errorf("trealla: error converting field %q in %v: %w", k, rv.Type(), err)
}
}
return nil
}

return fmt.Errorf("trealla: can't scan into type: %v; must be pointer to struct or map", rv.Type())
}

func convert(dstv, srcv reflect.Value, meta reflect.StructField) error {
if !srcv.IsValid() || !srcv.CanInterface() {
return fmt.Errorf("invalid src: %v", srcv)
}

ftype := dstv.Type()

if ftype == termType {
dstv.Set(srcv)
return nil
}

if srcv.Kind() == reflect.Interface && !srcv.IsNil() {
srcv = srcv.Elem()
}

if dstv.Kind() == reflect.Slice {
length := srcv.Len()
srctype := srcv.Type()
etype := ftype.Elem()
detype := dstv.Type().Elem()
var preconvert bool
switch {
case srctype == atomType && srcv.Interface().(Atom) == "[]":
// special case: empty list
length = 0
case srctype.Kind() == reflect.String:
// special case: convert string → list
runes := []rune(srcv.String())
srcv = reflect.ValueOf(runes)
length = len(runes)
// if []Atom or []Term
if dstv.Type().Elem().ConvertibleTo(atomType) || termType.AssignableTo(detype) {
preconvert = true
etype = atomType
}
}
slice := reflect.MakeSlice(ftype, length, length)
for i := 0; i < length; i++ {
x := srcv.Index(i)
if preconvert {
x = x.Convert(etype)
}
if err := convert(slice.Index(i), x, meta); err != nil {
return fmt.Errorf("can't convert %v[%d]; error: %w", srctype, i, err)
}
}
dstv.Set(slice)
return nil
}

// handle the empty string (rendered as [], the empty list)
if dstv.Kind() == reflect.String && srcv.Kind() == reflect.Slice && srcv.Len() == 0 {
dstv.SetString("")
return nil
}

// compound → struct
if srcv.Type() == compoundType && dstv.Kind() == reflect.Struct {
return decodeCompoundStruct(dstv, srcv.Interface().(Compound), meta)
}

if !srcv.CanConvert(ftype) {
return fmt.Errorf("can't convert from type %v to type: %v", srcv.Type(), ftype)
}
dstv.Set(srcv.Convert(ftype))
return nil
}

// TODO: break out reflect stuff into something like this:
// type structInfo struct {
// fields []reflect.Value
// meta []reflect.StructField
// functor string
// arity int
// }

func decodeCompoundStruct(dstv reflect.Value, src Compound, meta reflect.StructField) error {
rtype := dstv.Type()
fieldnum := rtype.NumField()
fields := make([]reflect.Value, 0, fieldnum)
fieldInfo := make([]reflect.StructField, 0, fieldnum)

var functor reflect.Value

var collect func(dstv reflect.Value) error
collect = func(dstv reflect.Value) error {
for i := 0; i < fieldnum; i++ {
field := rtype.Field(i)
fv := dstv.Field(i)
tag := field.Tag.Get("prolog")
if tag == "-" {
continue
}
exported := field.IsExported()
if field.Type == functorType && exported {
functor = fv
continue
}
if field.Anonymous && field.Type.Kind() == reflect.Struct {
if err := collect(fv); err != nil {
return err
}
continue
}
if !exported {
continue
}
fields = append(fields, fv)
fieldInfo = append(fieldInfo, field)
}
return nil
}
if err := collect(dstv); err != nil {
return err
}

if functor.IsValid() && functor.CanSet() {
// TODO: check tag?
functor.Set(reflect.ValueOf(Functor(src.Functor)))
}

for i := 0; i < min(len(fields), len(src.Args)); i++ {
info := fieldInfo[i]
if err := convert(fields[i], reflect.ValueOf(src.Args[i]), info); err != nil {
return fmt.Errorf("can't convert compound (%v) argument #%d (type %T, value: %v) into field %q: %w",
src.pi().String(), i, src.Args[i], src.Args[i], info.Name, err)
}
}
return nil
}

func encodeCompoundStruct(src any) (Compound, error) {
marker, ok := src.(compoundStruct)
if !ok {
return Compound{}, fmt.Errorf("can't encode %T to compound; no Functor field found", src)
}
srcv := reflect.ValueOf(src)
for srcv.Kind() == reflect.Pointer && !srcv.IsNil() {
srcv = srcv.Elem()
}
if srcv.Kind() != reflect.Struct {
return Compound{}, fmt.Errorf("not a struct: %T", src)
}

rtype := srcv.Type()
fieldnum := rtype.NumField()
fields := make([]reflect.Value, 0, fieldnum)
fieldInfo := make([]reflect.StructField, 0, fieldnum)
functor := marker.functor()

// var functor reflect.Value
// var expectFunctor Functor
// expectArity := -1
var expect string
var arity int
var collect func(dstv reflect.Value) error
collect = func(dstv reflect.Value) error {
for i := 0; i < fieldnum; i++ {
field := rtype.Field(i)
fv := dstv.Field(i)
tag := field.Tag.Get("prolog")
if tag == "-" {
continue
}
exported := field.IsExported()
if field.Type == functorType && exported {
expect, arity = structTag(tag)
continue
}
if field.Anonymous && field.Type.Kind() == reflect.Struct {
if err := collect(fv); err != nil {
return err
}
continue
}
if !exported {
continue
}
fields = append(fields, fv)
fieldInfo = append(fieldInfo, field)
}
return nil
}
if err := collect(srcv); err != nil {
return Compound{}, err
}

c := Compound{Functor: Atom(functor), Args: make([]Term, 0, len(fields))}
for i := 0; i < len(fields); i++ {
// info := fieldInfo[i]
iface := fields[i].Interface()
// tt, err := marshal(iface.(Term))
// if err != nil {
// return c, fmt.Errorf("can't encode compound (%v) argument #%d (type %T, value: %v) from field %q: %w",
// functor, i, iface, iface, info.Name, err)
// }
c.Args = append(c.Args, iface)
}
if c.Functor == "" && expect != "" {
c.Functor = Atom(expect)
}
if arity > 0 && len(c.Args) != arity {
names := make([]string, len(fields))
for i := 0; i < len(fields); i++ {
names[i] = fieldInfo[i].Name
}
return c, fmt.Errorf("# of fields in %T does not match arity of struct tag (%s/%d): have %d fields %v but expected %d",
src, expect, arity, len(c.Args), names, arity)
}

return c, nil
}

func structTag(tag string) (name string, arity int) {
if tag == "" {
return
}
name, _, _ = strings.Cut(tag, ",")
slash := strings.LastIndexByte(name, '/')
if slash > 0 && slash < len(name)-1 {
arity, _ = strconv.Atoi(name[slash+1:])
name = name[:slash]
}
return
}
Loading

0 comments on commit 05e0092

Please sign in to comment.