-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathkmean.py
28 lines (23 loc) · 831 Bytes
/
kmean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# -*- coding: utf-8 -*-
"""kmean.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1uC70_PvR3z6aSvxNmalHBtpQT1WHScTr
"""
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
# Module-level standardizer shared by clustering(); it is re-fitted on every
# call there (fit_transform), so it always reflects the most recent input.
scaler = StandardScaler()
def clustering(emb, n_clusters=7, random_state=None):
    """Cluster embeddings with k-means on a 2-D t-SNE projection and plot them.

    The input is standardized with the module-level ``scaler`` (which is
    re-fitted on ``emb``), projected to two dimensions with t-SNE, and the
    projected points are partitioned with k-means. The projection is shown
    in a new matplotlib figure: points coloured by cluster label, cluster
    centres drawn in black.

    Args:
        emb: Data to cluster, passed directly to ``scaler.fit_transform``
            (presumably a 2-D array of shape (n_samples, n_features) --
            confirm against callers).
        n_clusters: Number of k-means clusters. Defaults to 7, the value
            that was previously hard-coded.
        random_state: Optional seed forwarded to both ``TSNE`` and
            ``KMeans``. The default ``None`` keeps the library default
            (non-deterministic) behaviour.

    Returns:
        The cluster label of every projected point, as returned by
        ``KMeans.predict`` on the t-SNE projection.
    """
    standardized = scaler.fit_transform(emb)
    projected = TSNE(
        n_components=2,
        random_state=random_state,
    ).fit_transform(standardized)

    kmeans = KMeans(
        n_clusters=n_clusters,
        init='k-means++',
        n_init=20,
        max_iter=500,
        algorithm='elkan',
        random_state=random_state,
    )
    kmeans.fit(projected)
    labels = kmeans.predict(projected)

    # plt.figure has to be *called*: the original bare attribute reference
    # was a no-op, so the plot was drawn onto whatever figure was current.
    plt.figure()
    plt.scatter(projected[:, 0], projected[:, 1], c=labels, s=50, cmap='viridis')
    centers = kmeans.cluster_centers_
    plt.scatter(centers[:, 0], centers[:, 1], c='black', s=200, alpha=1)
    plt.show()
    return labels