diff --git a/scripts/tools/tmi2/add_cp_as_triggers.py b/scripts/tools/tmi2/add_cp_as_triggers.py index 6638f42..27fc9b6 100644 --- a/scripts/tools/tmi2/add_cp_as_triggers.py +++ b/scripts/tools/tmi2/add_cp_as_triggers.py @@ -42,7 +42,7 @@ def main(): # ============================ # END ON RUN STEP # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC): current = iface._read_int32() target = iface._read_int32() @@ -52,12 +52,12 @@ def main(): # ============================ # END ON CP COUNT # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC): iface._read_int32() - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC): - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.C_SHUTDOWN): iface.close() elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC): @@ -65,7 +65,7 @@ def main(): iface.execute_command( f"add_trigger {checkpoint_positions[i][0] - 2} {checkpoint_positions[i][1] - 2} {checkpoint_positions[i][2] - 2} {checkpoint_positions[i][0] + 2} {checkpoint_positions[i][1] + 2} {checkpoint_positions[i][2] + 2}" ) - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) else: pass diff --git a/scripts/tools/tmi2/add_vcp_as_triggers.py b/scripts/tools/tmi2/add_vcp_as_triggers.py index abf716c..c348ca6 100644 --- a/scripts/tools/tmi2/add_vcp_as_triggers.py +++ b/scripts/tools/tmi2/add_vcp_as_triggers.py @@ -42,7 +42,7 @@ def main(): # ============================ # END ON RUN STEP # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC): current = iface._read_int32() target = iface._read_int32() @@ -52,12 +52,12 @@ def main(): # ============================ # END ON CP COUNT # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC): iface._read_int32() - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC): - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.C_SHUTDOWN): iface.close() elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC): @@ -68,7 +68,7 @@ def main(): # print( # f"add_trigger {vcp[i][0] - 0.4:.2f} {vcp[i][1] - 0.4:.2f} {vcp[i][2] - 0.4:.2f} {vcp[i][0] + 0.4:.2f} {vcp[i][1] + 0.4:.2f} {vcp[i][2] + 0.4:.2f}" # ) - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) else: pass diff --git a/scripts/tools/tmi2/empty_template.py b/scripts/tools/tmi2/empty_template.py index 420f720..ca15763 100644 --- a/scripts/tools/tmi2/empty_template.py +++ b/scripts/tools/tmi2/empty_template.py @@ -35,7 +35,7 @@ def main(): # ============================ # END ON RUN STEP # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC): current = iface._read_int32() target = iface._read_int32() @@ -45,16 +45,16 @@ def main(): # ============================ # END ON CP COUNT # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC): iface._read_int32() - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC): - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.C_SHUTDOWN): iface.close() elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC): - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) else: pass diff --git a/scripts/tools/video_stuff/inputs_to_gbx.py b/scripts/tools/video_stuff/inputs_to_gbx.py index 5ca82ca..b4b2c76 100644 --- a/scripts/tools/video_stuff/inputs_to_gbx.py +++ b/scripts/tools/video_stuff/inputs_to_gbx.py @@ -224,7 +224,7 @@ def press_enter(N_presses=1): # expecting_replay_file = True iface.execute_command("finish") # request_map(iface,args.map_path) - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC): # print("On CP") current = iface._read_int32() @@ -235,14 +235,14 @@ def press_enter(N_presses=1): iface.close() # iface.prevent_simulation_finish() else: - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC): # print("On lap") iface._read_int32() iface._read_int32() - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC): - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.C_SHUTDOWN): iface.close() elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC): @@ -263,7 +263,7 @@ def press_enter(N_presses=1): map_loaded = True # else: # need_to_get_out_of_menu = True - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) else: pass close_game(tm_process_id) diff --git a/trackmania_rl/analysis_metrics.py b/trackmania_rl/analysis_metrics.py index 53db0aa..414d65f 100644 --- a/trackmania_rl/analysis_metrics.py +++ b/trackmania_rl/analysis_metrics.py @@ -190,15 +190,15 @@ def highest_prio_transitions(buffer, save_dir): shutil.rmtree(save_dir / "high_prio_figures", ignore_errors=True) (save_dir / "high_prio_figures").mkdir(parents=True, exist_ok=True) - prios = [buffer._sampler._sum_tree.at(i) for i in range(len(buffer))] + prios = [buffer.sampler.sum_tree.at(i) for i in range(len(buffer))] for high_error_idx in np.argsort(prios)[-20:]: for idx in range(max(0, high_error_idx - 4), min(len(buffer) - 1, high_error_idx + 5)): Image.fromarray( - np.hstack((buffer._storage[idx].state_img.squeeze(), buffer._storage[idx].next_state_img.squeeze())) + np.hstack((buffer.storage[idx].state_img.squeeze(), buffer.storage[idx].next_state_img.squeeze())) .repeat(4, 0) .repeat(4, 1) - ).save(save_dir / "high_prio_figures" / f"{high_error_idx}_{idx}_{buffer._storage[idx].n_steps}_{prios[idx]:.2f}.png") + ).save(save_dir / "high_prio_figures" / f"{high_error_idx}_{idx}_{buffer.storage[idx].n_steps}_{prios[idx]:.2f}.png") def get_output_and_target_for_batch(batch, online_network, target_network, num_quantiles): @@ -383,7 +383,7 @@ def distribution_curves(buffer, save_dir, online_network, target_network): np.hstack( ( np.expand_dims( - np.hstack((buffer._storage[i].state_img.squeeze(), buffer._storage[i].next_state_img.squeeze())) + np.hstack((buffer.storage[i].state_img.squeeze(), buffer.storage[i].next_state_img.squeeze())) .repeat(4, 0) .repeat(4, 1), axis=-1, @@ -392,5 +392,5 @@ def distribution_curves(buffer, save_dir, online_network, target_network): ) ) ) - ).save(save_dir / "distribution_curves" / f"{i}_{buffer._storage[i].n_steps}.png") + ).save(save_dir / "distribution_curves" / f"{i}_{buffer.storage[i].n_steps}.png") plt.close() diff --git a/trackmania_rl/buffer_utilities.py b/trackmania_rl/buffer_utilities.py index 3e94c61..62e303b 100644 --- a/trackmania_rl/buffer_utilities.py +++ b/trackmania_rl/buffer_utilities.py @@ -14,7 +14,8 @@ from torchrl.data import ListStorage, ReplayBuffer from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler from torchrl.data.replay_buffers.storages import Storage -from torchrl.data.replay_buffers.utils import INT_CLASSES, _to_numpy +from torchrl.data.replay_buffers.utils import INT_CLASSES +from torchrl.data.replay_buffers.utils import _to_numpy as to_numpy from config_files import config_copy @@ -247,8 +248,8 @@ def update_priority( else: if not (isinstance(priority, float) or len(priority) == 1 or len(index) == len(priority)): raise RuntimeError("priority should be a number or an iterable of the same length as index") - index = _to_numpy(index) - priority = _to_numpy(priority) + index = to_numpy(index) + priority = to_numpy(priority) # We track the _approximate_ number of memories in the buffer that have default priority : self._uninitialized_memories -= 0.3 * len(index) priority = np.power(priority + self._eps, self._alpha) @@ -274,17 +275,17 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def copy_buffer_content_to_other_buffer(source_buffer: ReplayBuffer, target_buffer: ReplayBuffer) -> None: - assert source_buffer._storage.max_size <= target_buffer._storage.max_size + assert source_buffer.storage.max_size <= target_buffer.storage.max_size - target_buffer.extend(source_buffer._storage._storage) + target_buffer.extend(source_buffer.storage._storage) - if isinstance(source_buffer._sampler, CustomPrioritizedSampler) and isinstance(target_buffer._sampler, CustomPrioritizedSampler): - target_buffer._sampler._average_priority = source_buffer._sampler._average_priority - target_buffer._sampler._uninitialized_memories = source_buffer._sampler._uninitialized_memories + if isinstance(source_buffer.sampler, CustomPrioritizedSampler) and isinstance(target_buffer.sampler, CustomPrioritizedSampler): + target_buffer.sampler._average_priority = source_buffer.sampler._average_priority + target_buffer.sampler._uninitialized_memories = source_buffer.sampler._uninitialized_memories - if isinstance(source_buffer._sampler, PrioritizedSampler) and isinstance(target_buffer._sampler, PrioritizedSampler): + if isinstance(source_buffer.sampler, PrioritizedSampler) and isinstance(target_buffer.sampler, PrioritizedSampler): for i in range(len(source_buffer)): - target_buffer._sampler._sum_tree[i] = source_buffer._sampler._sum_tree.at(i) + target_buffer.sampler.sum_tree[i] = source_buffer.sampler.sum_tree.at(i) def make_buffers(buffer_size: int) -> tuple[ReplayBuffer, ReplayBuffer]: diff --git a/trackmania_rl/multiprocess/learner_process.py b/trackmania_rl/multiprocess/learner_process.py index 8ae3ca0..71caa02 100644 --- a/trackmania_rl/multiprocess/learner_process.py +++ b/trackmania_rl/multiprocess/learner_process.py @@ -281,10 +281,10 @@ def learner_process_fn( param_group["epsilon"] = config_copy.adam_epsilon param_group["betas"] = (config_copy.adam_beta1, config_copy.adam_beta2) - if isinstance(buffer._sampler, PrioritizedSampler): - buffer._sampler._alpha = config_copy.prio_alpha - buffer._sampler._beta = config_copy.prio_beta - buffer._sampler._eps = config_copy.prio_epsilon + if isinstance(buffer.sampler, PrioritizedSampler): + buffer.sampler._alpha = config_copy.prio_alpha + buffer.sampler._beta = config_copy.prio_beta + buffer.sampler._eps = config_copy.prio_epsilon if config_copy.plot_race_time_left_curves and not is_explo and (loop_number // 5) % 17 == 0: race_time_left_curves(rollout_results, inferer, save_dir, map_name) @@ -495,7 +495,7 @@ def learner_process_fn( loss, grad_norm = trainer.train_on_batch(buffer, do_learn=True) accumulated_stats["cumul_number_single_memories_used"] += ( 10 * config_copy.batch_size - if (len(buffer) < buffer._storage.max_size and buffer._storage.max_size > 200_000) + if (len(buffer) < buffer.storage.max_size and buffer.storage.max_size > 200_000) else config_copy.batch_size ) # do fewer batches while memory is not full train_on_batch_duration_history.append(time.perf_counter() - train_start_time) @@ -575,8 +575,8 @@ def learner_process_fn( f"{key}_max": np.max(val), } ) - if isinstance(buffer._sampler, PrioritizedSampler): - all_priorities = np.array([buffer._sampler._sum_tree.at(i) for i in range(len(buffer))]) + if isinstance(buffer.sampler, PrioritizedSampler): + all_priorities = np.array([buffer.sampler.sum_tree.at(i) for i in range(len(buffer))]) step_stats.update( { "priorities_min": np.min(all_priorities), @@ -683,8 +683,8 @@ def learner_process_fn( # BUFFER STATS # =============================================== - mean_in_buffer = np.array([experience.state_float for experience in buffer._storage]).mean(axis=0) - std_in_buffer = np.array([experience.state_float for experience in buffer._storage]).std(axis=0) + mean_in_buffer = np.array([experience.state_float for experience in buffer.storage]).mean(axis=0) + std_in_buffer = np.array([experience.state_float for experience in buffer.storage]).std(axis=0) print("Raw mean in buffer :", mean_in_buffer.round(1)) print("Raw std in buffer :", std_in_buffer.round(1)) @@ -702,7 +702,7 @@ def learner_process_fn( # =============================================== # HIGH PRIORITY TRANSITIONS # =============================================== - if config_copy.make_highest_prio_figures and isinstance(buffer._sampler, PrioritizedSampler): + if config_copy.make_highest_prio_figures and isinstance(buffer.sampler, PrioritizedSampler): highest_prio_transitions(buffer, save_dir) # =============================================== diff --git a/trackmania_rl/tmi_interaction/game_instance_manager.py b/trackmania_rl/tmi_interaction/game_instance_manager.py index 4324e9b..2aa3d2e 100644 --- a/trackmania_rl/tmi_interaction/game_instance_manager.py +++ b/trackmania_rl/tmi_interaction/game_instance_manager.py @@ -331,7 +331,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt self.request_speed(self.running_speed) if self.msgtype_response_to_wakeup_TMI is not None: - self.iface._respond_to_call(self.msgtype_response_to_wakeup_TMI) + self.iface.respond_to_call(self.msgtype_response_to_wakeup_TMI) self.msgtype_response_to_wakeup_TMI = None self.last_rollout_crashed = False @@ -579,7 +579,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt # END ON RUN STEP # ============================ if self.msgtype_response_to_wakeup_TMI is None: - self.iface._respond_to_call(msgtype) + self.iface.respond_to_call(msgtype) if _time > 0 and this_rollout_has_seen_t_negative: if _time % 40 == 0: @@ -667,11 +667,11 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt # END ON CP COUNT # ============================ if self.msgtype_response_to_wakeup_TMI is None: - self.iface._respond_to_call(msgtype) + self.iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC): self.iface._read_int32() self.iface._read_int32() - self.iface._respond_to_call(msgtype) + self.iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC): frame = self.grab_screen() frame_expected = False @@ -728,7 +728,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt self.game_spawning_lock.release() instrumentation__request_inputs_and_speed += time.perf_counter_ns() - pc8 - self.iface._respond_to_call(msgtype) + self.iface.respond_to_call(msgtype) elif msgtype == int(MessageType.C_SHUTDOWN): self.iface.close() elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC): @@ -746,7 +746,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt if self.iface.is_in_menus() and map_path != self.latest_map_path_requested: print("Requested map load") self.request_map(map_path, zone_centers) - self.iface._respond_to_call(msgtype) + self.iface.respond_to_call(msgtype) else: pass except socket.timeout as err: diff --git a/trackmania_rl/tmi_interaction/tminterface2.py b/trackmania_rl/tmi_interaction/tminterface2.py index 5e67e10..c512a50 100644 --- a/trackmania_rl/tmi_interaction/tminterface2.py +++ b/trackmania_rl/tmi_interaction/tminterface2.py @@ -145,7 +145,7 @@ def get_inputs(self): string_length = self._read_int32() return self.sock.recv(string_length, socket.MSG_WAITALL).decode("utf-8") - def _respond_to_call(self, response_type): + def respond_to_call(self, response_type): self.sock.sendall(struct.pack("i", np.int32(response_type))) def _read_int32(self):