-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathmain.py
347 lines (321 loc) · 16.5 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
import numpy as np
import random
from person import Person
import pygame
from data_collector import DataCollector
import json
# Rendering palettes, keyed by color-model name and then by disease state (RGB tuples).
# Only the 'SIR' model exists right now.
color_models = {
    'SIR': {
        'susceptible': (204, 255, 204),
        'infected': (255, 204, 204),
        'recovered': (204, 204, 255),
    },
}

# Rendering shapes, keyed by shape-model name ('SD' = social distancing, 'WM' = wearing mask)
# and then by whether the person has that trait.
shape_models = {
    model_name: {True: 'circle', False: 'rect'}
    for model_name in ('SD', 'WM')
}

# Policies that define the overall safety level of the population: each policy uses the same
# probability for social distancing and for mask wearing.
policies_safety = {
    policy_name: {'social_distance_prob': probability, 'wear_mask_prob': probability}
    for policy_name, probability in (
        ('very high', 0.75),
        ('high', 0.5),
        ('medium', 0.25),
        ('low', 0.10),
    )
}
class CellularAutomation:
    """Cellular automaton of people moving on a wrapping 2D grid while a disease spreads.

    The grid stores person ids (or None for an empty cell); the Person objects themselves
    live in `id_person`. People are split into two id sets depending on whether they
    practice social distancing, because the two groups are updated with different
    movement rules.

    Args:
        constants: Mapping with the sections 'grid', 'render', 'person' and 'disease'.
        data_collect: Collector object that receives per-person and per-step data.
    """

    def __init__(self, constants, data_collect):
        self.grid_C = constants['grid']
        self.render_C = constants['render']
        self.person_C = constants['person']
        self.disease_C = constants['disease']
        self.data_collect = data_collect
        # Need two sets of ids (each id corresponds uniquely to a person)
        # 1) Those who practice social distancing
        # 2) Those who do not practice social distancing
        self.ids_social_distance = set()
        self.ids_not_social_distance = set()
        # The person objects
        # id (int): person (object)
        self.id_person = {}
        # Grid stores the person ids in a 2D structure, indexed [y, x].
        # `object` replaces the `np.object` alias, which newer NumPy releases removed.
        self.grid = np.empty(shape=(self.grid_C['height'], self.grid_C['width']), dtype=object)
        self.next_id = 0
        # The currently open positions (no person on them), as (x, y) tuples
        self.open_positions = [
            (x, y)
            for y in range(self.grid_C['height'])
            for x in range(self.grid_C['width'])
        ]
        # Initialize the grid
        self._initialize_grid()

    def _oob(self, x, y):
        """Return True if (x, y) lies outside the grid bounds."""
        return x < 0 or y < 0 or x >= self.grid_C['width'] or y >= self.grid_C['height']

    def _get_cell_pos(self, x, y):
        """Map (x, y) onto the grid, wrapping around the edges.

        Modulo handles any overshoot in either direction; the previous
        `max_val - val` form was only correct when `val == max_val`.
        """
        return x % self.grid_C['width'], y % self.grid_C['height']

    def _is_empty(self, x=None, y=None, position=None):
        """Return True if the cell holds no person id.

        The cell is given either as `x` and `y` or as `position` (an (x, y) pair).
        Coordinates are tested against None so that a coordinate of 0 is accepted.
        """
        if x is None or y is None:
            x, y = position[0], position[1]
        return self.grid[y, x] is None

    def _kill_person(self, id, social_distance):
        """Remove a person who died from the disease from all bookkeeping structures.

        Args:
            id: Id of the person to remove.
            social_distance: Whether the id is stored in the social-distancing set.
        """
        person = self.id_person[id]
        if social_distance:
            self.ids_social_distance.remove(id)
        else:
            self.ids_not_social_distance.remove(id)
        self._clear_cell(person.position)
        assert person.infected
        self.data_collect.increment_death_data(person)
        # todo: Not sure if lifetime-infected data should be recorded here, because someone
        #  who dies early does not really get a full infectious lifetime
        # self.data_collect.add_lifetime_infected(person.num_people_infected, person.infectious_days_info)
        del self.id_person[id]

    def _clear_cell(self, position):
        """Empty the cell at `position` and mark the position as open."""
        self.grid[position[1], position[0]] = None
        self.open_positions.append(position)

    def _add_to_cell(self, id, position):
        """Place person `id` on the (empty) cell at `position` and mark it as taken."""
        assert self._is_empty(position=position)
        self.id_person[id].position = position
        self.grid[position[1], position[0]] = id
        self.open_positions.remove(position)

    def _move_person(self, id, person, new_position):
        """Move `person` (with id `id`) from its current cell to `new_position`."""
        # Clear current position
        self._clear_cell(person.position)
        # Place person in new cell
        self._add_to_cell(id, new_position)
        person.set_position(new_position)

    # Grid initialization ------
    def _create_person(self, position):
        """Create a person with randomly drawn traits and place it at `position`."""
        assert self._is_empty(position=position)
        age_low, age_high = self.person_C['age_range'][0], self.person_C['age_range'][1]
        age = np.random.randint(age_low, age_high + 1)
        policy = policies_safety[self.person_C['policy_type']]
        # bool() keeps these plain Python booleans rather than numpy booleans
        SD = bool(np.random.random() < policy['social_distance_prob'])
        WM = bool(np.random.random() < policy['wear_mask_prob'])
        infected = np.random.random() < self.person_C['initial_infection_prob']
        if not infected:
            self.data_collect.increment_initial_S()
        person = Person(position, age, SD, WM, self.person_C['movement_prob'],
                        self.person_C['altruistic_movement_prob'],
                        self.person_C['altruistic_prob'], infected,
                        self.disease_C['total_length_infection'],
                        self.disease_C['incubation_period_duration_range'],
                        self.disease_C['infectious_start_before_symptoms_range'],
                        self.disease_C['infectious_period_duration_range'],
                        self.disease_C['severe_symptoms_start_range'],
                        self.disease_C['death_occurrence_range'], self.disease_C['asymptomatic_prob'],
                        self.disease_C['severity_prob'], self.disease_C['death_prob'])
        self.id_person[self.next_id] = person
        # Use the collector handed to this instance, not a module-level global
        self.data_collect.update_data(person)
        if SD:
            self.ids_social_distance.add(self.next_id)
        else:
            self.ids_not_social_distance.add(self.next_id)
        self._add_to_cell(self.next_id, position)
        self.next_id += 1

    def _initialize_grid(self):
        """Create the initial population, each person on a randomly chosen open cell.

        Raises:
            ValueError: If the initial population does not fit on the open cells.
        """
        population = self.grid_C['initial_pop_size']
        if population > len(self.open_positions):
            raise ValueError(
                'initial_pop_size ({}) exceeds the number of open cells ({})'.format(
                    population, len(self.open_positions)))
        for _ in range(population):
            # Select random position
            position = random.choice(self.open_positions)
            # Create person
            self._create_person(position)

    def _yield_neighbors(self, position, side_length):
        """Yield every cell of the `side_length` x `side_length` square centered on `position`.

        The center cell itself is included. Cells are visited row by row.

        Yields:
            (person or None, wrapped absolute (x, y), relative (dx, dy)) per cell.
        """
        x, y = position[0], position[1]
        cluster_mid = side_length // 2
        for row in range(side_length):
            rel_y = row - cluster_mid
            for col in range(side_length):
                rel_x = col - cluster_mid
                correct_x, correct_y = self._get_cell_pos(x + rel_x, y + rel_y)
                located_id = self.grid[correct_y, correct_x]  # None if no person there
                if located_id is None:
                    yield None, (correct_x, correct_y), (rel_x, rel_y)
                else:
                    yield self.id_person[located_id], (correct_x, correct_y), (rel_x, rel_y)

    def _check_neighbors_SD(self, id, person):
        """Infection check and movement for a person who practices social distancing.

        The person first gets an infection check against its infectious neighbors. It then
        moves if any adjacent cell is occupied (a single step) or, otherwise, with
        probability `person.movement_prob` (up to `move_length` steps). Each step goes to a
        randomly chosen adjacent free cell whose own 3x3 surroundings contain no one else,
        and is followed by another infection check.
        """
        def check_neighbors(last_position=None):
            # Relative offset -> absolute position of adjacent cells the person could move to
            safe_cells = {(-1, -1): None, (0, -1): None, (1, -1): None, (-1, 0): None,
                          (1, 0): None, (-1, 1): None, (0, 1): None, (1, 1): None}
            infected_neighbors = []
            for neighbor, neighbor_pos, neighbor_pos_rel in self._yield_neighbors(person.position, 3):
                if neighbor_pos == person.position:
                    continue
                # Get abs pos
                safe_cells[neighbor_pos_rel] = neighbor_pos
                # Add to infected if infectious neighbor
                if neighbor and neighbor.is_infectious():
                    infected_neighbors.append(neighbor)
                # Remove from safe cells if the cell contains a person or it was the last pos
                if neighbor or last_position == neighbor_pos:
                    del safe_cells[neighbor_pos_rel]
            return infected_neighbors, safe_cells

        # First get the infectious neighbors and check if this person gets infected
        infected_neighbors, safe_cells = check_neighbors()
        self._check_infection(person, infected_neighbors)

        # Then moving: move if the current cell is not safe OR the person moves intentionally
        crowded = len(safe_cells) < 8
        # Move only one time if just moving because the current cell is unsafe
        move_length_SD = 1 if crowded else self.person_C['move_length']
        if crowded or np.random.random() < person.movement_prob:
            for _step in range(move_length_SD):
                last_position = person.position
                did_move = False
                # Shuffle the candidate cells and take the first one whose 3x3 surroundings
                # (not counting this person's own cell) are free of people
                for safe_cell_rel_pos in random.sample(list(safe_cells), len(safe_cells)):
                    safe_cell_abs_pos = safe_cells[safe_cell_rel_pos]
                    safe = True
                    for neighbor, neighbor_pos, _rel in self._yield_neighbors(safe_cell_abs_pos, 3):
                        if neighbor_pos == person.position:
                            continue
                        if neighbor:
                            safe = False
                            break
                    # First one that is safe: move there
                    if safe:
                        self._move_person(id, person, safe_cell_abs_pos)
                        infected_neighbors, safe_cells = check_neighbors(last_position)
                        self._check_infection(person, infected_neighbors)
                        did_move = True
                        break
                # End if it did not move
                if not did_move:
                    break

    def _check_neighbors_not_SD(self, id, person):
        """Infection check and movement for a person who does not practice social distancing.

        The person gets an infection check, then with probability `person.movement_prob`
        takes up to `move_length` steps onto random adjacent empty cells (never straight
        back to the cell it just left), with another infection check after each step.
        """
        def check_neighbors(last_position=None):
            empty_spots = []
            infected_neighbors = []
            for neighbor, neighbor_pos, _rel in self._yield_neighbors(person.position, 3):
                if neighbor_pos == person.position:
                    continue
                # Add to infected if infectious neighbor
                if neighbor and neighbor.is_infectious():
                    infected_neighbors.append(neighbor)
                # Add empty spot (unless it is the position the person just came from)
                if not neighbor and neighbor_pos != last_position:
                    empty_spots.append(neighbor_pos)
            return infected_neighbors, empty_spots

        # First get the infectious neighbors and check if this person gets infected
        infected_neighbors, empty_spots = check_neighbors()
        self._check_infection(person, infected_neighbors)

        # Then moving
        if np.random.random() < person.movement_prob:
            for _step in range(self.person_C['move_length']):
                # Stop if there is nowhere to move
                if len(empty_spots) == 0:
                    break
                new_spot = random.choice(empty_spots)
                last_position = person.position
                self._move_person(id, person, new_spot)
                infected_neighbors, empty_spots = check_neighbors(last_position)
                self._check_infection(person, infected_neighbors)

    def _check_infection(self, person, infected_neighbors):
        """Let `person` possibly get infected by `infected_neighbors`.

        If the person was newly infected, every neighbor in `infected_neighbors` gets its
        `num_people_infected` counter incremented (used for calculating Ro).
        """
        newly_infected = person.gets_infected(infected_neighbors, self.disease_C['base_infection_prob'],
                                              self.disease_C['mask_infection_prob_decrease'],
                                              self.data_collect)
        if newly_infected:
            for neighbor in infected_neighbors:
                neighbor.num_people_infected += 1

    def _update_person(self, id):
        """Advance one person by one time step.

        Returns:
            None if the person died this step; otherwise the second value returned by
            `person.progress_infection` (`run` compares it against True/False to detect
            a switch between the social-distancing groups).
        """
        person = self.id_person[id]
        # Progress infection (if infected)
        dead, new_SD = person.progress_infection(self.data_collect)
        if dead:
            self._kill_person(id, person.social_distance)
            return None  # Continue to next person
        # Infection check against the neighbors, followed by movement
        if person.social_distance:
            self._check_neighbors_SD(id, person)
        else:
            self._check_neighbors_not_SD(id, person)
        return new_SD

    # For rendering
    def _get_person_color(self, person):
        """Return the RGB color for `person` according to the configured color model.

        Raises:
            ValueError: If the configured color model is not in `color_models`.
        """
        model = self.render_C['color_model']
        if model not in color_models:
            raise ValueError('Unknown color model: {!r}'.format(model))
        colors = color_models[model]
        if person.susceptible:
            return colors['susceptible']
        if person.infected:
            return colors['infected']
        return colors['recovered']

    def _render(self, id, screen):
        """Draw person `id` on `screen` as a cell-sized rectangle or circle."""
        person = self.id_person[id]
        shape_model = self.render_C['shape_model']
        # 'SD' picks the shape by social distancing; any other model by mask wearing
        if shape_model == 'SD':
            shape = shape_models[shape_model][person.social_distance]
        else:
            shape = shape_models[shape_model][person.wear_mask]
        color = self._get_person_color(person)
        cell_size = self.render_C['cell_size']
        left = person.position[0] * cell_size
        top = person.position[1] * cell_size
        if shape == 'rect':
            pygame.draw.rect(screen, color, [left, top, cell_size, cell_size])
        else:
            radius = cell_size // 2
            pygame.draw.circle(screen, color, (left + radius, top + radius), radius)

    def _update_ids(self, ids, screen=None):
        """Update every person in `ids` once, in random order.

        Args:
            ids: Collection of person ids to update (copied before iterating, so the
                underlying set may change while people die).
            screen: Surface to render each surviving person on, or None to skip rendering.

        Returns:
            (ids that switched to social distancing, ids that switched away from it).
        """
        # Keep track of any switches between SD lists
        new_SD_list = []
        new_not_SD_list = []
        # Shuffle - random order
        shuffled = list(ids)
        random.shuffle(shuffled)
        for id in shuffled:
            new_SD = self._update_person(id)
            # If dead then continue
            if id not in self.id_person:
                continue
            if new_SD is True:
                new_SD_list.append(id)
            elif new_SD is False:
                new_not_SD_list.append(id)
            # Update data collection
            self.data_collect.update_data(self.id_person[id])
            # Render
            if screen is not None:
                self._render(id, screen)
        return new_SD_list, new_not_SD_list

    def run(self, render=False):
        """Run the simulation for `number_iterations` time steps.

        Args:
            render: If True, draw every time step in a pygame window.
        """
        iterations = self.grid_C['number_iterations']
        screen = None
        clock = None
        if render:
            # Initialize the game engine
            pygame.init()
            # Set the height and width and title of the screen
            screen_width = self.render_C['cell_size'] * self.grid_C['width']
            screen_height = self.render_C['cell_size'] * self.grid_C['height']
            screen = pygame.display.set_mode((screen_width, screen_height))
            pygame.display.set_caption("Population Dynamics")
            clock = pygame.time.Clock()
            # Initially set the screen to all black
            screen.fill((0, 0, 0))

        for t in range(iterations):
            self.data_collect.reset(t)
            # Update (in random order) those who do NOT practice social distancing
            new_SD, new_not_SD_list = self._update_ids(self.ids_not_social_distance, screen)
            assert len(new_not_SD_list) == 0
            # Next update (in random order) those who DO practice social distancing, so they
            # get to be at a safe distance from others at the end of the iteration
            new_SD_list, new_not_SD = self._update_ids(self.ids_social_distance, screen)
            assert len(new_SD_list) == 0
            # Switch people between the two groups
            for id in new_SD:
                self.ids_not_social_distance.remove(id)
                self.ids_social_distance.add(id)
            for id in new_not_SD:
                self.ids_social_distance.remove(id)
                self.ids_not_social_distance.add(id)
            if render:
                # Service the event queue so the window is not reported as unresponsive
                pygame.event.pump()
                pygame.display.flip()
                screen.fill((0, 0, 0))
                # Frames per second
                if self.render_C['fps']:
                    clock.tick(self.render_C['fps'])
        # Equals t + 1 after the loop, but is also defined when there are zero iterations
        self.data_collect.reset(iterations, last=True)
if __name__ == '__main__':
    # Script entry point: load the constants, set up data collection and run the simulation.
    # The file is opened in a `with` block so the handle is closed once the JSON is parsed.
    with open('constants.json', encoding='utf-8') as constants_file:
        constants = json.load(constants_file)
    # Can save a run as an experiment which saves the data, visualizations and constants in a experiments directory
    data_collect = DataCollector(constants, save_experiment=True, print_visualizations=True)
    # Can print data (look at `data_options` at top of `data_collector.py` for options) and how often to print
    data_collect.set_print_options(basic_to_print=['S', 'I', 'R', 'death'], frequency=1)
    CA = CellularAutomation(constants, data_collect)
    # Can render each timestep with pygame
    CA.run(render=True)