Skip to content

Commit

Permalink
Fix warnings: Access to a protected member of a class
Browse files Browse the repository at this point in the history
IntelliJ complains when a protected member is accessed from outside the class.

Example:
- tminterface2.py: class `TMInterface` has a method `_read_int32()`
- add_cp_as_triggers.py: uses the method and has a warning

This commit fixes the warnings where possible.

The warning is not always correct as found here:

trackmania_rl/buffer_utilities.py:280
`target_buffer.extend(source_buffer.storage._storage)`
Warning: _ListStorage._storage is not publicly accessible_ is wrong because `ListStorage._storage` is not publicly accessible

Fixes #56
  • Loading branch information
Wuodan committed Aug 30, 2024
1 parent 040fa15 commit 07c6fc3
Show file tree
Hide file tree
Showing 9 changed files with 53 additions and 52 deletions.
10 changes: 5 additions & 5 deletions scripts/tools/tmi2/add_cp_as_triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main():
# ============================
# END ON RUN STEP
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC):
current = iface._read_int32()
target = iface._read_int32()
Expand All @@ -52,20 +52,20 @@ def main():
# ============================
# END ON CP COUNT
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC):
iface._read_int32()
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC):
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.C_SHUTDOWN):
iface.close()
elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC):
for i in range(0, len(checkpoint_positions), 1):
iface.execute_command(
f"add_trigger {checkpoint_positions[i][0] - 2} {checkpoint_positions[i][1] - 2} {checkpoint_positions[i][2] - 2} {checkpoint_positions[i][0] + 2} {checkpoint_positions[i][1] + 2} {checkpoint_positions[i][2] + 2}"
)
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
else:
pass

Expand Down
10 changes: 5 additions & 5 deletions scripts/tools/tmi2/add_vcp_as_triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def main():
# ============================
# END ON RUN STEP
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC):
current = iface._read_int32()
target = iface._read_int32()
Expand All @@ -52,12 +52,12 @@ def main():
# ============================
# END ON CP COUNT
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC):
iface._read_int32()
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC):
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.C_SHUTDOWN):
iface.close()
elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC):
Expand All @@ -68,7 +68,7 @@ def main():
# print(
# f"add_trigger {vcp[i][0] - 0.4:.2f} {vcp[i][1] - 0.4:.2f} {vcp[i][2] - 0.4:.2f} {vcp[i][0] + 0.4:.2f} {vcp[i][1] + 0.4:.2f} {vcp[i][2] + 0.4:.2f}"
# )
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
else:
pass

Expand Down
10 changes: 5 additions & 5 deletions scripts/tools/tmi2/empty_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main():
# ============================
# END ON RUN STEP
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC):
current = iface._read_int32()
target = iface._read_int32()
Expand All @@ -45,16 +45,16 @@ def main():
# ============================
# END ON CP COUNT
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC):
iface._read_int32()
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC):
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.C_SHUTDOWN):
iface.close()
elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC):
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
else:
pass

Expand Down
10 changes: 5 additions & 5 deletions scripts/tools/video_stuff/inputs_to_gbx.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def press_enter(N_presses=1):
# expecting_replay_file = True
iface.execute_command("finish")
# request_map(iface,args.map_path)
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC):
# print("On CP")
current = iface._read_int32()
Expand All @@ -235,14 +235,14 @@ def press_enter(N_presses=1):
iface.close()
# iface.prevent_simulation_finish()
else:
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC):
# print("On lap")
iface._read_int32()
iface._read_int32()
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC):
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.C_SHUTDOWN):
iface.close()
elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC):
Expand All @@ -263,7 +263,7 @@ def press_enter(N_presses=1):
map_loaded = True
# else:
# need_to_get_out_of_menu = True
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
else:
pass
close_game(tm_process_id)
Expand Down
10 changes: 5 additions & 5 deletions trackmania_rl/analysis_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,15 @@ def highest_prio_transitions(buffer, save_dir):
shutil.rmtree(save_dir / "high_prio_figures", ignore_errors=True)
(save_dir / "high_prio_figures").mkdir(parents=True, exist_ok=True)

prios = [buffer._sampler._sum_tree.at(i) for i in range(len(buffer))]
prios = [buffer.sampler.sum_tree.at(i) for i in range(len(buffer))]

for high_error_idx in np.argsort(prios)[-20:]:
for idx in range(max(0, high_error_idx - 4), min(len(buffer) - 1, high_error_idx + 5)):
Image.fromarray(
np.hstack((buffer._storage[idx].state_img.squeeze(), buffer._storage[idx].next_state_img.squeeze()))
np.hstack((buffer.storage[idx].state_img.squeeze(), buffer.storage[idx].next_state_img.squeeze()))
.repeat(4, 0)
.repeat(4, 1)
).save(save_dir / "high_prio_figures" / f"{high_error_idx}_{idx}_{buffer._storage[idx].n_steps}_{prios[idx]:.2f}.png")
).save(save_dir / "high_prio_figures" / f"{high_error_idx}_{idx}_{buffer.storage[idx].n_steps}_{prios[idx]:.2f}.png")


