-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
45 lines (40 loc) · 1.15 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
# Loading the data and visualizing it.
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def load_data(source="https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data"):
    """Load the whitespace-separated housing data into a labelled DataFrame.

    Args:
        source: Anything ``pandas.read_csv`` accepts as its first argument
            (URL, filesystem path, or open file-like object). Defaults to
            the UCI housing data URL, so ``load_data()`` keeps working as
            before.

    Returns:
        A DataFrame with the 14 columns named CRIM, ZN, INDUS, CHAS, NOX,
        RM, AGE, DIS, RAD, TAX, PTRATIO, B, LSTAT and MEDV, in that order.

    Raises:
        ValueError: Raised by pandas when the parsed data does not have
            exactly 14 columns (the column assignment length-mismatches).
    """
    column_names = [
        'CRIM', 'ZN', 'INDUS',
        'CHAS', 'NOX', 'RM',
        'AGE', 'DIS', 'RAD',
        'TAX', 'PTRATIO', 'B',
        'LSTAT', 'MEDV',
    ]
    # Raw string: '\s' inside a plain string literal is an invalid escape
    # sequence and triggers a warning on newer Python versions.
    df = pd.read_csv(source, header=None, sep=r'\s+')
    # Assign (rather than pass names=) so a wrong column count fails loudly.
    df.columns = column_names
    return df
def visualize_scatterplot(df):
    """Display a pairwise scatterplot matrix of selected housing columns.

    Args:
        df: DataFrame that contains at least the columns LSTAT, INDUS,
            NOX, RM and MEDV.

    Side effects:
        Sets the seaborn style, creates a ``figures`` directory in the
        current working directory if nothing by that name exists, shows
        the plot, and clears the current figure afterwards.
    """
    sns.set(style='whitegrid', context='notebook')
    cols = ['LSTAT', 'INDUS', 'NOX', 'RM', 'MEDV']
    # seaborn renamed pairplot's `size` keyword to `height`; the old
    # spelling is deprecated.
    sns.pairplot(df[cols], height=2.5)
    # Ensure the output directory for the (currently disabled) savefig call
    # exists. Attempting the mkdir directly avoids the check-then-create
    # race, and still leaves any pre-existing 'figures' entry untouched.
    try:
        os.mkdir('figures')
    except FileExistsError:
        pass
    plt.show()
    # plt.savefig('./figures/scatterplot.png')
    plt.gcf().clear()  # to ensure the canvas is clear for next plot
def corr_heatmap(df):
    """Display an annotated correlation heatmap of selected housing columns.

    Args:
        df: DataFrame that contains at least the columns LSTAT, INDUS,
            NOX, RM and MEDV.

    Side effects:
        Sets the seaborn font scale to 1.5, shows the plot, and clears the
        current figure afterwards.
    """
    features = ['LSTAT', 'INDUS', 'NOX', 'RM', 'MEDV']
    # np.corrcoef treats each row as one variable, hence the transpose of
    # the (samples x features) value array.
    corr_matrix = np.corrcoef(df[features].values.T)
    sns.set(font_scale=1.5)
    heatmap_options = {
        'cbar': True,
        'annot': True,
        'square': True,
        'fmt': '.2f',
        'annot_kws': {'size': 15},
        'yticklabels': features,
        'xticklabels': features,
    }
    sns.heatmap(corr_matrix, **heatmap_options)
    plt.show()
    # plt.savefig('./figures/correlation-heatmap.png')
    plt.gcf().clear()  # leave a blank canvas for whatever is plotted next