Skip to content

Commit

Permalink
feat: skiplist加上线程安全的实现
Browse files Browse the repository at this point in the history
  • Loading branch information
guonaihong committed Dec 27, 2024
1 parent d5de1bd commit 42dc179
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 11 deletions.
158 changes: 147 additions & 11 deletions skiplist/skiplist.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"errors"
"fmt"
"math/rand"
"sync"
"time"

"github.com/antlabs/gstl/api"
Expand All @@ -54,6 +55,44 @@ var (
ErrNotFound = errors.New("not found element")
)

type nodePool[K constraints.Ordered, T any] struct {
pool sync.Pool
}

var (
poolMap sync.Map // map[reflect.Type]*nodePool[K, T]
)

func getNodePool[K constraints.Ordered, T any]() *nodePool[K, T] {
var k K
var t T
typKey := fmt.Sprintf("%T-%T", k, t)

if p, ok := poolMap.Load(typKey); ok {
return p.(*nodePool[K, T])
}

pool := &nodePool[K, T]{}
pool.pool.New = func() interface{} {
return new(Node[K, T])
}

actual, _ := poolMap.LoadOrStore(typKey, pool)
return actual.(*nodePool[K, T])
}

func (p *nodePool[K, T]) getNode() *Node[K, T] {
return p.pool.Get().(*Node[K, T])
}

func (p *nodePool[K, T]) putNode(n *Node[K, T]) {
n.score = *new(K)
n.elem = *new(T)
n.backward = nil
n.NodeLevel = nil
p.pool.Put(n)
}

