From ff156e5c7a85986d031ff259f58e1fc14e432294 Mon Sep 17 00:00:00 2001 From: caixw Date: Wed, 22 May 2024 21:53:04 +0800 Subject: [PATCH] =?UTF-8?q?refactor(types):=20=E9=87=87=E7=94=A8=20map=20?= =?UTF-8?q?=E4=BB=A3=E6=9B=BF=E5=8E=9F=E6=9D=A5=E7=9A=84=20slice?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 性能相差无几,但是少了许多代码。 --- header/value.go | 6 ++-- internal/tree/test.go | 2 +- types/context.go | 76 ++++++++++--------------------------------- types/context_test.go | 34 +++++-------------- types/types.go | 8 +++-- 5 files changed, 36 insertions(+), 90 deletions(-) diff --git a/header/value.go b/header/value.go index b3010c5..9f4d8aa 100644 --- a/header/value.go +++ b/header/value.go @@ -10,9 +10,9 @@ const ( NoCache = "no-cache" Identity = "identity" - MessageHTTP = "message/http" // MessageHTTP TRACE 请求方法的 content-type 值 - MultipartFormData = "multipart/form-data" // MultipartFormData 表单提交类型 - FormData = "application/x-www-form-urlencoded" // FormData 普通的表单上传 + MessageHTTP = "message/http" // SSE 的 content-type 值 + MultipartFormData = "multipart/form-data" // 表单提交类型 + FormData = "application/x-www-form-urlencoded" // 普通的表单上传 EventStream = "text/event-stream" Plain = "text/plain" HTML = "text/html" diff --git a/internal/tree/test.go b/internal/tree/test.go index 0fe76ce..3ef2396 100644 --- a/internal/tree/test.go +++ b/internal/tree/test.go @@ -20,7 +20,7 @@ import ( func TestTrace(w http.ResponseWriter, r *http.Request) { trace.Trace(w, r, true) } -// NewTestTree 返回以 http.Handler 作为参数实例化的 Tree +// NewTestTree 返回以 [http.Handler] 作为参数实例化的 Tree func NewTestTree(a *assert.Assertion, lock bool, trace http.Handler, i *syntax.Interceptors) *Tree[http.Handler] { t := New("", lock, i, http.NotFoundHandler(), trace, BuildTestNodeHandlerFunc(http.StatusMethodNotAllowed), BuildTestNodeHandlerFunc(http.StatusOK)) a.NotNil(t) diff --git a/types/context.go b/types/context.go index 1344b4a..d534ed9 100644 --- a/types/context.go +++ b/types/context.go @@ -15,12 +15,10 @@ var contextPool = &sync.Pool{New: func() any { return &Context{} }} // // Context 同时实现了 [Route] 接口。 type Context struct { - Path string // 实际请求的路径信息 - keys []string - vals []string - paramsCount int - routerName string - node Node + Path string // 实际请求的路径信息 + params map[string]string + routerName string + node Node } func NewContext() *Context { @@ -31,9 +29,7 @@ func NewContext() *Context { func (ctx *Context) Reset() { ctx.Path = "" - ctx.keys = ctx.keys[:0] - ctx.vals = ctx.vals[:0] - ctx.paramsCount = 0 + clear(ctx.params) ctx.routerName = "" ctx.node = nil } @@ -50,7 +46,7 @@ func (ctx *Context) RouterName() string { return ctx.routerName } func (ctx *Context) Destroy() { const destroyMaxSize = 30 - if ctx != nil && len(ctx.keys) <= destroyMaxSize { + if ctx != nil && len(ctx.params) <= destroyMaxSize { contextPool.Put(ctx) } } @@ -139,67 +135,31 @@ func (ctx *Context) MustFloat(key string, def float64) float64 { } func (ctx *Context) Get(key string) (string, bool) { - if ctx == nil { + if ctx.params == nil { return "", false } - - for i, k := range ctx.keys { - if k == key { - return ctx.vals[i], true - } - } - return "", false + v, f := ctx.params[key] + return v, f } -func (ctx *Context) Count() (cnt int) { - if ctx == nil { - return 0 - } - return ctx.paramsCount -} +func (ctx *Context) Count() int { return len(ctx.params) } func (ctx *Context) Set(k, v string) { - deletedIndex := -1 - - for i, key := range ctx.keys { - if key == k { - ctx.vals[i] = v - return - } - if key == "" && deletedIndex == -1 { - deletedIndex = i - } - } - - // 没有需要修改的项 - ctx.paramsCount++ - if deletedIndex != -1 { // 优先考虑被标记为删除的项作为添加 - ctx.keys[deletedIndex] = k - ctx.vals[deletedIndex] = v - } else { - ctx.keys = append(ctx.keys, k) - ctx.vals = append(ctx.vals, v) + if ctx.params == nil { + ctx.params = map[string]string{k: v} + return } + ctx.params[k] = v } func (ctx *Context) Delete(k string) { - if ctx == nil { - return - } - - for i, key := range ctx.keys { - if key == k { - ctx.keys[i] = "" - ctx.paramsCount-- - return - } + if ctx.params != nil { + delete(ctx.params, k) } } func (ctx *Context) Range(f func(key, val string)) { - for i, k := range ctx.keys { - if k != "" { - f(k, ctx.vals[i]) - } + for k, v := range ctx.params { + f(k, v) } } diff --git a/types/context_test.go b/types/context_test.go index de5f380..f667396 100644 --- a/types/context_test.go +++ b/types/context_test.go @@ -148,10 +148,6 @@ func TestContext_Float(t *testing.T) { val, err = ctx.Float("k5") a.ErrorIs(err, ErrParamNotExists()).Equal(val, 0.0) a.Equal(ctx.MustFloat("k5", -10.0), -10.0) - - var ps2 *Context - val, err = ps2.Float("key1") - a.Equal(err, ErrParamNotExists()).Equal(val, 0.0) } func TestContext_Set(t *testing.T) { @@ -162,27 +158,20 @@ func TestContext_Set(t *testing.T) { a.Equal(ctx.Count(), 1) ctx.Set("k1", "v2") - a.Equal(ctx.Count(), 1) - a.Equal(ctx.keys, []string{"k1"}) - a.Equal(ctx.vals, []string{"v2"}) + a.Equal(ctx.Count(), 1). + Equal(ctx.params, map[string]string{"k1": "v2"}) ctx.Set("k2", "v2") - a.Equal(ctx.keys, []string{"k1", "k2"}) - a.Equal(ctx.vals, []string{"v2", "v2"}) - a.Equal(ctx.Count(), 2) + a.Equal(ctx.params, map[string]string{"k1": "v2", "k2": "v2"}). + Equal(ctx.Count(), 2) } func TestContext_Get(t *testing.T) { a := assert.New(t, false) - var ctx *Context - a.Zero(ctx.Count()) - v, found := ctx.Get("not-exists") - a.False(found).Zero(v) - - ctx = NewContext() + ctx := NewContext() ctx.Set("k1", "v1") - v, found = ctx.Get("k1") + v, found := ctx.Get("k1") a.True(found).Equal(v, "v1") v, found = ctx.Get("not-exists") @@ -193,10 +182,7 @@ func TestContext_Get(t *testing.T) { func TestContext_Delete(t *testing.T) { a := assert.New(t, false) - var ctx *Context - ctx.Delete("k1") - - ctx = NewContext() + ctx := NewContext() ctx.Path = "/path" ctx.Set("k1", "v1") ctx.Set("k2", "v2") @@ -207,12 +193,10 @@ func TestContext_Delete(t *testing.T) { a.Equal(1, ctx.Count()) ctx.Delete("k2") - a.Equal(0, ctx.Count()). - Equal(2, len(ctx.keys)) + a.Equal(0, ctx.Count()) ctx.Set("k3", "v3") - a.Equal(1, ctx.Count()). - Equal(2, len(ctx.keys)) + a.Equal(1, ctx.Count()) } func TestContext_Range(t *testing.T) { diff --git a/types/types.go b/types/types.go index 9fbe082..2b84841 100644 --- a/types/types.go +++ b/types/types.go @@ -109,9 +109,11 @@ type Node interface { // BuildNodeHandler 为节点生成处理方法 type BuildNodeHandler[T any] func(Node) T -// Middleware 中间件对象需要实现的接口 +// Middleware 中间件 +// +// 中间件用于改变一个路由项的行为,运行于在路由项的初始化阶段。 type Middleware[T any] interface { - // Middleware 包装处理 next + // Middleware 调整路由项 next 的行为 // // next 路由项的处理函数; // method 当前路由的请求方法; @@ -125,7 +127,7 @@ type Middleware[T any] interface { Middleware(next T, method, pattern, router string) T } -// MiddlewareFunc 中间件处理函数 +// MiddlewareFunc 中间件 type MiddlewareFunc[T any] func(next T, method, pattern, router string) T func (f MiddlewareFunc[T]) Middleware(next T, method, pattern, router string) T {