-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathmodeling.py
335 lines (276 loc) · 10.9 KB
/
modeling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
"""
# Statistical model to extrapolate Covid-19 actives cases
Here we build the statistical model behind
https://covid19-dash.github.io/
The model is made by fitting a weighted least square on the log of the
number of cases over a window of the few last days.
This is a computational notebook: it is the actual code that is run to
build the model.
"""
# %%
# Use joblib, to speed up interactions
# (results of expensive calls are cached on disk under the current
# directory, so re-running notebook cells is fast)
from joblib import Memory
mem = Memory(location='.')
# %%
# ## Load and plot the data
# %%
# Download the data
# data_input.get_data is cached by joblib so the download only happens once
import data_input
data = mem.cache(data_input.get_data)()
# We model only the confirmed cases
# (a dataframe: countries as columns, dates as index -- see usage below)
confirmed = data['confirmed']
# %%
# First plot the time course of the most affected countries
last_day = confirmed.iloc[-1]
# Sort columns by the latest case count, most affected country first
most_affected_countries = confirmed.columns[last_day.argsort()][::-1]
import matplotlib.pyplot as plt
ax = confirmed[most_affected_countries[:20]].plot(figsize=(12, 7))
ax.set_yscale('log')
ax.set_title("Log-scale plot of number of confirmed cases")
plt.legend(loc='best', ncol=3)
plt.tight_layout()
# %%
# ## Define our weighted window
#
# The windows that we use are weighted: we give more weight to the last
# day, and less to the days further away in time.
# %%
# Some functions to build normalized weighted windows
import numpy as np
def ramp_window(start=14, middle=7):
    """ Build a normalized weight window of length `start`.

        The first `middle` weights ramp linearly up from 0; the
        remaining weights are flat. Weights sum to 1.
    """
    rising = np.arange(middle) / float(middle)
    flat = np.ones(start - middle)
    weights = np.concatenate([rising, flat])
    return weights / weights.sum()
def exp_window(start=14, growth=1.1):
    """ Build a normalized weight window of length `start` whose
        weight at position i is proportional to growth**i.
    """
    powers = np.power(growth, np.arange(start))
    total = powers.sum()
    return powers / total
# %%
# To define the optimal window, specified below, we used a historical
# replay: we chose the window length and the set of weights that best
# predict observations already seen, from their own past (the replay
# analysis is at the bottom of this file)
window_size = 17
weighted_window = exp_window(start=window_size, growth=1.6)
plt.figure()
plt.plot(weighted_window)
plt.title('The weights over the last few days')
# %%
# # A simple model: fit the last few points
# %%
# the log of the confirmed counts in the last fortnight
last_fortnight = confirmed.iloc[-window_size:]
# log(0) gives -inf for countries with zero cases: silence the
# divide-by-zero warning here and replace the -inf values by 0 below
np.seterr(divide='ignore')
log_last_fortnight = np.log(last_fortnight)
log_last_fortnight[log_last_fortnight == -np.inf] = 0
ax = log_last_fortnight[most_affected_countries[:20]].plot(figsize=(12, 7))
ax.set_title('Log of the number of confirmed cases in the last fortnight')
plt.legend(loc='best', ncol=3)
plt.tight_layout()
# %%
# Our model-fitting routine: a weighted least square on the log of
# the confirmed counts
#
# The errors in the data are expected to be proportional to the value of
# the data: the more cases are present, the more tests are realized, and
# the more errors as well as the more cases are missed. This noise
# becomes additive after taking the log, and can then reasonnably be
# assumed Gaussian, which justifies the use of a squared loss.
import pandas as pd
import statsmodels.api as sm
def fit_on_window(data, window):
""" Fit the last window of the data
"""
window_size = len(window)
last_fortnight = data.iloc[-window_size:]
log_last_fortnight = np.log(last_fortnight)
log_last_fortnight[log_last_fortnight == -np.inf] = 0
design = pd.DataFrame({'linear': np.arange(window_size),
'const': np.ones(window_size)})
growth_rate = pd.DataFrame(data=np.zeros((1, len(data.columns))),
columns=data.columns)
predicted_cases = pd.DataFrame()
predicted_cases_lower = pd.DataFrame()
predicted_cases_upper = pd.DataFrame()
prediction_dates = pd.date_range(data.index[-window_size],
periods=window_size + 7)
for country in data.columns:
mod_wls = sm.WLS(log_last_fortnight[country].values, design,
weights=window, hasconst=True)
res_wls = mod_wls.fit()
growth_rate[country] = np.exp(res_wls.params.linear)
predicted_cases[country] = np.exp(res_wls.params.const +
res_wls.params.linear * np.arange(len(prediction_dates))
)
# 1st and 3rd quartiles in the confidence intervals
conf_int = res_wls.conf_int(alpha=.25)
# We chose to account only for error in growth rate, and not in
# baseline number of cases
predicted_cases_lower[country] = np.exp(res_wls.params.const +
conf_int[0].linear * np.arange(len(prediction_dates))
)
predicted_cases_upper[country] = np.exp(res_wls.params.const +
conf_int[1].linear * np.arange(len(prediction_dates))
)
predicted_cases = pd.concat(dict(prediction=predicted_cases,
lower_bound=predicted_cases_lower,
upper_bound=predicted_cases_upper),
axis=1)
predicted_cases['date'] = prediction_dates
predicted_cases = predicted_cases.set_index('date')
if window_size > 10:
# Don't show predictions more than 10 days ago
predicted_cases = predicted_cases.iloc[window_size - 10:]
return growth_rate, predicted_cases
# %%
# Fit it on the data
growth_rate, predicted_cases = fit_on_window(confirmed, weighted_window)
# Horizontal bars of the per-country daily growth factor; the vertical
# line at 1 marks a stable number of cases
ax = growth_rate[most_affected_countries[:20]].T.plot(kind='barh',
                                                      legend=False)
