From a80f1e8fb14ca2028fcea41621e4a7bad6fa8398 Mon Sep 17 00:00:00 2001 From: caixw Date: Mon, 20 May 2024 16:18:39 +0800 Subject: [PATCH] =?UTF-8?q?fix(group):=20=E4=BF=AE=E6=AD=A3=20Group=20?= =?UTF-8?q?=E4=B8=8B=E7=9A=84=E8=B7=AF=E7=94=B1=E5=90=8C=E4=B8=80=E4=B8=AD?= =?UTF-8?q?=E9=97=B4=E4=BB=B6=E5=A4=9A=E6=AC=A1=E8=B0=83=E7=94=A8=E7=9A=84?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- group/group.go | 10 ++++++---- group/group_test.go | 27 +++++++++++++++++++++++++-- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/group/group.go b/group/group.go index bf0fbae..a4ec2d3 100644 --- a/group/group.go +++ b/group/group.go @@ -21,8 +21,9 @@ type ( GroupOf[T any] struct { routers []*routerOf[T] - call mux.CallOf[T] - notFound T + call mux.CallOf[T] + notFound T // 所有路由都找不差时调用的方法,该方法应用的中间件中 router 参数是为空的。 + originNotFound T // 这是应用中间件的 notFound methodNotAllowedBuilder, optionsBuilder types.BuildNodeHandleOf[T] options []options.Option @@ -45,6 +46,7 @@ func NewOf[T any](call mux.CallOf[T], notFound T, methodNotAllowedBuilder, optio call: call, notFound: notFound, + originNotFound: notFound, methodNotAllowedBuilder: methodNotAllowedBuilder, optionsBuilder: optionsBuilder, options: o, @@ -73,7 +75,7 @@ func (g *GroupOf[T]) ServeHTTP(w http.ResponseWriter, r *http.Request) { // 新路由会继承 [NewOf] 中指定的参数,其中的 o 可以覆盖由 [NewOf] 中指定的相关参数; func (g *GroupOf[T]) New(name string, matcher Matcher, o ...mux.Option) *mux.RouterOf[T] { o = g.mergeOption(o...) - r := mux.NewRouterOf(name, g.call, g.notFound, g.methodNotAllowedBuilder, g.optionsBuilder, o...) + r := mux.NewRouterOf(name, g.call, g.originNotFound, g.methodNotAllowedBuilder, g.optionsBuilder, o...) g.Add(matcher, r) return r } @@ -130,8 +132,8 @@ func (g *GroupOf[T]) Use(m ...types.MiddlewareOf[T]) { for _, r := range g.routers { r.r.Use(m...) } - g.notFound = tree.ApplyMiddleware(g.notFound, "", "", "", m...) + g.notFound = tree.ApplyMiddleware(g.notFound, "", "", "", m...) g.middleware = append(g.middleware, m...) } diff --git a/group/group_test.go b/group/group_test.go index bcbf3c6..4e49fa2 100644 --- a/group/group_test.go +++ b/group/group_test.go @@ -61,31 +61,54 @@ func TestGroupOf_Use(t *testing.T) { h1 := g.New("h1", NewHosts(false, "h1.example.com")) h1.Get("/posts/5.html", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("h1")) })) + g.Use(tree.BuildTestMiddleware(a, "m1")) h2 := g.New("h2", NewHosts(false, "h2.example.com")) h2.Get("/posts/5.html", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Write([]byte("h2")) })) rest.NewRequest(a, http.MethodGet, "https://h1.example.com/posts/5.html"). Do(g). + Status(http.StatusOK). StringBody("h1m1") rest.NewRequest(a, http.MethodGet, "https://h2.example.com/posts/5.html"). Do(g). + Status(http.StatusOK). StringBody("h2m1") + + // h2 中的 notFound + rest.NewRequest(a, http.MethodGet, "https://h2.example.com/not-exists"). + Do(g). + Status(http.StatusNotFound). + StringBody("404 page not found\nm1") + // group.notFound rest.NewRequest(a, http.MethodGet, "https://not-match.example.com/posts/5.html"). Do(g). - BodyFunc(func(a *assert.Assertion, body []byte) { a.Contains(body, "m1") }) + Status(http.StatusNotFound). + StringBody("404 page not found\nm1") + + // 添加了新的中间件 g.Use(tree.BuildTestMiddleware(a, "m2")) rest.NewRequest(a, http.MethodGet, "https://h1.example.com/posts/5.html"). Do(g). + Status(http.StatusOK). StringBody("h1m1m2") rest.NewRequest(a, http.MethodGet, "https://h2.example.com/posts/5.html"). Do(g). + Status(http.StatusOK). StringBody("h2m1m2") + + // h2 中的 notFound + rest.NewRequest(a, http.MethodGet, "https://h2.example.com/not-exists"). + Do(g). + Status(http.StatusNotFound). + StringBody("404 page not found\nm1m2") + // group.notFound rest.NewRequest(a, http.MethodGet, "https://not-match.example.com/posts/5.html"). Do(g). - BodyFunc(func(a *assert.Assertion, body []byte) { a.Contains(body, "m1m2") }) + Status(http.StatusNotFound). + StringBody("404 page not found\nm1m2") } func TestGroupOf_Add(t *testing.T) {