Skip to content

Commit

Permalink
Provide the current visitor location as a path string
Browse files Browse the repository at this point in the history
  • Loading branch information
chriso committed Jul 9, 2024
1 parent b24f1d4 commit 2de08bd
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 27 deletions.
160 changes: 149 additions & 11 deletions internal/reflectext/visitor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package reflectext
import (
"fmt"
"reflect"
"slices"
"strings"
)

// Visitor visits values in a reflect.Value graph.
Expand Down Expand Up @@ -136,6 +138,7 @@ func (flags VisitorFlags) Has(flag VisitorFlags) bool {
// VisitorContext is Visitor context.
type VisitorContext struct {
parent *VisitorContext
path []pathSegment
impl Visitor
flags VisitorFlags
}
Expand All @@ -153,7 +156,7 @@ func (ctx VisitorContext) Visit(v reflect.Value) {

// Special case for reflect.Value.
if ctx.flags.Has(VisitReflectValues) && v.Type() == ReflectValueType {
ctx.Visit(v.Interface().(reflect.Value))
ctx.WithReflectValue(v).Visit(v.Interface().(reflect.Value))
return
}

Expand Down Expand Up @@ -185,7 +188,7 @@ func (ctx VisitorContext) Visit(v reflect.Value) {

case reflect.Pointer:
if visitor.VisitPointer(ctx, v) && !v.IsNil() {
ctx.Visit(v.Elem())
ctx.WithPointer(v).Visit(v.Elem())
}

case reflect.UnsafePointer:
Expand All @@ -194,7 +197,7 @@ func (ctx VisitorContext) Visit(v reflect.Value) {
case reflect.Array:
if visitor.VisitArray(ctx, v) {
for i := 0; i < v.Len(); i++ {
ctx.Visit(v.Index(i))
ctx.WithArrayIndex(v, i).Visit(v.Index(i))
}
}

Expand All @@ -204,16 +207,18 @@ func (ctx VisitorContext) Visit(v reflect.Value) {
v = MakeSlice(v.Type(), v.UnsafePointer(), v.Cap(), v.Cap())
}
for i := 0; i < v.Len(); i++ {
ctx.Visit(v.Index(i))
ctx.WithSliceIndex(v, i).Visit(v.Index(i))
}
}

case reflect.Map:
if visitor.VisitMap(ctx, v) && !v.IsNil() {
iter := v.MapRange()
for iter.Next() {
ctx.Visit(iter.Key())
ctx.Visit(iter.Value())
key := iter.Key()
val := iter.Value()
ctx.WithMapKey(v, key).Visit(key)
ctx.WithMapValue(v, key).Visit(val)
}
}

Expand All @@ -223,16 +228,16 @@ func (ctx VisitorContext) Visit(v reflect.Value) {

case reflect.Struct:
if visitor.VisitStruct(ctx, v) {
t := v.Type()
if ctx.flags.Has(VisitUnexportedFields) {
unrestricted := StructValueOf(v)
for i := 0; i < unrestricted.NumField(); i++ {
ctx.Visit(unrestricted.Field(i))
ctx.WithStructField(v, t.Field(i)).Visit(unrestricted.Field(i))
}
} else {
t := v.Type()
for i := 0; i < v.NumField(); i++ {
if ft := t.Field(i); ft.IsExported() {
ctx.Visit(v.Field(i))
ctx.WithStructField(v, t.Field(i)).Visit(v.Field(i))
}
}
}
Expand All @@ -241,20 +246,34 @@ func (ctx VisitorContext) Visit(v reflect.Value) {
case reflect.Func:
if visitor.VisitFunc(ctx, v) && !v.IsNil() && ctx.flags.Has(VisitClosures) {
if closure, ok := FuncValueOf(v).Closure(); ok {
ctx.Visit(closure)
ctx.WithClosure(v).Visit(closure)
}
}

case reflect.Interface:
if visitor.VisitInterface(ctx, v) && !v.IsNil() {
ctx.Visit(v.Elem())
ctx.WithInterface(v).Visit(v.Elem())
}

default:
panic("unreachable")
}
}

// Path returns a string that represents the location in the
// reflect.Value graph.
func (ctx VisitorContext) Path() string {
var b strings.Builder
for _, s := range ctx.path {
b.WriteString(s.String())
}
parent := "$"
if ctx.parent != nil {
parent = ctx.parent.Path()
}
return parent + b.String()
}

// Fork creates a new Visitor context, linked to the current
// context and its location.
func (ctx VisitorContext) Fork(impl Visitor) VisitorContext {
Expand All @@ -265,6 +284,125 @@ func (ctx VisitorContext) Fork(impl Visitor) VisitorContext {
}
}

func (ctx VisitorContext) withPathSegment(s pathSegment) VisitorContext {
c := ctx // shallow copy
c.path = append(slices.Clip(c.path), s)
return c
}

func (ctx VisitorContext) WithArrayIndex(a reflect.Value, i int) VisitorContext {
return ctx.withPathSegment(pathSegment{
typ: indexType,
val: a,
idx: i,
})
}

func (ctx VisitorContext) WithSliceIndex(s reflect.Value, i int) VisitorContext {
return ctx.withPathSegment(pathSegment{
typ: indexType,
val: s,
idx: i,
})
}

func (ctx VisitorContext) WithStructField(s reflect.Value, f reflect.StructField) VisitorContext {
return ctx.withPathSegment(pathSegment{
typ: structType,
val: s,
field: f,
})
}

func (ctx VisitorContext) WithMapKey(m, k reflect.Value) VisitorContext {
return ctx.withPathSegment(pathSegment{
typ: mapKeyType,
val: m,
elem: k,
})
}

func (ctx VisitorContext) WithMapValue(m, k reflect.Value) VisitorContext {
return ctx.withPathSegment(pathSegment{
typ: mapValType,
val: m,
elem: k,
})
}

func (ctx VisitorContext) WithClosure(f reflect.Value) VisitorContext {
return ctx.withPathSegment(pathSegment{
typ: closureType,
val: f,
})
}

func (ctx VisitorContext) WithInterface(i reflect.Value) VisitorContext {
return ctx.withPathSegment(pathSegment{
typ: interfaceType,
val: i,
})
}

func (ctx VisitorContext) WithPointer(p reflect.Value) VisitorContext {
return ctx.withPathSegment(pathSegment{
typ: pointerType,
val: p,
})
}

func (ctx VisitorContext) WithReflectValue(v reflect.Value) VisitorContext {
return ctx.withPathSegment(pathSegment{
typ: reflectType,
val: v,
})
}

type pathSegment struct {
typ segmentType
val reflect.Value
elem reflect.Value

field reflect.StructField
idx int
}

type segmentType int

const (
indexType segmentType = iota
structType
mapKeyType
mapValType
closureType
interfaceType
pointerType
reflectType
)

func (s pathSegment) String() string {
switch s.typ {
case indexType:
return fmt.Sprintf("[%d]", s.idx)
case structType:
return "." + s.field.Name
case mapKeyType:
return fmt.Sprintf(".$key(%v)", s.elem)
case mapValType:
return fmt.Sprintf("[%v]", s.elem)
case closureType:
return ".$closure"
case interfaceType:
return fmt.Sprintf(".(%T)", s.val.Interface())
case pointerType:
return ""
case reflectType:
return ".$reflect"
default:
panic("unreachable")
}
}

// DefaultVisitor is a Visitor that visits all values in a
// reflect.Value graph.
type DefaultVisitor struct{}
Expand Down
35 changes: 19 additions & 16 deletions types/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func (s *Serializer) Visit(ctx reflectext.VisitorContext, v reflect.Value) bool
case reflectext.ReflectValueType:
rv := v.Interface().(reflect.Value)
s.appendReflectType(rv.Type())
ctx.Visit(rv)
ctx.WithReflectValue(v).Visit(rv)
return false
}

Expand Down Expand Up @@ -85,7 +85,7 @@ func (d *Deserializer) Visit(ctx reflectext.VisitorContext, v reflect.Value) boo
rt = reflect.ArrayOf(length, rt)
}
rv := reflect.New(rt).Elem()
ctx.Visit(rv)
ctx.WithReflectValue(v).Visit(rv)
v.Set(reflect.ValueOf(rv))
return false
}
Expand Down Expand Up @@ -297,8 +297,10 @@ func (s *Serializer) VisitMap(ctx reflectext.VisitorContext, v reflect.Value) bo

iter := v.MapRange()
for iter.Next() {
mapVisitor.Visit(iter.Key())
mapVisitor.Visit(iter.Value())
key := iter.Key()
val := iter.Value()
mapVisitor.WithMapKey(v, key).Visit(key)
mapVisitor.WithMapValue(v, val).Visit(val)
}

region.Data = regionSerializer.buffer
Expand Down Expand Up @@ -342,15 +344,16 @@ func (d *Deserializer) VisitMap(ctx reflectext.VisitorContext, v reflect.Value)

mapVisitor := ctx.Fork(regionDeserializer)

nv := reflect.MakeMapWithSize(t, n)
v.Set(nv)
v.Set(reflect.MakeMapWithSize(t, n))
d.store(sID(id), v.Addr().UnsafePointer())
for i := 0; i < n; i++ {
kv := reflect.New(keyType).Elem()
mapVisitor.Visit(kv)
vv := reflect.New(valType).Elem()
mapVisitor.Visit(vv)
v.SetMapIndex(kv, vv)
key := reflect.New(keyType).Elem()
mapVisitor.WithMapKey(v, key).Visit(key)

val := reflect.New(valType).Elem()
mapVisitor.WithMapKey(v, val).Visit(val)

v.SetMapIndex(key, val)
}
return false
}
Expand Down Expand Up @@ -511,9 +514,9 @@ func (s *Serializer) serializeRegion(ctx reflectext.VisitorContext, et reflect.T
regionSerializer := s.fork()
regionVisitor := ctx.Fork(regionSerializer)
if r.len >= 0 { // array
data := reflectext.MakeSlice(reflect.SliceOf(r.typ), r.addr, r.len, r.len)
for i := 0; i < data.Len(); i++ {
regionVisitor.Visit(data.Index(i))
v := reflectext.MakeSlice(reflect.SliceOf(r.typ), r.addr, r.len, r.len)
for i := 0; i < v.Len(); i++ {
regionVisitor.WithArrayIndex(v, i).Visit(v.Index(i))
}
} else {
v := reflect.NewAt(r.typ, r.addr).Elem()
Expand Down Expand Up @@ -566,9 +569,9 @@ func (d *Deserializer) deserializeRegion(ctx reflectext.VisitorContext, t reflec
} else {
regionDeserializer := d.fork(region.Data)
regionVisitor := ctx.Fork(regionDeserializer)
data := reflectext.MakeSlice(reflect.SliceOf(regionType), p, length, length)
v := reflectext.MakeSlice(reflect.SliceOf(regionType), p, length, length)
for i := 0; i < length; i++ {
regionVisitor.Visit(data.Index(i))
regionVisitor.WithArrayIndex(v, i).Visit(v.Index(i))
}
}
} else {
Expand Down

0 comments on commit 2de08bd

Please sign in to comment.