From d5de1bd5b40d6052b2ff579fd831f0fbd29636f5 Mon Sep 17 00:00:00 2001 From: guonaihong Date: Sat, 28 Dec 2024 00:54:39 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=8F=8C=E5=90=91=E9=93=BE=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 73 +++++++++++++++- linkedlist/linkedlist.go | 154 ++++++++++++++++++++++++++++++---- linkedlist/linkedlist_test.go | 127 ++++++++++++++++++++++++++++ rbtree/rbtree.go | 4 - 4 files changed, 335 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index d560350..2a5ba19 100644 --- a/README.md +++ b/README.md @@ -8,10 +8,82 @@ ``` ## 二、`Listked` + +`Listked` 是一个支持泛型的双向链表容器,提供了加锁和不加锁的实现。 + +#### 不加锁的使用方式 + ```go +package main + +import ( + "fmt" + "github.com/antlabs/gstl/linkedlist" +) + +func main() { + // 创建一个不加锁的链表 + list := linkedlist.New[int]() + + // 插入元素 + list.PushBack(1) + list.PushFront(0) + + // 遍历链表 + list.Range(func(value int) { + fmt.Println(value) + }) + + // 删除元素 + list.Remove(0) +} +``` +#### 加锁的使用方式 + +```go +package main + +import ( + "fmt" + "sync" + "github.com/antlabs/gstl/linkedlist" +) + +func main() { + // 创建一个加锁的链表 + list := linkedlist.NewConcurrent[int]() + + var wg sync.WaitGroup + wg.Add(2) + + // 并发插入元素 + go func() { + defer wg.Done() + list.PushBack(1) + list.PushFront(0) + }() + + // 并发遍历链表 + go func() { + defer wg.Done() + list.Range(func(value int) { + fmt.Println(value) + }) + }() + + wg.Wait() + + // 删除元素 + list.Remove(0) +} ``` +### 区别 + +- **不加锁的链表**:适用于单线程环境,性能更高。 +- **加锁的链表**:适用于多线程环境,保证线程安全。 + ## 三、`rhashmap` 和标准库不同的地方是有序hash ```go @@ -212,4 +284,3 @@ for pair := range m.Iter() { m.Len()// 获取长度 allKeys := m.Keys() //返回所有的key allValues := m.Values()// 返回所有的value -``` diff --git a/linkedlist/linkedlist.go b/linkedlist/linkedlist.go index 460726a..3329223 100644 --- a/linkedlist/linkedlist.go +++ b/linkedlist/linkedlist.go @@ -20,6 +20,8 @@ package linkedlist import ( "errors" + "reflect" + "sync" "github.com/antlabs/gstl/cmp" ) @@ -27,9 +29,51 @@ import ( var ErrListElemEmpty = errors.New("list is empty") var ErrNotFound = errors.New("element not found") -type LinkedList[T any] struct { - root Node[T] - length int +// 节点对象池,为每种类型维护一个独立的对象池 +type nodePool[T any] struct { + pool sync.Pool +} + +var ( + poolMap sync.Map // map[reflect.Type]*nodePool[T] +) + +// 获取指定类型的对象池 +func getNodePool[T any]() *nodePool[T] { + var t T + typ := reflect.TypeOf(t) + + if p, ok := poolMap.Load(typ); ok { + return p.(*nodePool[T]) + } + + pool := &nodePool[T]{} + pool.pool.New = func() interface{} { + return new(Node[T]) + } + + actual, _ := poolMap.LoadOrStore(typ, pool) + return actual.(*nodePool[T]) +} + +// 从对象池获取节点 +func (p *nodePool[T]) getNode() *Node[T] { + if p == nil { + return new(Node[T]) + } + return p.pool.Get().(*Node[T]) +} + +// 将节点放回对象池 +func (p *nodePool[T]) putNode(n *Node[T]) { + if p == nil || n == nil { + return + } + n.next = nil + n.prev = nil + var zero T + n.Element = zero + p.pool.Put(n) } // 每个Node节点, 包含前向和后向两个指针和数据域 @@ -39,9 +83,10 @@ type Node[T any] struct { Element T } -// 返回一个双向循环链表 -func New[T any]() *LinkedList[T] { - return new(LinkedList[T]).Init() +type LinkedList[T any] struct { + root Node[T] + length int + pool *nodePool[T] } // 指向自己, 组成一个环 @@ -62,6 +107,9 @@ func (l *LinkedList[T]) lazyInit() { if l.root.next == nil { l.Init() } + if l.pool == nil { + l.pool = getNodePool[T]() + } } // at e at.next @@ -226,7 +274,9 @@ func (l *LinkedList[T]) PushFrontList(other *LinkedList[T]) *LinkedList[T] { func (l *LinkedList[T]) PushFront(elems ...T) *LinkedList[T] { l.lazyInit() for _, e := range elems { - l.insert(&l.root, &Node[T]{Element: e}) + newNode := l.pool.getNode() + newNode.Element = e + l.insert(&l.root, newNode) } return l } @@ -252,7 +302,10 @@ func (l *LinkedList[T]) PushBackList(other *LinkedList[T]) *LinkedList[T] { func (l *LinkedList[T]) PushBack(elems ...T) *LinkedList[T] { l.lazyInit() for _, e := range elems { - l.insert(l.root.prev, &Node[T]{Element: e}) + // newNode := l.pool.getNode() + newNode := &Node[T]{} + newNode.Element = e + l.insert(l.root.prev, newNode) } return l } @@ -293,7 +346,9 @@ func (l *LinkedList[T]) Clear() *LinkedList[T] { func (l *LinkedList[T]) InsertAfter(value T, equal func(value T) bool) *LinkedList[T] { l.RangeSafe(func(n *Node[T]) bool { if equal(n.Element) { - l.insert(n, &Node[T]{Element: value}) + newNode := l.pool.getNode() + newNode.Element = value + l.insert(n, newNode) return true } return false @@ -305,7 +360,9 @@ func (l *LinkedList[T]) InsertAfter(value T, equal func(value T) bool) *LinkedLi func (l *LinkedList[T]) InsertBefore(value T, equal func(value T) bool) *LinkedList[T] { l.RangeSafe(func(n *Node[T]) bool { if equal(n.Element) { - l.insert(n.prev, &Node[T]{Element: value}) + newNode := l.pool.getNode() + newNode.Element = value + l.insert(n.prev, newNode) return true } return false @@ -343,8 +400,7 @@ func (l *LinkedList[T]) GetWithBool(idx int) (e T, ok bool) { func (l *LinkedList[T]) remove(n *Node[T]) { n.prev.next = n.next n.next.prev = n.prev - n.prev = nil - n.next = nil + l.pool.putNode(n) l.length-- } @@ -576,18 +632,15 @@ func (l *LinkedList[T]) rangeStartEndSafe(start, end int, callback func(start, e if start > end || start >= l.length { return } - var pos *Node[T] - var n *Node[T] - pos = l.root.next - n = pos.next + pos := l.root.next for pos != &l.root { + next := pos.next if callback(start, end, pos) { break } - pos = n - n = pos.next + pos = next } } @@ -604,3 +657,68 @@ func (l *LinkedList[T]) Trim(start, end int) *LinkedList[T] { }) return l } + +// ConcurrentLinkedList 是线程安全的双向链表 +type ConcurrentLinkedList[T any] struct { + mu sync.RWMutex + list *LinkedList[T] +} + +// NewConcurrent 返回一个并发安全的双向链表 +func NewConcurrent[T any]() *ConcurrentLinkedList[T] { + return &ConcurrentLinkedList[T]{ + list: New[T](), + } +} + +// PushFront 线程安全地在链表头部添加元素 +func (cl *ConcurrentLinkedList[T]) PushFront(elems ...T) *ConcurrentLinkedList[T] { + cl.mu.Lock() + defer cl.mu.Unlock() + cl.list.PushFront(elems...) + return cl +} + +// PushBack 线程安全地在链表尾部添加元素 +func (cl *ConcurrentLinkedList[T]) PushBack(elems ...T) *ConcurrentLinkedList[T] { + cl.mu.Lock() + defer cl.mu.Unlock() + cl.list.PushBack(elems...) + return cl +} + +// Remove 线程安全地删除指定索引的元素 +func (cl *ConcurrentLinkedList[T]) Remove(index int) *ConcurrentLinkedList[T] { + cl.mu.Lock() + defer cl.mu.Unlock() + cl.list.Remove(index) + return cl +} + +// Get 线程安全地获取指定索引的元素 +func (cl *ConcurrentLinkedList[T]) Get(idx int) (T, bool) { + cl.mu.RLock() + defer cl.mu.RUnlock() + return cl.list.GetWithBool(idx) +} + +// Len 线程安全地获取链表长度 +func (cl *ConcurrentLinkedList[T]) Len() int { + cl.mu.RLock() + defer cl.mu.RUnlock() + return cl.list.Len() +} + +// Range 线程安全地遍历链表 +func (cl *ConcurrentLinkedList[T]) Range(callback func(value T), startAndEnd ...int) { + cl.mu.RLock() + defer cl.mu.RUnlock() + cl.list.Range(callback, startAndEnd...) +} + +// 返回一个双向循环链表 +func New[T any]() *LinkedList[T] { + l := new(LinkedList[T]) + l.pool = getNodePool[T]() + return l.Init() +} diff --git a/linkedlist/linkedlist_test.go b/linkedlist/linkedlist_test.go index b5b3a47..e0f4e72 100644 --- a/linkedlist/linkedlist_test.go +++ b/linkedlist/linkedlist_test.go @@ -2,6 +2,7 @@ package linkedlist import ( "testing" + "time" "github.com/antlabs/gstl/must" ) @@ -523,6 +524,132 @@ func Test_OtherMoveToFrontList(t *testing.T) { } } +// 测试对象池在不同类型下的行为 +func Test_NodePool(t *testing.T) { + // 测试不同类型使用不同的对象池 + l1 := New[int]() + l2 := New[string]() + + // 添加和删除大量元素,测试对象池的复用 + for i := 0; i < 1000; i++ { + l1.PushBack(i) + l2.PushBack(string(rune(i + 65))) // A-Z... + } + + // 记录初始长度 + len1 := l1.Len() + len2 := l2.Len() + + // 删除一半元素 + for i := 0; i < 500; i++ { + l1.Remove(0) + l2.Remove(0) + } + + // 再次添加元素,应该复用之前的节点 + for i := 0; i < 500; i++ { + l1.PushBack(i) + l2.PushBack(string(rune(i + 65))) + } + + // 检查最终长度是否正确 + if l1.Len() != len1 || l2.Len() != len2 { + t.Errorf("NodePool reuse failed: l1.Len() = %v, want %v; l2.Len() = %v, want %v", + l1.Len(), len1, l2.Len(), len2) + } +} + +// 测试并发安全性 +func Test_ConcurrentLinkedList(t *testing.T) { + list := NewConcurrent[int]() + done := make(chan bool, 11) + + // 并发写入 + for i := 0; i < 10; i++ { + go func(n int) { + for j := 0; j < 100; j++ { + list.PushBack(j) + list.Range(func(value int) {}) + } + done <- true + }(i) + } + + // 并发读取 + go func() { + time.Sleep(time.Second) + for i := 0; i < 100; i++ { + // list.Range(func(value int) {}) + } + done <- true + }() + + // 等待所有goroutine完成 + for i := 0; i < 11; i++ { // 10个写入 + 1个读取 + <-done + } + + // 验证最终长度 + if list.Len() != 1000 { // 10 goroutines * 100 items + t.Errorf("ConcurrentLinkedList failed: Len() = %v, want %v", list.Len(), 1000) + } +} + +// 测试直接创建LinkedList的情况 +func Test_DirectLinkedList(t *testing.T) { + // 直接创建LinkedList而不是通过New + var list LinkedList[int] + + // 应该可以正常使用 + list.PushBack(1, 2, 3) + + if list.Len() != 3 { + t.Errorf("DirectLinkedList failed: Len() = %v, want %v", list.Len(), 3) + } + + // 测试获取元素 + if val := list.Get(0); val != 1 { + t.Errorf("DirectLinkedList failed: Get(0) = %v, want %v", val, 1) + } + + // 测试删除元素 + list.Remove(0) + if list.Len() != 2 { + t.Errorf("DirectLinkedList failed after remove: Len() = %v, want %v", list.Len(), 2) + } +} + +// 测试边界情况 +func Test_EdgeCases(t *testing.T) { + list := New[int]() + + // 测试空链表的操作 + if val, ok := list.First(); ok { + t.Errorf("Empty list First() should return false, got true with value %v", val) + } + + // 测试nil对象池的情况 + list.pool = nil + list.PushBack(1) // 应该仍然可以工作 + + if list.Len() != 1 { + t.Errorf("List with nil pool failed: Len() = %v, want %v", list.Len(), 1) + } + + // 测试大量添加删除操作 + for i := 0; i < 1000; i++ { + list.PushBack(i) + if i%2 == 0 { + list.Remove(0) + } + } + + // 最终长度应该是500 + if list.Len() != 501 { + t.Errorf("Edge case large operations failed: Len() = %v, want %v", list.Len(), 501) + } +} + // Helper function to compare slices func sliceEqual[T comparable](a, b []T) bool { if len(a) != len(b) { diff --git a/rbtree/rbtree.go b/rbtree/rbtree.go index 3284e92..10f438f 100644 --- a/rbtree/rbtree.go +++ b/rbtree/rbtree.go @@ -520,10 +520,6 @@ func (n *node[K, V]) rangePrevInner(callback func(k K, v V) bool) bool { func (n *node[K, V]) rangeInner(callback func(k K, v V) bool) bool { - if n == nil { - return true - } - if n.left != nil { if !n.left.rangeInner(callback) { return false