Skip to content

Commit

Permalink
Global param propagation for sequential merger
Browse files Browse the repository at this point in the history
Signed-off-by: thedae <thedae@gmail.com>
  • Loading branch information
thedae committed Jan 13, 2025
1 parent b8244dd commit 2604f92
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 71 deletions.
210 changes: 143 additions & 67 deletions proxy/merging.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package proxy
import (
"context"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
Expand All @@ -27,7 +28,7 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
}
serviceTimeout := time.Duration(85*endpointConfig.Timeout.Nanoseconds()/100) * time.Nanosecond
combiner := getResponseCombiner(endpointConfig.ExtraConfig)
isSequential := shouldRunSequentialMerger(endpointConfig)
isSequential, propagatedParams := sequentialMergerConfig(endpointConfig)

logger.Debug(
fmt.Sprintf(
Expand Down Expand Up @@ -57,24 +58,86 @@ func NewMergeDataMiddleware(logger logging.Logger, endpointConfig *config.Endpoi
return parallelMerge(reqClone, serviceTimeout, combiner, next...)
}

patterns := make([]string, len(endpointConfig.Backend))
sequentialReplacements := make([][]sequentialBackendReplacement, totalBackends)

var rePropagatedParams = regexp.MustCompile(`[Rr]esp(\d+)_?([\w-\.]+)?`)
var reUrlPatterns = regexp.MustCompile(`\{\{\.Resp(\d+)_([\w-\.]+)\}\}`)
destKeyGenerator := func(i string, t string) string {
key := "Resp" + i
if t != "" {
key += "_" + t
}
return key
}

for i, b := range endpointConfig.Backend {
patterns[i] = b.URLPattern
for _, match := range reUrlPatterns.FindAllStringSubmatch(b.URLPattern, -1) {
if len(match) > 1 {
backendIndex, err := strconv.Atoi(match[1])
if err != nil {
continue
}

sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{
backendIndex: backendIndex,
destination: destKeyGenerator(match[1], match[2]),
source: strings.Split(match[2], "."),
fullResponse: len(match[2]) == 0,
})
}
}

if i > 0 {
for _, p := range propagatedParams {
for _, match := range rePropagatedParams.FindAllStringSubmatch(p, -1) {
if len(match) > 1 {
backendIndex, err := strconv.Atoi(match[1])
if err != nil || backendIndex >= totalBackends {
continue
}

sequentialReplacements[i] = append(sequentialReplacements[i], sequentialBackendReplacement{
backendIndex: backendIndex,
destination: destKeyGenerator(match[1], match[2]),
source: strings.Split(match[2], "."),
fullResponse: len(match[2]) == 0,
})
}
}
}
}
}
return sequentialMerge(reqClone, patterns, serviceTimeout, combiner, next...)

return sequentialMerge(reqClone, sequentialReplacements, serviceTimeout, combiner, next...)
}
}

