-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathkeys_storage.go
171 lines (144 loc) · 4.31 KB
/
keys_storage.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
package doubleratchet
import (
"bytes"
"fmt"
"sort"
)
// KeysStorage abstracts an in-memory or persistent store of message keys.
// Each message key is addressed by a key k and a message number msgNum, and is
// tagged with the session that owns it and a per-key sequence number used to
// expire old entries.
type KeysStorage interface {
	// Get returns the message key stored under k and msgNum.
	// ok reports whether such a message key exists.
	Get(k Key, msgNum uint) (mk Key, ok bool, err error)

	// Put saves mk under k and msgNum, recording sessionID as its owner and
	// keySeqNum as its sequence number.
	Put(sessionID []byte, k Key, msgNum uint, mk Key, keySeqNum uint) error

	// DeleteMk ensures there's no message key under k and msgNum.
	DeleteMk(k Key, msgNum uint) error

	// DeleteOldMks deletes the message keys of sessionID whose sequence
	// number is less than or equal to deleteUntilSeqKey.
	DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error

	// TruncateMks keeps at most maxKeys message keys for sessionID,
	// discarding those with the lowest sequence numbers first.
	TruncateMks(sessionID []byte, maxKeys int) error

	// Count returns the number of message keys stored under k.
	Count(k Key) (uint, error)

	// All returns every stored message key, grouped by an index derived from
	// its key and then by message number.
	All() (map[string]map[uint]Key, error)
}
// KeysStorageInMemory keeps message keys in process memory. Entries are
// bucketed by the hex encoding of the public key they belong to and, inside
// each bucket, by message number. The zero value is ready to use: Put
// allocates the map lazily and the read paths tolerate a nil map.
// NOTE(review): none of the methods in this file take a lock, so concurrent
// use needs external synchronization.
type KeysStorageInMemory struct {
	keys map[string]map[uint]InMemoryKey // hex(pubKey) -> msgNum -> entry
}
// Get fetches the message key stored for pubKey at message number msgNum.
// The boolean result is false, and the returned Key is empty, when nothing is
// stored there. The in-memory store never returns a non-nil error.
func (s *KeysStorageInMemory) Get(pubKey Key, msgNum uint) (Key, bool, error) {
	// Indexing a nil map (outer or inner) yields the zero value. A missing
	// bucket, a missing message number and an uninitialised store therefore
	// all fall through to the final return.
	if entry, found := s.keys[fmt.Sprintf("%x", pubKey)][msgNum]; found {
		return entry.messageKey, true, nil
	}
	return Key{}, false, nil
}
// InMemoryKey is a single entry held by KeysStorageInMemory: a message key
// plus the bookkeeping used to expire it.
type InMemoryKey struct {
	messageKey Key    // the stored message key, returned by Get and All
	seqNum     uint   // sequence number passed to Put; compared by TruncateMks and DeleteOldMks
	sessionID  []byte // owning session, stored exactly as passed to Put (not copied)
}
// Put stores mk as the message key for pubKey at message number msgNum and
// tags it with sessionID and the sequence number seqNum. An existing entry at
// the same position is overwritten. The backing maps are created on demand,
// so Put works on a zero-value KeysStorageInMemory. sessionID and mk are kept
// as given, not copied.
func (s *KeysStorageInMemory) Put(sessionID []byte, pubKey Key, msgNum uint, mk Key, seqNum uint) error {
	if s.keys == nil {
		s.keys = map[string]map[uint]InMemoryKey{}
	}

	index := fmt.Sprintf("%x", pubKey)
	bucket, exists := s.keys[index]
	if !exists {
		bucket = map[uint]InMemoryKey{}
		s.keys[index] = bucket
	}

	bucket[msgNum] = InMemoryKey{
		messageKey: mk,
		seqNum:     seqNum,
		sessionID:  sessionID,
	}
	return nil
}
// DeleteMk removes the message key stored for pubKey at msgNum, if there is
// one. When that leaves pubKey with no message keys, its bucket is removed as
// well. Deleting an absent entry is a no-op, and the error is always nil.
func (s *KeysStorageInMemory) DeleteMk(pubKey Key, msgNum uint) error {
	index := fmt.Sprintf("%x", pubKey)

	// A lookup in a nil outer map simply reports "not found".
	bucket, present := s.keys[index]
	if !present {
		return nil
	}
	if _, present = bucket[msgNum]; !present {
		return nil
	}

	delete(bucket, msgNum)
	if len(bucket) == 0 {
		delete(s.keys, index)
	}
	return nil
}
// TruncateMks caps the number of message keys stored for sessionID at
// maxKeys. It deletes the keys with the lowest sequence numbers first (the
// oldest ones, presumably, if seqNum grows with each Put — confirm with
// callers). Exactly len(keys)-maxKeys entries are removed, even when several
// entries share a sequence number. Keys of other sessions are untouched.
// Buckets emptied by the truncation are removed, as DeleteMk does.
//
// A negative maxKeys is rejected with an error. Previously it sliced past
// the collected sequence numbers and could panic.
func (s *KeysStorageInMemory) TruncateMks(sessionID []byte, maxKeys int) error {
	if maxKeys < 0 {
		return fmt.Errorf("maxKeys must be non-negative, got %d", maxKeys)
	}

	// location pins down one stored message key precisely. Deleting by
	// location rather than by seqNum value avoids over-deleting on ties.
	type location struct {
		index  string
		msgNum uint
		seqNum uint
	}

	var candidates []location
	for index, bucket := range s.keys {
		for msgNum, entry := range bucket {
			if bytes.Equal(entry.sessionID, sessionID) {
				candidates = append(candidates, location{index, msgNum, entry.seqNum})
			}
		}
	}

	excess := len(candidates) - maxKeys
	if excess <= 0 {
		// Nothing to do if we haven't exceeded the limit.
		return nil
	}

	// Oldest first. Ties are broken by index and msgNum so the outcome does
	// not depend on map iteration order.
	sort.Slice(candidates, func(i, j int) bool {
		a, b := candidates[i], candidates[j]
		if a.seqNum != b.seqNum {
			return a.seqNum < b.seqNum
		}
		if a.index != b.index {
			return a.index < b.index
		}
		return a.msgNum < b.msgNum
	})

	for _, loc := range candidates[:excess] {
		bucket := s.keys[loc.index]
		delete(bucket, loc.msgNum)
		if len(bucket) == 0 {
			delete(s.keys, loc.index)
		}
	}
	return nil
}
// DeleteOldMks deletes every message key of sessionID whose sequence number
// is less than or equal to deleteUntilSeqKey. Keys of other sessions are left
// alone. Buckets emptied by this call are removed, as DeleteMk does, so All
// does not report public keys that no longer hold any message key. The error
// is always nil.
func (s *KeysStorageInMemory) DeleteOldMks(sessionID []byte, deleteUntilSeqKey uint) error {
	for index, bucket := range s.keys {
		removed := false
		for msgNum, entry := range bucket {
			if entry.seqNum <= deleteUntilSeqKey && bytes.Equal(entry.sessionID, sessionID) {
				// The Go spec permits deleting map entries during range.
				delete(bucket, msgNum)
				removed = true
			}
		}
		if removed && len(bucket) == 0 {
			delete(s.keys, index)
		}
	}
	return nil
}
// Count reports how many message keys are currently stored for pubKey. It
// never fails; the error result exists only to match the KeysStorage method
// signature.
func (s *KeysStorageInMemory) Count(pubKey Key) (uint, error) {
	// Looking up a nil map, or a missing index, yields a nil bucket of length 0.
	bucket := s.keys[fmt.Sprintf("%x", pubKey)]
	return uint(len(bucket)), nil
}
// All returns a snapshot of every stored message key, grouped by the hex
// index of its public key and then by message number. Both map levels are
// freshly allocated, so callers may modify them without affecting the store.
// The Key values themselves are returned as stored. Buckets with no entries
// appear as empty inner maps. The error is always nil.
func (s *KeysStorageInMemory) All() (map[string]map[uint]Key, error) {
	snapshot := make(map[string]map[uint]Key, len(s.keys))
	for index, bucket := range s.keys {
		inner := make(map[uint]Key, len(bucket))
		for msgNum, entry := range bucket {
			inner[msgNum] = entry.messageKey
		}
		snapshot[index] = inner
	}
	return snapshot, nil
}