Skip to content

Commit

Permalink
trie.Common method, possible memory leakage fix (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
s0rg authored Sep 19, 2023
1 parent 6a37156 commit fdd3223
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 3 deletions.
3 changes: 2 additions & 1 deletion node.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ func (n *node[T]) SetValue(val T) {
}

func (n *node[T]) DropValue() {
n.ok = false
var zero T
n.value, n.ok = zero, false
}

func (n *node[T]) HasValue() (ok bool) {
Expand Down
57 changes: 55 additions & 2 deletions trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package trie

import (
"bytes"
"cmp"
"fmt"
"io"
"slices"
"strings"
)

Expand Down Expand Up @@ -112,7 +114,42 @@ func (t *Trie[T]) Iter(prefix string, walker func(key string, value T)) {
walker(prefix, n.value)
}

dfsKeys(n, prefix, walker)
dfsValues(n, prefix, walker)
}

// Common returns slice of common keys with at least `minLength` of their length
// Pass prefix="" to find all commons whithin given length
// Resulting slice is sorted by overall matching keys count, key with most goes first.
func (t *Trie[T]) Common(prefix string, minLength int) (rv []string) {
n, ok := t.find(prefix)
if !ok {
return
}

minLength -= len(prefix)

dfsKeys(n, prefix, func(k string, n *node[T]) bool {
if len(k) < minLength {
return true
}

if n.HasValue() || len(n.childs) > 1 {
rv = append(rv, k)

return false
}

return true
})

slices.SortStableFunc(rv, func(a, b string) int {
sa, _ := t.Suggest(a)
sb, _ := t.Suggest(b)

return cmp.Compare(len(sb), len(sa))
})

return rv
}

// String implements fmt.Stringer interface.
Expand Down Expand Up @@ -151,7 +188,7 @@ func writeNode[T any](
}
}

func dfsKeys[T any](
func dfsValues[T any](
n *node[T],
prefix string,
handler func(key string, value T),
Expand All @@ -163,6 +200,22 @@ func dfsKeys[T any](
handler(key, c.value)
}

dfsValues(c, key, handler)
}
}

func dfsKeys[T any](
n *node[T],
prefix string,
handler func(string, *node[T]) bool,
) {
for r, c := range n.childs {
key := prefix + string(r)

if !handler(key, c) {
return
}

dfsKeys(c, key, handler)
}
}
35 changes: 35 additions & 0 deletions trie_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package trie_test

import (
"slices"
"strings"
"testing"

Expand Down Expand Up @@ -250,6 +251,40 @@ func TestTrieSuggest(t *testing.T) {
}
}

func TestTrieCommons(t *testing.T) {
t.Parallel()

tr := trie.New[int]()

if res := tr.Common("foo", 3); len(res) > 0 {
t.FailNow()
}

tr.Add("foo", 1)
tr.Add("food", 2)
tr.Add("car", 3)
tr.Add("carpet", 4)
tr.Add("cart", 5)
tr.Add("cartridge", 9)
tr.Add("probe", 10)
tr.Add("problem", 11)
tr.Add("probability", 12)

want := []string{
"car",
"prob",
"foo",
}

if res := tr.Common("", 3); slices.Compare(res, want) != 0 {
t.Fail()
}

if res := tr.Common("ca", 3); len(res) != 1 || res[0] != "car" {
t.Fail()
}
}

func FuzzTrie(f *testing.F) {
f.Add("foo:F,bar:B")

Expand Down

0 comments on commit fdd3223

Please sign in to comment.