fix(pu): fix gradient accumulation_steps option
puyuan committed Feb 8, 2025
1 parent 51185e3 commit 581310b
Showing 4 changed files with 66 additions and 38 deletions.
45 changes: 32 additions & 13 deletions lzero/policy/unizero.py
@@ -343,7 +343,7 @@ def _init_learn(self) -> None:
wandb.watch(self._learn_model.representation_network, log="all")

# TODO: ========
self.accumulation_steps = 4 # number of gradient accumulation steps
self.accumulation_steps = 1 # number of gradient accumulation steps

# @profile
def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
@@ -467,8 +467,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
# assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"

# Core learn model update step
if train_iter % self.accumulation_steps == 0: # update parameters once every accumulation_steps steps
# print(f'train_iter:{train_iter}')
# print(f'train_iter:{train_iter}')
# assumes train_iter counts from 0
if (train_iter % self.accumulation_steps) == 0:
# zero the gradients at the first step of each accumulation cycle
# print(f'train_iter:{train_iter} self._optimizer_world_model.zero_grad()')
self._optimizer_world_model.zero_grad()

weighted_total_loss = weighted_total_loss / self.accumulation_steps # scale the loss when accumulating gradients
Expand All @@ -481,16 +484,30 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
# if param.requires_grad:
# print(name, param.grad.norm())

if self._cfg.analysis_sim_norm:
del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
self._target_model.encoder_hook.clear_data()

total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
self._cfg.grad_clip_value)
# if self._cfg.analysis_sim_norm:
# del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
# self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
# self._target_model.encoder_hook.clear_data()

# total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
# self._cfg.grad_clip_value)

# check whether an accumulation cycle has completed (e.g. with accumulation_steps=4, parameters are updated at iterations 4, 8, 12, ...)
if (train_iter + 1) % self.accumulation_steps == 0:
# print(f'train_iter:{train_iter} self._optimizer_world_model.step()')

# ========== gradient-norm analysis ==========
if self._cfg.analysis_sim_norm:
# drop the previous analysis results to avoid holding extra memory
del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
self._target_model.encoder_hook.clear_data()

# clip the gradients
total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(
self._learn_model.world_model.parameters(), self._cfg.grad_clip_value
)

if train_iter % self.accumulation_steps == 0: # update parameters once every accumulation_steps steps
# print(f'pos 2 train_iter:{train_iter}')

if self._cfg.multi_gpu:
self.sync_gradients(self._learn_model)
@@ -503,8 +520,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
# Core target model update step
self._target_model.update(self._learn_model.state_dict())

if self.accumulation_steps>1:
if self.accumulation_steps > 1:
torch.cuda.empty_cache()
else:
total_grad_norm_before_clip_wm = torch.tensor(0.)

if torch.cuda.is_available():
torch.cuda.synchronize()
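Note on the pattern this diff implements: gradients are zeroed at the first iteration of each accumulation cycle, the loss is divided by `accumulation_steps` before `backward()`, and clipping plus `optimizer.step()` happen only once the cycle completes. A minimal, self-contained sketch of that flow is below; the model, optimizer, and `get_batch` are placeholders for illustration, not code from this repository.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the world model, optimizer, and batch sampler.
model = nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

def get_batch():
    return torch.randn(32, 8), torch.randn(32, 1)

accumulation_steps = 4   # set to 1 to recover a plain per-step update
grad_clip_value = 10.0

for train_iter in range(100):
    x, y = get_batch()
    loss = nn.functional.mse_loss(model(x), y)

    # First step of an accumulation cycle (train_iter counts from 0): clear old gradients.
    if train_iter % accumulation_steps == 0:
        optimizer.zero_grad()

    # Scale the loss so the summed gradients match one full-batch update.
    (loss / accumulation_steps).backward()

    # Last step of the cycle: clip the accumulated gradients, then update once.
    if (train_iter + 1) % accumulation_steps == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_value)
        optimizer.step()
```

With `accumulation_steps = 1`, as set in the hunk above, every iteration zeroes, backprops, clips, and steps, so the behavior matches the non-accumulating code path.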
18 changes: 10 additions & 8 deletions zoo/jericho/configs/jericho_ppo_config.py
@@ -4,7 +4,12 @@
action_space_size = 10
max_steps = 50
model_name = 'BAAI/bge-base-en-v1.5'
env_id = 'detective.z5'
# env_id = 'detective.z5'

action_space_size = 10
max_steps = 400
env_id = 'zork1.z5'

evaluator_env_num = 2

# proj train
@@ -22,14 +27,13 @@
# num_unroll_steps = 5
# infer_context_length = 2
jericho_ppo_config = dict(
# exp_name=f"data_ppo_detective/jericho_ppo_projtrain_bs{batch_size}_seed0",
exp_name=f"data_ppo_detective_debug/jericho_add-loc-inv_ppo_projtrain_bs{batch_size}_seed0",
exp_name=f"data_ppo_detective/jericho_{env_id}_ms{max_steps}_ass{action_space_size}_ppo_projtrain_bs{batch_size}_seed0",
# exp_name=f"data_ppo_detective_debug/jericho_add-loc-inv_ppo_projtrain_bs{batch_size}_seed0",
env=dict(
remove_stuck_actions=False,
# remove_stuck_actions=True,
add_location_and_inventory=True,
# add_location_and_inventory=False,

# add_location_and_inventory=True,
add_location_and_inventory=False,
stop_value=int(1e6),
observation_shape=512,
max_steps=max_steps,
@@ -60,13 +64,11 @@
epoch_per_collect=4,
batch_size=batch_size,
learning_rate=0.0005,
# entropy_weight=0.01,
entropy_weight=0.05,
value_norm=True,
grad_clip_value=10,
),
collect=dict(
# n_sample=1024,
n_sample=320, # TODO: DEBUG
discount_factor=0.99,
gae_lambda=0.95,
7 changes: 4 additions & 3 deletions zoo/jericho/configs/jericho_unizero_config.py
@@ -69,8 +69,8 @@ def main(env_id='detective.z5', seed=0):
# ==============================================================
jericho_unizero_config = dict(
env=dict(
# remove_stuck_actions=False,
remove_stuck_actions=True,
remove_stuck_actions=False,
# remove_stuck_actions=True,

stop_value=int(1e6),
observation_shape=512,
@@ -167,7 +167,8 @@ def main(env_id='detective.z5', seed=0):
main_config = jericho_unizero_config
create_config = jericho_unizero_create_config

main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}-remove-novalid_proj-train-accstep4_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
main_config.exp_name = f'data_unizero_detective_20250209/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_proj-train-accstep1_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
# main_config.exp_name = f'data_unizero_detective_20250209/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}-remove-novalid_proj-train-accstep4_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
# main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_proj-train-accstep4_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'

# main_config.exp_name = f'data_unizero_detective_20250107/{model_name}/{env_id[:8]}_ms{max_steps}_action-space-{action_space_size}_all-train_uz_nlayer{num_layers}_embed512_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
34 changes: 20 additions & 14 deletions zoo/jericho/envs/jericho_env.py
@@ -96,17 +96,19 @@ def prepare_obs(self, obs, return_str: bool = False):

if len(available_actions) <= self.max_action_num:
action_mask = [1] * len(available_actions) + [0] * (self.max_action_num - len(available_actions))
else:
elif len(available_actions) == self.max_action_num:
action_mask = [1] * len(available_actions)
else:
action_mask = [1] * self.max_action_num

action_mask = np.array(action_mask, dtype=np.int8)

if return_str: # TODO===============
# return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}
return {'observation': full_obs, 'action_mask': action_mask}
if return_str: # TODO: unizero needs 'to_play' here ===============
return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}
# return {'observation': full_obs, 'action_mask': action_mask}
else:
# return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1}
return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask}
return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1}
# return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask}


def reset(self, return_str: bool = False):
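The `prepare_obs` change above makes the action mask handle all cases explicitly: fewer available actions than `max_action_num` (pad with zeros), exactly equal, and more (truncate). A small sketch of that padding/truncation logic with illustrative names, not the environment's actual attributes:

```python
import numpy as np

def build_action_mask(num_available: int, max_action_num: int) -> np.ndarray:
    """Fixed-length 0/1 mask over an action space of size max_action_num."""
    if num_available <= max_action_num:
        # Pad: available actions first, remaining slots masked out.
        mask = [1] * num_available + [0] * (max_action_num - num_available)
    else:
        # Truncate: only the first max_action_num actions stay selectable.
        mask = [1] * max_action_num
    return np.array(mask, dtype=np.int8)

print(build_action_mask(3, 10))   # [1 1 1 0 0 0 0 0 0 0]
print(build_action_mask(12, 10))  # [1 1 1 1 1 1 1 1 1 1]
```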
@@ -179,8 +181,8 @@ def step(self, action: int, return_str: bool = False):
self.timestep += 1
# print(f'step: {self.timestep}, [OBS]:{observation} self._action_list:{self._action_list}')

# TODO: for PPO
reward = np.array([float(reward)])
# TODO: for PPO; for unizero, comment out the line below
# reward = np.array([float(reward)])

self.env_step += 1
self.episode_return += reward
@@ -234,16 +236,20 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
from easydict import EasyDict
env_cfg = EasyDict(
dict(
max_steps=100,
max_steps=400,
# game_path="z-machine-games-master/jericho-game-suite/zork1.z5",
game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/detective.z5",
game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/zork1.z5",
# game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/detective.z5",
# game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/905.z5",
max_action_num=50,
max_env_step=100,
# max_action_num=50,
max_action_num=10,
# max_env_step=100,
tokenizer_path="google-bert/bert-base-uncased",
max_seq_len=512,
remove_stuck_actions=True, # whether to remove invalid (stuck) actions
add_location_and_inventory=True
remove_stuck_actions=False, # whether to remove invalid (stuck) actions
add_location_and_inventory=False
# remove_stuck_actions=True, # whether to remove invalid (stuck) actions
# add_location_and_inventory=True
)
)
env = JerichoEnv(env_cfg)
