Skip to content

Commit

Permalink
optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
justlorain committed Dec 1, 2023
1 parent 793bb38 commit 1025dc1
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 32 deletions.
72 changes: 47 additions & 25 deletions ws_reverse_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,31 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// MIT License
//
// Copyright (c) 2018 YeQiang
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//
// This file may have been modified by CloudWeGo authors.
// All CloudWeGo Modifications are Copyright 2023 CloudWeGo Authors.

package reverseproxy

Expand Down Expand Up @@ -52,17 +77,14 @@ func (w *WSReverseProxy) ServeHTTP(ctx context.Context, c *app.RequestContext) {
forwardHeader := prepareForwardHeader(ctx, c)
// NOTE: customer Director will overwrite existed header if they have the same header key
if w.options.Director != nil {
appendHeader := w.options.Director(ctx, c)
appendHeader.VisitAll(func(key, value []byte) {
forwardHeader.SetBytesKV(key, value)
})
w.options.Director(ctx, c, forwardHeader)
}
connBackend, respBackend, err := w.options.Dialer.Dial(w.target, ConvertHZHeaderToStdHeader(forwardHeader))
connBackend, respBackend, err := w.options.Dialer.Dial(w.target, forwardHeader)
if err != nil {
hlog.Errorf("can not dial to remote backend(%v): %v", w.target, err)
hlog.CtxErrorf(ctx, "can not dial to remote backend(%v): %v", w.target, err)
if respBackend != nil {
if err = wsCopyResponse(&c.Response, respBackend); err != nil {
hlog.Errorf("can not copy response: %v", err)
hlog.CtxErrorf(ctx, "can not copy response: %v", err)
}
} else {
c.AbortWithMsg(err.Error(), consts.StatusServiceUnavailable)
Expand All @@ -78,7 +100,7 @@ func (w *WSReverseProxy) ServeHTTP(ctx context.Context, c *app.RequestContext) {
errMsg string
)

hlog.Debugf("upgrade handler working...")
hlog.CtxDebugf(ctx, "upgrade handler working...")

// replicateWSRespConn
// ┌─────────────────────────────────┐
Expand All @@ -96,8 +118,8 @@ func (w *WSReverseProxy) ServeHTTP(ctx context.Context, c *app.RequestContext) {
// │ ◄───────────┤ (server) ◄─────────────┤ (server) │
// └──────────┘ └────────────────┘ └──────────┘

go replicateWSRespConn(connClient, connBackend, errClientC)
go replicateWSReqConn(connBackend, connClient, errBackendC)
go replicateWSRespConn(ctx, connClient, connBackend, errClientC)
go replicateWSReqConn(ctx, connBackend, connClient, errBackendC)

for {
select {
Expand All @@ -110,11 +132,11 @@ func (w *WSReverseProxy) ServeHTTP(ctx context.Context, c *app.RequestContext) {
var ce *websocket.CloseError
var hzce *hzws.CloseError
if !errors.As(err, &ce) || !errors.As(err, &hzce) {
hlog.Errorf(errMsg, err)
hlog.CtxErrorf(ctx, errMsg, err)
}
}
}); err != nil {
hlog.Errorf("can not upgrade to websocket: %v", err)
hlog.CtxErrorf(ctx, "can not upgrade to websocket: %v", err)
}
}

Expand All @@ -128,8 +150,8 @@ func ConvertHZHeaderToStdHeader(hzHeader *protocol.RequestHeader) http.Header {
return header
}

func prepareForwardHeader(_ context.Context, c *app.RequestContext) *protocol.RequestHeader {
forwardHeader := &protocol.RequestHeader{}
func prepareForwardHeader(_ context.Context, c *app.RequestContext) http.Header {
forwardHeader := new(protocol.RequestHeader)
if origin := c.Request.Header.Peek("Origin"); origin != nil {
forwardHeader.SetBytesKV([]byte("Origin"), origin)
}
Expand All @@ -153,61 +175,61 @@ func prepareForwardHeader(_ context.Context, c *app.RequestContext) *protocol.Re
if string(c.Request.URI().Scheme()) == "https" {
forwardHeader.Set("X-Forwarded-Proto", "https")
}
return forwardHeader
return ConvertHZHeaderToStdHeader(forwardHeader)
}

func replicateWSReqConn(dst *websocket.Conn, src *hzws.Conn, errC chan error) {
func replicateWSReqConn(ctx context.Context, dst *websocket.Conn, src *hzws.Conn, errC chan error) {
for {
msgType, msg, err := src.ReadMessage()
if err != nil {
hlog.Errorf("read message failed when replicating websocket conn: msgType=%v msg=%v err=%v", msgType, msg, err)
hlog.CtxErrorf(ctx, "read message failed when replicating websocket conn: msgType=%v msg=%v err=%v", msgType, msg, err)
var ce *hzws.CloseError
if errors.As(err, &ce) {
msg = hzws.FormatCloseMessage(ce.Code, ce.Text)
} else {
hlog.Errorf("read message failed when replicate websocket conn: err=%v", err)
hlog.CtxErrorf(ctx, "read message failed when replicate websocket conn: err=%v", err)
msg = hzws.FormatCloseMessage(hzws.CloseAbnormalClosure, err.Error())
}
errC <- err

if err = dst.WriteMessage(websocket.CloseMessage, msg); err != nil {
hlog.Errorf("write message failed when replicate websocket conn: err=%v", err)
hlog.CtxErrorf(ctx, "write message failed when replicate websocket conn: err=%v", err)
}
break
}

err = dst.WriteMessage(msgType, msg)
if err != nil {
hlog.Errorf("write message failed when replicating websocket conn: msgType=%v msg=%v err=%v", msgType, msg, err)
hlog.CtxErrorf(ctx, "write message failed when replicating websocket conn: msgType=%v msg=%v err=%v", msgType, msg, err)
errC <- err
break
}
}
}

func replicateWSRespConn(dst *hzws.Conn, src *websocket.Conn, errC chan error) {
func replicateWSRespConn(ctx context.Context, dst *hzws.Conn, src *websocket.Conn, errC chan error) {
for {
msgType, msg, err := src.ReadMessage()
if err != nil {
hlog.Errorf("read message failed when replicating websocket conn: msgType=%v msg=%v err=%v", msgType, msg, err)
hlog.CtxErrorf(ctx, "read message failed when replicating websocket conn: msgType=%v msg=%v err=%v", msgType, msg, err)
var ce *websocket.CloseError
if errors.As(err, &ce) {
msg = websocket.FormatCloseMessage(ce.Code, ce.Text)
} else {
hlog.Errorf("read message failed when replicate websocket conn: err=%v", err)
hlog.CtxErrorf(ctx, "read message failed when replicate websocket conn: err=%v", err)
msg = websocket.FormatCloseMessage(websocket.CloseAbnormalClosure, err.Error())
}
errC <- err

if err = dst.WriteMessage(hzws.CloseMessage, msg); err != nil {
hlog.Errorf("write message failed when replicate websocket conn: err=%v", err)
hlog.CtxErrorf(ctx, "write message failed when replicate websocket conn: err=%v", err)
}
break
}

err = dst.WriteMessage(msgType, msg)
if err != nil {
hlog.Errorf("write message failed when replicating websocket conn: msgType=%v msg=%v err=%v", msgType, msg, err)
hlog.CtxErrorf(ctx, "write message failed when replicating websocket conn: msgType=%v msg=%v err=%v", msgType, msg, err)
errC <- err
break
}
Expand Down
4 changes: 2 additions & 2 deletions ws_reverse_proxy_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ package reverseproxy

import (
"context"
"net/http"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/gorilla/websocket"
hzws "github.com/hertz-contrib/websocket"
)

type Director func(ctx context.Context, c *app.RequestContext) *protocol.RequestHeader
type Director func(ctx context.Context, c *app.RequestContext, forwardHeader http.Header)

type Option func(o *Options)

Expand Down
8 changes: 3 additions & 5 deletions ws_reverse_proxy_option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,18 @@ package reverseproxy
import (
"context"
"fmt"
"net/http"
"testing"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/gorilla/websocket"
hzws "github.com/hertz-contrib/websocket"
)

func TestOptions(t *testing.T) {
director := func(ctx context.Context, c *app.RequestContext) *protocol.RequestHeader {
header := &protocol.RequestHeader{}
header.Add("X-Test-Head", "content")
return header
director := func(ctx context.Context, c *app.RequestContext, forwardHeader http.Header) {
forwardHeader.Add("X-Test-Head", "content")
}
dialer := websocket.DefaultDialer
upgrader := &hzws.HertzUpgrader{
Expand Down

0 comments on commit 1025dc1

Please sign in to comment.