train_actor_critic_opt.py
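"""Online training entry point for the actor-critic setup with a potential network.

The concrete environment, policy, critic, agent, and buffer classes are chosen by
name on the command line; each class then contributes its own arguments through
`parse_model_args`. The script builds the environment, the policy, the critic(s),
a potential network, and the replay buffer, then hands them to the agent and calls
`agent.train()`. See the example invocation sketched at the bottom of this file.
"""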
from tqdm import tqdm
from time import time
import torch
from torch.utils.data import DataLoader
import argparse
import numpy as np
import os
from model.agent import *
from model.policy import *
from model.critic import *
from model.buffer import *
from env import *
import utils
if __name__ == '__main__':
    # initial args: pick the concrete classes before parsing their own arguments
    init_parser = argparse.ArgumentParser()
    init_parser.add_argument('--env_class', type=str, required=True, help='Environment class.')
    init_parser.add_argument('--policy_class', type=str, required=True, help='Policy class.')
    init_parser.add_argument('--critic_class', type=str, required=True, help='Critic class.')
    init_parser.add_argument('--agent_class', type=str, required=True, help='Learning agent class.')
    init_parser.add_argument('--buffer_class', type=str, required=True, help='Buffer class.')
    initial_args, _ = init_parser.parse_known_args()
    print(initial_args)

    # resolve class names, e.g. 'TD3' -> TD3.TD3 (module and class share the same name)
    envClass = eval('{0}.{0}'.format(initial_args.env_class))
    policyClass = eval('{0}.{0}'.format(initial_args.policy_class))
    criticClass = eval('{0}.{0}'.format(initial_args.critic_class))
    agentClass = eval('{0}.{0}'.format(initial_args.agent_class))
    bufferClass = eval('{0}.{0}'.format(initial_args.buffer_class))

    # experimental control args
    parser = argparse.ArgumentParser()
    parser.add_argument('--seed', type=int, default=11, help='random seed')
    parser.add_argument('--cuda', type=int, default=-1, help='cuda device number; set to -1 (default) if using cpu')
    parser.add_argument('--w', type=float, default=0.1, help='w for potential')

    # customized args contributed by each selected class
    parser = envClass.parse_model_args(parser)
    parser = policyClass.parse_model_args(parser)
    parser = criticClass.parse_model_args(parser)
    parser = agentClass.parse_model_args(parser)
    parser = bufferClass.parse_model_args(parser)
    args, _ = parser.parse_known_args()

    if args.cuda >= 0 and torch.cuda.is_available():
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)
        torch.cuda.set_device(args.cuda)
        device = f"cuda:{args.cuda}"
    else:
        device = "cpu"
    args.device = device
    utils.set_random_seed(args.seed)
    print(f"Device: {device}")

    # Environment
    print("Loading environment")
    env = envClass(args)

    # Policy, Critic, Buffer, Agent
    print("Setup policy:")
    policy = policyClass(args, env)
    policy.to(device)
    print(policy)

    print("Setup critic:")
    if initial_args.agent_class == 'TD3':
        # TD3 keeps a pair of critics for clipped double-Q learning
        critic1 = criticClass(args, env, policy)
        critic1.to(device)
        critic2 = criticClass(args, env, policy)
        critic2.to(device)
        critic = [critic1, critic2]
    else:
        critic = criticClass(args, env, policy)
        critic.to(device)
    print(critic)

    print("Setup potential:")
    potential = criticClass(args, env, policy)
    potential.to(device)

    print("Setup buffer:")
    buffer = bufferClass(args, env, policy, critic)
    print(buffer)

    print("Setup agent:")
    agent = agentClass(args, env, policy, critic, potential, buffer)
    print(agent)

    # online training
    try:
        print(args)
        agent.train()
    except KeyboardInterrupt:
        print("Early stop manually")
        exit_here = input("Exit completely without evaluation? (y/n) (default n): ")
        if exit_here.lower().startswith('y'):
            print(os.linesep + '-' * 20 + ' END: ' + utils.get_local_time() + ' ' + '-' * 20)
            exit(1)
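
# Example invocation (illustrative only; the env/policy/critic/buffer class names below
# are placeholders and must match classes actually defined under env/ and model/ in this
# repository -- only 'TD3' is referenced explicitly by the script above):
#
#   python train_actor_critic_opt.py \
#       --env_class MyEnvironment --policy_class MyPolicy \
#       --critic_class MyCritic --agent_class TD3 --buffer_class MyBuffer \
#       --seed 11 --cuda 0 --w 0.1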