-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathdata_collector.py
230 lines (212 loc) · 11.8 KB
/
data_collector.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import os
from datetime import datetime
import json
import pandas as pd
data_options = ['S', 'I', 'R', 'WM', 'SD', 'death', 'mild', 'severe', 'asymptomatic']
# SAR -> Secondary Attack Rate = total # infected people / total # susceptible (overall metric, ie calculated at end)
# R0 -> Basic Reproductive Number = The number of people an infected person directly infects
# R0S -> R0 x S -> If > 1 then can multiply, = 1 then can become endemic (persistent but tame), < 1 then can die off
advanced_equations = ['SAR', 'R0', 'R0S']
class DataCollector:
def __init__(self, constants, save_experiment, print_visualizations):
self.constants = constants
self.save_experiment = save_experiment
self.print_visualizations = print_visualizations
self.basic_to_print = None
self.adv_to_print = None
self.frequency_print = 1
self._reset_data_options(hist=True)
# Advanced Infection data collection
# WM, SD, both, neither, total
self.adv_infection_data = {'total': 0, 'SD': 0, 'not SD': 0}
self.adv_infection_data_history = OrderedDict({'total': [], 'SD': [], 'not SD': []})
# For adv equations
# For SAR (Secondary Attack Rate) need total number of infected overtime
self.total_infected = 0
# And need number of S not including initial infected
self.initial_S = 0
# For R0 need the current number of each infection lifetime for the current bin
self.lifetime_infected_bin_size = 5
self.current_bin_lifetime_infected = []
# Saves all the bin averages
self.lifetime_infected_bin_avgs = OrderedDict()
self.last_bin_avgs = {'total': None, 'SD': None, 'not SD': None, 'WM': None, 'not WM': None, 'both': None, 'neither': None}
def _reset_data_options(self, hist=False):
self.current_data = {}
if hist: self.data_history = OrderedDict()
for k in data_options:
self.current_data[k] = 0
if hist: self.data_history[k] = []
def set_print_options(self, basic_to_print='all', adv_to_print='all', frequency=1):
self.basic_to_print = data_options if basic_to_print == 'all' else basic_to_print
self.adv_to_print = advanced_equations if adv_to_print == 'all' else adv_to_print
self.frequency_print = frequency
def increment_total_infected(self):
self.total_infected += 1
def increment_initial_S(self):
self.initial_S += 1
def _update_adv_infection_data(self, person):
SD = person.social_distance
not_SD = not SD
self.adv_infection_data['total'] += 1
if SD: self.adv_infection_data['SD'] += 1
if not_SD: self.adv_infection_data['not SD'] += 1
def update_data(self, person):
self.current_data['S'] += person.susceptible
self.current_data['I'] += person.infected
if person.infected:
self._update_adv_infection_data(person)
self.current_data['R'] += person.recovered
self.current_data['WM'] += person.wear_mask
self.current_data['SD'] += person.social_distance
if person.current_symptom_stage == 'mild':
self.current_data['mild'] += 1
elif person.current_symptom_stage == 'severe':
self.current_data['severe'] += 1
elif person.current_symptom_stage == 'asymptomatic':
self.current_data['asymptomatic'] += 1
def increment_death_data(self, person):
self.current_data['death'] += 1
def add_lifetime_infected(self, num_infected, infectious_days_info):
# Bin infectious_days_info into majority SD, minority SD, majority WM, minority WM (ie did they SD more often then not)
SD = infectious_days_info['SD'] > infectious_days_info['not SD']
WM = infectious_days_info['WM'] > infectious_days_info['not WM']
both = SD and WM
neither = not SD and not WM
self.current_bin_lifetime_infected.append({'total': num_infected, 'SD': SD, 'not SD': not SD, 'WM': WM,
'not WM': not WM, 'both': both, 'neither': neither})
def reset(self, timestep, last=False):
# Aggregate history data
for key, value in list(self.current_data.items()):
self.data_history[key].append(value)
for key, value in list(self.adv_infection_data.items()):
self.adv_infection_data_history[key].append(value)
self.adv_infection_data[key] = 0
# If bin is done in lifetime infected get avg and empty bin
if timestep % self.lifetime_infected_bin_size == 0 and timestep != 0:
self.lifetime_infected_bin_avgs[timestep] = {}
# If no one infected recovered/died them move on
if len(self.current_bin_lifetime_infected) == 0:
# Set to the last avgs initially and if new ones then set them
for k, last_avg in list(self.last_bin_avgs.items()):
self.lifetime_infected_bin_avgs[timestep][k] = last_avg
else:
for k in list(self.last_bin_avgs.keys()):
bin_arr = [dic['total'] for dic in self.current_bin_lifetime_infected if dic[k]]
if len(bin_arr) == 0: # No people with that bin type
self.lifetime_infected_bin_avgs[timestep][k] = self.last_bin_avgs[k]
continue
bin_avg = sum(bin_arr) / len(bin_arr)
self.lifetime_infected_bin_avgs[timestep][k] = bin_avg
self.last_bin_avgs[k] = bin_avg
self.current_bin_lifetime_infected = []
# Print
if timestep % self.frequency_print == 0 and (self.basic_to_print or self.adv_to_print):
st = 'At timestep: {} --- '.format(timestep)
if self.basic_to_print:
for i, val in enumerate(self.basic_to_print):
st += '{}: {}'.format(val, self.current_data[val])
if i != len(self.basic_to_print)-1:
st += ' --- '
if self.adv_to_print:
if timestep in self.lifetime_infected_bin_avgs:
total_bin_avg = self.lifetime_infected_bin_avgs[timestep]['total']
if 'R0' in self.adv_to_print and total_bin_avg != None:
st += '\nBasic Reproduction Number (R0): {:.02f}'.format(total_bin_avg)
if 'R0S' in self.adv_to_print and total_bin_avg != None:
st += '\nR0S: {:.02f} x {} = {:.02f}'.format(total_bin_avg, self.current_data['S'], total_bin_avg * self.current_data['S'])
print(st)
# Reset data
self._reset_data_options()
# If last print advanced equations
if last:
SAR = self.total_infected / self.initial_S
if 'SAR' in self.adv_to_print:
print('Secondary Attack Rate (SAR): {} / {} = {:.02f}'.format(self.total_infected, self.initial_S, SAR))
# Convert the lifetime infected bin avgs to a a dict of lists and a list for the x-vals
self.R0_hist = {'total': [], 'SD': [], 'WM': [], 'not SD': [], 'not WM': [], 'both': [], 'neither': []}
self.R0S_hist = {'total': [], 'SD': [], 'WM': [], 'not SD': [], 'not WM': [], 'both': [], 'neither': []}
self.R0_xvals = []
for x_val, info in list(self.lifetime_infected_bin_avgs.items()):
self.R0_xvals.append(x_val)
S = self.data_history['S'][x_val]
for k, y_val in list(info.items()):
if not y_val: y_val = np.nan
R0 = y_val
R0S = S * y_val
self.R0_hist[k].append(R0)
self.R0S_hist[k].append(R0S)
# Visualizations
fig, axs = plt.subplots(2, 2, figsize=(15, 10))
# Infections
I_xvals = list(range(len(self.adv_infection_data_history['total'])))
axs[0, 0].plot(I_xvals, self.adv_infection_data_history['total'], 'C0', label='total')
axs[0, 0].plot(I_xvals, self.adv_infection_data_history['SD'], 'C2', label='SD')
axs[0, 0].plot(I_xvals, self.adv_infection_data_history['not SD'], 'C3', label='not SD')
axs[0, 0].set_title('Infections based on SD')
axs[0, 0].legend(loc="upper left")
# R0
R0_xvals = self.R0_xvals
axs[1, 0].plot(R0_xvals, self.R0_hist['total'], 'C0', label='total')
axs[1, 0].plot(R0_xvals, self.R0_hist['SD'], 'C2', label='SD')
axs[1, 0].plot(R0_xvals, self.R0_hist['not SD'], 'C3', label='not SD')
axs[1, 0].plot(R0_xvals, self.R0_hist['WM'], 'C4', label='WM')
axs[1, 0].plot(R0_xvals, self.R0_hist['not WM'], 'C1', label='not WM')
axs[1, 0].plot(R0_xvals, self.R0_hist['both'], 'C5', label='both')
axs[1, 0].plot(R0_xvals, self.R0_hist['neither'], 'C6', label='neither')
axs[1, 0].set_title('R0 based on SD and WM')
axs[1, 0].legend(loc="upper left")
# R0S
axs[1, 1].plot(R0_xvals, self.R0S_hist['total'], 'C0', label='total')
axs[1, 1].plot(R0_xvals, self.R0S_hist['SD'], 'C2', label='SD')
axs[1, 1].plot(R0_xvals, self.R0S_hist['not SD'], 'C3', label='not SD')
axs[1, 1].plot(R0_xvals, self.R0S_hist['WM'], 'C4', label='WM')
axs[1, 1].plot(R0_xvals, self.R0S_hist['not WM'], 'C1', label='not WM')
axs[1, 1].plot(R0_xvals, self.R0S_hist['both'], 'C5', label='both')
axs[1, 1].plot(R0_xvals, self.R0S_hist['neither'], 'C6', label='neither')
axs[1, 1].set_title('R0S based on SD and WM')
axs[1, 1].legend(loc="upper left")
# Save
if self.save_experiment:
# Create new directory (name of current date and time)
now = datetime.now()
dt_string = now.strftime("%d-%m-%Y_%H-%M-%S")
sub_dir = os.path.join('experiments', dt_string)
new_dir = os.path.join(os.getcwd(), sub_dir)
os.mkdir(new_dir)
# Save constants
constants_file = os.path.join(sub_dir, 'constants.json')
json.dump(self.constants, open(constants_file, 'w'), indent=4)
# Save visualizations
figure_file = os.path.join(sub_dir, 'plots.png')
plt.savefig(figure_file)
# Save data as .csv and txt
# Basic
basic_data_file = os.path.join(sub_dir, 'basic_data.csv')
self.data_history['timestep'] = I_xvals
basic_data_df = pd.DataFrame(data=self.data_history)
basic_data_df.to_csv(basic_data_file, index=False)
# Advanced infection
adv_I_file = os.path.join(sub_dir, 'infection_data.csv')
self.adv_infection_data_history['timestep'] = I_xvals
adv_I_df = pd.DataFrame(data=self.adv_infection_data_history)
adv_I_df.to_csv(adv_I_file, index=False)
# R0
R0_file = os.path.join(sub_dir, 'R0_data.csv')
self.R0_hist['timestep'] = self.R0_xvals
R0_df = pd.DataFrame(data=self.R0_hist)
R0_df.to_csv(R0_file, index=False)
# R0S
R0S_file = os.path.join(sub_dir, 'R0S_data.csv')
self.R0S_hist['timestep'] = self.R0_xvals
R0S_df = pd.DataFrame(data=self.R0S_hist)
R0S_df.to_csv(R0S_file, index=False)
# Save SAR to txt file
SAR_file = os.path.join(sub_dir, 'SAR.txt')
with open(SAR_file, 'w') as f:
f.write(str(SAR))
if self.print_visualizations:
plt.show()