-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathquickshift_fromNick_Ol.py
254 lines (226 loc) · 8.77 KB
/
quickshift_fromNick_Ol.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
"""
Quick shift implementation based on the article "Quick Shift and Kernel
Methods for Mode Seeking" by A. Vedaldi and S. Soatto, 2008.
All points are connected into a single tree whose root is the point
with maximal estimated density. A final threshold parameter tau is
therefore needed to break the branches that are longer than tau.
Time complexity is O(n_features * n_samples**2).
Any contribution is welcome.
"""
# Author: Clement Nicolle <clement.nicolle@student.ecp.fr>
from __future__ import division
import numpy as np
from sklearn.cluster import estimate_bandwidth
from sklearn.metrics.pairwise import pairwise_distances
import matplotlib.pyplot as plt
def compute_distance_matrix(data, metric):
    """Return the matrix of pairwise distances between the rows of `data`.

    Parameters
    ----------
    data : array-like, shape=[n_samples, n_features]
        Input points.
    metric : string
        Name of the metric forwarded to sklearn's pairwise_distances (see its
        documentation for the accepted values).

    Returns
    -------
    distance_matrix : array-like, shape=[n_samples, n_samples]
        Entry [i, j] is the distance between point i and point j.
    """
    distances = pairwise_distances(data, metric=metric)
    return distances
def compute_weight_matrix(dist_matrix, window_type, bandwidth):
    """Turn a distance matrix into a matrix of window weights.

    Parameters
    ----------
    dist_matrix : array-like, shape=[n_samples, n_samples]
        Distance matrix.
    window_type : string
        Either "flat" (1 inside the bandwidth, 0 outside) or "normal"
        (Gaussian kernel with standard deviation `bandwidth`).
    bandwidth : float
        Width of the window.

    Returns
    -------
    weight_matrix : array-like, shape=[n_samples, n_samples]
        Weight for each pair of points.

    Raises
    ------
    ValueError
        If `window_type` is neither "flat" nor "normal".
    """
    if window_type not in ('flat', 'normal'):
        raise ValueError("Unknown window type")
    if window_type == 'normal':
        return np.exp(-(dist_matrix ** 2) / (2 * bandwidth ** 2))
    # Multiplying by 1 promotes the boolean mask to integer 0/1 weights.
    inside = dist_matrix <= bandwidth
    return 1 * inside
def compute_medoids(dist_matrix, weight_matrix, tau):
    """For each point, compute the associated medoid (its parent in the tree).

    A point j is linked to the point i that maximises
    sign(density[i] - density[j]) / distance[i, j]. In practice this is the
    closest point within `tau` that has a strictly higher density. A point is
    linked to itself when no such neighbour exists; ties in the score
    resolve to the lowest index, as with np.argmax.

    Parameters
    ----------
    dist_matrix : array-like, shape=[n_samples, n_samples]
        Distance matrix. It is NOT modified by this function.
    weight_matrix : array-like, shape=[n_samples, n_samples]
        Weight for each pair of points. Its column sums are used as density
        estimates.
    tau : float
        Threshold parameter. Distance should not be over tau so that two points
        may be connected to each other.

    Returns
    -------
    medoids : array, shape=[n_samples]
        i-th value is the index of the medoid for i-th point.
    """
    # Work on a float copy. The zero-distance substitution below must not leak
    # back into the caller's matrix, and must not be truncated when the input
    # is an integer array.
    dist = np.array(dist_matrix, dtype=float)
    weights = np.asarray(weight_matrix)
    # Density estimate of each point: sum of each column of the weights.
    density = weights.sum(axis=0)
    # P[i, j] = density[i] - density[j]
    P = density[:, np.newaxis] - density
    # Avoid dividing by zero on the diagonal and for duplicated points.
    dist[dist == 0] = tau / 2
    S = np.sign(P) * (1 / dist)  # pointwise product
    # Candidates farther than tau get a negative score. For tau > 0 this is
    # below the 0 score of the point itself (sign(0) on the diagonal).
    S[dist > tau] = -1
    # The new medoid of point j is the highest coefficient in column j of S.
    return np.argmax(S, axis=0)
def compute_stationary_medoids(data, tau, window_type, bandwidth, metric):
    """Compute every point's medoid and find the points that are their own.

    Parameters
    ----------
    data : array-like, shape=[n_samples, n_features]
        Input points.
    tau : float
        Threshold parameter. Distance should not be over tau so that two points
        may be connected to each other.
    window_type : string
        Type of window to compute the weights matrix. Can be
        "flat" or "normal".
    bandwidth : float
        Value of the bandwidth for the window.
    metric : string
        Metric used to compute the distance. See pairwise_distances doc to
        look at all the possible values.

    Returns
    -------
    medoids : array, shape=[n_samples]
        i-th value is the index of the medoid for i-th point.
    stationary_pts : array, shape=[n_stationary_pts]
        Indices, in increasing order, of the points which are their own
        medoids.
    """
    distances = compute_distance_matrix(data, metric)
    weights = compute_weight_matrix(distances, window_type, bandwidth)
    medoids = compute_medoids(distances, weights, tau)
    # A point whose medoid is itself is the root of its tree.
    own_medoid = [idx for idx, med in enumerate(medoids) if med == idx]
    return medoids, np.asarray(own_medoid)
