move more agent params to config file
campbelljc committed Jan 1, 2019
1 parent ecc9768 commit 124be8a
Showing 3 changed files with 72 additions and 33 deletions.
67 changes: 58 additions & 9 deletions gym_nethack/configs.py
@@ -1,4 +1,7 @@
from keras.optimizers import Adam
from libs.rl.agents.dqn import DQNAgent
from libs.rl.memory import SequentialMemory
from libs.rl.policy import LinearAnnealedPolicy

from gym_nethack.nhdata import *
from gym_nethack.policies import *
@@ -28,6 +31,7 @@

combat_thesis_configs = [
({ # 0
# env. parameters
'env_name': 'NetHackCombat-v0',
'name': 'thesis_dqn',
'num_actions': 40000,
@@ -41,13 +45,28 @@
'item_sampling': 'all',
'clvl_to_mlvl_diff': -3,
'dlvl': 10
}, {
}, { # model parameters
'agent': DQNAgent,
'agent_params': {
'nb_steps_warmup': 4000, # 10%
'enable_dueling_network': True,
'dueling_type': 'max',
'gamma': 0.99,
'delta_clip': 1.,
'memory': SequentialMemory,
'target_model_update': 400
},
'optimizer': Adam(0.0001),
'policy': LinearAnnealedPolicy,
'test_policy': EpsGreedyPossibleQPolicy(eps=0),
'memory': SequentialMemory,
'lr': 0.0001,
'units_d1': 32,
'units_d2': 16
}, { # policy parameters
'inner_policy': EpsGreedyPossibleQPolicy(),
'attr': 'eps',
'value_max': 1,
'value_min': 0,
'value_test': 0
}),
({ # 1
'env_name': 'NetHackCombat-v0',
@@ -96,12 +115,27 @@
'fixed_ac': 0,
'dlvl': 25
}, {
'agent': DQNAgent,
'agent_params': {
'nb_steps_warmup': 4000, # 10%
'enable_dueling_network': True,
'dueling_type': 'max',
'gamma': 0.99,
'delta_clip': 1.,
'memory': SequentialMemory,
'target_model_update': 400,
},
'optimizer': Adam(0.000001),
'policy': LinearAnnealedPolicy,
'test_policy': EpsGreedyPossibleQPolicy(eps=0),
'memory': SequentialMemory,
'lr': 0.000001,
'units_d1': 64,
'units_d2': 32
}, { # policy parameters
'inner_policy': EpsGreedyPossibleQPolicy(),
'attr': 'eps',
'value_max': 1,
'value_min': 0,
'value_test': 0
}),
({ # 4
'env_name': 'NetHackCombat-v0',
@@ -128,12 +162,27 @@
'clvl_to_mlvl_diff': 3,
'fixed_ac': -15
}, {
'agent': DQNAgent,
'agent_params': {
'nb_steps_warmup': 4000, # 10%
'enable_dueling_network': True,
'dueling_type': 'max',
'gamma': 0.99,
'delta_clip': 1.,
'memory': SequentialMemory,
'target_model_update': 400
},
'optimizer': Adam(0.000001),
'policy': LinearAnnealedPolicy,
'test_policy': EpsGreedyPossibleQPolicy(eps=0),
'memory': SequentialMemory,
'lr': 0.000001,
'units_d1': 64,
'units_d2': 32
}, { # policy parameters
'inner_policy': EpsGreedyPossibleQPolicy(),
'attr': 'eps',
'value_max': 1,
'value_min': 0,
'value_test': 0
}),
]

@@ -254,6 +303,6 @@
})
]

#configs = combat_thesis_configs
configs = exploration_configs
configs = combat_thesis_configs
#configs = exploration_configs
#configs = level_configs
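Taken together, the configs.py changes settle on a three-dict layout per config tuple: environment parameters, model parameters, and policy parameters. The sketch below restates that layout in one place for config 0; it is illustrative only (a few env. keys are elided, and EpsGreedyPossibleQPolicy is assumed to come from the gym_nethack.policies star import at the top of the file):

    from keras.optimizers import Adam
    from libs.rl.agents.dqn import DQNAgent
    from libs.rl.memory import SequentialMemory
    from libs.rl.policy import LinearAnnealedPolicy
    from gym_nethack.policies import *  # assumed to provide EpsGreedyPossibleQPolicy

    example_config = (
        {   # env. parameters (match the set_config() keyword signature in combat.py)
            'env_name': 'NetHackCombat-v0',
            'name': 'thesis_dqn',
            'num_actions': 40000,
            # ... remaining env. keys as in config 0 above
        },
        {   # model parameters: everything needed to build and compile the agent
            'agent': DQNAgent,                # agent class, instantiated in ngym.get_agent()
            'agent_params': {                 # kwargs forwarded to the agent constructor
                'nb_steps_warmup': 4000,      # 10%
                'enable_dueling_network': True,
                'dueling_type': 'max',
                'gamma': 0.99,
                'delta_clip': 1.,
                'memory': SequentialMemory,   # class; instantiated with a limit in get_agent()
                'target_model_update': 400
            },
            'optimizer': Adam(0.0001),        # instantiated optimizer; the learning rate now lives here
            'policy': LinearAnnealedPolicy,   # annealing-policy class
            'test_policy': EpsGreedyPossibleQPolicy(eps=0),
            'units_d1': 32,
            'units_d2': 16
        },
        {   # policy parameters: kwargs forwarded to the annealing-policy constructor
            'inner_policy': EpsGreedyPossibleQPolicy(),
            'attr': 'eps',
            'value_max': 1,
            'value_min': 0,
            'value_test': 0
        }
    )

Nothing in ngym.py hard-codes these values any more; see its diff below.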
3 changes: 1 addition & 2 deletions gym_nethack/envs/combat.py
@@ -146,7 +146,7 @@ def get_savedir_info_list(self):
'df' + str(self.clvl_to_mlvl_diff)
]

def set_config(self, proc_id, num_actions=-1, num_episodes=-1, clvl_to_mlvl_diff=-3, monsters='none', initial_equipment=[], items=None, item_sampling='uniform', num_start_items=5, action_list='all', fixed_ac=999, dlvl=None, tabular=False, test_policy=None, lr=0, units_d1=0, units_d2=0, skip_training=False, load_combats=False, **args):
def set_config(self, proc_id, num_actions=-1, num_episodes=-1, clvl_to_mlvl_diff=-3, monsters='none', initial_equipment=[], items=None, item_sampling='uniform', num_start_items=5, action_list='all', fixed_ac=999, dlvl=None, tabular=False, test_policy=None, units_d1=0, units_d2=0, skip_training=False, load_combats=False, **args):
"""Set config.
Args:
@@ -164,7 +164,7 @@ def set_config(self, proc_id, num_actions=-1, num_episodes=-1, clvl_to_mlvl_diff
dlvl: dungeon level for the episode. affects monster attributes (thus difficulty).
tabular: whether we are using a tabular representation for the Q-values (deprecated)
test_policy: used for record folder name (also used in ngym.py)
lr: used for record folder name (also used in ngym.py)
units_d1: used for record folder name (also used in ngym.py)
units_d2: used for record folder name (also used in ngym.py)
skip_training: if True, will not add above info to folder name
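The set_config() change above simply drops the lr keyword: after this commit the learning rate travels inside the instantiated optimizer in the model-parameters dict (Adam(0.0001) in configs.py), so the environment no longer folds it into its record-folder name. A hypothetical call under the new scheme (names are illustrative; the real call site, get_env(), is not part of this diff):

    env_params, model_params, policy_params = example_config  # as sketched under configs.py above
    env.set_config(proc_id=0, **env_params)                   # no separate 'lr' kwarg; unrecognized keys land in **args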
35 changes: 13 additions & 22 deletions ngym.py
@@ -73,10 +73,10 @@ def get_agent(model, env, config, policy_config=None):
Args:
config['learning_agent']: If False, then use a TestAgent with heuristic (not learned) policy given by config['test_policy']. Else, use deep Q-learning agent.
config['policy']: Keras annealing policy class, e.g., LinearAnnealedPolicy.
config['policy_eps']: Tuple to specify max/min vals for annealing policy.
config['test_policy']: The test policy class object to use (should already be instantiated).
config['memory']: Memory class to use, e.g., SequentialMemory.
config['lr']: learning rate to use with Adam optimizer (default is 0.00001).
config['optimizer']: instantiated Keras-RL optimizer object
config['agent']: Keras-RL learning agent class
config['agent_params']: parameters for constructor of above agent class
"""

policy = None
@@ -97,20 +97,10 @@ def get_agent(model, env, config, policy_config=None):
test_policy.agent = agent
test_policy.set_config(**policy_config)
else:
from keras.optimizers import Adam
from libs.rl.agents.dqn import DQNAgent
from libs.rl.policy import LinearAnnealedPolicy

memory = config['memory'](limit=env.memory_size if env.from_file else env.max_num_actions, window_length=1)

if 'policy' in config:
policy_eps = config['policy_eps'] if 'policy_eps' in config else (1, 0)
policy = config['policy'](inner_policy=EpsGreedyPossibleQPolicy(), attr='eps', value_max=policy_eps[0], value_min=policy_eps[1], value_test=policy_eps[1], nb_steps=env.max_num_actions_to_anneal_eps if env.from_file else env.max_num_actions)

agent = DQNAgent(model=model, nb_actions=env.action_space.n, nb_steps_warmup=env.max_num_actions/10,
enable_dueling_network=True, dueling_type='max', target_model_update=env.max_num_actions/100, gamma=0.99,
delta_clip=1., policy=policy, test_policy=test_policy, memory=memory)
agent.compile(Adam(lr=config['lr'] if 'lr' in config else 0.00001), metrics=['mae'])
config['agent_params']['memory'] = config['agent_params']['memory'](limit=env.memory_size if env.from_file else env.max_num_actions, window_length=1)
policy = config['policy'](nb_steps=env.max_num_actions_to_anneal_eps if env.from_file else env.max_num_actions, **policy_config)
agent = config['agent'](model=model, nb_actions=env.action_space.n, policy=policy, test_policy=test_policy, **config['agent_params'])
agent.compile(config['optimizer'], metrics=['mae'])

return agent

@@ -135,17 +125,18 @@ def get_agent(model, env, config, policy_config=None):
num_procs = int(sys.argv[3]) if len(sys.argv) > 3 else 1
print("Proc id:", proc_id, ", config id:", config_id)

configs[config_id][0]['num_procs'] = num_procs
if len(configs[config_id]) >= 3:
configs[config_id][2]['proc_id'] = proc_id
configs[config_id][2]['num_procs'] = num_procs
config = configs[config_id]

config[0]['num_procs'] = num_procs
if len(config) >= 3 and ('learning_agent' in config[1] and not config[1]['learning_agent']):
config[2]['proc_id'] = proc_id
config[2]['num_procs'] = num_procs

env = get_env(proc_id, config_id)

if not os.path.exists(env.savedir):
os.makedirs(env.savedir)

config = configs[config_id]
learning = False if 'learning_agent' in config[1] and not config[1]['learning_agent'] else True
model = get_model(env, config[1]) if learning else None
dqn = get_agent(model, env, config[1], config[2] if len(config) >= 3 else None)
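For a learning agent, the rewritten get_agent() now pulls every class, constructor argument, and optimizer out of the config instead of hard-coding them. The function below is a condensed, hypothetical restatement of that branch (build_learning_agent is not a name from the commit, and in ngym.py the test policy is picked up earlier in get_agent rather than read here):

    def build_learning_agent(model, env, model_config, policy_config):
        # The memory is stored in the config as a class and instantiated here
        # with an environment-dependent limit.
        model_config['agent_params']['memory'] = model_config['agent_params']['memory'](
            limit=env.memory_size if env.from_file else env.max_num_actions,
            window_length=1)

        # The annealing-policy class gets its kwargs from the policy-parameters dict.
        policy = model_config['policy'](
            nb_steps=env.max_num_actions_to_anneal_eps if env.from_file else env.max_num_actions,
            **policy_config)

        # The agent class and its remaining constructor kwargs also come from the config.
        agent = model_config['agent'](
            model=model,
            nb_actions=env.action_space.n,
            policy=policy,
            test_policy=model_config['test_policy'],
            **model_config['agent_params'])
        agent.compile(model_config['optimizer'], metrics=['mae'])
        return agent

The upshot is that switching agents, optimizers, or annealing schedules becomes a pure configs.py edit; ngym.py no longer needs the keras and libs.rl imports it previously pulled in locally.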
