-
Notifications
You must be signed in to change notification settings - Fork 13
/
graph.py
95 lines (72 loc) · 2.98 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
class Node:
    """Record holding the bookkeeping fields of one disjoint-set element.

    Attributes are stored exactly as passed in:
      parent -- the parent link (``Forest`` stores an integer node index here)
      rank   -- rank value, 0 unless given
      size   -- size value, 1 unless given
    """

    def __init__(self, parent, rank=0, size=1):
        self.parent, self.rank, self.size = parent, rank, size

    def __repr__(self):
        """Return the fields formatted as ``(parent=.., rank=.., size=..)``."""
        fields = (self.parent, self.rank, self.size)
        return '(parent=%s, rank=%s, size=%s)' % fields
class Forest:
    """Disjoint-set forest (union-find) over the integer ids 0 .. num_nodes-1.

    Attributes:
      nodes    -- list of ``Node`` objects; ``nodes[i].parent`` is the index of
                  i's parent, and a root is a node that is its own parent.
      num_sets -- number of sets, decremented by one on every merge.
    """

    def __init__(self, num_nodes):
        """Create ``num_nodes`` singleton sets, each node being its own root."""
        # ``range`` (not the Python 2-only ``xrange``) keeps this runnable on
        # both Python 2 and Python 3.
        self.nodes = [Node(i) for i in range(num_nodes)]
        self.num_sets = num_nodes

    def size_of(self, i):
        """Return the ``size`` stored on node ``i``.

        The value is only kept up to date on roots (``merge`` updates the
        surviving root), so pass a root id to get the size of a whole set.
        """
        return self.nodes[i].size

    def find(self, n):
        """Return the root of the set containing node ``n``.

        Every node visited on the way up is re-pointed directly at the root
        (path compression), so later lookups along the same path are short.
        """
        nodes = self.nodes

        # First pass: walk parent links until a self-parented node is reached.
        root = n
        while root != nodes[root].parent:
            root = nodes[root].parent

        # Second pass: re-point the whole traversed path at the root.
        while n != root:
            next_up = nodes[n].parent
            nodes[n].parent = root
            n = next_up
        return root

    def merge(self, a, b):
        """Union the sets rooted at ``a`` and ``b`` (union by rank).

        ``a`` and ``b`` are used as given and are expected to be root ids, as
        returned by ``find``; this method does not resolve them itself.
        The root with the larger rank survives; on a tie ``a`` is attached
        under ``b`` and ``b``'s rank grows by one.  Merging a root with itself
        is a no-op.
        """
        if a == b:
            # Without this guard a self-merge would double the size, bump the
            # rank and decrement ``num_sets`` although nothing was joined.
            return

        node_a = self.nodes[a]
        node_b = self.nodes[b]
        if node_a.rank > node_b.rank:
            node_b.parent = a
            node_a.size = node_a.size + node_b.size
        else:
            node_a.parent = b
            node_b.size = node_b.size + node_a.size
            if node_a.rank == node_b.rank:
                node_b.rank = node_b.rank + 1
        self.num_sets = self.num_sets - 1

    def print_nodes(self):
        """Print every node, one per line, using the node's string form."""
        for node in self.nodes:
            # Single-argument call form prints the same on Python 2 and 3.
            print(node)
def create_edge(img, width, x, y, x1, y1, diff):
    """Build the weighted edge joining grid position (x, y) to (x1, y1).

    A position (col, row) is numbered ``row * width + col``.  The weight is
    whatever ``diff(img, x, y, x1, y1)`` returns; ``img`` is only forwarded
    to ``diff`` and never inspected here.

    Returns the tuple ``(id of (x, y), id of (x1, y1), weight)``.
    """
    edge_weight = diff(img, x, y, x1, y1)
    source_id = y * width + x
    target_id = y1 * width + x1
    return (source_id, target_id, edge_weight)
def build_graph(img, width, height, diff, neighborhood_8=False):
    """Return the list of neighbour edges of a ``width`` x ``height`` grid.

    Positions are visited row by row.  Each position is linked to its left
    and upper neighbour; with ``neighborhood_8`` it is additionally linked to
    its upper-left and lower-left neighbour, so every adjacent pair
    (including diagonals) appears exactly once.  Neighbours outside the grid
    are skipped.

    Parameters:
      img            -- passed through to ``diff`` unchanged
      width, height  -- grid dimensions
      diff           -- callable ``diff(img, x, y, x1, y1)`` giving the weight
      neighborhood_8 -- also add the two diagonal links when true

    Returns a list of ``(vertex_id, vertex_id, weight)`` tuples as produced by
    ``create_edge``.
    """
    graph = []
    # ``range`` instead of the Python 2-only ``xrange`` so the function also
    # runs on Python 3; iteration order is unchanged.
    for y in range(height):
        for x in range(width):
            if x > 0:
                graph.append(create_edge(img, width, x, y, x - 1, y, diff))
            if y > 0:
                graph.append(create_edge(img, width, x, y, x, y - 1, diff))
            if neighborhood_8:
                if x > 0 and y > 0:
                    graph.append(
                        create_edge(img, width, x, y, x - 1, y - 1, diff))
                if x > 0 and y < height - 1:
                    graph.append(
                        create_edge(img, width, x, y, x - 1, y + 1, diff))
    return graph
def remove_small_components(forest, graph, min_size):
    """Merge across every edge that touches a set smaller than ``min_size``.

    Edges are processed in the order given.  For each edge the roots of its
    two endpoints (``edge[0]`` and ``edge[1]``) are looked up; if they differ
    and at least one of the two sets has fewer than ``min_size`` members, the
    sets are merged.

    ``forest`` is modified in place and also returned.
    """
    for edge in graph:
        root_u, root_v = forest.find(edge[0]), forest.find(edge[1])
        if root_u == root_v:
            # Both endpoints already belong to the same set.
            continue
        if forest.size_of(root_u) < min_size or forest.size_of(root_v) < min_size:
            forest.merge(root_u, root_v)
    return forest
def segment_graph(graph, num_nodes, const, min_size, threshold_func):
    """Partition ``num_nodes`` vertices by greedily merging along light edges.

    Edges (``(vertex, vertex, weight)`` sequences) are processed in order of
    increasing weight.  Two different sets are merged when the edge weight
    does not exceed the current threshold of either set.  Every set starts
    with the threshold ``threshold_func(1, const)``; after a merge the
    surviving root's threshold becomes
    ``weight + threshold_func(size of merged set, const)``.

    Finally the sorted edges are handed to ``remove_small_components`` with
    ``min_size``, and the forest it returns is the result.
    """
    forest = Forest(num_nodes)
    edges_by_weight = sorted(graph, key=lambda e: e[2])
    threshold = [threshold_func(1, const)] * num_nodes

    for edge in edges_by_weight:
        root_u = forest.find(edge[0])
        root_v = forest.find(edge[1])
        edge_weight = edge[2]

        within_u = edge_weight <= threshold[root_u]
        within_v = edge_weight <= threshold[root_v]
        if root_u == root_v or not (within_u and within_v):
            continue

        forest.merge(root_u, root_v)
        # Whichever root survived the merge now carries the threshold.
        merged_root = forest.find(root_u)
        threshold[merged_root] = edge_weight + threshold_func(
            forest.nodes[merged_root].size, const)

    return remove_small_components(forest, edges_by_weight, min_size)