type Node[K constraints.Ordered, T any] struct {
score K
elem T
Expand All @@ -73,6 +112,7 @@ type SkipList[K constraints.Ordered, T any] struct {
r *rand.Rand
length int
level int
pool *nodePool[K, T]

//compare func(T, T) int
}
Expand All @@ -86,8 +126,10 @@ func New[K constraints.Ordered, T any]() *SkipList[K, T] {

//s.compare = compare
var score K
var elem T
s.pool = getNodePool[K, T]()
s.resetRand()
s.head = newNode(SKIPLIST_MAXLEVEL, score, *new(T))
s.head = s.newNode(SKIPLIST_MAXLEVEL, score, elem)
return s
}

Expand All @@ -113,28 +155,35 @@ func (s *SkipList[K, T]) rand() int {
return SKIPLIST_MAXLEVEL
}

func newNode[K constraints.Ordered, T any](level int, score K, elem T) *Node[K, T] {
return &Node[K, T]{
score: score,
elem: elem,
NodeLevel: make([]struct {
forward *Node[K, T]
span int
}, level),
}
func (s *SkipList[K, T]) newNode(level int, score K, elem T) *Node[K, T] {
node := s.pool.getNode()
node.score = score
node.elem = elem
node.NodeLevel = make([]struct {
forward *Node[K, T]
span int
}, level)
return node
}

func (s *SkipList[K, T]) putNode(node *Node[K, T]) {
s.pool.putNode(node)
}

// 设置值, 和Insert是同义词
func (s *SkipList[K, T]) Set(score K, elem T) {

s.InsertInner(score, elem, s.rand())
}

func (s *SkipList[K, T]) SetWithPrev(score K, elem T) (prev T, replaced bool) {

return s.InsertInner(score, elem, s.rand())
}

// 设置值
func (s *SkipList[K, T]) Insert(score K, elem T) {

s.InsertInner(score, elem, s.rand())
}

Expand Down Expand Up @@ -189,7 +238,7 @@ func (s *SkipList[K, T]) InsertInner(score K, elem T, level int) (prev T, replac
}

// 创建新节点
x = newNode(level, score, elem)
x = s.newNode(level, score, elem)
for i := 0; i < level; i++ {
// x.NodeLevel[i]的节点假设等于a, 需要插入的节点x在a之后,
// a, x, a.forward三者的关系就是[a, x, a.forward]
Expand Down Expand Up @@ -329,10 +378,12 @@ func (s *SkipList[K, T]) removeNode(x *Node[K, T], update []*Node[K, T]) {
s.level--
}
s.length--
s.putNode(x)
}

// 根据score删除
func (s *SkipList[K, T]) Delete(score K) {

s.Remove(score)
}

Expand All @@ -358,6 +409,7 @@ func (s *SkipList[K, T]) Remove(score K) *SkipList[K, T] {
}

func (s *SkipList[K, T]) Draw() *SkipList[K, T] {

if s.head == nil {
return s
}
Expand All @@ -378,6 +430,7 @@ func (s *SkipList[K, T]) Draw() *SkipList[K, T] {

// 遍历
func (s *SkipList[K, T]) Range(callback func(score K, v T) bool) {

if s.head == nil {
return
}
Expand All @@ -392,6 +445,7 @@ func (s *SkipList[K, T]) Range(callback func(score K, v T) bool) {

// 返回最小的n个值, 升序返回, 比如0,1,2,3
func (s *SkipList[K, T]) TopMin(limit int, callback func(score K, v T) bool) {

s.Range(func(score K, v T) bool {
if limit <= 0 {
return false
Expand All @@ -404,11 +458,13 @@ func (s *SkipList[K, T]) TopMin(limit int, callback func(score K, v T) bool) {

// 返回长度
func (s *SkipList[K, T]) Len() int {

return s.length
}

// 从后向前倒序遍历b tree
func (s *SkipList[K, T]) RangePrev(callback func(k K, v T) bool) *SkipList[K, T] {

// 遍历
if s.tail == nil {
return s
Expand All @@ -425,6 +481,7 @@ func (s *SkipList[K, T]) RangePrev(callback func(k K, v T) bool) *SkipList[K, T]

// 返回最大的n个值, 降序返回, 10, 9, 8, 7
func (s *SkipList[K, T]) TopMax(limit int, callback func(k K, v T) bool) {

s.RangePrev(func(k K, v T) bool {
if limit <= 0 {
return false
Expand All @@ -434,3 +491,82 @@ func (s *SkipList[K, T]) TopMax(limit int, callback func(k K, v T) bool) {
return true
})
}

// ConcurrentSkipList represents a concurrent-safe skip list
// It embeds SkipList and adds a RWMutex for thread safety
type ConcurrentSkipList[K constraints.Ordered, T any] struct {
mu sync.RWMutex
*SkipList[K, T]
}

// NewConcurrent returns a new concurrent-safe skip list
func NewConcurrent[K constraints.Ordered, T any]() *ConcurrentSkipList[K, T] {
return &ConcurrentSkipList[K, T]{
SkipList: New[K, T](),
}
}

// Insert inserts a new element into the concurrent skip list
func (c *ConcurrentSkipList[K, T]) Insert(score K, elem T) {
c.mu.Lock()
defer c.mu.Unlock()
c.SkipList.Insert(score, elem)
}

// Delete removes an element from the concurrent skip list
func (c *ConcurrentSkipList[K, T]) Delete(score K) {
c.mu.Lock()
defer c.mu.Unlock()
c.SkipList.Delete(score)
}

// Get retrieves an element from the concurrent skip list
func (c *ConcurrentSkipList[K, T]) Get(score K) (elem T, ok bool) {
c.mu.RLock()
defer c.mu.RUnlock()
return c.SkipList.GetWithBool(score)
}

// Range iterates over elements in the concurrent skip list
func (c *ConcurrentSkipList[K, T]) Range(callback func(score K, v T) bool) {
c.mu.RLock()
defer c.mu.RUnlock()
c.SkipList.Range(callback)
}

// Draw draws the concurrent skip list
func (c *ConcurrentSkipList[K, T]) Draw() *ConcurrentSkipList[K, T] {
c.mu.RLock()
defer c.mu.RUnlock()
c.SkipList.Draw()
return c
}

// Len returns the length of the concurrent skip list
func (c *ConcurrentSkipList[K, T]) Len() int {
c.mu.RLock()
defer c.mu.RUnlock()
return c.SkipList.Len()
}

// RangePrev iterates over elements in the concurrent skip list in reverse order
func (c *ConcurrentSkipList[K, T]) RangePrev(callback func(k K, v T) bool) *ConcurrentSkipList[K, T] {
c.mu.RLock()
defer c.mu.RUnlock()
c.SkipList.RangePrev(callback)
return c
}

// TopMin returns the top min elements from the concurrent skip list
func (c *ConcurrentSkipList[K, T]) TopMin(limit int, callback func(score K, v T) bool) {
c.mu.RLock()
defer c.mu.RUnlock()
c.SkipList.TopMin(limit, callback)
}

// TopMax returns the top max elements from the concurrent skip list
func (c *ConcurrentSkipList[K, T]) TopMax(limit int, callback func(k K, v T) bool) {
c.mu.RLock()
defer c.mu.RUnlock()
c.SkipList.TopMax(limit, callback)
}
60 changes: 60 additions & 0 deletions skiplist/skiplist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package skiplist
// apache 2.0 antlabs
import (
"fmt"
"sync"
"testing"

"github.com/antlabs/gstl/cmp"
Expand Down Expand Up @@ -292,6 +293,65 @@ func Test_Skiplist_TopMax(t *testing.T) {
}
}

func Test_ConcurrentSkipList_InsertGet(t *testing.T) {
csl := NewConcurrent[int, string]()
var wg sync.WaitGroup
count := 1000

// Concurrent inserts
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
defer wg.Done()
csl.Insert(i, fmt.Sprintf("value%d", i))
}(i)
}

wg.Wait()

// Concurrent gets
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
defer wg.Done()
if val, ok := csl.Get(i); !ok || val != fmt.Sprintf("value%d", i) {
t.Errorf("expected value%d, got %v", i, val)
}
}(i)
}

wg.Wait()
}

func Test_ConcurrentSkipList_Delete(t *testing.T) {
csl := NewConcurrent[int, string]()
var wg sync.WaitGroup
count := 1000

// Insert elements
for i := 0; i < count; i++ {
csl.Insert(i, fmt.Sprintf("value%d", i))
}

// Concurrent deletes
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
defer wg.Done()
csl.Delete(i)
}(i)
}

wg.Wait()

// Verify all elements are deleted
for i := 0; i < count; i++ {
if _, ok := csl.Get(i); ok {
t.Errorf("expected element %d to be deleted", i)
}
}
}

// 辅助函数,用于比较两个切片是否相等
func equalSlices(a, b []int) bool {
if len(a) != len(b) {
Expand Down

0 comments on commit 42dc179

Please sign in to comment.