-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
33 lines (25 loc) · 924 Bytes
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# Config as authored, in XGen's new format.
NEW_JSON_PATH = 'unet_config/xgen.json'
# Destination of the converted legacy-format config. This MUST differ from
# NEW_JSON_PATH. Writing back to the same file destroys the authored config,
# and on the next run get_old_config would receive an already-converted config.
OLD_JSON_PATH = 'unet_config/xgen_old.json'


def run(onnx_path, quantized, pruning, output_path, **kwargs):
    """Simulated deployment/evaluation callback passed to ``xgen``.

    No real measurement is done; a fixed result is returned. The positional
    parameters are accepted (presumably because xgen calls the callback with
    this signature; confirm against xgen) but are not used here.

    Args:
        onnx_path: Path of the exported model (unused in the simulation).
        quantized: Quantization flag (unused in the simulation).
        pruning: Pruning flag (unused in the simulation).
        output_path: Reported back as ``'output_dir'``.
        **kwargs: Must contain ``'sp_prune_ratios'``.

    Returns:
        dict: ``{'output_dir': output_path, 'latency': 50}``. The latency unit
        is not established here (TODO: confirm what xgen expects).

    Raises:
        KeyError: If ``'sp_prune_ratios'`` is not supplied.
    """
    # The simulation ignores the ratios. Reading the key keeps the original
    # fail-fast behavior when xgen does not pass them.
    _prune_ratios = kwargs['sp_prune_ratios']
    return {'output_dir': output_path, 'latency': 50}


def convert_config(new_json_path=NEW_JSON_PATH, old_json_path=OLD_JSON_PATH):
    """Convert a new-format XGen config to the legacy format on disk.

    Args:
        new_json_path: Path of the new-format JSON config to read.
        old_json_path: Path to write the converted (legacy) JSON config to.

    Returns:
        str: ``old_json_path``, ready to be passed to ``xgen``.
    """
    import json

    from xgen.utils.args_ai_map import get_old_config

    with open(new_json_path, encoding='utf-8') as f:
        new = json.load(f)
    old = get_old_config(new)
    with open(old_json_path, 'w', encoding='utf-8') as f:
        json.dump(old, f)
    return old_json_path


def main(config_path=NEW_JSON_PATH, convert_new_config=True):
    """Run XGen in customization mode with the project's training entry point.

    Args:
        config_path: XGen config file. When ``convert_new_config`` is true it
            is treated as a new-format config and converted first. Otherwise
            it is passed to ``xgen`` as-is (e.g. a legacy template such as
            ``'args_ai_template_sgpu.json'``).
        convert_new_config: Whether to convert ``config_path`` with
            ``get_old_config`` before running.
    """
    from train_script_main import training_main
    from xgen.xgen_run import xgen

    if convert_new_config:
        # Use the converted legacy config instead of the new-format one.
        json_path = convert_config(config_path, OLD_JSON_PATH)
    else:
        json_path = config_path
    xgen(training_main, run, xgen_config_path=json_path, xgen_mode='customization')


if __name__ == '__main__':
    main()