Skip to content

Commit

Permalink
Fix update.Response getting skipped by the proxy (#209)
Browse files Browse the repository at this point in the history
Fix update.Response getting skipped by the proxy
  • Loading branch information
Quinn-With-Two-Ns authored Feb 10, 2025
1 parent 44277ce commit dad8b16
Show file tree
Hide file tree
Showing 3 changed files with 243 additions and 8 deletions.
60 changes: 54 additions & 6 deletions cmd/proxygenerator/interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ type VisitFailuresOptions struct {
// Context is the same for every call of a visit, callers should not store it.
// Visitor is free to mutate the passed failure struct.
Visitor func(*VisitFailuresContext, *failure.Failure) (error)
// Will be called for each Any encountered. If not set, the default is to recurse into the Any
// object, unmarshal it, visit, and re-marshal it always (even if there are no changes).
WellKnownAnyVisitor func(*VisitFailuresContext, *anypb.Any) error
}
// VisitFailures calls the options.Visitor function for every Failure proto within msg.
Expand Down Expand Up @@ -162,6 +165,25 @@ func NewFailureVisitorInterceptor(options FailureVisitorInterceptorOptions) (grp
}, nil
}
func (o *VisitFailuresOptions) defaultWellKnownAnyVisitor(ctx *VisitFailuresContext, p *anypb.Any) error {
child, err := p.UnmarshalNew()
if err != nil {
return fmt.Errorf("failed to unmarshal any: %w", err)
}
// We choose to visit and re-marshal always instead of cloning, visiting,
// and checking if anything changed before re-marshaling. It is assumed the
// clone + equality check is not much cheaper than re-marshal.
if err := visitFailures(ctx, o, child); err != nil {
return err
}
// Confirmed this replaces both Any fields on non-error, there is nothing
// left over
if err := p.MarshalFrom(child); err != nil {
return fmt.Errorf("failed to marshal any: %w", err)
}
return nil
}
func (o *VisitPayloadsOptions) defaultWellKnownAnyVisitor(ctx *VisitPayloadsContext, p *anypb.Any) error {
child, err := p.UnmarshalNew()
if err != nil {
Expand Down Expand Up @@ -299,6 +321,20 @@ func visitFailures(ctx *VisitFailuresContext, options *VisitFailuresOptions, obj
if o == nil { continue }
if err := options.Visitor(ctx, o); err != nil { return err }
if err := visitFailures(ctx, options, o.GetCause()); err != nil { return err }
case *anypb.Any:
if o == nil {
continue
}
visitor := options.WellKnownAnyVisitor
if visitor == nil {
visitor = options.defaultWellKnownAnyVisitor
}
ctx.Parent = o
err := visitor(ctx, o)
ctx.Parent = nil
if err != nil {
return err
}
{{range $type, $record := .FailureTypes}}
{{if $record.Slice}}
case []{{$type}}:
Expand Down Expand Up @@ -508,17 +544,19 @@ func generateInterceptor(cfg config) error {
if err != nil {
return err
}
// For the purposes of payloads, we also consider the Any well known type as

failureTypes, err := lookupTypes("go.temporal.io/api/failure/v1", []string{"Failure"})
if err != nil {
return err
}

// For the purposes of payloads and failures, we also consider the Any well known type as
// possible
if anyTypes, err := lookupTypes("google.golang.org/protobuf/types/known/anypb", []string{"Any"}); err != nil {
return err
} else {
payloadTypes = append(payloadTypes, anyTypes...)
}

failureTypes, err := lookupTypes("go.temporal.io/api/failure/v1", []string{"Failure"})
if err != nil {
return err
failureTypes = append(failureTypes, anyTypes...)
}

// UnimplementedWorkflowServiceServer is auto-generated via our API package
Expand All @@ -542,6 +580,11 @@ func generateInterceptor(cfg config) error {
}
workflowExecutions := types.NewPointer(exportTypes[0])

updateTypes, err := lookupTypes("go.temporal.io/api/update/v1", []string{"Acceptance", "Rejection", "Response"})
if err != nil {
return err
}

payloadRecords := map[string]*TypeRecord{}
failureRecords := map[string]*TypeRecord{}

Expand Down Expand Up @@ -572,6 +615,11 @@ func generateInterceptor(cfg config) error {
walk(payloadTypes, workflowExecutions, &payloadRecords, true)
walk(failureTypes, workflowExecutions, &failureRecords, false)

for _, ut := range updateTypes {
walk(payloadTypes, types.NewPointer(ut), &payloadRecords, true)
walk(failureTypes, types.NewPointer(ut), &failureRecords, false)
}

payloadRecords = pruneRecords(payloadRecords)
failureRecords = pruneRecords(failureRecords)

Expand Down
131 changes: 131 additions & 0 deletions proxy/interceptor.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit dad8b16

Please sign in to comment.