-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathcityflow_eval
68 lines (60 loc) · 2.92 KB
/
cityflow_eval
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
#!/usr/bin/env python3
import os
import torch
import argparse
from functools import partial
from easydict import EasyDict
from ding.config import compile_config
from ding.policy import create_policy
from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import InteractionSerialEvaluator
from ding.utils.default_helper import set_pkg_seed, deep_merge_dicts
from smartcross.utils.config_utils import read_ding_config
from smartcross.policy import FixedPolicy
def main(args, seed=None):
    """Evaluate a policy on a CityFlow environment and return the episode rewards.

    Args:
        args: Parsed command-line namespace. Reads ``ding_cfg``, ``env_cfg``,
            ``policy_type``, ``env_num`` and ``ckpt_path``.
        seed: Optional seed forwarded to ``compile_config``. When it is not
            ``None`` the evaluator environments and the packages are seeded
            with ``cfg.seed``.

    Returns:
        list: One ``final_eval_reward`` value (converted with ``.item()``) per
        entry returned by the evaluator.
    """
    main_config, create_config = read_ding_config(args.ding_cfg)
    use_fixed_policy = args.policy_type == 'fix'
    if use_fixed_policy:
        # The fixed policy needs its own policy type in the create config
        # before the configuration is compiled.
        create_config.policy['type'] = 'smartcross_fix'
    main_config.env = deep_merge_dicts(main_config.env, {'config_path': args.env_cfg})
    cfg = compile_config(main_config, create_cfg=create_config, seed=seed, auto=True, save_cfg=False)

    if args.env_num > 0:
        cfg.env.evaluator_env_num = args.env_num
        # Make sure at least one episode is scheduled per evaluator env.
        if cfg.env.n_evaluator_episode < args.env_num:
            cfg.env.n_evaluator_episode = cfg.env.evaluator_env_num

    # Only the env manager is built here: a standalone env instance is not
    # needed and would never be closed.
    env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    if seed is not None:
        evaluator_env.seed(cfg.seed, dynamic_seed=False)
        set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

    if use_fixed_policy:
        policy = FixedPolicy(evaluator_env.action_space)
    else:
        policy = create_policy(cfg.policy, enable_field=['eval']).eval_mode
        if args.ckpt_path is not None:
            # NOTE(review): torch.load unpickles the file, so only trusted
            # checkpoints should be passed here.
            state_dict = torch.load(args.ckpt_path, map_location='cpu')
            policy.load_state_dict(state_dict)

    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator,
        evaluator_env,
        policy,
    )
    try:
        _, episode_infos = evaluator.eval(None, -1, -1, cfg.env.n_evaluator_episode)
        eval_reward = [info['final_eval_reward'].item() for info in episode_infos]
        print('Eval is over! The performance is {}'.format(eval_reward))
    finally:
        # Release the evaluator even when the evaluation fails.
        evaluator.close()
    return eval_reward
if __name__ == "__main__":
    # Command-line entry point: parse the options and run one evaluation.
    parser = argparse.ArgumentParser(description='DI-smartcross CityFlow evaluation script')
    # NOTE(review): defaults to None, which main() passes straight to
    # read_ding_config -- confirm that helper accepts None.
    parser.add_argument('-d', '--ding-cfg', default=None, help='DI-engine configuration path')
    parser.add_argument('-e', '--env-cfg', required=True, help='cityflow environment configuration path')
    parser.add_argument('-s', '--seed', default=None, type=int, help='random seed for evaluation')
    parser.add_argument(
        '-p', '--policy-type', default='dqn', choices=['fix', 'dqn', 'ppo'], help='RL policy type'
    )
    parser.add_argument('-n', '--env-num', type=int, default=1, help='cityflow env num for evaluation')
    parser.add_argument('-c', '--ckpt-path', type=str, default=None, help='model ckpt path')
    args = parser.parse_args()
    main(args, args.seed)