def get_output_and_target_for_batch(batch, online_network, target_network, num_quantiles):
Expand Down Expand Up @@ -383,7 +383,7 @@ def distribution_curves(buffer, save_dir, online_network, target_network):
np.hstack(
(
np.expand_dims(
np.hstack((buffer._storage[i].state_img.squeeze(), buffer._storage[i].next_state_img.squeeze()))
np.hstack((buffer.storage[i].state_img.squeeze(), buffer.storage[i].next_state_img.squeeze()))
.repeat(4, 0)
.repeat(4, 1),
axis=-1,
Expand All @@ -392,5 +392,5 @@ def distribution_curves(buffer, save_dir, online_network, target_network):
)
)
)
).save(save_dir / "distribution_curves" / f"{i}_{buffer._storage[i].n_steps}.png")
).save(save_dir / "distribution_curves" / f"{i}_{buffer.storage[i].n_steps}.png")
plt.close()
21 changes: 11 additions & 10 deletions trackmania_rl/buffer_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from torchrl.data import ListStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
from torchrl.data.replay_buffers.storages import Storage
from torchrl.data.replay_buffers.utils import INT_CLASSES, _to_numpy
from torchrl.data.replay_buffers.utils import INT_CLASSES
from torchrl.data.replay_buffers.utils import _to_numpy as to_numpy

from config_files import config_copy

