Skip to content

Commit

Permalink
refactor: 所有添加路由项的方法可以直接指定中间件
Browse files Browse the repository at this point in the history
  • Loading branch information
caixw committed May 20, 2024
1 parent 861f9c4 commit 2f4db8b
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 56 deletions.
98 changes: 55 additions & 43 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,44 +110,44 @@ func (r *RouterOf[T]) Use(m ...types.MiddlewareOf[T]) {
// Handle 添加一条路由数据
//
// pattern 为路由匹配模式,可以是正则匹配也可以是字符串匹配,
// 若语法不正确,则直接 panic,可以通过 [CheckSyntax] 检测语法的有效性,其它接口也相同。
// methods 该路由项对应的请求方法,如果未指定值,则表示所有支持的请求方法,其中 OPTIONS 和 HEAD 不受控。
func (r *RouterOf[T]) Handle(pattern string, h T, methods ...string) *RouterOf[T] {
r.handle(pattern, h, nil, methods...)
return r
}

func (r *RouterOf[T]) handle(pattern string, h T, ms []types.MiddlewareOf[T], methods ...string) {
if err := r.tree.Add(pattern, h, slices.Concat(ms, r.middleware), methods...); err != nil {
// 若语法不正确,则直接 panic,可以通过 [CheckSyntax] 检测语法的有效性,其它接口也相同;
// m 为应用于当前路由项的中间件;
// methods 该路由项对应的请求方法,如果未指定值,则表示所有支持的请求方法,其中 OPTIONS 和 HEAD 不受控;
func (r *RouterOf[T]) Handle(pattern string, h T, m []types.MiddlewareOf[T], methods ...string) *RouterOf[T] {
m = slices.Concat(m, r.middleware)
if err := r.tree.Add(pattern, h, m, methods...); err != nil {
panic(err)
}
return r
}

// Get 相当于 RouterOf.Handle(pattern, h, http.MethodGet) 的简易写法
//
// h 不应该主动调用 WriteHeader,否则会导致 HEAD 请求获取不到 Content-Length 报头。
func (r *RouterOf[T]) Get(pattern string, h T) *RouterOf[T] {
return r.Handle(pattern, h, http.MethodGet)
func (r *RouterOf[T]) Get(pattern string, h T, m ...types.MiddlewareOf[T]) *RouterOf[T] {
return r.Handle(pattern, h, m, http.MethodGet)
}

func (r *RouterOf[T]) Post(pattern string, h T) *RouterOf[T] {
return r.Handle(pattern, h, http.MethodPost)
func (r *RouterOf[T]) Post(pattern string, h T, m ...types.MiddlewareOf[T]) *RouterOf[T] {
return r.Handle(pattern, h, m, http.MethodPost)
}

func (r *RouterOf[T]) Delete(pattern string, h T) *RouterOf[T] {
return r.Handle(pattern, h, http.MethodDelete)
func (r *RouterOf[T]) Delete(pattern string, h T, m ...types.MiddlewareOf[T]) *RouterOf[T] {
return r.Handle(pattern, h, m, http.MethodDelete)
}

func (r *RouterOf[T]) Put(pattern string, h T) *RouterOf[T] {
return r.Handle(pattern, h, http.MethodPut)
func (r *RouterOf[T]) Put(pattern string, h T, m ...types.MiddlewareOf[T]) *RouterOf[T] {
return r.Handle(pattern, h, m, http.MethodPut)
}

func (r *RouterOf[T]) Patch(pattern string, h T) *RouterOf[T] {
return r.Handle(pattern, h, http.MethodPatch)
func (r *RouterOf[T]) Patch(pattern string, h T, m ...types.MiddlewareOf[T]) *RouterOf[T] {
return r.Handle(pattern, h, m, http.MethodPatch)
}

// Any 添加一条包含全部请求方法的路由
func (r *RouterOf[T]) Any(pattern string, h T) *RouterOf[T] { return r.Handle(pattern, h) }
func (r *RouterOf[T]) Any(pattern string, h T, m ...types.MiddlewareOf[T]) *RouterOf[T] {
return r.Handle(pattern, h, m)
}

// URL 根据参数生成地址
//
Expand Down Expand Up @@ -216,33 +216,34 @@ func (r *RouterOf[T]) ServeContext(w http.ResponseWriter, req *http.Request, ctx
// Name 路由名称
func (r *RouterOf[T]) Name() string { return r.name }

func (p *PrefixOf[T]) Handle(pattern string, h T, methods ...string) *PrefixOf[T] {
p.router.handle(p.Pattern()+pattern, h, p.middleware, methods...)
func (p *PrefixOf[T]) Handle(pattern string, h T, m []types.MiddlewareOf[T], methods ...string) *PrefixOf[T] {
m = slices.Concat(m, p.middleware)
p.router.Handle(p.Pattern()+pattern, h, m, methods...)
return p
}

func (p *PrefixOf[T]) Get(pattern string, h T) *PrefixOf[T] {
return p.Handle(pattern, h, http.MethodGet)
func (p *PrefixOf[T]) Get(pattern string, h T, m ...types.MiddlewareOf[T]) *PrefixOf[T] {
return p.Handle(pattern, h, m, http.MethodGet)
}

func (p *PrefixOf[T]) Post(pattern string, h T) *PrefixOf[T] {
return p.Handle(pattern, h, http.MethodPost)
func (p *PrefixOf[T]) Post(pattern string, h T, m ...types.MiddlewareOf[T]) *PrefixOf[T] {
return p.Handle(pattern, h, m, http.MethodPost)
}

func (p *PrefixOf[T]) Delete(pattern string, h T) *PrefixOf[T] {
return p.Handle(pattern, h, http.MethodDelete)
func (p *PrefixOf[T]) Delete(pattern string, h T, m ...types.MiddlewareOf[T]) *PrefixOf[T] {
return p.Handle(pattern, h, m, http.MethodDelete)
}

func (p *PrefixOf[T]) Put(pattern string, h T) *PrefixOf[T] {
return p.Handle(pattern, h, http.MethodPut)
func (p *PrefixOf[T]) Put(pattern string, h T, m ...types.MiddlewareOf[T]) *PrefixOf[T] {
return p.Handle(pattern, h, m, http.MethodPut)
}

func (p *PrefixOf[T]) Patch(pattern string, h T) *PrefixOf[T] {
return p.Handle(pattern, h, http.MethodPatch)
func (p *PrefixOf[T]) Patch(pattern string, h T, m ...types.MiddlewareOf[T]) *PrefixOf[T] {
return p.Handle(pattern, h, m, http.MethodPatch)
}

func (p *PrefixOf[T]) Any(pattern string, h T) *PrefixOf[T] {
return p.Handle(pattern, h)
func (p *PrefixOf[T]) Any(pattern string, h T, m ...types.MiddlewareOf[T]) *PrefixOf[T] {
return p.Handle(pattern, h, m)
}

// Pattern 当前对象的路径
Expand Down Expand Up @@ -283,25 +284,36 @@ func (r *RouterOf[T]) Prefix(prefix string, m ...types.MiddlewareOf[T]) *PrefixO
return &PrefixOf[T]{router: r, pattern: prefix, middleware: slices.Clone(m)}
}

// Router 返回与当前关联的 *RouterOf 实例
// Router 返回与当前关联的 [RouterOf] 实例
func (p *PrefixOf[T]) Router() *RouterOf[T] { return p.router }

func (r *ResourceOf[T]) Handle(h T, methods ...string) *ResourceOf[T] {
r.router.handle(r.pattern, h, r.middleware, methods...)
func (r *ResourceOf[T]) Handle(h T, m []types.MiddlewareOf[T], methods ...string) *ResourceOf[T] {
m = slices.Concat(m, r.middleware)
r.router.Handle(r.pattern, h, m, methods...)
return r
}

func (r *ResourceOf[T]) Get(h T) *ResourceOf[T] { return r.Handle(h, http.MethodGet) }
func (r *ResourceOf[T]) Get(h T, m ...types.MiddlewareOf[T]) *ResourceOf[T] {
return r.Handle(h, m, http.MethodGet)
}

func (r *ResourceOf[T]) Post(h T) *ResourceOf[T] { return r.Handle(h, http.MethodPost) }
func (r *ResourceOf[T]) Post(h T, m ...types.MiddlewareOf[T]) *ResourceOf[T] {
return r.Handle(h, m, http.MethodPost)
}

func (r *ResourceOf[T]) Delete(h T) *ResourceOf[T] { return r.Handle(h, http.MethodDelete) }
func (r *ResourceOf[T]) Delete(h T, m ...types.MiddlewareOf[T]) *ResourceOf[T] {
return r.Handle(h, m, http.MethodDelete)
}

func (r *ResourceOf[T]) Put(h T) *ResourceOf[T] { return r.Handle(h, http.MethodPut) }
func (r *ResourceOf[T]) Put(h T, m ...types.MiddlewareOf[T]) *ResourceOf[T] {
return r.Handle(h, m, http.MethodPut)
}

func (r *ResourceOf[T]) Patch(h T) *ResourceOf[T] { return r.Handle(h, http.MethodPatch) }
func (r *ResourceOf[T]) Patch(h T, m ...types.MiddlewareOf[T]) *ResourceOf[T] {
return r.Handle(h, m, http.MethodPatch)
}

func (r *ResourceOf[T]) Any(h T) *ResourceOf[T] { return r.Handle(h) }
func (r *ResourceOf[T]) Any(h T, m ...types.MiddlewareOf[T]) *ResourceOf[T] { return r.Handle(h, m) }

// Remove 删除指定匹配模式的路由项
func (r *ResourceOf[T]) Remove(methods ...string) { r.router.Remove(r.pattern, methods...) }
Expand Down
10 changes: 5 additions & 5 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func TestRouterOf(t *testing.T) {

// 不能主动添加 Head
a.PanicString(func() {
r.Handle("/options", rest.BuildHandler(a, 202, "", nil), http.MethodOptions)
r.Handle("/options", rest.BuildHandler(a, 202, "", nil), nil, http.MethodOptions)
}, "OPTIONS")
}

Expand All @@ -74,9 +74,9 @@ func TestRouterOf_Handle_Remove(t *testing.T) {
// 添加 GET /api/1
// 添加 PUT /api/1
// 添加 GET /api/2
r.Handle("/api/1", rest.BuildHandler(a, 201, "", nil), http.MethodGet)
r.Handle("/api/1", rest.BuildHandler(a, 201, "", nil), http.MethodPut)
r.Handle("/api/2", rest.BuildHandler(a, 202, "", nil), http.MethodGet)
r.Handle("/api/1", rest.BuildHandler(a, 201, "", nil), nil, http.MethodGet)
r.Handle("/api/1", rest.BuildHandler(a, 201, "", nil), nil, http.MethodPut)
r.Handle("/api/2", rest.BuildHandler(a, 202, "", nil), nil, http.MethodGet)

rest.Get(a, "/api/1").Do(r).Status(201)
rest.Put(a, "/api/1", nil).Do(r).Status(201)
Expand All @@ -96,7 +96,7 @@ func TestRouterOf_Handle_Remove(t *testing.T) {
rest.Get(a, "/api/2").Do(r).Status(http.StatusNotFound) // 整个节点被删除

// 添加 POST /api/1
r.Handle("/api/1", rest.BuildHandler(a, 201, "", nil), http.MethodPost)
r.Handle("/api/1", rest.BuildHandler(a, 201, "", nil), nil, http.MethodPost)
rest.Post(a, "/api/1", nil).Do(r).Status(201)

// 删除 ANY /api/1
Expand Down
8 changes: 4 additions & 4 deletions routertest/bench.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func (t *Tester[T]) benchURL(b *testing.B, h T) {

router := mux.NewRouterOf("test", t.c, t.notFound, t.m, t.o, mux.Lock(true), mux.URLDomain(domain))
for _, api := range apis {
router.Handle(api.pattern, h, api.method)
router.Handle(api.pattern, h, nil, api.method)
}

b.ReportAllocs()
Expand Down Expand Up @@ -76,7 +76,7 @@ func (t *Tester[T]) benchAddAndServeHTTP(b *testing.B, h T) {
w := httptest.NewRecorder()
r, _ := http.NewRequest(api.method, api.test, nil)

router.Handle(api.pattern, h, api.method)
router.Handle(api.pattern, h, nil, api.method)
router.ServeHTTP(w, r)
router.Remove(api.pattern, api.method)

Expand All @@ -89,7 +89,7 @@ func (t *Tester[T]) benchAddAndServeHTTP(b *testing.B, h T) {
func (t *Tester[T]) benchServeHTTP(b *testing.B, h T) {
router := mux.NewRouterOf("test", t.c, t.notFound, t.m, t.o)
for _, api := range apis {
router.Handle(api.pattern, h, api.method)
router.Handle(api.pattern, h, nil, api.method)
}

b.ReportAllocs()
Expand All @@ -112,7 +112,7 @@ func (t *Tester[T]) calcMemStats(h T) uint64 {
return calcMemStats(func() {
r := mux.NewRouterOf("test", t.c, t.notFound, t.m, t.o, mux.Lock(true))
for _, api := range apis {
r.Handle(api.pattern, h, api.method)
r.Handle(api.pattern, h, nil, api.method)
}
})
}
Expand Down
8 changes: 4 additions & 4 deletions routertest/routertest.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,14 @@ func (t *Tester[T]) Serve(a *assert.Assertion, h func(status int) T) {
a.NotNil(router)
srv := rest.NewServer(a, router, nil)

router.Handle("/posts/{path}.html", h(201))
router.Handle("/posts/{path}.html", h(201), nil)
srv.Get("/posts/2017/1.html").Do(nil).Status(201)
srv.Get("/Posts/2017/1.html").Do(nil).Status(404) // 大小写不一样

router.Handle("/posts/{path:.+}.html", h(202))
router.Handle("/posts/{path:.+}.html", h(202), nil)
srv.Get("/posts/2017/1.html").Do(nil).Status(202)

router.Handle("/posts/{id:digit}123", h(203))
router.Handle("/posts/{id:digit}123", h(203), nil)
srv.Get("/posts/123123").Do(nil).Status(203)

router.Get("///", h(201))
Expand All @@ -131,6 +131,6 @@ func (t *Tester[T]) Serve(a *assert.Assertion, h func(status int) T) {
router = mux.NewRouterOf("test", t.c, t.notFound, t.m, t.o)
srv = rest.NewServer(a, router, nil)

router.Handle("/posts/{path}.html", h(201))
router.Handle("/posts/{path}.html", h(201), nil)
srv.Get("/posts/2017/1.html").Do(nil).Status(201)
}

0 comments on commit 2f4db8b

Please sign in to comment.