ax.set_title('Estimated growth rate')
ax.axvline(1, color='.5')
plt.tight_layout()
# %%
# Display the estimated growth rates
pd.set_option('display.max_rows', 60)
growth_rate[most_affected_countries[:60]].T
# %%
# Plot our prediction
ax = last_fortnight[most_affected_countries[:10]].plot(figsize=(12, 7))
# Dashed lines: central prediction; dotted lines: the lower/upper bounds
# estimated from the confidence interval in fit_on_window
predicted_cases['prediction'][most_affected_countries[:10]].plot(
    ax=ax, style='--')
predicted_cases['lower_bound'][most_affected_countries[:10]].plot(
    ax=ax, style=':')
predicted_cases['upper_bound'][most_affected_countries[:10]].plot(
    ax=ax, style=':')
plt.legend(loc=(.8, -1.3))
ax.set_yscale('log')
ax.set_title('Number of confirmed cases in the last fortnight and prediction')
# %%
# Save our results for the dashboard. We pickle a dict, because
# hierarchical columns do not pickle right
import pickle
with open('predictions.pkl', 'wb') as out_file:
    pickle.dump(dict(prediction=predicted_cases['prediction'],
                     lower_bound=predicted_cases['lower_bound'],
                     upper_bound=predicted_cases['upper_bound'],
                     ),
                out_file)
# %%
# --------
# Now an analysis to optimize the window.
#
# This takes longer and is left out of the notebook displayed on the
# website (modeling_short)
# %%
# Historical replay to estimate an error
def historical_replay(data, window, threshold=50, prediction_horizon=4):
    """ Run the forecasting model in the past and measure how well it does.

        Parameters
        ==========
        data: dataframe
            The dataframe of the cases across countries (columns) and
            time (index)
        window: 1d numpy array
            The array of weights defining the window
        threshold: number
            Do not include a country in the evaluation if the
            last observed data point has fewer cases than "threshold"
        prediction_horizon: number
            The number of points we consider in the future to compute the
            error

        Returns
        =======
        1d numpy array of length prediction_horizon: the mean absolute
        relative error per day ahead, averaged over all replayed dates
    """
    all_errors = list()
    # Replay every date that leaves room for a full window plus a test
    # horizon before it
    for i in range(len(window) + prediction_horizon, data.shape[0]):
        past_data = data[:i]
        # First, limit to countries with cases at the end that are more than
        # threshold
        past_data = past_data[past_data.columns[
            (past_data.iloc[-prediction_horizon - 1:]
             > threshold).all()]]
        # Fit on everything but the horizon; evaluate on the horizon
        train_data = past_data[:-prediction_horizon]
        test_data = past_data[-prediction_horizon:]
        _, predicted_data = fit_on_window(train_data, window)
        predicted_data = predicted_data['prediction']
        # We now compute the mean absolute relative error
        # Note that pandas' axis alignment magically matches the dates below
        relative_error = ((test_data - predicted_data[-prediction_horizon:])
                          / test_data)
        relative_error = relative_error.abs().mean(axis=1)
        # reset_index drops the calendar dates, so errors from different
        # replays align on "days ahead" when averaged below
        all_errors.append(relative_error.reset_index()[0])
    return np.mean(all_errors, axis=0)
