From d57183fe4c3fd2a5a6a56a9a88fe0fd2ad404223 Mon Sep 17 00:00:00 2001 From: clark Date: Fri, 7 May 2021 10:16:07 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=E9=94=99=E8=AF=AF=E5=8F=98=E9=87=8F=20?= =?UTF-8?q?=E6=8B=86=E5=88=86=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- npubo.go | 187 ++++++++++++++------------------------------------ npubo_test.go | 18 ++++- 2 files changed, 66 insertions(+), 139 deletions(-) diff --git a/npubo.go b/npubo.go index 8adf785..00da64e 100644 --- a/npubo.go +++ b/npubo.go @@ -3,9 +3,15 @@ package npubo import ( "errors" "fmt" - "strings" + "runtime" "sync" - "time" +) + +var ( + ErrTopicNotFound = errors.New("topic not found") + ErrSubscriberTimeout = errors.New("subscriber timeout") + ErrInvaildTopic = errors.New("invaild topic") + ErrNilNode = errors.New("nil node ") ) type ( @@ -15,7 +21,12 @@ type ( // 推送错误返回回调 ErrCall func(sub *Subscriber, e error) - // 前缀输节点 + ChanCall struct { + Subscriber *Subscriber + Content interface{} + } + + // 前缀树节点 Node struct { Calls map[string]*Call NextNode map[string]*Node @@ -24,17 +35,21 @@ type ( // 推送结构 Publisher struct { Node *Node + Root *Node timeout int + openChan bool rwLock *sync.RWMutex workerLock *sync.Mutex } // 订阅结构 Subscriber struct { - Node *Node // 所在节点 - Publisher *Publisher // 所在推送 - Topic string // 订阅路径 - CId string // 客户端Id + Node *Node // 所在节点 + Publisher *Publisher // 所在推送 + Topic string // 订阅路径 + CallTopic string // 推送订阅路径 + CId string // 客户端Id + C chan (ChanCall) // 通道订阅 isWorker bool } ) @@ -47,144 +62,35 @@ func newNode() *Node { } } -func NewPublisher(timeout int) *Publisher { +func NewPublisher(timeout int, openChan bool) *Publisher { return &Publisher{ - Node: newNode(), + Node: newNode(), + Root: &Node{ + Calls: make(map[string]*Call), + NextNode: nil, + }, + openChan: openChan, timeout: timeout, rwLock: &sync.RWMutex{}, workerLock: &sync.Mutex{}, } } -// 订阅 -func (that *Publisher) Subscribe(topic string, c_id string, call Call) (*Subscriber, error) { - that.rwLock.Lock() - defer that.rwLock.Unlock() - - nowNode := that.Node.NextNode - cals := strings.Split(topic, "/") - sub := &Subscriber{ - Topic: topic, - CId: c_id, - Node: nil, - } - - if strings.Contains(topic, "*") || cals[0] == "" { - return sub, errors.New("invaild topic") - } - - for i, v := range cals { - if _, ok := nowNode[v]; !ok { - nowNode[v] = newNode() - } - if i == len(cals)-1 { - nowNode[v].Calls[c_id] = &call - sub.Node = nowNode[v] - } - nowNode = nowNode[v].NextNode - } - return sub, nil -} - -// 发布消息 -func (that *Publisher) Publish(topic string, val interface{}, errBack ErrCall) { - that.rwLock.RLock() - defer that.rwLock.RUnlock() - - nowNode := that.Node.NextNode - cals := strings.Split(topic, "/") - - topicRecode := []string{} - for i, v := range cals { - if _, ok := nowNode[v]; ok { // 正常匹配路径 - topicRecode = append(topicRecode, v) - if i == len(cals)-1 { - for c_id, call := range nowNode[v].Calls { - sub := &Subscriber{ - Topic: strings.Join(topicRecode, "/"), - CId: c_id, - Node: nowNode[v], - Publisher: that, - isWorker: true, - } - e := that.callSubscriber(call, sub, val) - if e != nil && errBack != nil { - that.callErrBack(errBack, sub, e) - } - } - } - nowNode = nowNode[v].NextNode - } else if v == "*" { // 匹配通配符 - that.callAllNode(nowNode, errBack, topicRecode, val) - break - } else { - break - } - } -} - -func (that *Publisher) callAllNode(nowNode map[string]*Node, errBack ErrCall, topicInit []string, val interface{}) { - var wg sync.WaitGroup - for t, n := range nowNode { - topic := topicInit - topic = append(topic, t) - for c_id, call := range n.Calls { - sub := &Subscriber{ - Topic: strings.Join(topic, "/"), - CId: c_id, - Node: n, - Publisher: that, - isWorker: true, - } - e := that.callSubscriber(call, sub, val) - if e != nil && errBack != nil { - that.callErrBack(errBack, sub, e) - } - } - - wg.Add(1) - go func(nextNode map[string]*Node, topic []string, val interface{}) { - defer wg.Done() - that.callAllNode(nextNode, errBack, topic, val) - }(n.NextNode, topic, val) +func (that *Subscriber) RootEvict(c_id string, call Call) error { + if that.isWorker { + that.Publisher.workerLock.Lock() + defer that.Publisher.workerLock.Unlock() + } else { + that.Publisher.rwLock.Lock() + defer that.Publisher.rwLock.Unlock() } - wg.Wait() -} - -func (that *Publisher) callSubscriber(call *Call, sub *Subscriber, val interface{}) (e error) { - done := make(chan byte, 0) - go func() { - defer func() { - if r := recover(); r != nil { - e = errors.New(fmt.Sprint(r)) - } - }() - e = (*call)(sub, val) - done <- 0 - }() - select { - case <-done: - return e - case <-time.After(time.Microsecond * time.Duration(that.timeout)): - close(done) - return errors.New("subscriber timeout") + if that.Node == nil { + return nil } -} -func (that *Publisher) callErrBack(errBack ErrCall, sub *Subscriber, e error) { - done := make(chan byte, 0) - go func() { - defer func() { recover() }() - errBack(sub, e) - done <- 0 - }() - select { - case <-done: - return - case <-time.After(time.Microsecond * time.Duration(that.timeout)): - close(done) - return - } + delete(that.Node.Calls, that.CId) + that.Node = nil + return nil } // 取消订阅 @@ -201,6 +107,8 @@ func (that *Subscriber) Evict() error { } delete(that.Node.Calls, that.CId) + defer func() { recover() }() + close(that.C) that.Node = nil return nil } @@ -215,7 +123,7 @@ func (that *Subscriber) RewriteCall(call Call) error { defer that.Publisher.rwLock.Unlock() } if that.Node == nil { - return errors.New("node is nil") + return ErrNilNode } that.Node.Calls[that.CId] = &call @@ -226,5 +134,10 @@ func (that *Subscriber) RewriteCall(call Call) error { func (that *Publisher) Close() { that.rwLock.RLock() defer that.rwLock.RUnlock() - that.Node = newNode() + that.Node, that.Root = nil, nil + + runtime.GC() + + that.Node, that.Root = newNode(), newNode() + fmt.Println(that) } diff --git a/npubo_test.go b/npubo_test.go index 0cb57c2..84a3b15 100644 --- a/npubo_test.go +++ b/npubo_test.go @@ -8,14 +8,23 @@ import ( "time" ) -var pub *npubo.Publisher = npubo.NewPublisher(500) +var pub *npubo.Publisher = npubo.NewPublisher(500, true) func TestSub(t *testing.T) { - pub.Subscribe("sub_one/one", "QwQ", func(sub *npubo.Subscriber, val interface{}) error { + + sub, _ := pub.Subscribe("sub_one/one", "QwQ", func(sub *npubo.Subscriber, val interface{}) error { fmt.Println("sub", sub, " message", val) return nil }) + go func(sub *npubo.Subscriber) { + for val := range sub.C { + fmt.Println("chan:", val.Subscriber.CallTopic) + } + }(sub) + + //sub.Evict() + pub.Subscribe("sub_one/timeout", "QwQ", func(sub *npubo.Subscriber, val interface{}) error { time.Sleep(time.Second) return nil @@ -24,6 +33,7 @@ func TestSub(t *testing.T) { pub.Subscribe("sub_one/error", "QwQ", func(sub *npubo.Subscriber, val interface{}) error { return errors.New("a error") }) + //pub.Close() } func TestPub(t *testing.T) { @@ -38,6 +48,10 @@ func TestPub(t *testing.T) { pub.Publish("sub_one/error", "Message", func(sub *npubo.Subscriber, e error) { fmt.Println("sub", sub, " error", e) }) + + pub.Publish("*", "Call All Subscriber", func(sub *npubo.Subscriber, e error) { + fmt.Println("sub", sub, " error", e) + }) } func BenchmarkSub(b *testing.B) {