train_reward_model.py
import os

import hydra
import numpy as np
import torch as th
from torch.utils.tensorboard import SummaryWriter

import control_pcgrl
from control_pcgrl.configs.config import Config
from control_pcgrl.reward_model_wrappers import init_reward_model, train_reward_model
from control_pcgrl.rl.envs import make_env
from control_pcgrl.rl.utils import validate_config

batch_size = 64
n_train_iters = 10000


@hydra.main(config_path="control_pcgrl/configs", config_name="config")
def main(cfg: Config):
    """Train a model to predict relevant metrics in a PCGRL env. Generate data with random actions
    (i.e. random map edits).
    """
    if not validate_config(cfg):
        print("Invalid config!")
        return

    log_dir = 'logs_reward_model'
    log_dir = os.path.join(hydra.utils.get_original_cwd(), log_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    cfg.train_reward_model = True
    env = make_env(cfg)
    metric_keys = list(env.metrics.keys())
    env.reset()
    model, optimizer = init_reward_model(env)
    writer = SummaryWriter(log_dir=log_dir)

    for i in range(n_train_iters):
        # Collect data
        while len(env.datapoints) < batch_size:
            env.step(env.action_space.sample())
            # print(f"Collected {len(env.datapoints)} datapoints")

        # Train
        feats, metrics = env.collect_data()
        loss = train_reward_model(model, optimizer, feats, metrics)
        writer.add_scalar("Loss", loss, i)
        print(f"Loss: {loss}")

        # Reset
        env.datapoints = []
        env.reset()


if __name__ == "__main__":
    main()
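
The actual model construction and update step are delegated to init_reward_model and train_reward_model from control_pcgrl.reward_model_wrappers, which are not shown on this page. As a rough idea of the technique the loop above exercises, here is a minimal, hypothetical sketch of such helpers (the _sketch names, the MLP architecture, and the MSE objective are assumptions, not the library's actual implementation): a small network regresses the env's metric values from flattened map features, one gradient step per collected batch.

    # Hypothetical sketch only; the real helpers live in control_pcgrl.reward_model_wrappers.
    import torch as th
    import torch.nn as nn


    def init_reward_model_sketch(n_feats: int, n_metrics: int, lr: float = 1e-3):
        """Build a small MLP that regresses env metrics from flattened map features."""
        model = nn.Sequential(
            nn.Linear(n_feats, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, n_metrics),
        )
        optimizer = th.optim.Adam(model.parameters(), lr=lr)
        return model, optimizer


    def train_reward_model_sketch(model, optimizer, feats, metrics) -> float:
        """One MSE regression step on a batch of (features, metrics); returns the scalar loss."""
        feats = th.as_tensor(feats, dtype=th.float32)
        targets = th.as_tensor(metrics, dtype=th.float32)
        preds = model(feats)
        loss = nn.functional.mse_loss(preds, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        return loss.item()

Because of the @hydra.main decorator, the script runs as a normal entry point (python train_reward_model.py), and fields of Config can be overridden on the command line in Hydra's usual key=value style; batch_size and n_train_iters are hard-coded module-level constants rather than config fields.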