def quick_shift(data, tau, window_type, bandwidth, metric):
    """Perform quick shift clustering of data with corresponding parameters.

    Parameters
    ----------
    data : array-like, shape=[n_samples, n_features]
        Input points.
    tau : float or None
        Threshold parameter. Distance should not be over tau so that two points
        may be connected to each other. If None, it is set with
        sklearn's estimate_bandwidth(data).
    window_type : string
        Type of window to compute the weights matrix. Can be
        "flat" or "normal".
    bandwidth : float or None
        Value of the bandwidth for the window. If None, it is set with
        sklearn's estimate_bandwidth(data).
    metric : string
        Metric used to compute the distance. See pairwise_distances doc to
        look at all the possible values.

    Returns
    -------
    cluster_centers : array, shape=[n_clusters, n_features]
        Coordinates of cluster centers.
    labels : array, shape=[n_samples]
        Cluster labels for each point. Label k refers to cluster_centers[k].
    cluster_centers_idx : array, shape=[n_clusters]
        Index in data of cluster centers.
    """
    # Accept any array-like (e.g. nested lists). The fancy indexing below
    # needs an ndarray.
    data = np.asarray(data)
    if tau is None:
        tau = estimate_bandwidth(data)
    if bandwidth is None:
        bandwidth = estimate_bandwidth(data)
    medoids, cluster_centers_idx = compute_stationary_medoids(data, tau,
                                                              window_type,
                                                              bandwidth,
                                                              metric)
    cluster_centers = data[cluster_centers_idx]

    # Follow the medoid links up to the root of each tree by pointer jumping.
    # Each pass replaces every point's current ancestor with that ancestor's
    # ancestor. The links form a forest whose only cycles are the self-loops
    # of the stationary points, so the array stops changing exactly when
    # every entry is a stationary point.
    roots = np.asarray(medoids)
    while True:
        next_roots = roots[roots]
        if np.array_equal(next_roots, roots):
            break
        roots = next_roots

    # The label of a root is its position in cluster_centers_idx.
    label_of = np.empty(len(roots), dtype=int)
    label_of[cluster_centers_idx] = np.arange(len(cluster_centers_idx))
    labels = label_of[roots]
    return cluster_centers, labels, cluster_centers_idx
def visualize2D(data, labels, clusters_centers_idx):
    """Plot the clustering result for points in 2D.

    Only the first two features of each point are drawn. Each cluster gets a
    random colour. Cluster centers are drawn as large black X markers.

    Parameters
    ----------
    data : array-like, shape=[n_samples, n_features]
        Input points.
    labels : array, shape=[n_samples]
        Cluster labels for each point (used as indices into the colour list).
    clusters_centers_idx : array, shape=[n_clusters]
        Index in data of cluster centers.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes holding the plot.
    """
    data = np.asarray(data)
    labels = np.asarray(labels, dtype=int)
    centers = np.asarray(clusters_centers_idx, dtype=int)
    # Generate one random colour per cluster.
    colors = ['#%06X' % np.random.randint(0, 0xFFFFFF)
              for _ in range(len(centers))]
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    # A single scatter call for all points, rather than one artist per point.
    if len(data):
        ax.scatter(data[:, 0], data[:, 1], c=[colors[k] for k in labels])
    # clusters centers as large black X
    if len(centers):
        ax.scatter(data[centers, 0], data[centers, 1],
                   color='k', marker='x', s=100)
    return fig, ax
class QuickShift():
    """Compute the Quick shift algorithm with flat or normal window.

    Parameters
    ----------
    tau : float or None, default=None
        Threshold parameter. Distance should not be over tau so that two points
        may be connected to each other. If None, it is estimated from the data
        when `fit` is called.
    bandwidth : float or None, default=None
        Value of the bandwidth for the window. If None, it is estimated from
        the data when `fit` is called.
    window_type : string, default="flat"
        Type of window to compute the weights matrix. Can be
        "flat" or "normal".
    metric : string, default="euclidean"
        Metric used to compute the distance. See pairwise_distances doc to
        look at all the possible values.

    Attributes
    ----------
    cluster_centers_ : array, [n_clusters, n_features]
        Coordinates of cluster centers.
    labels_ : array, shape=[n_samples]
        Cluster label of each point.
    cluster_centers_idx_ : array, shape=[n_clusters]
        Index in data of cluster centers.
    """

    def __init__(self, tau=None, bandwidth=None,
                 window_type="flat", metric="euclidean"):
        self.tau = tau
        self.bandwidth = bandwidth
        self.window_type = window_type
        self.metric = metric

    def fit(self, data):
        """Perform clustering.

        Parameters
        ----------
        data : array-like, shape=[n_samples, n_features]
            Samples to cluster.

        Returns
        -------
        self : QuickShift
            The fitted estimator.
        """
        self.cluster_centers_, self.labels_, self.cluster_centers_idx_ = \
            quick_shift(data, tau=self.tau, window_type=self.window_type,
                        bandwidth=self.bandwidth, metric=self.metric)
        return self

    def fit_predict(self, data):
        """Perform clustering and return the cluster label of each point.

        Parameters
        ----------
        data : array-like, shape=[n_samples, n_features]
            Samples to cluster.

        Returns
        -------
        labels : array, shape=[n_samples]
            Cluster label of each point (same as `labels_` after fitting).
        """
        return self.fit(data).labels_