-
Notifications
You must be signed in to change notification settings - Fork 3
/
onnx_to_tvm.py
176 lines (140 loc) · 5.92 KB
/
onnx_to_tvm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import os
import onnx
import numpy as np
import tvm
import tvm.relay as relay
import argparse
import pandas as pd
from tvm import autotvm
from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
from tvm.contrib.util import tempdir
import tvm.contrib.graph_runtime as runtime
# Command-line interface. Each entry is (flag, keyword arguments handed to
# ``parser.add_argument``); the order below is the order shown in ``--help``.
_CLI_OPTIONS = (
    ('--model', dict(default='resnet50', type=str)),
    ('--layer_info', dict(default='resnet/layer_info.csv',
                          help='.csv file generated by one of the *_to_onnx.py scripts')),
    ('--layer', dict(default='', type=str)),
    ('--device_key', dict(default='1080ti', type=str)),
    ('--opencl', dict(action='store_true')),
    ('--cpu', dict(action='store_true')),
    ('--n_trials', dict(default=1000, type=int)),
    ('--drop_until', dict(default=0, type=int)),
    ('--benchmark', dict(action='store_true')),
    ('--log_file', dict(default='', type=str)),
)

parser = argparse.ArgumentParser(description='AutoTVM from ONNX checkpoints')
for _flag, _spec in _CLI_OPTIONS:
    parser.add_argument(_flag, **_spec)
args = parser.parse_args()

# Unless OpenCL was requested, expose only CUDA device 1 to this process.
if not args.opencl:
    os.environ.update(CUDA_VISIBLE_DEVICES='1')

# Element type of the random input tensor fed to the network in benchmark().
dtype = 'float32'
def get_network(filename, input_shape, input_name='input'):
    """Load an ONNX checkpoint and convert it to a Relay program.

    Args:
        filename: Path of the ONNX file to load.
        input_shape: Shape of the network input, e.g. ``(1, C, H, W)``.
            A single integer is treated as a one-dimensional shape.
        input_name: Name of the graph input that ``input_shape`` is bound
            to. Defaults to ``'input'``, the name this function always used.

    Returns:
        The ``(module, params)`` pair returned by
        ``relay.frontend.from_onnx``.
    """
    onnx_model = onnx.load(filename)
    # The frontend only needs the shape, so there is no reason to allocate a
    # full random tensor just to read its ``.shape`` back.
    shape = tuple(int(dim) for dim in np.atleast_1d(input_shape))
    shape_dict = {input_name: shape}
    sym, params = relay.frontend.from_onnx(onnx_model, shape_dict)
    return sym, params
# Choose the compilation target and the host target from the CLI flags.
# --cpu wins over --opencl; CUDA is the fallback when neither is given.
if args.cpu:
    target = target_host = 'llvm -target=aarch64-linux-gnu'
elif args.opencl:
    target_host = 'llvm -target=aarch64-linux-gnu'
    target = tvm.target.create('opencl -device=mali')
else:
    target_host = 'llvm'
    target = tvm.target.cuda()
    # NOTE(review): the indentation of this file was lost, so it is assumed
    # that ``ctx`` belongs to the CUDA branch. Nothing in this script reads
    # the module-level ``ctx``; benchmark() creates its own remote context.
    ctx = tvm.gpu()
def tune_tasks(tasks,
               measure_option,
               tuner='xgb',
               n_trial=1000,
               early_stopping=None,
               log_filename='tuning.log'):
    """Tune every AutoTVM task and store the best records in ``log_filename``.

    Args:
        tasks: Sequence of AutoTVM tasks. They are tuned in reverse order.
        measure_option: Measurement configuration forwarded to the tuner.
        tuner: Which tuner to use: ``'xgb'`` (default), ``'ga'``,
            ``'random'`` or ``'gridsearch'``.
        n_trial: Upper bound on the number of trials per task. It is clamped
            to the size of the task's configuration space.
        early_stopping: Forwarded unchanged to ``tuner.tune``.
        log_filename: File that receives the best records. All raw records
            are first written to ``log_filename + ".tmp"``.

    Raises:
        ValueError: If ``tuner`` is not one of the supported names.
    """
    tuner_factories = {
        'xgb': lambda task: XGBTuner(task, loss_type='rank'),
        'ga': lambda task: GATuner(task, pop_size=100),
        'random': RandomTuner,
        'gridsearch': GridSearchTuner,
    }
    if tuner not in tuner_factories:
        raise ValueError("Invalid tuner: " + str(tuner))
    make_tuner = tuner_factories[tuner]

    # Make sure the log directory exists up front; otherwise the first write
    # to the log fails only after measurements have already been taken.
    log_dir = os.path.dirname(log_filename)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # create tmp log file
    tmp_log_file = log_filename + ".tmp"
    if os.path.exists(tmp_log_file):
        os.remove(tmp_log_file)

    for i, tsk in enumerate(reversed(tasks)):
        prefix = "\t[Task %2d/%2d] " % (i + 1, len(tasks))
        tuner_obj = make_tuner(tsk)

        # do tuning; the progress bar must be sized with the clamped trial
        # count, otherwise it never reaches 100% on small config spaces.
        tsk_trial = min(n_trial, len(tsk.config_space))
        tuner_obj.tune(n_trial=tsk_trial,
                       early_stopping=early_stopping,
                       measure_option=measure_option,
                       callbacks=[
                           autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
                           autotvm.callback.log_to_file(tmp_log_file)])

    # Nothing was recorded (for example an empty task list): there is no
    # temporary log to distil, so leave ``log_filename`` untouched.
    if not os.path.exists(tmp_log_file):
        return

    # pick best records to a cache file
    autotvm.record.pick_best(tmp_log_file, log_filename)
    os.remove(tmp_log_file)
