From 859bf41a0c1eb46483c52273345c74ab0af2415f Mon Sep 17 00:00:00 2001 From: luopengift <870148195@qq.com> Date: Sun, 17 Nov 2024 03:06:21 +0800 Subject: [PATCH] fix: Logf override the past. update OnStart and OnShutdown. --- options.go | 22 +++++++++++++++++++--- requests_test.go | 5 +++-- server.go | 45 +++++++++++++++++---------------------------- server_test.go | 33 +++++++++++++++++++++++---------- transport.go | 4 +++- 5 files changed, 65 insertions(+), 44 deletions(-) diff --git a/options.go b/options.go index 5ba8be2..08f7de2 100644 --- a/options.go +++ b/options.go @@ -25,12 +25,15 @@ type Options struct { MaxConns int Verify bool Stream func(int64, []byte) error + Log func(ctx context.Context, stat *Stat) Transport http.RoundTripper HttpRoundTripper []func(http.RoundTripper) http.RoundTripper Handler http.Handler HttpHandler []func(http.Handler) http.Handler + OnStart func(*http.Server) + OnShutdown func(*http.Server) certFile string keyFile string @@ -53,6 +56,9 @@ func newOptions(opts []Option, extends ...Option) Options { Timeout: 30 * time.Second, MaxConns: 100, Proxy: http.ProxyFromEnvironment, + + OnStart: func(*http.Server) {}, + OnShutdown: func(*http.Server) {}, } for _, o := range opts { o(&opt) @@ -283,10 +289,20 @@ func RoundTripper(tr http.RoundTripper) Option { } // Logf print log -func Logf(f func(ctx context.Context, stat *Stat)) Option { +func Logf(f func(context.Context, *Stat)) Option { + return func(o *Options) { + o.Log = f + } +} + +func OnStart(f func(*http.Server)) Option { return func(o *Options) { - o.HttpRoundTripper = append(o.HttpRoundTripper, printRoundTripper(f)) - o.HttpHandler = append(o.HttpHandler, printHandler(f)) + o.OnStart = f + } +} +func OnShutdown(f func(*http.Server)) Option { + return func(o *Options) { + o.OnShutdown = f } } diff --git a/requests_test.go b/requests_test.go index 298f93a..42c4b6f 100644 --- a/requests_test.go +++ b/requests_test.go @@ -48,7 +48,7 @@ func Test_PostBody(t *testing.T) { sess := New( BasicAuth("user", "123456"), Logf(func(context.Context, *Stat) { - fmt.Println("session") + fmt.Println("111111111") }), ) @@ -71,7 +71,8 @@ func Test_PostBody(t *testing.T) { Header("hello", "world"), //TraceLv(9), Logf(func(ctx context.Context, stat *Stat) { - t.Logf("%v", stat.String()) + fmt.Println(22222222) + //t.Logf("%v", stat.String()) }), ) if err != nil { diff --git a/server.go b/server.go index 7a63393..373e962 100644 --- a/server.go +++ b/server.go @@ -153,6 +153,9 @@ func (mux *ServeMux) Use(fn ...func(http.Handler) http.Handler) { func (mux *ServeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { current := mux.root.Find(r.URL.Path[1:]) handler, options := current.handler, newOptions(mux.opts, current.opts...) + if options.Log != nil { + options.HttpHandler = append(options.HttpHandler, printHandler(options.Log)) + } for _, h := range options.HttpHandler { handler = h(handler) } @@ -172,19 +175,16 @@ func (mux *ServeMux) Pprof() { type Server struct { options Options server *http.Server - - onStartup func(*http.Server) - onShutdown func(*http.Server) } // NewServer new server, opts is not add to ServeMux func NewServer(ctx context.Context, h http.Handler, opts ...Option) *Server { - s := &Server{server: &http.Server{Handler: h}, onStartup: func(*http.Server) {}, onShutdown: func(*http.Server) {}} - if mux, ok := h.(*ServeMux); ok { - s.options = newOptions(mux.opts, opts...) - } else { - s.options = newOptions(opts) + s := &Server{server: &http.Server{Handler: h}} + mux, ok := h.(*ServeMux) + if !ok { + mux = NewServeMux() } + s.options = newOptions(mux.opts, opts...) u, err := url.Parse(s.options.URL) if err != nil { @@ -192,27 +192,18 @@ func NewServer(ctx context.Context, h http.Handler, opts ...Option) *Server { } s.server.Addr = u.Host - s.server.RegisterOnShutdown(func() { s.onShutdown(s.server) }) - - go func() { - select { - case <-ctx.Done(): - if err := s.server.Shutdown(ctx); err != nil { - panic(err) - } - } - }() + s.options.OnStart(s.server) + s.server.RegisterOnShutdown(func() { s.options.OnShutdown(s.server) }) + go s.Shutdown(ctx) return s } -// OnStartup do something before serve startup -func (s *Server) OnStartup(f func(s *http.Server)) { - s.onStartup = f -} - -// OnShutdown do something after serve shutdown -func (s *Server) OnShutdown(f func(s *http.Server)) { - s.onShutdown = f +// Shutdown gracefully shuts down the server without interrupting any active connections. +func (s *Server) Shutdown(ctx context.Context) error { + select { + case <-ctx.Done(): + return s.server.Shutdown(ctx) + } } // ListenAndServe listens on the TCP network address srv.Addr and then @@ -231,8 +222,6 @@ func (s *Server) OnShutdown(f func(s *http.Server)) { // ListenAndServe(TLS) always returns a non-nil error. After [Server.Shutdown] or // [Server.Close], the returned error is [ErrServerClosed]. func (s *Server) ListenAndServe() (err error) { - - s.onStartup(s.server) if s.options.certFile == "" || s.options.keyFile == "" { return s.server.ListenAndServe() } diff --git a/server_test.go b/server_test.go index 2bce215..f561c7d 100644 --- a/server_test.go +++ b/server_test.go @@ -23,11 +23,27 @@ func (h) ServeHTTP(w http.ResponseWriter, r *http.Request) { } func Test_Server(t *testing.T) { - mux := requests.NewServeMux() + mux := requests.NewServeMux(requests.Logf(func(ctx context.Context, stat *requests.Stat) { + fmt.Println("serveMux11111----------") + }), requests.Logf(func(ctx context.Context, stat *requests.Stat) { + fmt.Println("serveMux222222----------") + }), + ) mux.Pprof() - s := requests.NewServer(context.Background(), mux, requests.URL("http://127.0.0.1:6066")) - s.OnStartup(func(s *http.Server) { fmt.Println("http serve") }) - mux.Handle("/handle", h{}) + s := requests.NewServer( + context.Background(), + mux, + requests.URL("http://127.0.0.1:6066"), + requests.Logf(func(ctx context.Context, stat *requests.Stat) { + fmt.Println("--------server") + }), + requests.OnStart(func(s *http.Server) { + fmt.Println("OnStart", s.Addr) + }), + ) + mux.Handle("/handle", h{}, requests.Logf(func(ctx context.Context, stat *requests.Stat) { + fmt.Println("handle------") + })) mux.HandleFunc("/handler", func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("handler ok")) }) @@ -80,10 +96,9 @@ func Test_Use(t *testing.T) { ) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - s := requests.NewServer(ctx, r, requests.URL("http://0.0.0.0:9099")) - s.OnShutdown(func(s *http.Server) { + s := requests.NewServer(ctx, r, requests.URL("http://0.0.0.0:9099"), requests.OnShutdown(func(s *http.Server) { t.Logf("http: %s shutdown...", s.Addr) - }) + })) r.Pprof() go s.ListenAndServe() time.Sleep(1 * time.Second) @@ -93,9 +108,7 @@ func Test_Use(t *testing.T) { //sess.DoRequest(context.Background(), Path("/ping"), Logf(LogS)) //sess.DoRequest(context.Background(), Path("/1234"), Logf(LogS)) - select { - case <-ctx.Done(): - } + time.Sleep(3 * time.Second) } diff --git a/transport.go b/transport.go index 2982f7b..f4bbd69 100644 --- a/transport.go +++ b/transport.go @@ -123,7 +123,9 @@ func (t *Transport) RoundTripper(opts ...Option) http.RoundTripper { if options.Transport == nil { options.Transport = t.Transport } - + if options.Log != nil { + options.HttpRoundTripper = append(options.HttpRoundTripper, printRoundTripper(options.Log)) + } for _, tr := range options.HttpRoundTripper { options.Transport = tr(options.Transport) }