diff --git a/trie.go b/trie.go index 3c6b648..3cbf2fe 100644 --- a/trie.go +++ b/trie.go @@ -1,7 +1,6 @@ package goblin import ( - "fmt" "net/http" "path" "regexp" @@ -164,10 +163,7 @@ type regCache struct { func (rc *regCache) getReg(ptn string) (*regexp.Regexp, error) { v, ok := rc.s.Load(ptn) if ok { - reg, ok := v.(*regexp.Regexp) - if !ok { - return nil, fmt.Errorf("the value of %q is wrong", ptn) - } + reg, _ := v.(*regexp.Regexp) return reg, nil } reg, err := regexp.Compile(ptn) @@ -185,10 +181,6 @@ func (t *tree) Search(path string) (*action, Params, error) { path = cleanPath(path) curNode := t.node - if path == "/" && curNode.action == nil { - return nil, nil, ErrNotFound - } - path = removeTrailingSlash(path) cnt := strings.Count(path, "/") @@ -217,7 +209,6 @@ func (t *tree) Search(path string) (*action, Params, error) { // no matching path was found. return nil, nil, ErrNotFound } - break } nextNode := curNode.getChild(l) diff --git a/trie_test.go b/trie_test.go index 31fbb7d..61c0b9c 100644 --- a/trie_test.go +++ b/trie_test.go @@ -1,8 +1,10 @@ package goblin import ( + "fmt" "net/http" "reflect" + "regexp" "testing" ) @@ -703,6 +705,7 @@ func TestSearchRegexp(t *testing.T) { fooBarHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) fooBarIDHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) fooBarIDNameHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) + bazInvalidIDHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}) tree.Insert(`/`, rootHandler, []middleware{first}) tree.Insert(`/:*[(.+)]`, rootWildCardHandler, []middleware{first}) @@ -712,6 +715,7 @@ func TestSearchRegexp(t *testing.T) { tree.Insert(`/foo/bar`, fooBarHandler, []middleware{first}) tree.Insert(`/foo/bar/:id`, fooBarIDHandler, []middleware{first}) tree.Insert(`/foo/bar/:id/:name`, fooBarIDNameHandler, []middleware{first}) + tree.Insert(`/baz/:id[[\d+]`, bazInvalidIDHandler, []middleware{first}) cases := []caseWithFailure{ { @@ -883,6 +887,14 @@ func TestSearchRegexp(t *testing.T) { }, }, }, + { + hasError: true, + item: &item{ + path: "/baz/1", + }, + expectedAction: nil, + expectedParams: Params{}, + }, } testWithFailure(t, tree, cases) @@ -1010,6 +1022,59 @@ func testWithFailure(t *testing.T, tree *tree, cases []caseWithFailure) { } } +func TestGetReg(t *testing.T) { + cases := []struct { + name string + ptn string + isCached bool + expectedReg *regexp.Regexp + expectedErr error + }{ + { + name: "Valid - no cache", + ptn: `\d+`, + isCached: false, + expectedReg: regexp.MustCompile(`\d+`), + expectedErr: nil, + }, + { + name: "Valid - cached", + ptn: `\d+`, + isCached: true, + expectedReg: regexp.MustCompile(`\d+`), + expectedErr: nil, + }, + { + name: "Invalid - regexp compile error", + ptn: `[\d+`, + isCached: false, + expectedReg: nil, + expectedErr: fmt.Errorf("error parsing regexp: missing closing ]: `[\\d+`"), + }, + } + + cache := regCache{} + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + if c.isCached { + cache.s.Store(c.ptn, c.expectedReg) + } + reg, err := cache.getReg(c.ptn) + + if !reflect.DeepEqual(reg, c.expectedReg) { + t.Errorf("actual:%v expected:%v", reg, c.expectedReg) + } + + if err != nil { + if err.Error() != c.expectedErr.Error() { + t.Errorf("actual:%v expected:%v", err.Error(), c.expectedErr.Error()) + } + } + }) + } +} + func TestGetPattern(t *testing.T) { cases := []struct { actual string