-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathstate.go
178 lines (143 loc) · 4.52 KB
/
state.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
package doubleratchet
// TODO: During each DH ratchet step a new ratchet key pair and sending chain are generated.
// As the sending chain is not needed right away, these steps could be deferred until the party
// is about to send a new message.
import (
"fmt"
)
// State holds the double ratchet state.
type State struct {
	// Crypto supplies the cryptographic operations the ratchet calls
	// (DH and GenerateDH are used by dhRatchet).
	Crypto Crypto

	// DHr is the DH ratchet public key (the remote key).
	DHr Key

	// DHs is the DH ratchet key pair (the self ratchet key).
	DHs DHPair

	// RootCh is the symmetric ratchet root chain.
	RootCh kdfRootChain

	// SendCh is the symmetric ratchet sending chain.
	SendCh kdfChain
	// RecvCh is the symmetric ratchet receiving chain.
	RecvCh kdfChain

	// PN is the number of messages in the previous sending chain.
	PN uint32

	// MkSkipped is the dictionary of skipped-over message keys, indexed by
	// ratchet public key or header key and message number.
	MkSkipped KeysStorage

	// MaxSkip is the maximum number of message keys that can be skipped in a
	// single chain. It should be set (see WithMaxSkip) high enough to tolerate
	// routine lost or delayed messages, but low enough that a malicious sender
	// can't trigger excessive recipient computation.
	MaxSkip uint

	// HKr is the receiving header key. Only used for header encryption.
	HKr Key
	// NHKr is the next receiving header key. Only used for header encryption.
	NHKr Key

	// HKs is the sending header key. Only used for header encryption.
	HKs Key
	// NHKs is the next sending header key. Only used for header encryption.
	NHKs Key

	// MaxKeep is how long message keys are kept, counted in number of
	// messages received. For example, if MaxKeep is 5 only the message keys
	// of the last 5 messages are kept and everything before n - 5 is deleted.
	MaxKeep uint

	// MaxMessageKeysPerSession is the maximum number of message keys per
	// session; older keys are deleted in FIFO fashion.
	MaxMessageKeysPerSession int

	// Step is the number of the current ratchet step.
	Step uint

	// KeysCount is the number of keys generated for decrypting.
	KeysCount uint
}
// DefaultState returns a State seeded with sharedKey and the package
// defaults: DefaultCrypto, an empty dhPair, an in-memory skipped-key store,
// MaxSkip of 1000, and MaxKeep / MaxMessageKeysPerSession of 2000.
func DefaultState(sharedKey Key) State {
	crypto := DefaultCrypto{}

	// Both the sending and the receiving chain start from sharedKey so that
	// either party can send and receive messages from the very beginning.
	sendChain := kdfChain{CK: sharedKey, Crypto: crypto}
	recvChain := kdfChain{CK: sharedKey, Crypto: crypto}

	var st State
	st.DHs = dhPair{}
	st.Crypto = crypto
	st.RootCh = kdfRootChain{CK: sharedKey, Crypto: crypto}
	st.SendCh = sendChain
	st.RecvCh = recvChain
	st.MkSkipped = &KeysStorageInMemory{}
	st.MaxSkip = 1000
	st.MaxMessageKeysPerSession = 2000
	st.MaxKeep = 2000
	st.KeysCount = 0
	return st
}
// applyOptions runs every option in opts against s, in order, and stops at
// the first one that fails.
//
// The returned error wraps the failing option's error (via %w), so callers
// can still match the underlying cause with errors.Is or errors.As. Options
// that ran before the failure have already modified s.
func (s *State) applyOptions(opts []option) error {
	for _, apply := range opts {
		if err := apply(s); err != nil {
			return fmt.Errorf("failed to apply option: %w", err)
		}
	}
	return nil
}
// newState builds a State from DefaultState(sharedKey) and then applies opts
// to it in order.
//
// It returns an error when sharedKey is empty or when any option fails; in
// both cases the returned State is the zero value.
func newState(sharedKey Key, opts ...option) (State, error) {
	// len covers both a nil key and a non-nil key of zero length; the latter
	// used to slip through the plain nil comparison.
	if len(sharedKey) == 0 {
		return State{}, fmt.Errorf("sharedKey mustn't be empty")
	}

	s := DefaultState(sharedKey)
	if err := s.applyOptions(opts); err != nil {
		return State{}, err
	}
	return s, nil
}
// dhRatchet performs a single DH ratchet step driven by the ratchet key
// carried in header m.
//
// In order, it:
//  1. stores the current sending chain length in PN, adopts m.DH as the
//     remote ratchet key and promotes the next header keys to current;
//  2. derives a new receiving chain (and next receiving header key) from the
//     current self key pair and the new remote key;
//  3. generates a fresh self key pair;
//  4. derives a new sending chain (and next sending header key) from the
//     fresh pair and the remote key.
//
// Errors from the Crypto implementation are wrapped with %w so the cause
// stays reachable through errors.Is / errors.As. The state is mutated in
// place as the steps run, so on error s is left partially updated.
func (s *State) dhRatchet(m MessageHeader) error {
	s.PN = s.SendCh.N
	s.DHr = m.DH
	s.HKs = s.NHKs
	s.HKr = s.NHKr

	recvSecret, err := s.Crypto.DH(s.DHs, s.DHr)
	if err != nil {
		return fmt.Errorf("failed to generate dh recieve ratchet secret: %w", err)
	}
	s.RecvCh, s.NHKr = s.RootCh.step(recvSecret)

	s.DHs, err = s.Crypto.GenerateDH()
	if err != nil {
		return fmt.Errorf("failed to generate dh pair: %w", err)
	}

	sendSecret, err := s.Crypto.DH(s.DHs, s.DHr)
	if err != nil {
		return fmt.Errorf("failed to generate dh send ratchet secret: %w", err)
	}
	s.SendCh, s.NHKs = s.RootCh.step(sendSecret)

	return nil
}
// skippedKey describes one message key that skipMessageKeys derived while
// advancing the receiving chain; applyChanges later hands each field to
// MkSkipped.Put.
type skippedKey struct {
	// key is the key argument skipMessageKeys was called with; it is passed
	// to the skipped-keys storage together with nr.
	key Key
	// nr is the receiving-chain counter value (RecvCh.N - 1 after the step)
	// for which mk was produced.
	nr uint
	// mk is the message key returned by the receiving chain step.
	mk Key
	// seq is the value of State.KeysCount at the moment the key was derived.
	seq uint
}
// skipMessageKeys advances the current receiving chain until its counter
// reaches until, collecting the message key produced by every step.
//
// key is recorded on each returned entry. Each entry also carries the chain
// counter value it belongs to and the KeysCount value at derivation time;
// KeysCount is incremented once per derived key.
//
// It returns an error, without stepping the chain, when until is behind the
// current chain counter or when more than MaxSkip keys would have to be
// skipped. When no key needs skipping the result is an empty, non-nil slice.
func (s *State) skipMessageKeys(key Key, until uint) ([]skippedKey, error) {
	start := uint(s.RecvCh.N)
	if until < start {
		return nil, fmt.Errorf("bad until: probably an out-of-order message that was deleted")
	}

	// until >= start here, so the subtraction cannot underflow. Comparing the
	// distance (rather than start+MaxSkip against until) avoids a uint
	// overflow when MaxSkip is very large, which would otherwise wrap around
	// and reject every call.
	count := until - start
	if count > s.MaxSkip {
		return nil, fmt.Errorf("too many messages")
	}

	skipped := make([]skippedKey, 0, count)
	for uint(s.RecvCh.N) < until {
		mk := s.RecvCh.step()
		skipped = append(skipped, skippedKey{
			key: key,
			// step has already advanced the counter, so the key just
			// produced belongs to the previous counter value.
			nr:  uint(s.RecvCh.N - 1),
			mk:  mk,
			seq: s.KeysCount,
		})
		// Increment key count
		s.KeysCount++
	}
	return skipped, nil
}
// applyChanges overwrites s with sc and then updates the skipped-keys
// storage for sessionID: every entry of skipped is stored with Put, the
// session is truncated to MaxMessageKeysPerSession keys, and, once KeysCount
// has reached MaxKeep, keys older than KeysCount-MaxKeep are deleted.
//
// The first storage error is returned as is. Because s is replaced before
// any storage call, s already holds sc even when an error is returned.
func (s *State) applyChanges(sc State, sessionID []byte, skipped []skippedKey) error {
	*s = sc

	for i := range skipped {
		entry := skipped[i]
		err := s.MkSkipped.Put(sessionID, entry.key, entry.nr, entry.mk, entry.seq)
		if err != nil {
			return err
		}
	}

	err := s.MkSkipped.TruncateMks(sessionID, s.MaxMessageKeysPerSession)
	if err != nil {
		return err
	}

	// Nothing is old enough to delete until KeysCount reaches MaxKeep.
	if s.KeysCount < s.MaxKeep {
		return nil
	}

	threshold := s.KeysCount - s.MaxKeep
	if err := s.MkSkipped.DeleteOldMks(sessionID, threshold); err != nil {
		return err
	}
	return nil
}