-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathdata_input.py
156 lines (126 loc) · 4.77 KB
/
data_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Data massaging: prepare the data so that it is easy to plot it.
"""
import os
import pickle
import pandas as pd
from fetcher import fetch_john_hopkins_data
def tidy_most_recent(df, column='confirmed'):
    """ Return the most recent values of one case type as a tidy frame.

    Parameters
    ----------
    df : pandas.DataFrame
        Wide frame as produced by ``get_data``: one row per date (dates are
        assumed unique, which ``get_data``'s pivot on 'date' guarantees)
        and a column MultiIndex whose first level is the case type and
        which contains an ``'iso'`` level.
    column : str
        Case type to extract, e.g. ``'confirmed'``.

    Returns
    -------
    pandas.DataFrame
        Long-format frame with a ``'date'`` column, one column per remaining
        column level (``'iso'`` and ``'country_region'`` for ``get_data``
        output) and a ``'value'`` column. It has one row per country for
        the latest date, sorted by ``'iso'``, and is empty if ``df`` has no
        rows.
    """
    wide = df[column]
    # One row per remaining column label, one column per level. The frame
    # is built from the labels directly instead of melting the whole
    # history: only the last date is kept anyway, and melting a frame with
    # MultiIndex columns after ``reset_index`` is fragile across pandas
    # versions.
    tidy = wide.columns.to_frame(index=False)
    if wide.index.empty:
        # No dates at all: nothing to report, keep only the column layout.
        date_max = pd.NaT
        latest = []
        tidy = tidy.iloc[:0].copy()
    else:
        date_max = wide.index.max()
        # Series in the same order as ``wide.columns``, hence as tidy's rows.
        latest = wide.loc[date_max].to_numpy()
    tidy.insert(0, 'date', date_max)
    tidy['value'] = latest
    return tidy.sort_values('iso')
# Alternative spellings of country names, mapped onto a single canonical
# name per country. Several sources spell the same country differently,
# hence the many-to-one entries. (Presumably the canonical names are those
# of the reference database used for matching — confirm in fetcher.)
MAP_UNMATCHED_COUNTRIES = {
    "Bahamas, The": "Bahamas",
    "The Bahamas": "Bahamas",
    "Congo (Kinshasa)": "Democratic Republic of the Congo",
    "Congo, the Democratic Republic of the": "Democratic Republic of the Congo",
    "Congo (Brazzaville)": "Republic of the Congo",
    "Cape Verde": "Cabo Verde",
    "Czech Republic": "Czechia",
    "Cote d'Ivoire": "Ivory Coast",
    "Côte d'Ivoire": "Ivory Coast",
    "Swaziland": "Eswatini",
    "The Gambia": "Gambia",
    "Gambia, The": "Gambia",
    "Hong Kong SAR": "Hong Kong",
    # NOTE(review): the Holy See is folded into Italy here, while
    # "Vatican City" below maps to "Vatican" — confirm this is intended.
    "Holy See": "Italy",
    "Iran, Islamic Republic of": "Iran",
    "Iran (Islamic Republic of)": "Iran",
    "Korea, South": "South Korea",
    "Republic of Korea": "South Korea",
    "Macau": "Macao",
    "Macao SAR": "Macao",
    "Mainland China": "China",
    "Moldova, Republic of": "Moldova",
    "Republic of Moldova": "Moldova",
    "Republic of Ireland": "Ireland",
    "Macedonia, the former Yugoslav Republic of": "North Macedonia",
    "Réunion": "Reunion",
    "Russian Federation": "Russia",
    "St. Martin": "Saint Martin",
    "Taiwan*": "Taiwan",
    "Taipei and environs": "Taiwan",
    "Timor-Leste": "Timor Leste",
    "East Timor": "Timor Leste",
    "US": "United States",
    "UK": "United Kingdom",
    "North Ireland": "United Kingdom",
    "occupied Palestinian territory": "Palestinian Territory",
    "Palestine": "Palestinian Territory",
    "Viet Nam": "Vietnam",
    "Vatican City": "Vatican",
}

# Entries that do not name a country (or have no counterpart); presumably
# they are left out of the matching — confirm where this list is used.
UNMATCHED_COUNTRIES = ["Cruise Ship", "Others", "Channel Islands"]
def get_data():
    """ Download the data and return it as a 'wide' data frame.

    Returns
    -------
    pandas.DataFrame
        One row per date. The columns form a MultiIndex with levels
        ('type', 'iso', 'country_region') and hold the number of reported
        cases. Gaps in a series are forward-filled; values still missing
        (before a series' first report) are set to 0.
    """
    df = fetch_john_hopkins_data()
    # The number of reported cases per day, country, and type. Only
    # 'cases' is aggregated: it is the only column used below, and summing
    # any other (possibly non-numeric) column is wasted work that can fail
    # on recent pandas versions.
    df_day = df.groupby(['country_region', 'iso', 'date', 'type'])[['cases']].sum()
    # Switch to wide format (time series)
    data = df_day.pivot_table(values='cases',
                              columns=['type', 'iso', 'country_region'],
                              index=['date'])
    # Carry the last known value forward; ``fillna(method='ffill')`` is
    # deprecated since pandas 2.1 and removed in 3.0.
    data = data.ffill()
    data = data.fillna(value=0)
    return data
def exec_full(filepath):
    """ Run the Python source file *filepath* as if it were the main script.

    The source is compiled with *filepath* as its filename, so tracebacks
    point into that file. It is executed in a fresh namespace where
    ``__name__`` is ``"__main__"``, so an ``if __name__ == "__main__":``
    section in that file runs as well. The namespace is then discarded.
    """
    with open(filepath, 'rb') as source_file:
        source = source_file.read()
    code = compile(source, filepath, 'exec')
    namespace = dict(__file__=filepath, __name__="__main__")
    exec(code, namespace)
def get_all_data():
    """ Return the observed data together with the predictions of our model.

    The predictions are read from ``predictions.pkl`` in the current working
    directory. When that file does not exist, ``modeling.py`` is executed
    first (presumably it writes the pickle — confirm in modeling.py).

    Returns
    -------
    tuple
        ``(data, predictions)``: the wide frame from ``get_data()`` and the
        unpickled mapping. Each value of the mapping has its columns rebuilt
        as an ('iso', 'country') MultiIndex.
    """
    observed = get_data()  # all data
    pickle_path = 'predictions.pkl'
    if not os.path.exists(pickle_path):
        print('Running the model')
        exec_full('modeling.py')
    # NOTE(review): pickle.load can execute arbitrary code; this is only
    # acceptable because the file is produced locally by modeling.py.
    with open(pickle_path, 'rb') as pickle_file:
        predictions = pickle.load(pickle_file)
    # The column labels come back as plain tuples (the original note says
    # the MultiIndex does not survive pickling), so rebuild it.
    # NOTE(review): the level is named 'country' here but 'country_region'
    # in get_data() — confirm consumers expect the difference.
    for frame in predictions.values():
        frame.columns = pd.MultiIndex.from_tuples(
            frame.columns, names=('iso', 'country'))
    return observed, predictions
def get_populations(filepath='data/countryInfo.txt', skiprows=50):
    """ Load the information that we have about countries.

    Parameters
    ----------
    filepath : str or path-like
        Tab-separated country information file. The default is relative to
        the current working directory.
    skiprows : int
        Number of leading lines to skip before the header row. The default
        of 50 matches the comment preamble of the bundled file (presumably
        the GeoNames ``countryInfo.txt`` dump); re-check it if the file is
        updated.

    Returns
    -------
    pandas.DataFrame
        The table as read. The callers in this module use its 'ISO3' and
        'Population' columns.
    """
    return pd.read_csv(filepath, sep='\t', skiprows=skiprows)
def normalize_by_population(tidy_df, pop=None):
    """ Normalize by population the column "value" of a dataframe with
    lines being the country ISO.

    Parameters
    ----------
    tidy_df : pandas.DataFrame
        One row per country, with at least an 'iso' column (ISO3 code) and a
        'value' column.
    pop : pandas.DataFrame, optional
        Country table with 'ISO3' and 'Population' columns, as returned by
        ``get_populations()``. It is loaded from disk when omitted; pass it
        explicitly to avoid re-reading the file on every call.

    Returns
    -------
    pandas.Series
        value / Population, indexed by ISO3 code.

    Raises
    ------
    AssertionError
        If some rows of ``tidy_df`` could not be normalized, e.g. because
        their ISO code is absent from the population table or their value
        is missing.
    """
    if pop is None:
        pop = get_populations()
    population = pop.set_index('ISO3')['Population']
    normalized_values = tidy_df.set_index('iso')['value'] / population
    # Division aligns on the union of both indexes: countries present only
    # in the population table come out as NaN, so drop them.
    normalized_values = normalized_values.dropna()
    if len(normalized_values) != len(tidy_df):
        # Raised explicitly rather than with ``assert`` so the check survives
        # ``python -O``; the AssertionError type is kept for existing callers.
        missing = sorted(set(map(str, tidy_df['iso']))
                         - set(map(str, population.index)))
        message = ("Not every country in the given dataframe was found in our "
                   "database of populations")
        if missing:
            message += "; missing ISO codes: %s" % ", ".join(missing)
        raise AssertionError(message)
    return normalized_values
def normalize_by_population_wide(df, pop=None):
    """ Normalize by population the columns of a dataframe with
    column names being the country iso.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose columns have a level named 'iso' holding ISO3 codes
        (e.g. the output of ``get_data``).
    pop : pandas.DataFrame, optional
        Country table with 'ISO3' and 'Population' columns, as returned by
        ``get_populations()``. It is loaded from disk when omitted; pass it
        explicitly to avoid re-reading the file on every call.

    Returns
    -------
    pandas.DataFrame
        ``df`` with each column divided by the population of its country.
        Columns whose ISO code has no population entry come out as NaN.
    """
    if pop is None:
        pop = get_populations()
    # Grab a series of populations, indexed by "iso"
    population = pop.rename(dict(ISO3='iso'), axis=1).set_index('iso')['Population']
    # ``.div`` (rather than ``/``) supports aligning on a single column
    # level, here 'iso', broadcasting across the other levels.
    return df.div(population, level='iso', axis=1)
if __name__ == "__main__":
    # Smoke check of the entity matching between the case data and the
    # population database: normalize_by_population's length check fails
    # when a country of the latest confirmed counts has no population.
    latest_confirmed = tidy_most_recent(get_data())
    normalized_values = normalize_by_population(latest_confirmed)