# %%
# Calibrate the errors of our model for different windows
# %%
# First with ramp windows
# Grid-search the (start, middle) ramp parameters; results keyed by a
# human-readable window name
errors_by_window = dict()
for start in range(8, 14):
    for middle in range(2, start + 1):
        window = ramp_window(start, middle)
        window_name = f'Ramp, from -{start} to -{middle}'
        # historical_replay is expensive: cache it on disk with joblib
        errors = mem.cache(historical_replay)(confirmed, window)
        errors_by_window[window_name] = (errors, start, middle)
# %%
# First we plot the errors as a function of prediction time
plt.figure()
for window_name, errors in errors_by_window.items():
    # errors is a (per-day-ahead error array, start, middle) tuple
    plt.plot(errors[0], label=window_name)
plt.legend(loc='best')
plt.xlabel('Days to predict')
plt.ylabel('Relative absolute error')
plt.title('Errors as a function of time')
# %%
# Our conclusion from the above is that the shape of the error does not
# depend much on the window
# %%
# We now plot the error after 4 days as a function of window params
plt.figure()
error, start, middle = zip(*errors_by_window.values())
error = np.array(error)
# Bug fix: the original issued two plt.scatter calls; the first (which
# passed the raw error values positionally as the marker size) was fully
# overplotted by the second and only cluttered the figure, so it is
# removed. The error array is also converted to a numpy array once
# instead of three times.
# NOTE(review): size encodes the day-2 error (index 1) while color
# encodes the day-4 error (index 3) -- confirm this mismatch with the
# "error after 4 days" comment above is intended.
plt.scatter(start, middle, s=300*error[:, 1],
            c=error[:, 3], marker='o')
plt.colorbar()
plt.xlabel('start parameter')
plt.ylabel('middle parameter')
plt.title('Errors as a function of ramp window parameter')
# %%
# These results tell us that we want a ramp with a length of 10 and
# ramping all the way
# %%
# Now the exponential windows
# Grid-search the (start, growth) exponential-window parameters
errors_by_window = dict()
for start in range(12, 18):
    for growth in [1.4, 1.5, 1.6, 1.7, 1.8, 1.9]:
        window = exp_window(start, growth)
        window_name = f'Exp, from -{start} with growth {growth}'
        # historical_replay is expensive: cache it on disk with joblib
        errors = mem.cache(historical_replay)(confirmed, window)
        errors_by_window[window_name] = (errors, start, growth)
# %%
# First we plot the errors as a function of prediction time
plt.figure()
for window_name, errors in errors_by_window.items():
    # errors is a (per-day-ahead error array, start, growth) tuple
    plt.plot(errors[0], label=window_name)
plt.legend(loc='best')
plt.xlabel('Days to predict')
plt.ylabel('Relative absolute error')
plt.title('Errors as a function of time')
# %%
# Our conclusion from the above is that the shape of the error does not
# depend much on the window
# %%
# We now plot the error after 4 days as a function of window params
plt.figure()
error, start, growth = zip(*errors_by_window.values())
error = np.array(error)
# Bug fix: the original issued two plt.scatter calls; the first (which
# passed the raw error values positionally as the marker size) was fully
# overplotted by the second and only cluttered the figure, so it is
# removed. The error array is also converted to a numpy array once
# instead of three times.
# NOTE(review): size encodes the day-2 error (index 1) while color
# encodes the day-4 error (index 3) -- confirm this mismatch with the
# "error after 4 days" comment above is intended.
plt.scatter(start, growth, s=300*error[:, 1],
            c=error[:, 3], marker='o')
plt.colorbar()
plt.xlabel('start parameter')
plt.ylabel('growth parameter')
plt.title('Errors as a function of exp window parameter')
# %%
# We see that longer windows are better, with 1.6 growth.
#
# We chose not to explore longer than 17 days because these very long
# windows only improve prediction slightly, but risk biasing it when
# there is a change in public policy.