Skip to content

Commit

Permalink
Merge pull request #32 from ipfs/feat/configure-hamt-tree-width-31
Browse files Browse the repository at this point in the history
Configurable HAMT tree bitwidth
  • Loading branch information
hannahhoward authored Sep 12, 2019
2 parents ee6e898 + c51764d commit 85b7f1d
Show file tree
Hide file tree
Showing 5 changed files with 209 additions and 45 deletions.
89 changes: 58 additions & 31 deletions hamt.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,49 @@ import (

cid "github.com/ipfs/go-cid"
cbor "github.com/ipfs/go-ipld-cbor"
murmur3 "github.com/spaolacci/murmur3"
cbg "github.com/whyrusleeping/cbor-gen"
xerrors "golang.org/x/xerrors"
)

const arrayWidth = 3
const defaultBitWidth = 8

type Node struct {
Bitfield *big.Int `refmt:"bf"`
Pointers []*Pointer `refmt:"p"`

// for fetching and storing children
store *CborIpldStore
store *CborIpldStore
bitWidth int
}

func NewNode(cs *CborIpldStore) *Node {
return &Node{
// Option is a functional option: a callback that receives a *Node and may
// adjust its configuration. NewNode and LoadNode accept any number of Options
// and invoke each one, in the order given, after the node's defaults have been
// filled in and before the node is returned to the caller.
type Option func(*Node)

// UseTreeBitWidth returns an Option that sets the node's tree bit width.
//
// Only widths in the inclusive range 1 through 8 are applied. Any other value
// is ignored, leaving the node with whatever bit width it already had.
func UseTreeBitWidth(bitWidth int) Option {
	// bitWidth is never reassigned, so the range check can be done once here
	// instead of on every invocation of the returned option.
	inRange := bitWidth >= 1 && bitWidth <= 8

	return func(nd *Node) {
		if !inRange {
			return
		}
		nd.bitWidth = bitWidth
	}
}

// NewNode builds an empty HAMT node backed by the store cs.
//
// The node starts with a zero bitfield, an empty (non-nil) pointer list and
// the default bit width. Each supplied option is then applied, in order, so
// callers can override the defaults before the node is used.
func NewNode(cs *CborIpldStore, options ...Option) *Node {
	nd := new(Node)
	nd.Bitfield = big.NewInt(0)
	nd.Pointers = []*Pointer{}
	nd.store = cs
	nd.bitWidth = defaultBitWidth

	// Options run last so they take precedence over the defaults above.
	for i := range options {
		options[i](nd)
	}
	return nd
}

type KV struct {
Expand All @@ -44,14 +66,8 @@ type Pointer struct {
cache *Node
}

var hash = func(k string) []byte {
h := murmur3.New128()
h.Write([]byte(k))
return h.Sum(nil)
}

func (n *Node) Find(ctx context.Context, k string, out interface{}) error {
return n.getValue(ctx, hash(k), 0, k, func(kv *KV) error {
return n.getValue(ctx, &hashBits{b: hash(k)}, k, func(kv *KV) error {
// used to just see if the thing exists in the set
if out == nil {
return nil
Expand All @@ -70,32 +86,32 @@ func (n *Node) Find(ctx context.Context, k string, out interface{}) error {
}

func (n *Node) Delete(ctx context.Context, k string) error {
return n.modifyValue(ctx, hash(k), 0, k, nil)
return n.modifyValue(ctx, &hashBits{b: hash(k)}, k, nil)
}

var ErrNotFound = fmt.Errorf("not found")
var ErrMaxDepth = fmt.Errorf("attempted to traverse hamt beyond max depth")

func (n *Node) getValue(ctx context.Context, hv []byte, depth int, k string, cb func(*KV) error) error {
if depth >= len(hv) {
func (n *Node) getValue(ctx context.Context, hv *hashBits, k string, cb func(*KV) error) error {
idx, err := hv.Next(n.bitWidth)
if err != nil {
return ErrMaxDepth
}

idx := hv[depth]
if n.Bitfield.Bit(int(idx)) == 0 {
if n.Bitfield.Bit(idx) == 0 {
return ErrNotFound
}

cindex := byte(n.indexForBitPos(int(idx)))
cindex := byte(n.indexForBitPos(idx))

c := n.getChild(cindex)
if c.isShard() {
chnd, err := c.loadChild(ctx, n.store)
chnd, err := c.loadChild(ctx, n.store, n.bitWidth)
if err != nil {
return err
}

return chnd.getValue(ctx, hv, depth+1, k, cb)
return chnd.getValue(ctx, hv, k, cb)
}

for _, kv := range c.KVs {
Expand All @@ -107,12 +123,13 @@ func (n *Node) getValue(ctx context.Context, hv []byte, depth int, k string, cb
return ErrNotFound
}

func (p *Pointer) loadChild(ctx context.Context, ns *CborIpldStore) (*Node, error) {
func (p *Pointer) loadChild(ctx context.Context, ns *CborIpldStore, bitWidth int) (*Node, error) {
if p.cache != nil {
return p.cache, nil
}

out, err := LoadNode(ctx, ns, p.Link)
out.bitWidth = bitWidth
if err != nil {
return nil, err
}
Expand All @@ -121,13 +138,19 @@ func (p *Pointer) loadChild(ctx context.Context, ns *CborIpldStore) (*Node, erro
return out, nil
}

func LoadNode(ctx context.Context, cs *CborIpldStore, c cid.Cid) (*Node, error) {
// LoadNode fetches the node stored under c from cs and prepares it for use.
//
// The decoded node is attached to cs, given the default bit width, and then
// passed through each option in order. If the store lookup fails, that error
// is returned and no node is produced.
func LoadNode(ctx context.Context, cs *CborIpldStore, c cid.Cid, options ...Option) (*Node, error) {
	nd := new(Node)

	err := cs.Get(ctx, c, nd)
	if err != nil {
		return nil, err
	}

	nd.store = cs
	nd.bitWidth = defaultBitWidth

	// Options run after the defaults so they can override them.
	for _, apply := range options {
		apply(nd)
	}

	return nd, nil
}

Expand All @@ -145,7 +168,7 @@ func (n *Node) checkSize(ctx context.Context) (uint64, error) {
totsize := uint64(len(blk.RawData()))
for _, ch := range n.Pointers {
if ch.isShard() {
chnd, err := ch.loadChild(ctx, n.store)
chnd, err := ch.loadChild(ctx, n.store, n.bitWidth)
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -197,7 +220,7 @@ func (n *Node) Set(ctx context.Context, k string, v interface{}) error {
d = &cbg.Deferred{Raw: b}
}

return n.modifyValue(ctx, hash(k), 0, k, d)
return n.modifyValue(ctx, &hashBits{b: hash(k)}, k, d)
}

func (n *Node) cleanChild(chnd *Node, cindex byte) error {
Expand Down Expand Up @@ -234,11 +257,11 @@ func (n *Node) cleanChild(chnd *Node, cindex byte) error {
}
}

func (n *Node) modifyValue(ctx context.Context, hv []byte, depth int, k string, v *cbg.Deferred) error {
if depth >= len(hv) {
func (n *Node) modifyValue(ctx context.Context, hv *hashBits, k string, v *cbg.Deferred) error {
idx, err := hv.Next(n.bitWidth)
if err != nil {
return ErrMaxDepth
}
idx := int(hv[depth])

if n.Bitfield.Bit(idx) != 1 {
return n.insertChild(idx, k, v)
Expand All @@ -248,12 +271,12 @@ func (n *Node) modifyValue(ctx context.Context, hv []byte, depth int, k string,

child := n.getChild(cindex)
if child.isShard() {
chnd, err := child.loadChild(ctx, n.store)
chnd, err := child.loadChild(ctx, n.store, n.bitWidth)
if err != nil {
return err
}

if err := chnd.modifyValue(ctx, hv, depth+1, k, v); err != nil {
if err := chnd.modifyValue(ctx, hv, k, v); err != nil {
return err
}

Expand Down Expand Up @@ -293,12 +316,15 @@ func (n *Node) modifyValue(ctx context.Context, hv []byte, depth int, k string,
// If the array is full, create a subshard and insert everything into it
if len(child.KVs) >= arrayWidth {
sub := NewNode(n.store)
if err := sub.modifyValue(ctx, hv, depth+1, k, v); err != nil {
sub.bitWidth = n.bitWidth
hvcopy := &hashBits{b: hv.b, consumed: hv.consumed}
if err := sub.modifyValue(ctx, hvcopy, k, v); err != nil {
return err
}

for _, p := range child.KVs {
if err := sub.modifyValue(ctx, hash(p.Key), depth+1, p.Key, p.Value); err != nil {
chhv := &hashBits{b: hash(p.Key), consumed: hv.consumed}
if err := sub.modifyValue(ctx, chhv, p.Key, p.Value); err != nil {
return err
}
}
Expand Down Expand Up @@ -360,6 +386,7 @@ func (n *Node) getChild(i byte) *Pointer {

func (n *Node) Copy() *Node {
nn := NewNode(n.store)
nn.bitWidth = n.bitWidth
nn.Bitfield.Set(n.Bitfield)
nn.Pointers = make([]*Pointer, len(n.Pointers))

Expand Down Expand Up @@ -388,7 +415,7 @@ func (p *Pointer) isShard() bool {
func (n *Node) ForEach(ctx context.Context, f func(k string, val interface{}) error) error {
for _, p := range n.Pointers {
if p.isShard() {
chnd, err := p.loadChild(ctx, n.store)
chnd, err := p.loadChild(ctx, n.store, n.bitWidth)
if err != nil {
return err
}
Expand Down
16 changes: 10 additions & 6 deletions hamt_bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,20 @@ func BenchmarkSerializeNode(b *testing.B) {
}

func BenchmarkFind(b *testing.B) {
b.Run("find-10k", doBenchmarkEntriesCount(10000))
b.Run("find-100k", doBenchmarkEntriesCount(100000))
b.Run("find-1m", doBenchmarkEntriesCount(1000000))
b.Run("find-10k", doBenchmarkEntriesCount(10000, 8))
b.Run("find-100k", doBenchmarkEntriesCount(100000, 8))
b.Run("find-1m", doBenchmarkEntriesCount(1000000, 8))
b.Run("find-10k-bitwidth-5", doBenchmarkEntriesCount(10000, 5))
b.Run("find-100k-bitwidth-5", doBenchmarkEntriesCount(100000, 5))
b.Run("find-1m-bitwidth-5", doBenchmarkEntriesCount(1000000, 5))

}

func doBenchmarkEntriesCount(num int) func(b *testing.B) {
func doBenchmarkEntriesCount(num int, bitWidth int) func(b *testing.B) {
r := rander{rand.New(rand.NewSource(int64(num)))}
return func(b *testing.B) {
cs := NewCborStore()
n := NewNode(cs)
n := NewNode(cs, UseTreeBitWidth(bitWidth))

var keys []string
for i := 0; i < num; i++ {
Expand All @@ -82,7 +86,7 @@ func doBenchmarkEntriesCount(num int) func(b *testing.B) {
b.ReportAllocs()

for i := 0; i < b.N; i++ {
nd, err := LoadNode(context.TODO(), cs, c)
nd, err := LoadNode(context.TODO(), cs, c, UseTreeBitWidth(bitWidth))
if err != nil {
b.Fatal(err)
}
Expand Down
29 changes: 21 additions & 8 deletions hamt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,22 @@ func TestCanonicalStructure(t *testing.T) {
defer func() {
hash = murmurHash
}()
addAndRemoveKeys(t, []string{"K"}, []string{"B"})
addAndRemoveKeys(t, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
addAndRemoveKeys(t, defaultBitWidth, []string{"K"}, []string{"B"})
addAndRemoveKeys(t, defaultBitWidth, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
}

// TestCanonicalStructureAlternateBitWidth runs the add-then-remove checks with
// bit widths 7, 6 and 5, using the identity hash for the duration of the test.
func TestCanonicalStructureAlternateBitWidth(t *testing.T) {
	hash = identityHash
	defer func() {
		// Put the murmur hash back so later tests are unaffected.
		hash = murmurHash
	}()

	for _, width := range []int{7, 6, 5} {
		addAndRemoveKeys(t, width, []string{"K"}, []string{"B"})
		addAndRemoveKeys(t, width, []string{"K0", "K1", "KAA1", "KAA2", "KAA3"}, []string{"KAA4"})
	}
}
func TestOverflow(t *testing.T) {
hash = identityHash
defer func() {
Expand Down Expand Up @@ -90,7 +102,7 @@ func TestOverflow(t *testing.T) {
}
}

func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
func addAndRemoveKeys(t *testing.T, bitWidth int, keys []string, extraKeys []string) {
ctx := context.Background()
vals := make(map[string][]byte)
for i := 0; i < len(keys); i++ {
Expand All @@ -99,7 +111,7 @@ func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
}

cs := NewCborStore()
begn := NewNode(cs)
begn := NewNode(cs, UseTreeBitWidth(bitWidth))
for _, k := range keys {
if err := begn.Set(ctx, k, vals[k]); err != nil {
t.Fatal(err)
Expand All @@ -122,7 +134,7 @@ func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
t.Fatal(err)
}
n.store = cs

n.bitWidth = bitWidth
for k, v := range vals {
var out []byte
err := n.Find(ctx, k, &out)
Expand Down Expand Up @@ -157,6 +169,7 @@ func addAndRemoveKeys(t *testing.T, keys []string, extraKeys []string) {
t.Fatal(err)
}
n2.store = cs
n2.bitWidth = bitWidth
if !nodesEqual(t, cs, &n, &n2) {
t.Fatal("nodes should be equal")
}
Expand All @@ -168,7 +181,7 @@ func dotGraphRec(n *Node, name *int) {
if p.isShard() {
*name++
fmt.Printf("\tn%d -> n%d;\n", cur, *name)
nd, err := p.loadChild(context.Background(), n.store)
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -200,7 +213,7 @@ func statsrec(n *Node, st *hamtStats) {
st.totalNodes++
for _, p := range n.Pointers {
if p.isShard() {
nd, err := p.loadChild(context.Background(), n.store)
nd, err := p.loadChild(context.Background(), n.store, n.bitWidth)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -305,7 +318,7 @@ func TestSetGet(t *testing.T) {
t.Fatal(err)
}
n.store = cs

n.bitWidth = defaultBitWidth
bef = time.Now()
//for k, v := range vals {
for _, k := range keys {
Expand Down
Loading

0 comments on commit 85b7f1d

Please sign in to comment.