From 6a05798e1cc5018b82676f51b4dbdd2f8e670118 Mon Sep 17 00:00:00 2001 From: kinggo <1963359402@qq.com> Date: Wed, 7 Sep 2022 21:46:58 +0800 Subject: [PATCH] feat: add reverseproxy (#5) * feat: add reverseproxy feature * docs: modify README.md Co-authored-by: Skyenought <1808644906@qq.com> --- .github/workflows/tests.yml | 2 +- .gitignore | 20 +- .licenserc.yaml | 10 + README.md | 87 +++++++ go.mod | 4 +- go.sum | 73 ++++++ reverse_proxy.go | 280 +++++++++++++++++++++ reverse_proxy_test.go | 479 ++++++++++++++++++++++++++++++++++++ 8 files changed, 952 insertions(+), 3 deletions(-) create mode 100644 .licenserc.yaml create mode 100644 README.md create mode 100644 go.sum create mode 100644 reverse_proxy.go create mode 100644 reverse_proxy_test.go diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1d38c83..5da2806 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -4,7 +4,7 @@ on: [push, pull_request] jobs: lint-and-ut: - runs-on: self-hosted + runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 diff --git a/.gitignore b/.gitignore index 62c8935..9017396 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,19 @@ -.idea/ \ No newline at end of file +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# IDE config +.vscode +.idea diff --git a/.licenserc.yaml b/.licenserc.yaml new file mode 100644 index 0000000..41e44dc --- /dev/null +++ b/.licenserc.yaml @@ -0,0 +1,10 @@ +header: + license: + spdx-id: Apache-2.0 + copyright-owner: CloudWeGo Authors + + paths: + - '**/*.go' + - '**/*.s' + + comment: on-failure diff --git a/README.md b/README.md new file mode 100644 index 0000000..f8f755e --- /dev/null +++ b/README.md @@ -0,0 +1,87 @@ +# reverseproxy (This is a community driven project) + +`reserve proxy` extension for Hertz + +## Quick Start + +```go +package main + +import ( + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/hertz-contrib/reverseproxy" +) + +func main() { + h := server.New() + rp, _ := reverseproxy.NewSingleHostReverseProxy("http://localhost:8082/test") + h.GET("/ping", rp.ServeHTTP) + h.Spin() +} +``` + +### Use tls + +Currently [netpoll](https://github.com/cloudwego/netpoll) does not support tls,we need to use the `net` (standard library) + +```go +package main + +import ( + "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/cloudwego/hertz/pkg/network/standard" + "github.com/hertz-contrib/reverseproxy" +) + +func main() { + h := server.New() + rp, _ := reverseproxy.NewSingleHostReverseProxy("https://localhost:8082/test",client.WithDialer(standard.NewDialer())) + h.GET("/ping", rp.ServeHTTP) + h.Spin() +} +``` + +### Use service discovery + +Use `nacos` as example and more information refer to [registry](https://github.com/hertz-contrib/registry) + +```go +package main + +import ( + "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/app/middlewares/client/sd" + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/cloudwego/hertz/pkg/common/config" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/hertz-contrib/registry/nacos" + "github.com/hertz-contrib/reverseproxy" +) + +func main() { + cli, err := client.NewClient() + if err != nil{ + panic(err) + } + r, err := nacos.NewDefaultNacosResolver() + if err != nil{ + panic(err) + } + cli.Use(sd.Discovery(r)) + h := server.New() + rp, _ := reverseproxy.NewSingleHostReverseProxy("http://test.demo.api/test") + rp.SetClient(cli) + rp.SetDirector(func(req *protocol.Request){ + req.SetRequestURI(string(reverseproxy.JoinURLPath(req, rp.Target))) + req.Header.SetHostBytes(req.URI().Host()) + req.Options().Apply([]config.RequestOption{config.WithSD(true)}) + }) + h.GET("/ping", rp.ServeHTTP) + h.Spin() +} +``` + +### Request/Response + +`ReverseProxy` provides `SetDirector`、`SetModifyResponse`、`SetErrorHandler` to modify `Request` and `Response`. \ No newline at end of file diff --git a/go.mod b/go.mod index 9cb3188..073f8a5 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/hertz-contrib/reverseproxy -go 1.18 +go 1.16 + +require github.com/cloudwego/hertz v0.3.0 diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..fc2cf6f --- /dev/null +++ b/go.sum @@ -0,0 +1,73 @@ +github.com/bytedance/go-tagexpr/v2 v2.9.2 h1:QySJaAIQgOEDQBLS3x9BxOWrnhqu5sQ+f6HaZIxD39I= +github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM= +github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 h1:PtwsQyQJGxf8iaPptPNaduEIu9BnrNms+pcRdHAxZaM= +github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= +github.com/bytedance/sonic v1.3.0 h1:T2rlvNytw6bTmczlAXvGqmuMzIqGJBOsJKYwRPWR7Y8= +github.com/bytedance/sonic v1.3.0/go.mod h1:V973WhNhGmvHxW6nQmsHEfHaoU9F3zTF+93rH03hcUQ= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06 h1:1sDoSuDPWzhkdzNVxCxtIaKiAe96ESVPv8coGwc1gZ4= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/cloudwego/hertz v0.3.0 h1:gyKaw/uc7I2YLuCzqWHD3G2W8MVKRZPZ5u8k8tiDBbk= +github.com/cloudwego/hertz v0.3.0/go.mod h1:GWWYlAVkq1gDu6vJd/XNciWsP6q0d4TrEKk5fpJYF04= +github.com/cloudwego/netpoll v0.2.4 h1:Kbo2HA1cXEgoy/bu1jSNrjcqZj2diENcJqLy6vKiROU= +github.com/cloudwego/netpoll v0.2.4/go.mod h1:1T2WVuQ+MQw6h6DpE45MohSvDTKdy2DlzCx2KsnPI4E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/goccy/go-json v0.9.4 h1:L8MLKG2mvVXiQu07qB6hmfqeSYQdOnqPot2GhsIwIaI= +github.com/goccy/go-json v0.9.4/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/henrylee2cn/ameda v1.4.8/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= +github.com/henrylee2cn/ameda v1.4.10 h1:JdvI2Ekq7tapdPsuhrc4CaFiqw6QXFvZIULWJgQyCAk= +github.com/henrylee2cn/ameda v1.4.10/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= +github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 h1:yE9ULgp02BhYIrO6sdV/FPe0xQM6fNHkVQW2IAymfM0= +github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8/go.mod h1:Nhe/DM3671a5udlv2AdV2ni/MZzgfv2qrPL5nIi3EGQ= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= +github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.13.0 h1:3TFY9yxOQShrvmjdM76K+jc66zJeT6D3/VFFYCGQf7M= +github.com/tidwall/gjson v1.13.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.4 h1:cuiLzLnaMeBhRmEv00Lpk3tkYrcxpmbU81tAY4Dw0tc= +github.com/tidwall/sjson v1.2.4/go.mod h1:098SZ494YoMWPmMO6ct4dcFnqxwj9r/gF0Etp19pSNM= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe h1:W8vbETX/n8S6EmY0Pu4Ix7VvpsJUESTwl0oCK8MJOgk= +golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= +google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/reverse_proxy.go b/reverse_proxy.go new file mode 100644 index 0000000..27b94c3 --- /dev/null +++ b/reverse_proxy.go @@ -0,0 +1,280 @@ +// Copyright 2021 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package reverseproxy + +import ( + "bytes" + "context" + "fmt" + "net" + "net/textproto" + "reflect" + "strings" + "unsafe" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/common/config" + "github.com/cloudwego/hertz/pkg/common/hlog" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" +) + +type ReverseProxy struct { + client *client.Client + + // target is set as a reverse proxy address + Target string + + // director must be a function which modifies the request + // into a new request. Its response is then redirected + // back to the original client unmodified. + // director must not access the provided Request + // after returning. + director func(*protocol.Request) + + // modifyResponse is an optional function that modifies the + // Response from the backend. It is called if the backend + // returns a response at all, with any HTTP status code. + // If the backend is unreachable, the optional errorHandler is + // called without any call to modifyResponse. + // + // If modifyResponse returns an error, errorHandler is called + // with its error value. If errorHandler is nil, its default + // implementation is used. + modifyResponse func(*protocol.Response) error + + // errorHandler is an optional function that handles errors + // reaching the backend or errors from modifyResponse. + // + // If nil, the default is to log the provided error and return + // a 502 Status Bad Gateway response. + errorHandler func(*app.RequestContext, error) +} + +// Hop-by-hop headers. These are removed when sent to the backend. +// As of RFC 7230, hop-by-hop headers are required to appear in the +// Connection header field. These are the headers defined by the +// obsoleted RFC 2616 (section 13.5.1) and are used for backward +// compatibility. +var hopHeaders = []string{ + "Connection", + "Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", // canonicalized version of "TE" + "Trailer", // not Trailers per URL above; https://www.rfc-editor.org/errata_search.php?eid=4522 + "Transfer-Encoding", + "Upgrade", +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that routes +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +// NewSingleHostReverseProxy does not rewrite the Host header. +// To rewrite Host headers, use ReverseProxy directly with a custom +// director policy. +func NewSingleHostReverseProxy(target string, options ...config.ClientOption) (*ReverseProxy, error) { + r := &ReverseProxy{ + Target: target, + director: func(req *protocol.Request) { + req.SetRequestURI(b2s(JoinURLPath(req, target))) + req.Header.SetHostBytes(req.URI().Host()) + }, + } + c, err := client.NewClient(options...) + if err != nil { + return nil, err + } + r.client = c + return r, nil +} + +func JoinURLPath(req *protocol.Request, target string) (path []byte) { + aslash := req.URI().Path()[0] == '/' + var bslash bool + if strings.HasPrefix(target, "http") { + // absolute path + bslash = strings.HasSuffix(target, "/") + } else { + // default redirect to local + bslash = strings.HasPrefix(target, "/") + if bslash { + target = fmt.Sprintf("%s%s", req.Host(), target) + } else { + target = fmt.Sprintf("%s/%s", req.Host(), target) + } + bslash = strings.HasSuffix(target, "/") + } + + targetQuery := strings.Split(target, "?") + var buffer bytes.Buffer + buffer.WriteString(targetQuery[0]) + switch { + case aslash && bslash: + buffer.Write(req.URI().Path()[1:]) + case !aslash && !bslash: + buffer.Write([]byte{'/'}) + buffer.Write(req.URI().Path()) + default: + buffer.Write(req.URI().Path()) + } + if len(targetQuery) > 1 { + buffer.Write([]byte{'?'}) + buffer.WriteString(targetQuery[1]) + } + if len(req.QueryString()) > 0 { + if len(targetQuery) == 1 { + buffer.Write([]byte{'?'}) + } else { + buffer.Write([]byte{'&'}) + } + buffer.Write(req.QueryString()) + } + return buffer.Bytes() +} + +// removeRequestConnHeaders removes hop-by-hop headers listed in the "Connection" header of h. +// See RFC 7230, section 6.1 +func removeRequestConnHeaders(c *app.RequestContext) { + c.Request.Header.VisitAll(func(k, v []byte) { + if b2s(k) == "Connection" { + for _, sf := range strings.Split(b2s(v), ",") { + if sf = textproto.TrimString(sf); sf != "" { + c.Request.Header.DelBytes(s2b(sf)) + } + } + } + }) +} + +// removeRespConnHeaders removes hop-by-hop headers listed in the "Connection" header of h. +// See RFC 7230, section 6.1 +func removeResponseConnHeaders(c *app.RequestContext) { + c.Response.Header.VisitAll(func(k, v []byte) { + if b2s(k) == "Connection" { + for _, sf := range strings.Split(b2s(v), ",") { + if sf = textproto.TrimString(sf); sf != "" { + c.Response.Header.DelBytes(s2b(sf)) + } + } + } + }) +} + +func (r *ReverseProxy) defaultErrorHandler(c *app.RequestContext, _ error) { + c.Response.Header.SetStatusCode(consts.StatusBadGateway) +} + +func (r *ReverseProxy) ServeHTTP(c context.Context, ctx *app.RequestContext) { + req := &ctx.Request + resp := &ctx.Response + + if r.director != nil { + r.director(&ctx.Request) + } + req.Header.ResetConnectionClose() + + removeRequestConnHeaders(ctx) + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. + for _, h := range hopHeaders { + req.Header.DelBytes(s2b(h)) + } + + // prepare request(replace headers and some URL host) + if ip, _, err := net.SplitHostPort(ctx.RemoteAddr().String()); err == nil { + tmp := req.Header.Peek("X-Forwarded-For") + if len(tmp) > 0 { + ip = fmt.Sprintf("%s, %s", tmp, ip) + } + if tmp == nil || string(tmp) != "" { + req.Header.Add("X-Forwarded-For", ip) + } + } + err := r.client.Do(c, req, resp) + if err != nil { + hlog.CtxErrorf(c, "HERTZ: Client request error: %#v", err.Error()) + r.getErrorHandler()(ctx, err) + return + } + + removeResponseConnHeaders(ctx) + + for _, h := range hopHeaders { + resp.Header.DelBytes(s2b(h)) + } + + if r.modifyResponse == nil { + return + } + err = r.modifyResponse(resp) + if err != nil { + r.getErrorHandler()(ctx, err) + } +} + +// SetDirector use to customize protocol.Request +func (r *ReverseProxy) SetDirector(director func(req *protocol.Request)) { + r.director = director +} + +// SetClient use to customize client +func (r *ReverseProxy) SetClient(client *client.Client) { + r.client = client +} + +// SetModifyResponse use to modify response +func (r *ReverseProxy) SetModifyResponse(mr func(*protocol.Response) error) { + r.modifyResponse = mr +} + +// SetErrorHandler use to customize error handler +func (r *ReverseProxy) SetErrorHandler(eh func(c *app.RequestContext, err error)) { + r.errorHandler = eh +} + +func (r *ReverseProxy) getErrorHandler() func(c *app.RequestContext, err error) { + if r.errorHandler != nil { + return r.errorHandler + } + return r.defaultErrorHandler +} + +// b2s converts byte slice to a string without memory allocation. +// See https://groups.google.com/forum/#!msg/Golang-Nuts/ENgbUzYvCuU/90yGx7GUAgAJ . +// +// Note it may break if string and/or slice header will change +// in the future go versions. +func b2s(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + +// s2b converts string to a byte slice without memory allocation. +// +// Note it may break if string and/or slice header will change +// in the future go versions. +func s2b(s string) (b []byte) { + bh := (*reflect.SliceHeader)(unsafe.Pointer(&b)) + sh := (*reflect.StringHeader)(unsafe.Pointer(&s)) + bh.Data = sh.Data + bh.Cap = sh.Len + bh.Len = sh.Len + return +} diff --git a/reverse_proxy_test.go b/reverse_proxy_test.go new file mode 100644 index 0000000..ea586e1 --- /dev/null +++ b/reverse_proxy_test.go @@ -0,0 +1,479 @@ +// Copyright 2021 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +package reverseproxy + +import ( + "bytes" + "context" + "errors" + "fmt" + "net/http" + "strings" + "testing" + "time" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/cloudwego/hertz/pkg/protocol" +) + +// Reverse proxy tests. +const fakeHopHeader = "X-Fake-Hop-Header-For-Test" +const fakeConnectionToken = "X-Fake-Connection-Token" + +func init() { + hopHeaders = append(hopHeaders, fakeHopHeader) + hopHeaders = append(hopHeaders, fakeConnectionToken) +} + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + r := server.New(server.WithHostPorts("127.0.0.1:9990")) + + r.GET("/proxy/backend", func(cc context.Context, ctx *app.RequestContext) { + if ctx.Query("mode") == "hangup" { + ctx.GetConn().Close() + return + } + if ctx.Request.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if c := ctx.Request.Header.Get("Connection"); c != "" { + t.Errorf("handler got Connection header value %q", c) + } + + if c := ctx.Request.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got Upgrade header value %q", c) + } + if c := ctx.Request.Header.Get("Proxy-Connection"); c != "" { + t.Errorf("handler got Proxy-Connection header value %q", c) + } + + ctx.Response.Header.Set("Trailers", "not a special header field name") + ctx.Response.Header.Set("Trailer", "X-Trailer") + ctx.Response.Header.Set("X-Foo", "bar") + ctx.Response.Header.Set("Upgrade", "foo") + ctx.Response.Header.Set(fakeHopHeader, "foo") + ctx.Response.Header.Add("X-Multi-Value", "foo") + ctx.Response.Header.Add("X-Multi-Value", "bar") + c := protocol.AcquireCookie() + c.SetKey("flavor") + c.SetValue("chocolateChip") + ctx.Response.Header.SetCookie(c) + protocol.ReleaseCookie(c) + ctx.Response.Header.Set("X-Trailer", "trailer_value") + ctx.Response.Header.Set(http.TrailerPrefix+"X-Unannounced-Trailer", "unannounced_trailer_value") + ctx.Data(backendStatus, "application/json", []byte(backendResponse)) + }) + + proxy, err := NewSingleHostReverseProxy("http://127.0.0.1:9990/proxy") + if err != nil { + t.Errorf("proxy error: %v", err) + } + + r.GET("/backend", func(c context.Context, ctx *app.RequestContext) { + proxy.ServeHTTP(c, ctx) + }) + go r.Spin() + time.Sleep(time.Second) + cli, _ := client.NewClient() + req := protocol.AcquireRequest() + res := protocol.AcquireResponse() + req.Header.Set("Connection", "close, TE") + req.Header.Add("Te", "foo") + req.Header.Add("Te", "bar, trailers") + req.Header.Set("Proxy-Connection", "should be deleted") + req.Header.Set("Upgrade", "foo") + req.SetConnectionClose() + req.SetHost("some-name") + req.SetRequestURI("http://localhost:9990/backend") + cli.Do(context.Background(), req, res) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode(), backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + if c := res.Header.Get(fakeHopHeader); c != "" { + t.Errorf("got %s header value %q", fakeHopHeader, c) + } + if g, e := res.Header.Get("Trailers"), "not a special header field name"; g != e { + t.Errorf("header Trailers = %q; want %q", g, e) + } + length := 0 + res.Header.VisitAll(func(key, value []byte) { + if string(key) == "X-Multi-Value" { + length++ + } + }) + if length != 2 { + t.Errorf("got %d X-Multi-Value header values; expected %d", 2, length) + } + length = 0 + res.Header.VisitAll(func(key, value []byte) { + if string(key) == "Set-Cookie" { + length++ + } + }) + if length != 1 { + t.Fatalf("got %d SetCookies, want %d", 1, 0) + } + + cookie := protocol.AcquireCookie() + cookie.SetKey("flavor") + if has := res.Header.Cookie(cookie); !has { + t.Errorf("unexpected cookie %q", cookie) + } + if g, e := string(res.Body()), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + + // Test that a backend failing to be reached or one which doesn't return + // a response results in a StatusBadGateway. + req.SetRequestURI("http://localhost:9990/backend?mode=hangup") + cli.Do(context.Background(), req, res) + if res.StatusCode() != http.StatusBadGateway { + t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.StatusCode()) + } +} + +func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { + hopHeaders = append(hopHeaders, fakeHopHeader) + const backendResponse = "I am the backend" + + // someConnHeader is some arbitrary header to be declared as a hop-by-hop header + // in the Request's Connection header. + const someConnHeader = "X-Some-Conn-Header" + r := server.New(server.WithHostPorts("127.0.0.1:9991")) + + r.GET("/proxy/backend", func(cc context.Context, ctx *app.RequestContext) { + if c := ctx.Request.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := ctx.Request.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } + if c := ctx.Request.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + ctx.Response.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) + ctx.Response.Header.Add("Connection", someConnHeader) + ctx.Response.Header.Set(someConnHeader, "should be deleted") + ctx.Response.Header.Set(fakeConnectionToken, "should be deleted") + ctx.Data(200, "application/json", []byte(backendResponse)) + }) + + proxy, err := NewSingleHostReverseProxy("http://127.0.0.1:9991/proxy") + if err != nil { + t.Errorf("proxy error: %v", err) + } + + r.GET("/backend", func(cc context.Context, ctx *app.RequestContext) { + proxy.ServeHTTP(cc, ctx) + }) + go r.Spin() + time.Sleep(time.Second) + cli, _ := client.NewClient() + req := protocol.AcquireRequest() + resp := protocol.AcquireResponse() + req.Header.Set(someConnHeader, "should be deleted") + req.Header.Add("Connection", "Upgrade, "+fakeConnectionToken) + req.Header.Add("Connection", someConnHeader) + req.Header.Set(fakeConnectionToken, "should be deleted") + req.SetRequestURI("http://localhost:9991/backend") + cli.Do(context.Background(), req, resp) + + if got, want := string(resp.Body()), backendResponse; got != want { + t.Errorf("got body %q; want %q", got, want) + } + if c := resp.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := resp.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + if c := resp.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } +} + +func TestReverseProxyStripEmptyConnection(t *testing.T) { + const backendResponse = "I am the backend" + + // someConnHeader is some arbitrary header to be declared as a hop-by-hop header + // in the Request's Connection header. + const someConnHeader = "X-Some-Conn-Header" + r := server.New(server.WithHostPorts("127.0.0.1:9992")) + + r.GET("/proxy/backend", func(cc context.Context, ctx *app.RequestContext) { + if c := ctx.Request.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %v; want empty", "Connection", c) + } + if c := ctx.Request.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } + ctx.Response.Header.Add("Connection", "") + ctx.Response.Header.Add("Connection", someConnHeader) + ctx.Response.Header.Set(someConnHeader, "should be deleted") + ctx.Data(200, "application/json", []byte(backendResponse)) + }) + proxy, err := NewSingleHostReverseProxy("http://127.0.0.1:9992/proxy") + if err != nil { + t.Errorf("proxy error: %v", err) + } + r.GET("/backend", func(cc context.Context, ctx *app.RequestContext) { + proxy.ServeHTTP(cc, ctx) + }) + go r.Spin() + defer r.Shutdown(context.TODO()) + + time.Sleep(time.Second) + cli, _ := client.NewClient() + req := protocol.AcquireRequest() + resp := protocol.AcquireResponse() + + req.Header.Add("Connection", "") + req.Header.Add("Connection", someConnHeader) + req.Header.Set(someConnHeader, "should be deleted") + req.SetRequestURI("http://localhost:9992/backend") + cli.Do(context.Background(), req, resp) + + if got, want := string(resp.Body()), backendResponse; got != want { + t.Errorf("got body %q; want %q", got, want) + } + if c := resp.Header.Get("Connection"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Connection", c) + } + if c := resp.Header.Get(someConnHeader); c != "" { + t.Errorf("handler got header %q = %q; want empty", someConnHeader, c) + } +} + +func TestXForwardedFor(t *testing.T) { + const prevForwardedFor = "client ip" + const backendResponse = "I am the backend" + const backendStatus = 404 + r := server.New(server.WithHostPorts("127.0.0.1:9993")) + + r.GET("/proxy/backend", func(cc context.Context, ctx *app.RequestContext) { + if ctx.Request.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + if !strings.Contains(ctx.Request.Header.Get("X-Forwarded-For"), prevForwardedFor) { + t.Errorf("X-Forwarded-For didn't contain prior data") + } + ctx.Data(backendStatus, "application/json", []byte(backendResponse)) + }) + proxy, err := NewSingleHostReverseProxy("http://127.0.0.1:9993/proxy") + if err != nil { + t.Errorf("proxy error: %v", err) + } + r.GET("/backend", proxy.ServeHTTP) + go r.Spin() + defer r.Shutdown(context.TODO()) + + time.Sleep(time.Second) + cli, _ := client.NewClient() + req := protocol.AcquireRequest() + resp := protocol.AcquireResponse() + + req.Header.Set("Connection", "close") + req.Header.Set("X-Forwarded-For", prevForwardedFor) + req.SetConnectionClose() + req.SetRequestURI("http://localhost:9993/backend") + cli.Do(context.Background(), req, resp) + if g, e := resp.StatusCode(), backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := string(resp.Body()), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +func TestXForwardedFor_Omit(t *testing.T) { + r := server.New(server.WithHostPorts("127.0.0.1:9994")) + + r.GET("/proxy/backend", func(cc context.Context, ctx *app.RequestContext) { + if v := ctx.Request.Header.Get("X-Forwarded-For"); v != "" { + t.Errorf("got X-Forwarded-For header: %q", v) + } + ctx.Data(200, "application/json", []byte("hi")) + }) + proxy, _ := NewSingleHostReverseProxy("http://127.0.0.1:9994/proxy") + proxy.SetDirector(func(req *protocol.Request) { + req.Header.Set("X-Forwarded-For", "") + }) + r.GET("/backend", proxy.ServeHTTP) + go r.Spin() + defer r.Shutdown(context.TODO()) + + time.Sleep(time.Second) + cli, _ := client.NewClient() + req := protocol.AcquireRequest() + resp := protocol.AcquireResponse() + req.SetConnectionClose() + req.SetRequestURI("http://localhost:9994/backend") + cli.Do(context.Background(), req, resp) +} + +var proxyQueryTests = []struct { + baseSuffix string // suffix to add to backend URL + reqSuffix string // suffix to add to frontend's request URL + want string // what backend should see for final request URL (without ?) +}{ + {"?sta=tic", "", "sta=tic"}, +} + +func TestReverseProxyQuery(t *testing.T) { + r := server.New(server.WithHostPorts("127.0.0.1:9995")) + + r.GET("/proxy/backend", func(cc context.Context, ctx *app.RequestContext) { + ctx.Response.Header.Set("X-Got-Query", string(ctx.Request.QueryString())) + ctx.Data(200, "application/json", []byte("hi")) + }) + + for i, tt := range proxyQueryTests { + proxy, _ := NewSingleHostReverseProxy("http://127.0.0.1:9995/proxy" + tt.baseSuffix) + r.GET("/backend", proxy.ServeHTTP) + go r.Spin() + defer r.Shutdown(context.TODO()) + time.Sleep(time.Second) + cli, _ := client.NewClient() + req := protocol.AcquireRequest() + resp := protocol.AcquireResponse() + req.SetRequestURI("http://localhost:9995/backend" + tt.reqSuffix) + cli.Do(context.Background(), req, resp) + if g, e := resp.Header.Get("X-Got-Query"), tt.want; g != e { + t.Errorf("%d. got query %q; expected %q", i, g, e) + } + } +} + +func TestReverseProxy_Post(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 200 + requestBody := bytes.Repeat([]byte("a"), 1<<20) + + r := server.New(server.WithHostPorts("127.0.0.1:9996")) + + r.POST("/proxy/backend", func(cc context.Context, ctx *app.RequestContext) { + sluproxy := ctx.Request.Body() + if len(sluproxy) != len(requestBody) { + t.Errorf("Backend read %d request body bytes; want %d", len(sluproxy), len(requestBody)) + } + if !bytes.Equal(sluproxy, requestBody) { + t.Error("Backend read wrong request body.") // 1MB; omitting details + } + ctx.Data(backendStatus, "application/json", []byte(backendResponse)) + }) + proxy, _ := NewSingleHostReverseProxy("http://127.0.0.1:9996/proxy") + r.POST("/backend", proxy.ServeHTTP) + go r.Spin() + time.Sleep(time.Second) + defer r.Shutdown(context.TODO()) + cli, _ := client.NewClient() + req := protocol.AcquireRequest() + req.SetMethod("POST") + req.SetBodyRaw(requestBody) + resp := protocol.AcquireResponse() + req.SetConnectionClose() + req.SetRequestURI("http://localhost:9996/backend") + cli.Do(context.Background(), req, resp) + if g, e := resp.StatusCode(), backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := string(resp.Body()), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} + +func TestReverseProxyModifyResponse(t *testing.T) { + r := server.New(server.WithHostPorts("127.0.0.1:9997")) + + r.GET("/proxy/mod", func(cc context.Context, ctx *app.RequestContext) { + ctx.Response.Header.Add("X-Hit-Mod", fmt.Sprintf("%v", true)) + ctx.Data(200, "application/json", []byte("hi")) + }) + r.GET("/proxy/schedule", func(cc context.Context, ctx *app.RequestContext) { + ctx.Response.Header.Add("X-Hit-Mod", fmt.Sprintf("%v", false)) + ctx.Data(200, "application/json", []byte("hi")) + }) + proxy, _ := NewSingleHostReverseProxy("http://127.0.0.1:9997/proxy") + proxy.SetModifyResponse(func(response *protocol.Response) error { + if response.Header.Get("X-Hit-Mod") != "true" { + return fmt.Errorf("tried to by-pass proxy") + } + return nil + }) + + r.GET("/*n", proxy.ServeHTTP) + go r.Spin() + defer r.Shutdown(context.TODO()) + + time.Sleep(time.Second) + cli, _ := client.NewClient() + + tests := []struct { + url string + wantCode int + }{ + {"http://localhost:9997" + "/mod", http.StatusOK}, + {"http://localhost:9997" + "/schedule", http.StatusBadGateway}, + } + + for i, tt := range tests { + req := protocol.AcquireRequest() + resp := protocol.AcquireResponse() + req.SetConnectionClose() + req.SetRequestURI(tt.url) + cli.Do(context.Background(), req, resp) + if g, e := resp.StatusCode(), tt.wantCode; g != e { + t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) + } + } +} + +func TestReverseProxyErrorHandler(t *testing.T) { + r := server.New(server.WithHostPorts("127.0.0.1:9998")) + + r.POST("/proxy/backend", func(cc context.Context, ctx *app.RequestContext) { + ctx.GetConn().Close() + }) + proxy, _ := NewSingleHostReverseProxy("http://127.0.0.1:9998/proxy") + proxy.SetErrorHandler(func(c *app.RequestContext, err error) { + c.Response.Header.SetStatusCode(http.StatusTeapot) + }) + proxy.SetModifyResponse(func(res *protocol.Response) error { + res.SetStatusCode(res.StatusCode() + 1) + return errors.New("some error to trigger errorHandler") + }) + r.POST("/backend", proxy.ServeHTTP) + go r.Spin() + time.Sleep(time.Second) + cli, _ := client.NewClient() + req := protocol.AcquireRequest() + req.SetMethod("POST") + resp := protocol.AcquireResponse() + req.SetRequestURI("http://127.0.0.1:9998/backend") + cli.Do(context.Background(), req, resp) + if g, e := resp.StatusCode(), http.StatusTeapot; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } +}