Skip to content

Commit

Permalink
Merge pull request #711 from luraproject/modifier_ctx
Browse files Browse the repository at this point in the history
Pass the application context to the request and response modifiers.
  • Loading branch information
kpacha authored Feb 23, 2024
2 parents 7bca413 + bf32a29 commit 9249c56
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 22 deletions.
43 changes: 28 additions & 15 deletions proxy/plugin.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,14 @@ func newPluginMiddleware(logger logging.Logger, tag, pattern string, cfg map[str
return resp, err
}

return executeResponseModifiers(respModifiers, resp)
return executeResponseModifiers(ctx, respModifiers, resp, newRequestWrapper(ctx, r))
}
}

if totRespModifiers == 0 {
return func(ctx context.Context, r *Request) (*Response, error) {
var err error
r, err = executeRequestModifiers(reqModifiers, r)
r, err = executeRequestModifiers(ctx, reqModifiers, r)
if err != nil {
return nil, err
}
Expand All @@ -119,7 +119,7 @@ func newPluginMiddleware(logger logging.Logger, tag, pattern string, cfg map[str

return func(ctx context.Context, r *Request) (*Response, error) {
var err error
r, err = executeRequestModifiers(reqModifiers, r)
r, err = executeRequestModifiers(ctx, reqModifiers, r)
if err != nil {
return nil, err
}
Expand All @@ -129,22 +129,14 @@ func newPluginMiddleware(logger logging.Logger, tag, pattern string, cfg map[str
return resp, err
}

return executeResponseModifiers(respModifiers, resp)
return executeResponseModifiers(ctx, respModifiers, resp, newRequestWrapper(ctx, r))
}
}
}

func executeRequestModifiers(reqModifiers []func(interface{}) (interface{}, error), r *Request) (*Request, error) {
func executeRequestModifiers(ctx context.Context, reqModifiers []func(interface{}) (interface{}, error), r *Request) (*Request, error) {
var tmp RequestWrapper
tmp = &requestWrapper{
method: r.Method,
url: r.URL,
query: r.Query,
path: r.Path,
body: r.Body,
params: r.Params,
headers: r.Headers,
}
tmp = newRequestWrapper(ctx, r)

for _, f := range reqModifiers {
res, err := f(tmp)
Expand All @@ -169,9 +161,11 @@ func executeRequestModifiers(reqModifiers []func(interface{}) (interface{}, erro
return r, nil
}

func executeResponseModifiers(respModifiers []func(interface{}) (interface{}, error), r *Response) (*Response, error) {
func executeResponseModifiers(ctx context.Context, respModifiers []func(interface{}) (interface{}, error), r *Response, req RequestWrapper) (*Response, error) {
var tmp ResponseWrapper
tmp = responseWrapper{
ctx: ctx,
request: req,
data: r.Data,
isComplete: r.IsComplete,
metadata: metadataWrapper{
Expand Down Expand Up @@ -222,7 +216,21 @@ type ResponseWrapper interface {
StatusCode() int
}

func newRequestWrapper(ctx context.Context, r *Request) *requestWrapper {
return &requestWrapper{
ctx: ctx,
method: r.Method,
url: r.URL,
query: r.Query,
path: r.Path,
body: r.Body,
params: r.Params,
headers: r.Headers,
}
}

type requestWrapper struct {
ctx context.Context
method string
url *url.URL
query url.Values
Expand All @@ -232,6 +240,7 @@ type requestWrapper struct {
headers map[string][]string
}

func (r *requestWrapper) Context() context.Context { return r.ctx }
func (r *requestWrapper) Method() string { return r.method }
func (r *requestWrapper) URL() *url.URL { return r.url }
func (r *requestWrapper) Query() url.Values { return r.query }
Expand All @@ -249,12 +258,16 @@ func (m metadataWrapper) Headers() map[string][]string { return m.headers }
func (m metadataWrapper) StatusCode() int { return m.statusCode }

type responseWrapper struct {
ctx context.Context
request interface{}
data map[string]interface{}
isComplete bool
metadata metadataWrapper
io io.Reader
}

func (r responseWrapper) Context() context.Context { return r.ctx }
func (r responseWrapper) Request() interface{} { return r.request }
func (r responseWrapper) Data() map[string]interface{} { return r.data }
func (r responseWrapper) IsComplete() bool { return r.isComplete }
func (r responseWrapper) Io() io.Reader { return r.io }
Expand Down
23 changes: 18 additions & 5 deletions proxy/plugin/modifier.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
// SPDX-License-Identifier: Apache-2.0

/*
Package plugin provides tools for loading and registering proxy plugins
Package plugin provides tools for loading and registering proxy plugins
*/
package plugin

import (
"context"
"fmt"
"plugin"
"strings"
Expand Down Expand Up @@ -84,6 +85,10 @@ type LoggerRegisterer interface {
RegisterLogger(interface{})
}

type ContextRegisterer interface {
RegisterContext(context.Context)
}

// RegisterModifierFunc type is the function passed to the loaded Registerers
type RegisterModifierFunc func(
name string,
Expand All @@ -99,18 +104,22 @@ func Load(path, pattern string, rmf RegisterModifierFunc) (int, error) {

// LoadWithLogger scans the given path using the pattern and registers all the found modifier plugins into the rmf
func LoadWithLogger(path, pattern string, rmf RegisterModifierFunc, logger logging.Logger) (int, error) {
return LoadWithLoggerAndContext(context.Background(), path, pattern, rmf, logger)
}

func LoadWithLoggerAndContext(ctx context.Context, path, pattern string, rmf RegisterModifierFunc, logger logging.Logger) (int, error) {
plugins, err := luraplugin.Scan(path, pattern)
if err != nil {
return 0, err
}
return load(plugins, rmf, logger)
return load(ctx, plugins, rmf, logger)
}

func load(plugins []string, rmf RegisterModifierFunc, logger logging.Logger) (int, error) {
func load(ctx context.Context, plugins []string, rmf RegisterModifierFunc, logger logging.Logger) (int, error) {
errors := []error{}
loadedPlugins := 0
for k, pluginName := range plugins {
if err := open(pluginName, rmf, logger); err != nil {
if err := open(ctx, pluginName, rmf, logger); err != nil {
errors = append(errors, fmt.Errorf("plugin #%d (%s): %s", k, pluginName, err.Error()))
continue
}
Expand All @@ -123,7 +132,7 @@ func load(plugins []string, rmf RegisterModifierFunc, logger logging.Logger) (in
return loadedPlugins, nil
}

func open(pluginName string, rmf RegisterModifierFunc, logger logging.Logger) (err error) {
func open(ctx context.Context, pluginName string, rmf RegisterModifierFunc, logger logging.Logger) (err error) {
defer func() {
if r := recover(); r != nil {
var ok bool
Expand Down Expand Up @@ -155,6 +164,10 @@ func open(pluginName string, rmf RegisterModifierFunc, logger logging.Logger) (e
}
}

if lr, ok := r.(ContextRegisterer); ok {
lr.RegisterContext(ctx)
}

registerer.RegisterModifiers(rmf)
return
}
Expand Down
81 changes: 79 additions & 2 deletions proxy/plugin/modifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,88 @@
package plugin

import (
"bytes"
"context"
"fmt"
"io"
"net/url"
"strings"
"testing"

"github.com/luraproject/lura/v2/logging"
)

func ExampleLoadWithLoggerAndContext() {
var data []byte

buf := bytes.NewBuffer(data)
logger, err := logging.NewLogger("DEBUG", buf, "")
if err != nil {
fmt.Println(err.Error())
return
}
total, err := LoadWithLoggerAndContext(context.Background(), "./tests", ".so", RegisterModifier, logger)
if err != nil {
fmt.Println(err.Error())
return
}
if total != 2 {
fmt.Printf("unexpected number of loaded plugins!. have %d, want 2\n", total)
return
}

modFactory, ok := GetRequestModifier("lura-request-modifier-example-request")
if !ok {
fmt.Println("modifier factory not found in the register")
return
}

modifier := modFactory(map[string]interface{}{})

input := requestWrapper{
ctx: context.WithValue(context.Background(), "myCtxKey", "some"),
path: "/bar",
method: "GET",
}

tmp, err := modifier(input)
if err != nil {
fmt.Println(err.Error())
return
}

output, ok := tmp.(RequestWrapper)
if !ok {
fmt.Println("unexpected result type")
return
}

if res := output.Path(); res != "/bar/fooo" {
fmt.Printf("unexpected result path. have %s, want /bar/fooo\n", res)
return
}

lines := strings.Split(buf.String(), "\n")
for i := range lines[:len(lines)-1] {
fmt.Println(lines[i][21:])
}

// output:
// DEBUG: [PLUGIN: lura-error-example] Logger loaded
// DEBUG: [PLUGIN: lura-request-modifier-example] Logger loaded
// DEBUG: [PLUGIN: lura-request-modifier-example] Context loaded
// DEBUG: [PLUGIN: lura-request-modifier-example] Request modifier injected
// DEBUG: context key: some
// DEBUG: params: map[]
// DEBUG: headers: map[]
// DEBUG: method: GET
// DEBUG: url: <nil>
// DEBUG: query: map[]
// DEBUG: path: /bar/fooo
}

func TestLoad(t *testing.T) {
total, err := Load("./tests", ".so", RegisterModifier)
total, err := LoadWithLogger("./tests", ".so", RegisterModifier, logging.NoOp)
if err != nil {
t.Error(err.Error())
t.Fail()
Expand All @@ -29,7 +104,7 @@ func TestLoad(t *testing.T) {

modifier := modFactory(map[string]interface{}{})

input := requestWrapper{path: "/bar"}
input := requestWrapper{ctx: context.WithValue(context.Background(), "myCtxKey", "some"), path: "/bar"}

tmp, err := modifier(input)
if err != nil {
Expand Down Expand Up @@ -59,6 +134,7 @@ type RequestWrapper interface {
}

type requestWrapper struct {
ctx context.Context
method string
url *url.URL
query url.Values
Expand All @@ -68,6 +144,7 @@ type requestWrapper struct {
headers map[string][]string
}

func (r requestWrapper) Context() context.Context { return r.ctx }
func (r requestWrapper) Method() string { return r.method }
func (r requestWrapper) URL() *url.URL { return r.url }
func (r requestWrapper) Query() url.Values { return r.query }
Expand Down
Loading

0 comments on commit 9249c56

Please sign in to comment.