-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathavl_tree.go
257 lines (230 loc) · 4.88 KB
/
avl_tree.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
package collections
type avlNode struct {
h int
value int
left *avlNode
right *avlNode
}
type AVLTree struct {
tree *avlNode
}
// 生成 AVL 树
func NewAVLTree() *AVLTree {
return &AVLTree{&avlNode{h: -2}}
}
// 插入节点
func (a *AVLTree) Insert(v int) {
a.tree = insert(v, a.tree)
}
// 搜索节点
func (a *AVLTree) Search(v int) bool {
return a.tree.search(v)
}
// 删除节点
func (a *AVLTree) Delete(v int) bool {
if a.tree.search(v) {
a.tree.delete(v)
return true
}
return false
}
// 获取所有节点中的最大值
func (a *AVLTree) GetMaxValue() int {
return a.tree.maxNode().value
}
// 获取所有节点中的最小值
func (a *AVLTree) GetMinValue() int {
return a.tree.minNode().value
}
// 返回排序后所有值
func (a *AVLTree) AllValues() []int {
return a.tree.values()
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
func insert(v int, t *avlNode) *avlNode {
if t == nil {
return &avlNode{value: v}
}
if t.h == -2 {
t.value = v
t.h = 0
return t
}
cmp := v - t.value
if cmp > 0 {
// 将节点插入到右子树中
t.right = insert(v, t.right)
} else if cmp < 0 {
// 将节点插入到左子树中
t.left = insert(v, t.left)
}
// 维持树平衡
t = t.keepBalance(v)
t.h = max(t.left.height(), t.right.height()) + 1
return t
}
func (t *avlNode) search(v int) bool {
if t == nil {
return false
}
cmp := v - t.value
if cmp > 0 {
// 如果 v 大于当前节点值,继续从右子树中寻找
return t.right.search(v)
} else if cmp < 0 {
// 如果 v 小于当前节点值,继续从左子树中寻找
return t.left.search(v)
} else {
// 相等则表示找到
return true
}
}
func (t *avlNode) delete(v int) *avlNode {
if t == nil {
return t
}
cmp := v - t.value
if cmp > 0 {
// 如果 v 大于当前节点值,继续从右子树中删除
t.right = t.right.delete(v)
} else if cmp < 0 {
// 如果 v 小于当前节点值,继续从左子树中删除
t.left = t.left.delete(v)
} else {
// 找到 v
if t.left != nil && t.right != nil {
// 如果该节点既有左子树又有右子树
// 使用右子树中的最小节点取代删除节点,然后删除右子树中的最小节点
t.value = t.right.minNode().value
t.right = t.right.delete(t.value)
} else if t.left != nil {
// 如果只有左子树,则直接删除节点
t = t.left
} else {
// 只有右子树或空树
t = t.right
}
}
if t != nil {
t.h = max(t.left.height(), t.right.height()) + 1
t = t.keepBalance(v)
}
return t
}
func (t *avlNode) minNode() *avlNode {
if t == nil {
return nil
}
// 整棵树的最左边节点就是值最小的节点
if t.left == nil {
return t
} else {
return t.left.minNode()
}
}
func (t *avlNode) maxNode() *avlNode {
if t == nil {
return nil
}
// 整棵树的最右边节点就是值最大的节点
if t.right == nil {
return t
} else {
return t.right.maxNode()
}
}
/*
左左情况:右旋
*
*
*
*/
func (t *avlNode) llRotate() *avlNode {
node := t.left
t.left = node.right
node.right = t
node.h = max(node.left.height(), node.right.height()) + 1
t.h = max(t.left.height(), t.right.height()) + 1
return node
}
/*
右右情况:左旋
*
*
*
*/
func (t *avlNode) rrRotate() *avlNode {
node := t.right
t.right = node.left
node.left = t
node.h = max(node.left.height(), node.right.height()) + 1
t.h = max(t.left.height(), t.right.height()) + 1
return node
}
/*
左右情况:先左旋 后右旋
*
*
*
*/
func (t *avlNode) lrRotate() *avlNode {
t.left = t.left.rrRotate()
return t.llRotate()
}
/*
右左情况:先右旋 后左旋
*
*
*
*/
func (t *avlNode) rlRotate() *avlNode {
t.right = t.right.llRotate()
return t.rrRotate()
}
func (t *avlNode) keepBalance(v int) *avlNode {
// 左子树失衡
if t.left.height()-t.right.height() == 2 {
if v-t.left.value < 0 {
// 当插入的节点在失衡节点的左子树的左子树中,直接右旋
t = t.llRotate()
} else {
// 当插入的节点在失衡节点的左子树的右子树中,先左旋后右旋
t = t.lrRotate()
}
} else if t.right.height()-t.left.height() == 2 {
if t.right.right.height() > t.right.left.height() {
// 当插入的节点在失衡节点的右子树的右子树中,直接左旋
t = t.rrRotate()
} else {
// 当插入的节点在失衡节点的右子树的左子树中,先右旋后左旋
t = t.rlRotate()
}
}
// 调整树高度
t.h = max(t.left.height(), t.right.height()) + 1
return t
}
func (t *avlNode) height() int {
if t != nil {
return t.h
}
return -1
}
// 中序遍历按顺序获取所有值
func appendValue(values []int, t *avlNode) []int {
if t != nil {
values = appendValue(values, t.left)
values = append(values, t.value)
values = appendValue(values, t.right)
}
return values
}
func (t *avlNode) values() []int {
values := make([]int, 0)
return appendValue(values, t)
}