diff --git a/internal/reflectext/visitor.go b/internal/reflectext/visitor.go index 27bcbbe..9b31bf3 100644 --- a/internal/reflectext/visitor.go +++ b/internal/reflectext/visitor.go @@ -3,6 +3,8 @@ package reflectext import ( "fmt" "reflect" + "slices" + "strings" ) // Visitor visits values in a reflect.Value graph. @@ -136,6 +138,7 @@ func (flags VisitorFlags) Has(flag VisitorFlags) bool { // VisitorContext is Visitor context. type VisitorContext struct { parent *VisitorContext + path []pathSegment impl Visitor flags VisitorFlags } @@ -153,7 +156,7 @@ func (ctx VisitorContext) Visit(v reflect.Value) { // Special case for reflect.Value. if ctx.flags.Has(VisitReflectValues) && v.Type() == ReflectValueType { - ctx.Visit(v.Interface().(reflect.Value)) + ctx.WithReflectValue(v).Visit(v.Interface().(reflect.Value)) return } @@ -185,7 +188,7 @@ func (ctx VisitorContext) Visit(v reflect.Value) { case reflect.Pointer: if visitor.VisitPointer(ctx, v) && !v.IsNil() { - ctx.Visit(v.Elem()) + ctx.WithPointer(v).Visit(v.Elem()) } case reflect.UnsafePointer: @@ -194,7 +197,7 @@ func (ctx VisitorContext) Visit(v reflect.Value) { case reflect.Array: if visitor.VisitArray(ctx, v) { for i := 0; i < v.Len(); i++ { - ctx.Visit(v.Index(i)) + ctx.WithArrayIndex(v, i).Visit(v.Index(i)) } } @@ -204,7 +207,7 @@ func (ctx VisitorContext) Visit(v reflect.Value) { v = MakeSlice(v.Type(), v.UnsafePointer(), v.Cap(), v.Cap()) } for i := 0; i < v.Len(); i++ { - ctx.Visit(v.Index(i)) + ctx.WithSliceIndex(v, i).Visit(v.Index(i)) } } @@ -212,8 +215,10 @@ func (ctx VisitorContext) Visit(v reflect.Value) { if visitor.VisitMap(ctx, v) && !v.IsNil() { iter := v.MapRange() for iter.Next() { - ctx.Visit(iter.Key()) - ctx.Visit(iter.Value()) + key := iter.Key() + val := iter.Value() + ctx.WithMapKey(v, key).Visit(key) + ctx.WithMapValue(v, key).Visit(val) } } @@ -223,16 +228,16 @@ func (ctx VisitorContext) Visit(v reflect.Value) { case reflect.Struct: if visitor.VisitStruct(ctx, v) { + t := v.Type() if ctx.flags.Has(VisitUnexportedFields) { unrestricted := StructValueOf(v) for i := 0; i < unrestricted.NumField(); i++ { - ctx.Visit(unrestricted.Field(i)) + ctx.WithStructField(v, t.Field(i)).Visit(unrestricted.Field(i)) } } else { - t := v.Type() for i := 0; i < v.NumField(); i++ { if ft := t.Field(i); ft.IsExported() { - ctx.Visit(v.Field(i)) + ctx.WithStructField(v, t.Field(i)).Visit(v.Field(i)) } } } @@ -241,13 +246,13 @@ func (ctx VisitorContext) Visit(v reflect.Value) { case reflect.Func: if visitor.VisitFunc(ctx, v) && !v.IsNil() && ctx.flags.Has(VisitClosures) { if closure, ok := FuncValueOf(v).Closure(); ok { - ctx.Visit(closure) + ctx.WithClosure(v).Visit(closure) } } case reflect.Interface: if visitor.VisitInterface(ctx, v) && !v.IsNil() { - ctx.Visit(v.Elem()) + ctx.WithInterface(v).Visit(v.Elem()) } default: @@ -255,6 +260,20 @@ func (ctx VisitorContext) Visit(v reflect.Value) { } } +// Path returns a string that represents the location in the +// reflect.Value graph. +func (ctx VisitorContext) Path() string { + var b strings.Builder + for _, s := range ctx.path { + b.WriteString(s.String()) + } + parent := "$" + if ctx.parent != nil { + parent = ctx.parent.Path() + } + return parent + b.String() +} + // Fork creates a new Visitor context, linked to the current // context and its location. func (ctx VisitorContext) Fork(impl Visitor) VisitorContext { @@ -265,6 +284,125 @@ func (ctx VisitorContext) Fork(impl Visitor) VisitorContext { } } +func (ctx VisitorContext) withPathSegment(s pathSegment) VisitorContext { + c := ctx // shallow copy + c.path = append(slices.Clip(c.path), s) + return c +} + +func (ctx VisitorContext) WithArrayIndex(a reflect.Value, i int) VisitorContext { + return ctx.withPathSegment(pathSegment{ + typ: indexType, + val: a, + idx: i, + }) +} + +func (ctx VisitorContext) WithSliceIndex(s reflect.Value, i int) VisitorContext { + return ctx.withPathSegment(pathSegment{ + typ: indexType, + val: s, + idx: i, + }) +} + +func (ctx VisitorContext) WithStructField(s reflect.Value, f reflect.StructField) VisitorContext { + return ctx.withPathSegment(pathSegment{ + typ: structType, + val: s, + field: f, + }) +} + +func (ctx VisitorContext) WithMapKey(m, k reflect.Value) VisitorContext { + return ctx.withPathSegment(pathSegment{ + typ: mapKeyType, + val: m, + elem: k, + }) +} + +func (ctx VisitorContext) WithMapValue(m, k reflect.Value) VisitorContext { + return ctx.withPathSegment(pathSegment{ + typ: mapValType, + val: m, + elem: k, + }) +} + +func (ctx VisitorContext) WithClosure(f reflect.Value) VisitorContext { + return ctx.withPathSegment(pathSegment{ + typ: closureType, + val: f, + }) +} + +func (ctx VisitorContext) WithInterface(i reflect.Value) VisitorContext { + return ctx.withPathSegment(pathSegment{ + typ: interfaceType, + val: i, + }) +} + +func (ctx VisitorContext) WithPointer(p reflect.Value) VisitorContext { + return ctx.withPathSegment(pathSegment{ + typ: pointerType, + val: p, + }) +} + +func (ctx VisitorContext) WithReflectValue(v reflect.Value) VisitorContext { + return ctx.withPathSegment(pathSegment{ + typ: reflectType, + val: v, + }) +} + +type pathSegment struct { + typ segmentType + val reflect.Value + elem reflect.Value + + field reflect.StructField + idx int +} + +type segmentType int + +const ( + indexType segmentType = iota + structType + mapKeyType + mapValType + closureType + interfaceType + pointerType + reflectType +) + +func (s pathSegment) String() string { + switch s.typ { + case indexType: + return fmt.Sprintf("[%d]", s.idx) + case structType: + return "." + s.field.Name + case mapKeyType: + return fmt.Sprintf(".$key(%v)", s.elem) + case mapValType: + return fmt.Sprintf("[%v]", s.elem) + case closureType: + return ".$closure" + case interfaceType: + return fmt.Sprintf(".(%T)", s.val.Interface()) + case pointerType: + return "" + case reflectType: + return ".$reflect" + default: + panic("unreachable") + } +} + // DefaultVisitor is a Visitor that visits all values in a // reflect.Value graph. type DefaultVisitor struct{} diff --git a/types/reflect.go b/types/reflect.go index ebeae76..9304a03 100644 --- a/types/reflect.go +++ b/types/reflect.go @@ -46,7 +46,7 @@ func (s *Serializer) Visit(ctx reflectext.VisitorContext, v reflect.Value) bool case reflectext.ReflectValueType: rv := v.Interface().(reflect.Value) s.appendReflectType(rv.Type()) - ctx.Visit(rv) + ctx.WithReflectValue(v).Visit(rv) return false } @@ -85,7 +85,7 @@ func (d *Deserializer) Visit(ctx reflectext.VisitorContext, v reflect.Value) boo rt = reflect.ArrayOf(length, rt) } rv := reflect.New(rt).Elem() - ctx.Visit(rv) + ctx.WithReflectValue(v).Visit(rv) v.Set(reflect.ValueOf(rv)) return false } @@ -297,8 +297,10 @@ func (s *Serializer) VisitMap(ctx reflectext.VisitorContext, v reflect.Value) bo iter := v.MapRange() for iter.Next() { - mapVisitor.Visit(iter.Key()) - mapVisitor.Visit(iter.Value()) + key := iter.Key() + val := iter.Value() + mapVisitor.WithMapKey(v, key).Visit(key) + mapVisitor.WithMapValue(v, val).Visit(val) } region.Data = regionSerializer.buffer @@ -342,15 +344,16 @@ func (d *Deserializer) VisitMap(ctx reflectext.VisitorContext, v reflect.Value) mapVisitor := ctx.Fork(regionDeserializer) - nv := reflect.MakeMapWithSize(t, n) - v.Set(nv) + v.Set(reflect.MakeMapWithSize(t, n)) d.store(sID(id), v.Addr().UnsafePointer()) for i := 0; i < n; i++ { - kv := reflect.New(keyType).Elem() - mapVisitor.Visit(kv) - vv := reflect.New(valType).Elem() - mapVisitor.Visit(vv) - v.SetMapIndex(kv, vv) + key := reflect.New(keyType).Elem() + mapVisitor.WithMapKey(v, key).Visit(key) + + val := reflect.New(valType).Elem() + mapVisitor.WithMapKey(v, val).Visit(val) + + v.SetMapIndex(key, val) } return false } @@ -511,9 +514,9 @@ func (s *Serializer) serializeRegion(ctx reflectext.VisitorContext, et reflect.T regionSerializer := s.fork() regionVisitor := ctx.Fork(regionSerializer) if r.len >= 0 { // array - data := reflectext.MakeSlice(reflect.SliceOf(r.typ), r.addr, r.len, r.len) - for i := 0; i < data.Len(); i++ { - regionVisitor.Visit(data.Index(i)) + v := reflectext.MakeSlice(reflect.SliceOf(r.typ), r.addr, r.len, r.len) + for i := 0; i < v.Len(); i++ { + regionVisitor.WithArrayIndex(v, i).Visit(v.Index(i)) } } else { v := reflect.NewAt(r.typ, r.addr).Elem() @@ -566,9 +569,9 @@ func (d *Deserializer) deserializeRegion(ctx reflectext.VisitorContext, t reflec } else { regionDeserializer := d.fork(region.Data) regionVisitor := ctx.Fork(regionDeserializer) - data := reflectext.MakeSlice(reflect.SliceOf(regionType), p, length, length) + v := reflectext.MakeSlice(reflect.SliceOf(regionType), p, length, length) for i := 0; i < length; i++ { - regionVisitor.Visit(data.Index(i)) + regionVisitor.WithArrayIndex(v, i).Visit(v.Index(i)) } } } else {