def tune_and_evaluate():
    """Tune the conv2d workloads of the layer selected with ``--layer``.

    The layer is looked up by its ``filename`` column in the ``--layer_info``
    CSV; the ``in_channels`` and ``input_spatial_x`` columns of that row give
    the input shape ``(1, in_channels, x, x)``. Records are written to
    ``models/<model>/logs/<log_file>.log``. If several rows share the
    filename, the first one is used.

    Despite its name, this function only tunes; use ``--benchmark`` to
    measure inference time.
    """
    df = pd.read_csv(args.layer_info)
    matches = df[df['filename'] == args.layer]
    if matches.empty:
        # Make a no-op run visible instead of exiting silently.
        print('No layer named %r found in %s' % (args.layer, args.layer_info))
        return
    # ``int(series)`` is deprecated in pandas and fails when more than one
    # row matches, so read scalar values from the first matching row instead.
    layer = matches.iloc[0]
    net_fname = layer['filename']

    print('Tuning: ', net_fname)

    #### TUNING OPTION ####
    log_file = "models/%s/logs/%s.log" % (args.model, args.log_file)
    tuning_opt = {
        'log_filename': log_file,
        'n_trial': args.n_trials,
        'measure_option': autotvm.measure_option(
            builder=autotvm.LocalBuilder(timeout=10),
            runner=autotvm.RPCRunner(
                args.device_key,
                '0.0.0.0', 9190,
                number=20, repeat=3, timeout=4, min_repeat_ms=150)
        ),
    }

    in_c = int(layer['in_channels'])
    in_x = int(layer['input_spatial_x'])
    input_shape = (1, in_c, in_x, in_x)
    print(input_shape)

    # extract workloads from relay program
    print("\tExtract tasks...")
    net, params = get_network(net_fname, input_shape)
    tasks = autotvm.task.extract_from_program(
        net['main'], target=target, target_host=target_host,
        params=params, ops=(relay.op.nn.conv2d,))

    # run tuning tasks
    print("\tTuning...")
    tune_tasks(tasks, **tuning_opt)
def benchmark(log_file):
    """Compile the ``--layer`` network with tuned records and time it remotely.

    Args:
        log_file: Path of the AutoTVM log whose best records are applied
            while building. It is used as given (no directory or ``.log``
            suffix is added, unlike the path built in tune_and_evaluate()).

    Raises:
        ValueError: If ``--layer`` does not appear in the ``filename`` column
            of the ``--layer_info`` CSV.
    """
    print(args.layer)
    df = pd.read_csv(args.layer_info)
    matches = df[df['filename'] == args.layer]
    if matches.empty:
        raise ValueError('No layer named %r found in %s'
                         % (args.layer, args.layer_info))
    # ``int(series)`` is deprecated in pandas and fails when more than one
    # row matches, so read scalar values from the first matching row instead.
    layer = matches.iloc[0]
    in_c = int(layer['in_channels'])
    in_x = int(layer['input_spatial_x'])
    input_shape = (1, in_c, in_x, in_x)

    net, params = get_network(args.layer, input_shape)

    # compile kernels with history best records
    with autotvm.apply_history_best(log_file):
        print("\tCompile...")
        with relay.build_config(opt_level=3):
            graph, lib, params = relay.build_module.build(
                net, target=target, target_host=target_host, params=params)

        # export library
        tmp = tempdir()
        filename = "net.tar"
        lib.export_library(tmp.relpath(filename))

        # upload module to device
        print("\tUpload...")
        remote = autotvm.measure.request_remote(args.device_key, '0.0.0.0', 9190,
                                                timeout=10000)
        remote.upload(tmp.relpath(filename))
        rlib = remote.load_module(filename)

        # upload parameters to device
        ctx = remote.context(str(target), 0)
        module = runtime.create(graph, rlib, ctx)
        data_tvm = tvm.nd.array((np.random.uniform(size=input_shape)).astype(dtype))
        # NOTE(review): get_network() binds the shape to a graph input called
        # 'input', while the data is fed to an input called '0' here. Confirm
        # which name the exported ONNX graphs actually use.
        module.set_input('0', data_tvm)
        module.set_input(**params)

        # evaluate
        print("\tEvaluate inference time cost...")
        ftimer = module.module.time_evaluator("run", ctx, number=1, repeat=30)
        prof_res = np.array(ftimer().results) * 1000  # convert to millisecond
        print("\tMean inference time (std dev): %.8f ms (%.2f ms)" %
              (np.mean(prof_res), np.std(prof_res)))
# Entry point: tune by default; with --benchmark, time the network using the
# records in --log_file instead.
if not args.benchmark:
    tune_and_evaluate()
else:
    benchmark(args.log_file)