diff --git a/scripts/tools/tmi2/add_cp_as_triggers.py b/scripts/tools/tmi2/add_cp_as_triggers.py index fcea293a..358f6232 100644 --- a/scripts/tools/tmi2/add_cp_as_triggers.py +++ b/scripts/tools/tmi2/add_cp_as_triggers.py @@ -29,34 +29,34 @@ def main(): print(e) while True: - msgtype = iface._read_int32() + msgtype = iface.read_int32() # ============================================= # READ INCOMING MESSAGES # ============================================= if msgtype == int(MessageType.SC_RUN_STEP_SYNC): - _time = iface._read_int32() + _time = iface.read_int32() # ============================ # BEGIN ON RUN STEP # ============================ # ============================ # END ON RUN STEP # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC): - current = iface._read_int32() - target = iface._read_int32() + current = iface.read_int32() + target = iface.read_int32() # ============================ # BEGIN ON CP COUNT # ============================ # ============================ # END ON CP COUNT # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC): - iface._read_int32() - iface._respond_to_call(msgtype) + iface.read_int32() + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC): - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.C_SHUTDOWN): iface.close() elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC): @@ -64,7 +64,7 @@ def main(): iface.execute_command( f"add_trigger {checkpoint_positions[i][0] - 2} {checkpoint_positions[i][1] - 2} {checkpoint_positions[i][2] - 2} {checkpoint_positions[i][0] + 2} {checkpoint_positions[i][1] + 2} {checkpoint_positions[i][2] + 2}" ) - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) else: pass diff --git a/scripts/tools/tmi2/add_vcp_as_triggers.py b/scripts/tools/tmi2/add_vcp_as_triggers.py index 606debdd..8bf87c05 100644 --- a/scripts/tools/tmi2/add_vcp_as_triggers.py +++ b/scripts/tools/tmi2/add_vcp_as_triggers.py @@ -29,34 +29,34 @@ def main(): print(e) while True: - msgtype = iface._read_int32() + msgtype = iface.read_int32() # ============================================= # READ INCOMING MESSAGES # ============================================= if msgtype == int(MessageType.SC_RUN_STEP_SYNC): - _time = iface._read_int32() + _time = iface.read_int32() # ============================ # BEGIN ON RUN STEP # ============================ # ============================ # END ON RUN STEP # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC): - current = iface._read_int32() - target = iface._read_int32() + current = iface.read_int32() + target = iface.read_int32() # ============================ # BEGIN ON CP COUNT # ============================ # ============================ # END ON CP COUNT # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC): - iface._read_int32() - iface._respond_to_call(msgtype) + iface.read_int32() + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC): - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.C_SHUTDOWN): iface.close() elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC): @@ -67,7 +67,7 @@ def main(): # print( # f"add_trigger {vcp[i][0] - 0.4:.2f} {vcp[i][1] - 0.4:.2f} {vcp[i][2] - 0.4:.2f} {vcp[i][0] + 0.4:.2f} {vcp[i][1] + 0.4:.2f} {vcp[i][2] + 0.4:.2f}" # ) - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) else: pass diff --git a/scripts/tools/tmi2/empty_template.py b/scripts/tools/tmi2/empty_template.py index 81bfd64d..a12a398f 100644 --- a/scripts/tools/tmi2/empty_template.py +++ b/scripts/tools/tmi2/empty_template.py @@ -21,12 +21,12 @@ def main(): print(e) while True: - msgtype = iface._read_int32() + msgtype = iface.read_int32() # ============================================= # READ INCOMING MESSAGES # ============================================= if msgtype == int(MessageType.SC_RUN_STEP_SYNC): - _time = iface._read_int32() + _time = iface.read_int32() # ============================ # BEGIN ON RUN STEP # ============================ @@ -34,26 +34,26 @@ def main(): # ============================ # END ON RUN STEP # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC): - current = iface._read_int32() - target = iface._read_int32() + current = iface.read_int32() + target = iface.read_int32() # ============================ # BEGIN ON CP COUNT # ============================ # ============================ # END ON CP COUNT # ============================ - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC): - iface._read_int32() - iface._respond_to_call(msgtype) + iface.read_int32() + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC): - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.C_SHUTDOWN): iface.close() elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC): - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) else: pass diff --git a/scripts/tools/video_stuff/inputs_to_gbx.py b/scripts/tools/video_stuff/inputs_to_gbx.py index 31b38118..16304e4f 100644 --- a/scripts/tools/video_stuff/inputs_to_gbx.py +++ b/scripts/tools/video_stuff/inputs_to_gbx.py @@ -211,10 +211,10 @@ def press_enter(N_presses=1): time.sleep(0.1) press_enter() if iface.registered: - msgtype = iface._read_int32() + msgtype = iface.read_int32() if msgtype == int(MessageType.SC_RUN_STEP_SYNC): # print("On step") - _time = iface._read_int32() + _time = iface.read_int32() if not give_up_signal_has_been_sent: iface.execute_command("load " + input_files[current_input_idx]) iface.give_up() @@ -223,25 +223,25 @@ def press_enter(N_presses=1): # expecting_replay_file = True iface.execute_command("finish") # request_map(iface,args.map_path) - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC): # print("On CP") - current = iface._read_int32() - target = iface._read_int32() + current = iface.read_int32() + target = iface.read_int32() if current == target and not expecting_replay_file: # Run finished expecting_replay_file = True # press_enter() iface.close() # iface.prevent_simulation_finish() else: - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC): # print("On lap") - iface._read_int32() - iface._read_int32() - iface._respond_to_call(msgtype) + iface.read_int32() + iface.read_int32() + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC): - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) elif msgtype == int(MessageType.C_SHUTDOWN): iface.close() elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC): @@ -262,7 +262,7 @@ def press_enter(N_presses=1): map_loaded = True # else: # need_to_get_out_of_menu = True - iface._respond_to_call(msgtype) + iface.respond_to_call(msgtype) else: pass close_game(tm_process_id) diff --git a/trackmania_rl/analysis_metrics.py b/trackmania_rl/analysis_metrics.py index 30507c38..e759fbe8 100644 --- a/trackmania_rl/analysis_metrics.py +++ b/trackmania_rl/analysis_metrics.py @@ -189,15 +189,15 @@ def highest_prio_transitions(buffer, save_dir): shutil.rmtree(save_dir / "high_prio_figures", ignore_errors=True) (save_dir / "high_prio_figures").mkdir(parents=True, exist_ok=True) - prios = [buffer._sampler._sum_tree.at(i) for i in range(len(buffer))] + prios = [buffer.sampler.sum_tree.at(i) for i in range(len(buffer))] for high_error_idx in np.argsort(prios)[-20:]: for idx in range(max(0, high_error_idx - 4), min(len(buffer) - 1, high_error_idx + 5)): Image.fromarray( - np.hstack((buffer._storage[idx].state_img.squeeze(), buffer._storage[idx].next_state_img.squeeze())) + np.hstack((buffer.storage[idx].state_img.squeeze(), buffer.storage[idx].next_state_img.squeeze())) .repeat(4, 0) .repeat(4, 1) - ).save(save_dir / "high_prio_figures" / f"{high_error_idx}_{idx}_{buffer._storage[idx].n_steps}_{prios[idx]:.2f}.png") + ).save(save_dir / "high_prio_figures" / f"{high_error_idx}_{idx}_{buffer.storage[idx].n_steps}_{prios[idx]:.2f}.png") def get_output_and_target_for_batch(batch, online_network, target_network, num_quantiles): @@ -382,7 +382,7 @@ def distribution_curves(buffer, save_dir, online_network, target_network): np.hstack( ( np.expand_dims( - np.hstack((buffer._storage[i].state_img.squeeze(), buffer._storage[i].next_state_img.squeeze())) + np.hstack((buffer.storage[i].state_img.squeeze(), buffer.storage[i].next_state_img.squeeze())) .repeat(4, 0) .repeat(4, 1), axis=-1, @@ -391,5 +391,5 @@ def distribution_curves(buffer, save_dir, online_network, target_network): ) ) ) - ).save(save_dir / "distribution_curves" / f"{i}_{buffer._storage[i].n_steps}.png") + ).save(save_dir / "distribution_curves" / f"{i}_{buffer.storage[i].n_steps}.png") plt.close() diff --git a/trackmania_rl/buffer_utilities.py b/trackmania_rl/buffer_utilities.py index 176484db..ac2597e7 100644 --- a/trackmania_rl/buffer_utilities.py +++ b/trackmania_rl/buffer_utilities.py @@ -13,7 +13,7 @@ from torchrl.data import ListStorage, ReplayBuffer from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler from torchrl.data.replay_buffers.storages import Storage -from torchrl.data.replay_buffers.utils import INT_CLASSES, _to_numpy +from torchrl.data.replay_buffers.utils import INT_CLASSES, _to_numpy as to_numpy from config_files import config_copy @@ -26,7 +26,6 @@ "float64": torch.float32, } - def fast_collate_cpu(batch, attr_name): elem = getattr(batch[0], attr_name) elem_array = hasattr(elem, "__len__") @@ -248,8 +247,8 @@ def update_priority( else: if not (isinstance(priority, float) or len(priority) == 1 or len(index) == len(priority)): raise RuntimeError("priority should be a number or an iterable of the same length as index") - index = _to_numpy(index) - priority = _to_numpy(priority) + index = to_numpy(index) + priority = to_numpy(priority) # We track the _approximate_ number of memories in the buffer that have default priority : self._uninitialized_memories -= 0.3 * len(index) priority = np.power(priority + self._eps, self._alpha) @@ -275,17 +274,17 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: def copy_buffer_content_to_other_buffer(source_buffer: ReplayBuffer, target_buffer: ReplayBuffer) -> None: - assert source_buffer._storage.max_size <= target_buffer._storage.max_size + assert source_buffer.storage.max_size <= target_buffer.storage.max_size - target_buffer.extend(source_buffer._storage._storage) + target_buffer.extend(source_buffer.storage._storage) - if isinstance(source_buffer._sampler, CustomPrioritizedSampler) and isinstance(target_buffer._sampler, CustomPrioritizedSampler): - target_buffer._sampler._average_priority = source_buffer._sampler._average_priority - target_buffer._sampler._uninitialized_memories = source_buffer._sampler._uninitialized_memories + if isinstance(source_buffer.sampler, CustomPrioritizedSampler) and isinstance(target_buffer.sampler, CustomPrioritizedSampler): + target_buffer.sampler._average_priority = source_buffer.sampler.average_priority + target_buffer.sampler._uninitialized_memories = source_buffer.sampler.uninitialized_memories - if isinstance(source_buffer._sampler, PrioritizedSampler) and isinstance(target_buffer._sampler, PrioritizedSampler): + if isinstance(source_buffer.sampler, PrioritizedSampler) and isinstance(target_buffer.sampler, PrioritizedSampler): for i in range(len(source_buffer)): - target_buffer._sampler._sum_tree[i] = source_buffer._sampler._sum_tree.at(i) + target_buffer.sampler.sum_tree[i] = source_buffer.sampler.sum_tree.at(i) def make_buffers(buffer_size: int) -> tuple[ReplayBuffer, ReplayBuffer]: diff --git a/trackmania_rl/multiprocess/learner_process.py b/trackmania_rl/multiprocess/learner_process.py index ffa5a9ce..830c6cdb 100644 --- a/trackmania_rl/multiprocess/learner_process.py +++ b/trackmania_rl/multiprocess/learner_process.py @@ -280,10 +280,10 @@ def learner_process_fn( param_group["epsilon"] = config_copy.adam_epsilon param_group["betas"] = (config_copy.adam_beta1, config_copy.adam_beta2) - if isinstance(buffer._sampler, PrioritizedSampler): - buffer._sampler._alpha = config_copy.prio_alpha - buffer._sampler._beta = config_copy.prio_beta - buffer._sampler._eps = config_copy.prio_epsilon + if isinstance(buffer.sampler, PrioritizedSampler): + buffer.sampler._alpha = config_copy.prio_alpha + buffer.sampler._beta = config_copy.prio_beta + buffer.sampler._eps = config_copy.prio_epsilon if config_copy.plot_race_time_left_curves and not is_explo and (loop_number // 5) % 17 == 0: race_time_left_curves(rollout_results, inferer, save_dir, map_name) @@ -494,7 +494,7 @@ def learner_process_fn( loss, grad_norm = trainer.train_on_batch(buffer, do_learn=True) accumulated_stats["cumul_number_single_memories_used"] += ( 10 * config_copy.batch_size - if (len(buffer) < buffer._storage.max_size and buffer._storage.max_size > 200_000) + if (len(buffer) < buffer.storage.max_size and buffer.storage.max_size > 200_000) else config_copy.batch_size ) # do fewer batches while memory is not full train_on_batch_duration_history.append(time.perf_counter() - train_start_time) @@ -574,8 +574,8 @@ def learner_process_fn( f"{key}_max": np.max(val), } ) - if isinstance(buffer._sampler, PrioritizedSampler): - all_priorities = np.array([buffer._sampler._sum_tree.at(i) for i in range(len(buffer))]) + if isinstance(buffer.sampler, PrioritizedSampler): + all_priorities = np.array([buffer.sampler.sum_tree.at(i) for i in range(len(buffer))]) step_stats.update( { "priorities_min": np.min(all_priorities), @@ -682,8 +682,8 @@ def learner_process_fn( # BUFFER STATS # =============================================== - mean_in_buffer = np.array([experience.state_float for experience in buffer._storage]).mean(axis=0) - std_in_buffer = np.array([experience.state_float for experience in buffer._storage]).std(axis=0) + mean_in_buffer = np.array([experience.state_float for experience in buffer.storage]).mean(axis=0) + std_in_buffer = np.array([experience.state_float for experience in buffer.storage]).std(axis=0) print("Raw mean in buffer :", mean_in_buffer.round(1)) print("Raw std in buffer :", std_in_buffer.round(1)) @@ -701,7 +701,7 @@ def learner_process_fn( # =============================================== # HIGH PRIORITY TRANSITIONS # =============================================== - if config_copy.make_highest_prio_figures and isinstance(buffer._sampler, PrioritizedSampler): + if config_copy.make_highest_prio_figures and isinstance(buffer.sampler, PrioritizedSampler): highest_prio_transitions(buffer, save_dir) # =============================================== diff --git a/trackmania_rl/tmi_interaction/game_instance_manager.py b/trackmania_rl/tmi_interaction/game_instance_manager.py index 483a3056..b3006cea 100644 --- a/trackmania_rl/tmi_interaction/game_instance_manager.py +++ b/trackmania_rl/tmi_interaction/game_instance_manager.py @@ -327,7 +327,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt self.request_speed(self.running_speed) if self.msgtype_response_to_wakeup_TMI is not None: - self.iface._respond_to_call(self.msgtype_response_to_wakeup_TMI) + self.iface.respond_to_call(self.msgtype_response_to_wakeup_TMI) self.msgtype_response_to_wakeup_TMI = None self.last_rollout_crashed = False @@ -475,13 +475,13 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt compute_action_asap_floats = False - msgtype = self.iface._read_int32() + msgtype = self.iface.read_int32() # ============================================= # READ INCOMING MESSAGES # ============================================= if msgtype == int(MessageType.SC_RUN_STEP_SYNC): - _time = self.iface._read_int32() + _time = self.iface.read_int32() if _time > 0 and this_rollout_has_seen_t_negative: if _time % 50 == 0: @@ -570,7 +570,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt # END ON RUN STEP # ============================ if self.msgtype_response_to_wakeup_TMI is None: - self.iface._respond_to_call(msgtype) + self.iface.respond_to_call(msgtype) if _time > 0 and this_rollout_has_seen_t_negative: if _time % 40 == 0: @@ -580,8 +580,8 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt instrumentation__answer_action_step += time.perf_counter_ns() - pc pc = time.perf_counter_ns() elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC): - current = self.iface._read_int32() - target = self.iface._read_int32() + current = self.iface.read_int32() + target = self.iface.read_int32() simulation_state = self.iface.get_simulation_state() end_race_stats["cp_time_ms"].append(simulation_state.race_time) @@ -658,11 +658,11 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt # END ON CP COUNT # ============================ if self.msgtype_response_to_wakeup_TMI is None: - self.iface._respond_to_call(msgtype) + self.iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC): - self.iface._read_int32() - self.iface._read_int32() - self.iface._respond_to_call(msgtype) + self.iface.read_int32() + self.iface.read_int32() + self.iface.respond_to_call(msgtype) elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC): frame = self.grab_screen() frame_expected = False @@ -719,7 +719,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt self.game_spawning_lock.release() instrumentation__request_inputs_and_speed += time.perf_counter_ns() - pc8 - self.iface._respond_to_call(msgtype) + self.iface.respond_to_call(msgtype) elif msgtype == int(MessageType.C_SHUTDOWN): self.iface.close() elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC): @@ -737,7 +737,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt if self.iface.is_in_menus() and map_path != self.latest_map_path_requested: print("Requested map load") self.request_map(map_path, zone_centers) - self.iface._respond_to_call(msgtype) + self.iface.respond_to_call(msgtype) else: pass except socket.timeout as err: diff --git a/trackmania_rl/tmi_interaction/tminterface2.py b/trackmania_rl/tmi_interaction/tminterface2.py index 520cfba3..c0db8230 100644 --- a/trackmania_rl/tmi_interaction/tminterface2.py +++ b/trackmania_rl/tmi_interaction/tminterface2.py @@ -90,7 +90,7 @@ def reset_camera(self): def get_simulation_state(self): self.sock.sendall(struct.pack("i", MessageType.C_GET_SIMULATION_STATE)) - state_length = self._read_int32() + state_length = self.read_int32() state = SimStateData(self.sock.recv(state_length, socket.MSG_WAITALL)) state.cp_data.resize(CheckpointData.cp_states_field, state.cp_data.cp_states_length) state.cp_data.resize(CheckpointData.cp_times_field, state.cp_data.cp_times_length) @@ -119,7 +119,7 @@ def set_speed(self, new_speed): def race_finished(self): self.sock.sendall(struct.pack("i", MessageType.C_RACE_FINISHED)) - a = self._read_int32() + a = self.read_int32() return a def request_frame(self, W: int, H: int): @@ -140,15 +140,15 @@ def set_on_step_period(self, period: int): def is_in_menus(self): self.sock.sendall(struct.pack("i", MessageType.C_IS_IN_MENUS)) - return self._read_int32() > 0 + return self.read_int32() > 0 def get_inputs(self): self.sock.sendall(struct.pack("i", MessageType.C_GET_INPUTS)) - string_length = self._read_int32() + string_length = self.read_int32() return self.sock.recv(string_length, socket.MSG_WAITALL).decode("utf-8") - def _respond_to_call(self, response_type): + def respond_to_call(self, response_type): self.sock.sendall(struct.pack("i", np.int32(response_type))) - def _read_int32(self): + def read_int32(self): return struct.unpack("i", self.sock.recv(4, socket.MSG_WAITALL))[0]