-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgrid.py
202 lines (166 loc) · 7.34 KB
/
grid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import math, geopy
import numpy as np
import pandas as pd
import geopandas as gpd
import contextily as ctx
import matplotlib.pyplot as plt
# from functional import seq
from geopy.distance import distance
from shapely.geometry import Point, Polygon
# from spatialnet.spatialnet import Trajectory
def generate_hex_grid(bounds, w):
    """Get a hexagonal grid over the provided rectangular area.

    The grid is laid out from the north-west corner of the bounding box.
    Hexagons have a vertex pointing due east and due west, and every vertex
    lies ``w / 2`` meters from the cell center.

    Params
    ------
    bounds : list or dict of floats (north, south, west, east)
        Geo coordinates of input area bounding box. A dict must provide the
        keys 'north', 'south', 'west' and 'east'; any other value is unpacked
        in that order.
    w : int or float
        Corner-to-corner width of a hexagon in meters (twice the distance
        from the center to a vertex). Must be positive.

    Returns
    -------
    GridDataFrame
        One entry per hexagon with a 'cellID' of the form 'Hex_<row>_<col>'
        and a Polygon geometry in EPSG:4326 (longitude, latitude).

    Raises
    ------
    ValueError
        If ``w`` is not positive.
    """
    if w <= 0:
        raise ValueError("hexagon width w must be positive, got {!r}".format(w))

    if isinstance(bounds, dict):
        north, south, west, east = (
            bounds[key] for key in ('north', 'south', 'west', 'east'))
    else:
        north, south, west, east = bounds

    # Distance between centers of consecutive hexes along each axis. Rows use
    # the "doubled" coordinate system, so one row step is half a hex height.
    x_sep = w * 3 / 4
    y_sep = math.sqrt(3) * w / 4

    # Number of column (west-east) and row (north-south) steps needed to
    # cover the input area.
    nw = geopy.Point(north, west)
    x_size = distance(nw, geopy.Point(north, east)).meters
    y_size = distance(nw, geopy.Point(south, west)).meters
    n_cols = 1 + int(x_size / x_sep)
    n_rows = 2 + int(y_size / y_sep)

    print("\tConstructing hex grid with {} cells".format(n_rows * n_cols // 2))
    print('\trow=' + str(n_rows) + ', cols=' + str(n_cols))

    def travel(origin, meters, bearing):
        """Return the point `meters` away from `origin` along `bearing`."""
        return distance(meters=meters).destination(point=origin,
                                                   bearing=bearing)

    # Geo coordinate of a cell center, measured from the north-west corner:
    # first east along the top edge, then south.
    # see stackoverflow.com/questions/24427828
    def offset(row, col):
        eastward = travel(nw, col * x_sep, 90)
        return travel(eastward, row * y_sep, 180)

    # The 6 vertices of the hexagon centered on `c`, as a Polygon.
    def get_hex(c):
        corners = [travel(c, w / 2, bearing)
                   for bearing in (330, 30, 90, 150, 210, 270)]
        return Polygon([(p.longitude, p.latitude) for p in corners])

    # hex grid algorithm, see redblobgames.com/grids/hexagons/ for details
    cells = {'cellID': [], 'geometry': []}
    for row in range(n_rows):
        for col in range(n_cols):
            # doubled coord system: only positions whose row and column
            # indices sum to an even number hold a hexagon
            if (row + col) % 2 != 0:
                continue
            cells['cellID'].append('Hex_{}_{}'.format(row, col))
            cells['geometry'].append(get_hex(offset(row, col)))

    return GridDataFrame(cells, crs="EPSG:4326")
class GridDataFrame(gpd.GeoDataFrame):
    """Handles construction and querying of geographical grid blocks.

    Each row is one grid cell. When the data contains a 'cellID' column it
    becomes the index of the frame.
    """

    def __init__(self, *args, **kwargs):
        """Initialize GridDataFrame with grid cells as index.

        All arguments are forwarded unchanged to ``GeoDataFrame``. If the
        resulting frame has a 'cellID' column, that column is moved into the
        index; otherwise the index is left as it is.
        """
        super().__init__(*args, **kwargs)
        # set_index only modifies the frame when inplace=True; a non-inplace
        # call returns a new frame and leaves this one untouched.
        if 'cellID' in self.columns:
            self.set_index('cellID', inplace=True)

    # def get_cell_counts(self, trajectories):
    #     """Get the count of intersecting trajectories for each cell.
    #     Params
    #     ------
    #     trajectories: GeoDataFrame, Series, iter of Trajectories, LineStrings
    #         Collection of trajectory traces of moving objects.
    #     Returns
    #     -------
    #     pandas.Series of ints
    #         The counts of traces passing through grid cells, indexed by cell
    #     """
    #     # transform to Sequence of Linestrings
    #     if isinstance(trajectories, gpd.GeoDataFrame):
    #         ls_seq = seq(trajectories.geometry.to_list())
    #     elif isinstance(trajectories, pd.Series):
    #         ls_seq = seq(trajectories.to_list())
    #     else:
    #         trajectories = list(trajectories)
    #         if isinstance(trajectories[0], Trajectory):
    #             ls_seq = seq(trajectories).map(get_linestring)
    #         else:
    #             ls_seq = seq(trajectories)
    #     bool_seq = ls_seq.map(lambda l: self.geometry.intersects(l))
    #     int_seq = bool_seq.map(lambda s:s.astype(int))
    #     counts_series = pd.Series(int_seq.aggregate(lambda c, a: c + a),
    #                               index=self.index)
    #     return counts_series

    def plot_map(self, **kwargs):
        """Plot the grid block outlines with a map background.

        Params
        ------
        kwargs
            Arguments for GeoDataFrame.plot()

        Returns
        -------
        The Axes returned by GeoDataFrame.plot(), with the basemap added.
        """
        # plot cells (cells are in the GeoDataFrame 'self')
        ax = self.plot(**kwargs)
        # get OpenStreetMap background
        ctx.add_basemap(ax, crs=self.crs.to_string(),
                        source=ctx.providers.OpenStreetMap.Mapnik)
        return ax

    def plot_heatmap(self, block_risk, dynamic=False):
        """Plot a heatmap of block risk values.

        Params
        ------
        block_risk : pandas.DataFrame, pandas.Series or iterable of floats
            Risk value per grid cell. A DataFrame holds one column per frame;
            a Series or a plain iterable is treated as a single frame.
            DataFrame and Series input is aligned to the grid by index, while
            a plain iterable is matched to the cells in order.
        dynamic : bool
            If True, every column is drawn in turn, pausing briefly between
            frames. If False, only the first column is drawn.

        Returns
        -------
        The Axes the heatmap was drawn on; it shows the last frame drawn.
        """
        # Bring every accepted input into the same shape: a DataFrame with
        # one column per frame.
        if isinstance(block_risk, pd.DataFrame):
            risk_frames = block_risk
        elif isinstance(block_risk, pd.Series):
            risk_frames = block_risk.to_frame()
        else:
            risk_frames = pd.Series(block_risk, index=self.index).to_frame()

        # Columns are selected by position so that any column labels
        # (timestamps, names, ...) are accepted.
        available = risk_frames.shape[1]
        n_frames = available if dynamic else min(1, available)
        shown = risk_frames.iloc[:, :n_frames]

        # One colour scale for all frames, so that colours are comparable
        # between frames and the colorbar matches what is drawn.
        vmin, vmax = shown.min().min(), shown.max().max()
        if pd.isna(vmin) or pd.isna(vmax):
            vmin = vmax = None

        fig, ax = plt.subplots(1, figsize=(10, 6))

        # The colorbar is created once, before the loop, so dynamic plotting
        # does not add a new one for every frame.
        sm = plt.cm.ScalarMappable(cmap='Reds',
                                   norm=plt.Normalize(vmin=vmin, vmax=vmax))
        sm.set_array([])
        fig.colorbar(sm, ax=ax)

        cells = pd.DataFrame(self)
        for frame_no in range(n_frames):
            # Clear the previous frame *before* drawing, so the frame drawn
            # last stays on the figure.
            ax.cla()
            # assign() works on a copy, so the grid never gains a 'risk'
            # column of its own.
            frame = gpd.GeoDataFrame(
                cells.assign(risk=shown.iloc[:, frame_no]))
            frame.plot(column='risk', cmap='Reds', linewidth=0.8, ax=ax,
                       edgecolor='0.8', vmin=vmin, vmax=vmax)
            # add a title
            ax.set_title('Risk Map',
                         fontdict={'fontsize': '25', 'fontweight': '3'})
            # create an annotation for the data source
            ax.annotate('Source: SUMO generator', xy=(.1, .08),
                        xycoords='figure fraction',
                        horizontalalignment='left', verticalalignment='top',
                        fontsize=5, color='#555555')
            plt.tight_layout()
            plt.draw()
            plt.pause(0.1)
        return ax

    @classmethod
    def from_file(cls, *args, **kwargs):
        """Read a GridDataFrame object from file.

        All arguments are forwarded to GeoDataFrame.from_file(); the result
        is wrapped in the class this method was called on.
        """
        return cls(super().from_file(*args, **kwargs))