From 42dc179b389fbe5d5f45477476bb99e7d5b14120 Mon Sep 17 00:00:00 2001 From: guonaihong Date: Sat, 28 Dec 2024 01:39:04 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20skiplist=E5=8A=A0=E4=B8=8A=E7=BA=BF?= =?UTF-8?q?=E7=A8=8B=E5=AE=89=E5=85=A8=E7=9A=84=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- skiplist/skiplist.go | 158 +++++++++++++++++++++++++++++++++++--- skiplist/skiplist_test.go | 60 +++++++++++++++ 2 files changed, 207 insertions(+), 11 deletions(-) diff --git a/skiplist/skiplist.go b/skiplist/skiplist.go index cfc6362..77bda16 100644 --- a/skiplist/skiplist.go +++ b/skiplist/skiplist.go @@ -38,6 +38,7 @@ import ( "errors" "fmt" "math/rand" + "sync" "time" "github.com/antlabs/gstl/api" @@ -54,6 +55,44 @@ var ( ErrNotFound = errors.New("not found element") ) +type nodePool[K constraints.Ordered, T any] struct { + pool sync.Pool +} + +var ( + poolMap sync.Map // map[reflect.Type]*nodePool[K, T] +) + +func getNodePool[K constraints.Ordered, T any]() *nodePool[K, T] { + var k K + var t T + typKey := fmt.Sprintf("%T-%T", k, t) + + if p, ok := poolMap.Load(typKey); ok { + return p.(*nodePool[K, T]) + } + + pool := &nodePool[K, T]{} + pool.pool.New = func() interface{} { + return new(Node[K, T]) + } + + actual, _ := poolMap.LoadOrStore(typKey, pool) + return actual.(*nodePool[K, T]) +} + +func (p *nodePool[K, T]) getNode() *Node[K, T] { + return p.pool.Get().(*Node[K, T]) +} + +func (p *nodePool[K, T]) putNode(n *Node[K, T]) { + n.score = *new(K) + n.elem = *new(T) + n.backward = nil + n.NodeLevel = nil + p.pool.Put(n) +} + type Node[K constraints.Ordered, T any] struct { score K elem T @@ -73,6 +112,7 @@ type SkipList[K constraints.Ordered, T any] struct { r *rand.Rand length int level int + pool *nodePool[K, T] //compare func(T, T) int } @@ -86,8 +126,10 @@ func New[K constraints.Ordered, T any]() *SkipList[K, T] { //s.compare = compare var score K + var elem T + s.pool = getNodePool[K, T]() s.resetRand() - s.head = newNode(SKIPLIST_MAXLEVEL, score, *new(T)) + s.head = s.newNode(SKIPLIST_MAXLEVEL, score, elem) return s } @@ -113,28 +155,35 @@ func (s *SkipList[K, T]) rand() int { return SKIPLIST_MAXLEVEL } -func newNode[K constraints.Ordered, T any](level int, score K, elem T) *Node[K, T] { - return &Node[K, T]{ - score: score, - elem: elem, - NodeLevel: make([]struct { - forward *Node[K, T] - span int - }, level), - } +func (s *SkipList[K, T]) newNode(level int, score K, elem T) *Node[K, T] { + node := s.pool.getNode() + node.score = score + node.elem = elem + node.NodeLevel = make([]struct { + forward *Node[K, T] + span int + }, level) + return node +} + +func (s *SkipList[K, T]) putNode(node *Node[K, T]) { + s.pool.putNode(node) } // 设置值, 和Insert是同义词 func (s *SkipList[K, T]) Set(score K, elem T) { + s.InsertInner(score, elem, s.rand()) } func (s *SkipList[K, T]) SetWithPrev(score K, elem T) (prev T, replaced bool) { + return s.InsertInner(score, elem, s.rand()) } // 设置值 func (s *SkipList[K, T]) Insert(score K, elem T) { + s.InsertInner(score, elem, s.rand()) } @@ -189,7 +238,7 @@ func (s *SkipList[K, T]) InsertInner(score K, elem T, level int) (prev T, replac } // 创建新节点 - x = newNode(level, score, elem) + x = s.newNode(level, score, elem) for i := 0; i < level; i++ { // x.NodeLevel[i]的节点假设等于a, 需要插入的节点x在a之后, // a, x, a.forward三者的关系就是[a, x, a.forward] @@ -329,10 +378,12 @@ func (s *SkipList[K, T]) removeNode(x *Node[K, T], update []*Node[K, T]) { s.level-- } s.length-- + s.putNode(x) } // 根据score删除 func (s *SkipList[K, T]) Delete(score K) { + s.Remove(score) } @@ -358,6 +409,7 @@ func (s *SkipList[K, T]) Remove(score K) *SkipList[K, T] { } func (s *SkipList[K, T]) Draw() *SkipList[K, T] { + if s.head == nil { return s } @@ -378,6 +430,7 @@ func (s *SkipList[K, T]) Draw() *SkipList[K, T] { // 遍历 func (s *SkipList[K, T]) Range(callback func(score K, v T) bool) { + if s.head == nil { return } @@ -392,6 +445,7 @@ func (s *SkipList[K, T]) Range(callback func(score K, v T) bool) { // 返回最小的n个值, 升序返回, 比如0,1,2,3 func (s *SkipList[K, T]) TopMin(limit int, callback func(score K, v T) bool) { + s.Range(func(score K, v T) bool { if limit <= 0 { return false @@ -404,11 +458,13 @@ func (s *SkipList[K, T]) TopMin(limit int, callback func(score K, v T) bool) { // 返回长度 func (s *SkipList[K, T]) Len() int { + return s.length } // 从后向前倒序遍历b tree func (s *SkipList[K, T]) RangePrev(callback func(k K, v T) bool) *SkipList[K, T] { + // 遍历 if s.tail == nil { return s @@ -425,6 +481,7 @@ func (s *SkipList[K, T]) RangePrev(callback func(k K, v T) bool) *SkipList[K, T] // 返回最大的n个值, 降序返回, 10, 9, 8, 7 func (s *SkipList[K, T]) TopMax(limit int, callback func(k K, v T) bool) { + s.RangePrev(func(k K, v T) bool { if limit <= 0 { return false @@ -434,3 +491,82 @@ func (s *SkipList[K, T]) TopMax(limit int, callback func(k K, v T) bool) { return true }) } + +// ConcurrentSkipList represents a concurrent-safe skip list +// It embeds SkipList and adds a RWMutex for thread safety +type ConcurrentSkipList[K constraints.Ordered, T any] struct { + mu sync.RWMutex + *SkipList[K, T] +} + +// NewConcurrent returns a new concurrent-safe skip list +func NewConcurrent[K constraints.Ordered, T any]() *ConcurrentSkipList[K, T] { + return &ConcurrentSkipList[K, T]{ + SkipList: New[K, T](), + } +} + +// Insert inserts a new element into the concurrent skip list +func (c *ConcurrentSkipList[K, T]) Insert(score K, elem T) { + c.mu.Lock() + defer c.mu.Unlock() + c.SkipList.Insert(score, elem) +} + +// Delete removes an element from the concurrent skip list +func (c *ConcurrentSkipList[K, T]) Delete(score K) { + c.mu.Lock() + defer c.mu.Unlock() + c.SkipList.Delete(score) +} + +// Get retrieves an element from the concurrent skip list +func (c *ConcurrentSkipList[K, T]) Get(score K) (elem T, ok bool) { + c.mu.RLock() + defer c.mu.RUnlock() + return c.SkipList.GetWithBool(score) +} + +// Range iterates over elements in the concurrent skip list +func (c *ConcurrentSkipList[K, T]) Range(callback func(score K, v T) bool) { + c.mu.RLock() + defer c.mu.RUnlock() + c.SkipList.Range(callback) +} + +// Draw draws the concurrent skip list +func (c *ConcurrentSkipList[K, T]) Draw() *ConcurrentSkipList[K, T] { + c.mu.RLock() + defer c.mu.RUnlock() + c.SkipList.Draw() + return c +} + +// Len returns the length of the concurrent skip list +func (c *ConcurrentSkipList[K, T]) Len() int { + c.mu.RLock() + defer c.mu.RUnlock() + return c.SkipList.Len() +} + +// RangePrev iterates over elements in the concurrent skip list in reverse order +func (c *ConcurrentSkipList[K, T]) RangePrev(callback func(k K, v T) bool) *ConcurrentSkipList[K, T] { + c.mu.RLock() + defer c.mu.RUnlock() + c.SkipList.RangePrev(callback) + return c +} + +// TopMin returns the top min elements from the concurrent skip list +func (c *ConcurrentSkipList[K, T]) TopMin(limit int, callback func(score K, v T) bool) { + c.mu.RLock() + defer c.mu.RUnlock() + c.SkipList.TopMin(limit, callback) +} + +// TopMax returns the top max elements from the concurrent skip list +func (c *ConcurrentSkipList[K, T]) TopMax(limit int, callback func(k K, v T) bool) { + c.mu.RLock() + defer c.mu.RUnlock() + c.SkipList.TopMax(limit, callback) +} diff --git a/skiplist/skiplist_test.go b/skiplist/skiplist_test.go index 68ec4d9..e7e1efb 100644 --- a/skiplist/skiplist_test.go +++ b/skiplist/skiplist_test.go @@ -3,6 +3,7 @@ package skiplist // apache 2.0 antlabs import ( "fmt" + "sync" "testing" "github.com/antlabs/gstl/cmp" @@ -292,6 +293,65 @@ func Test_Skiplist_TopMax(t *testing.T) { } } +func Test_ConcurrentSkipList_InsertGet(t *testing.T) { + csl := NewConcurrent[int, string]() + var wg sync.WaitGroup + count := 1000 + + // Concurrent inserts + wg.Add(count) + for i := 0; i < count; i++ { + go func(i int) { + defer wg.Done() + csl.Insert(i, fmt.Sprintf("value%d", i)) + }(i) + } + + wg.Wait() + + // Concurrent gets + wg.Add(count) + for i := 0; i < count; i++ { + go func(i int) { + defer wg.Done() + if val, ok := csl.Get(i); !ok || val != fmt.Sprintf("value%d", i) { + t.Errorf("expected value%d, got %v", i, val) + } + }(i) + } + + wg.Wait() +} + +func Test_ConcurrentSkipList_Delete(t *testing.T) { + csl := NewConcurrent[int, string]() + var wg sync.WaitGroup + count := 1000 + + // Insert elements + for i := 0; i < count; i++ { + csl.Insert(i, fmt.Sprintf("value%d", i)) + } + + // Concurrent deletes + wg.Add(count) + for i := 0; i < count; i++ { + go func(i int) { + defer wg.Done() + csl.Delete(i) + }(i) + } + + wg.Wait() + + // Verify all elements are deleted + for i := 0; i < count; i++ { + if _, ok := csl.Get(i); ok { + t.Errorf("expected element %d to be deleted", i) + } + } +} + // 辅助函数,用于比较两个切片是否相等 func equalSlices(a, b []int) bool { if len(a) != len(b) {