func shouldRunSequentialMerger(cfg *config.EndpointConfig) bool {
type sequentialBackendReplacement struct {
backendIndex int
destination string
source []string
fullResponse bool
}

func sequentialMergerConfig(cfg *config.EndpointConfig) (bool, []string) {
enabled := false
propagatedParams := []string{}
if v, ok := cfg.ExtraConfig[Namespace]; ok {
if e, ok := v.(map[string]interface{}); ok {
if v, ok := e[isSequentialKey]; ok {
c, ok := v.(bool)
return ok && c
enabled = ok && c
}
if v, ok := e[sequentialPropagateKey]; ok {
if a, ok := v.([]interface{}); ok {
for _, p := range a {
propagatedParams = append(propagatedParams, p.(string))
}
}
}
}
}
return false
return enabled, propagatedParams
}

func hasUnsafeBackends(cfg *config.EndpointConfig) bool {
Expand Down Expand Up @@ -118,75 +181,92 @@ func parallelMerge(reqCloner func(*Request) *Request, timeout time.Duration, rc
}
}

var reMergeKey = regexp.MustCompile(`\{\{\.Resp(\d+)_([\w-\.]+)\}\}`)

func sequentialMerge(reqCloner func(*Request) *Request, patterns []string, timeout time.Duration, rc ResponseCombiner, next ...Proxy) Proxy {
func sequentialMerge(reqCloner func(*Request) *Request, sequentialReplacements [][]sequentialBackendReplacement, timeout time.Duration, rc ResponseCombiner, next ...Proxy) Proxy {
return func(ctx context.Context, request *Request) (*Response, error) {
localCtx, cancel := context.WithTimeout(ctx, timeout)

parts := make([]*Response, len(next))
out := make(chan *Response, 1)
errCh := make(chan error, 1)
sequentialMergeRegistry := map[string]string{}

acc := newIncrementalMergeAccumulator(len(next), rc)
TxLoop:
for i, n := range next {
if i > 0 {
for _, match := range reMergeKey.FindAllStringSubmatch(patterns[i], -1) {
if len(match) > 1 {
rNum, err := strconv.Atoi(match[1])
if err != nil || rNum >= i || parts[rNum] == nil {
continue
}
key := "Resp" + match[1] + "_" + match[2]

var v interface{}
var ok bool

data := parts[rNum].Data
keys := strings.Split(match[2], ".")
if len(keys) > 1 {
for _, k := range keys[:len(keys)-1] {
v, ok = data[k]
if !ok {
break
}
clean, ok := v.(map[string]interface{})
if !ok {
break
}
data = clean
for _, r := range sequentialReplacements[i] {
if r.backendIndex >= i || parts[r.backendIndex] == nil {
continue
}

var v interface{}
var ok bool

data := parts[r.backendIndex].Data
if len(r.source) > 1 {
for _, k := range r.source[:len(r.source)-1] {
v, ok = data[k]
if !ok {
break
}
clean, ok := v.(map[string]interface{})
if !ok {
break
}
data = clean
}
}

v, ok = data[keys[len(keys)-1]]
if !ok {
if found := sequentialMergeRegistry[r.destination]; found != "" {
request.Params[r.destination] = found
continue
}

if r.fullResponse {
if parts[r.backendIndex].Io == nil {
continue
}
switch clean := v.(type) {
case []interface{}:
if len(clean) == 0 {
request.Params[key] = ""
continue
}
var b strings.Builder
for i := 0; i < len(clean)-1; i++ {
fmt.Fprintf(&b, "%v,", clean[i])
}
fmt.Fprintf(&b, "%v", clean[len(clean)-1])
request.Params[key] = b.String()
case string:
request.Params[key] = clean
case int:
request.Params[key] = strconv.Itoa(clean)
case float64:
request.Params[key] = strconv.FormatFloat(clean, 'E', -1, 32)
case bool:
request.Params[key] = strconv.FormatBool(clean)
default:
request.Params[key] = fmt.Sprintf("%v", v)
buf, err := io.ReadAll(parts[r.backendIndex].Io)

if err == nil {
request.Params[r.destination] = string(buf)
sequentialMergeRegistry[r.destination] = string(buf)
}
continue
}

v, ok = data[r.source[len(r.source)-1]]
if !ok {
continue
}

var param string

switch clean := v.(type) {
case []interface{}:
if len(clean) == 0 {
request.Params[r.destination] = ""
break
}
var b strings.Builder
for i := 0; i < len(clean)-1; i++ {
fmt.Fprintf(&b, "%v,", clean[i])
}
fmt.Fprintf(&b, "%v", clean[len(clean)-1])
param = b.String()
case string:
param = clean
case int:
param = strconv.Itoa(clean)
case float64:
param = strconv.FormatFloat(clean, 'E', -1, 32)
case bool:
param = strconv.FormatBool(clean)
default:
param = fmt.Sprintf("%v", v)
}
request.Params[r.destination] = param
sequentialMergeRegistry[r.destination] = param
}
}

Expand Down Expand Up @@ -284,30 +364,25 @@ func requestPart(ctx context.Context, next Proxy, request *Request, out chan<- *
}

func sequentialRequestPart(ctx context.Context, next Proxy, request *Request, out chan<- *Response, failed chan<- error) {
localCtx, cancel := context.WithCancel(ctx)

copyRequest := CloneRequest(request)

in, err := next(localCtx, request)
in, err := next(ctx, request)

*request = *copyRequest

if err != nil {
failed <- err
cancel()
return
}
if in == nil {
failed <- errNullResult
cancel()
return
}
select {
case out <- in:
case <-ctx.Done():
failed <- ctx.Err()
}
cancel()
}

func newMergeError(errs []error) error {
Expand Down Expand Up @@ -342,9 +417,10 @@ func RegisterResponseCombiner(name string, f ResponseCombiner) {
}

const (
mergeKey = "combiner"
isSequentialKey = "sequential"
defaultCombinerName = "default"
mergeKey = "combiner"
isSequentialKey = "sequential"
sequentialPropagateKey = "sequential_propagated_params"
defaultCombinerName = "default"
)

var responseCombiners = initResponseCombiners()
Expand Down
27 changes: 23 additions & 4 deletions proxy/merging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,16 @@ func TestNewMergeDataMiddleware_sequential(t *testing.T) {
{URLPattern: "/"},
{URLPattern: "/aaa/{{.Resp0_array}}"},
{URLPattern: "/aaa/{{.Resp0_int}}/{{.Resp0_string}}/{{.Resp0_bool}}/{{.Resp0_float}}/{{.Resp0_struct.foo}}"},
{URLPattern: "/aaa/{{.Resp0_int}}/{{.Resp0_string}}/{{.Resp0_bool}}/{{.Resp0_float}}/{{.Resp0_struct.foo}}?x={{.Resp1_tupu}}"},
{URLPattern: "/aaa/{{.Resp0_int}}/{{.Resp0_string}}/{{.Resp0_bool}}/{{.Resp0_float}}/{{.Resp0_struct.foo}}?x={{.Resp1_tupu}}", Encoding: "noop"},
{URLPattern: "/aaa/{{.Resp0_struct.foo}}/{{.Resp0_struct.struct.foo}}/{{.Resp0_struct.struct.struct.foo}}"},
{URLPattern: "/zzz", Encoding: "noop"},
{URLPattern: "/hit-me"},
},
Timeout: time.Duration(timeout) * time.Millisecond,
ExtraConfig: config.ExtraConfig{
Namespace: map[string]interface{}{
isSequentialKey: true,
isSequentialKey: true,
sequentialPropagateKey: []interface{}{"resp0_propagated", "resp5"},
},
},
}
Expand Down Expand Up @@ -144,11 +147,13 @@ func TestNewMergeDataMiddleware_sequential(t *testing.T) {
},
},
},
"array": []interface{}{"1", "2"},
"array": []interface{}{"1", "2"},
"propagated": "everywhere",
}, IsComplete: true}),
func(ctx context.Context, r *Request) (*Response, error) {
checkBody(t, r)
checkRequestParam(t, r, "Resp0_array", "1,2")
checkRequestParam(t, r, "Resp0_propagated", "everywhere")
return &Response{Data: map[string]interface{}{"tupu": "foo"}, IsComplete: true}, nil
},
func(ctx context.Context, r *Request) (*Response, error) {
Expand All @@ -158,6 +163,7 @@ func TestNewMergeDataMiddleware_sequential(t *testing.T) {
checkRequestParam(t, r, "Resp0_float", "3.14E+00")
checkRequestParam(t, r, "Resp0_bool", "true")
checkRequestParam(t, r, "Resp0_struct.foo", "bar")
checkRequestParam(t, r, "Resp0_propagated", "everywhere")
return &Response{Data: map[string]interface{}{"tupu": "foo"}, IsComplete: true}, nil
},
func(ctx context.Context, r *Request) (*Response, error) {
Expand All @@ -168,15 +174,28 @@ func TestNewMergeDataMiddleware_sequential(t *testing.T) {
checkRequestParam(t, r, "Resp0_bool", "true")
checkRequestParam(t, r, "Resp0_struct.foo", "bar")
checkRequestParam(t, r, "Resp1_tupu", "foo")
checkRequestParam(t, r, "Resp0_propagated", "everywhere")
return &Response{Data: map[string]interface{}{"aaaa": []int{1, 2, 3}}, IsComplete: true}, nil
},
func(ctx context.Context, r *Request) (*Response, error) {
checkBody(t, r)
checkRequestParam(t, r, "Resp0_struct.foo", "bar")
checkRequestParam(t, r, "Resp0_struct.struct.foo", "bar")
checkRequestParam(t, r, "Resp0_struct.struct.struct.foo", "bar")
checkRequestParam(t, r, "Resp0_propagated", "everywhere")
return &Response{Data: map[string]interface{}{"bbbb": []bool{true, false}}, IsComplete: true}, nil
},
func(ctx context.Context, r *Request) (*Response, error) {
checkBody(t, r)
checkRequestParam(t, r, "Resp0_propagated", "everywhere")
return &Response{Data: map[string]interface{}{}, Io: io.NopCloser(strings.NewReader("hello")), IsComplete: true}, nil
},
func(ctx context.Context, r *Request) (*Response, error) {
checkBody(t, r)
checkRequestParam(t, r, "Resp0_propagated", "everywhere")
checkRequestParam(t, r, "Resp5", "hello")
return &Response{Data: map[string]interface{}{}, IsComplete: true}, nil
},
)
mustEnd := time.After(time.Duration(2*timeout) * time.Millisecond)
out, err := p(context.Background(), &Request{
Expand All @@ -194,7 +213,7 @@ func TestNewMergeDataMiddleware_sequential(t *testing.T) {
case <-mustEnd:
t.Errorf("We were expecting a response but we got none\n")
default:
if len(out.Data) != 9 {
if len(out.Data) != 10 {
t.Errorf("We weren't expecting a partial response but we got %v!\n", out)
}
if !out.IsComplete {
Expand Down

0 comments on commit 2604f92

Please sign in to comment.