Expand Down Expand Up @@ -247,8 +248,8 @@ def update_priority(
else:
if not (isinstance(priority, float) or len(priority) == 1 or len(index) == len(priority)):
raise RuntimeError("priority should be a number or an iterable of the same length as index")
index = _to_numpy(index)
priority = _to_numpy(priority)
index = to_numpy(index)
priority = to_numpy(priority)
# We track the _approximate_ number of memories in the buffer that have default priority :
self._uninitialized_memories -= 0.3 * len(index)
priority = np.power(priority + self._eps, self._alpha)
Expand All @@ -274,17 +275,17 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:


def copy_buffer_content_to_other_buffer(source_buffer: ReplayBuffer, target_buffer: ReplayBuffer) -> None:
assert source_buffer._storage.max_size <= target_buffer._storage.max_size
assert source_buffer.storage.max_size <= target_buffer.storage.max_size

target_buffer.extend(source_buffer._storage._storage)
target_buffer.extend(source_buffer.storage._storage)

if isinstance(source_buffer._sampler, CustomPrioritizedSampler) and isinstance(target_buffer._sampler, CustomPrioritizedSampler):
target_buffer._sampler._average_priority = source_buffer._sampler._average_priority
target_buffer._sampler._uninitialized_memories = source_buffer._sampler._uninitialized_memories
if isinstance(source_buffer.sampler, CustomPrioritizedSampler) and isinstance(target_buffer.sampler, CustomPrioritizedSampler):
target_buffer.sampler._average_priority = source_buffer.sampler._average_priority
target_buffer.sampler._uninitialized_memories = source_buffer.sampler._uninitialized_memories

if isinstance(source_buffer._sampler, PrioritizedSampler) and isinstance(target_buffer._sampler, PrioritizedSampler):
if isinstance(source_buffer.sampler, PrioritizedSampler) and isinstance(target_buffer.sampler, PrioritizedSampler):
for i in range(len(source_buffer)):
target_buffer._sampler._sum_tree[i] = source_buffer._sampler._sum_tree.at(i)
target_buffer.sampler.sum_tree[i] = source_buffer.sampler.sum_tree.at(i)


def make_buffers(buffer_size: int) -> tuple[ReplayBuffer, ReplayBuffer]:
Expand Down
20 changes: 10 additions & 10 deletions trackmania_rl/multiprocess/learner_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,10 @@ def learner_process_fn(
param_group["epsilon"] = config_copy.adam_epsilon
param_group["betas"] = (config_copy.adam_beta1, config_copy.adam_beta2)

if isinstance(buffer._sampler, PrioritizedSampler):
buffer._sampler._alpha = config_copy.prio_alpha
buffer._sampler._beta = config_copy.prio_beta
buffer._sampler._eps = config_copy.prio_epsilon
if isinstance(buffer.sampler, PrioritizedSampler):
buffer.sampler._alpha = config_copy.prio_alpha
buffer.sampler._beta = config_copy.prio_beta
buffer.sampler._eps = config_copy.prio_epsilon

if config_copy.plot_race_time_left_curves and not is_explo and (loop_number // 5) % 17 == 0:
race_time_left_curves(rollout_results, inferer, save_dir, map_name)
Expand Down Expand Up @@ -495,7 +495,7 @@ def learner_process_fn(
loss, grad_norm = trainer.train_on_batch(buffer, do_learn=True)
accumulated_stats["cumul_number_single_memories_used"] += (
10 * config_copy.batch_size
if (len(buffer) < buffer._storage.max_size and buffer._storage.max_size > 200_000)
if (len(buffer) < buffer.storage.max_size and buffer.storage.max_size > 200_000)
else config_copy.batch_size
) # do fewer batches while memory is not full
train_on_batch_duration_history.append(time.perf_counter() - train_start_time)
Expand Down Expand Up @@ -575,8 +575,8 @@ def learner_process_fn(
f"{key}_max": np.max(val),
}
)
if isinstance(buffer._sampler, PrioritizedSampler):
all_priorities = np.array([buffer._sampler._sum_tree.at(i) for i in range(len(buffer))])
if isinstance(buffer.sampler, PrioritizedSampler):
all_priorities = np.array([buffer.sampler.sum_tree.at(i) for i in range(len(buffer))])
step_stats.update(
{
"priorities_min": np.min(all_priorities),
Expand Down Expand Up @@ -683,8 +683,8 @@ def learner_process_fn(
# BUFFER STATS
# ===============================================

mean_in_buffer = np.array([experience.state_float for experience in buffer._storage]).mean(axis=0)
std_in_buffer = np.array([experience.state_float for experience in buffer._storage]).std(axis=0)
mean_in_buffer = np.array([experience.state_float for experience in buffer.storage]).mean(axis=0)
std_in_buffer = np.array([experience.state_float for experience in buffer.storage]).std(axis=0)

print("Raw mean in buffer :", mean_in_buffer.round(1))
print("Raw std in buffer :", std_in_buffer.round(1))
Expand All @@ -702,7 +702,7 @@ def learner_process_fn(
# ===============================================
# HIGH PRIORITY TRANSITIONS
# ===============================================
if config_copy.make_highest_prio_figures and isinstance(buffer._sampler, PrioritizedSampler):
if config_copy.make_highest_prio_figures and isinstance(buffer.sampler, PrioritizedSampler):
highest_prio_transitions(buffer, save_dir)

# ===============================================
Expand Down
12 changes: 6 additions & 6 deletions trackmania_rl/tmi_interaction/game_instance_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt

self.request_speed(self.running_speed)
if self.msgtype_response_to_wakeup_TMI is not None:
self.iface._respond_to_call(self.msgtype_response_to_wakeup_TMI)
self.iface.respond_to_call(self.msgtype_response_to_wakeup_TMI)
self.msgtype_response_to_wakeup_TMI = None

self.last_rollout_crashed = False
Expand Down Expand Up @@ -579,7 +579,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt
# END ON RUN STEP
# ============================
if self.msgtype_response_to_wakeup_TMI is None:
self.iface._respond_to_call(msgtype)
self.iface.respond_to_call(msgtype)

if _time > 0 and this_rollout_has_seen_t_negative:
if _time % 40 == 0:
Expand Down Expand Up @@ -667,11 +667,11 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt
# END ON CP COUNT
# ============================
if self.msgtype_response_to_wakeup_TMI is None:
self.iface._respond_to_call(msgtype)
self.iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC):
self.iface._read_int32()
self.iface._read_int32()
self.iface._respond_to_call(msgtype)
self.iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC):
frame = self.grab_screen()
frame_expected = False
Expand Down Expand Up @@ -728,7 +728,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt
self.game_spawning_lock.release()

instrumentation__request_inputs_and_speed += time.perf_counter_ns() - pc8
self.iface._respond_to_call(msgtype)
self.iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.C_SHUTDOWN):
self.iface.close()
elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC):
Expand All @@ -746,7 +746,7 @@ def rollout(self, exploration_policy: Callable, map_path: str, zone_centers: npt
if self.iface.is_in_menus() and map_path != self.latest_map_path_requested:
print("Requested map load")
self.request_map(map_path, zone_centers)
self.iface._respond_to_call(msgtype)
self.iface.respond_to_call(msgtype)
else:
pass
except socket.timeout as err:
Expand Down
2 changes: 1 addition & 1 deletion trackmania_rl/tmi_interaction/tminterface2.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def get_inputs(self):
string_length = self._read_int32()
return self.sock.recv(string_length, socket.MSG_WAITALL).decode("utf-8")

def _respond_to_call(self, response_type):
def respond_to_call(self, response_type):
self.sock.sendall(struct.pack("i", np.int32(response_type)))

def _read_int32(self):
Expand Down

0 comments on commit 07c6fc3

Please sign in to comment.