Skip to content

Commit

Permalink
Access to a protected member of a class
Browse files Browse the repository at this point in the history
Note: ListStorage._storage is not publicly accessible
trackmania_rl/buffer_utilities.py:281:
`target_buffer.extend(source_buffer.storage._storage)`

Fixes Linesight-RL#56
  • Loading branch information
Wuodan committed Aug 21, 2024
1 parent 09d84b5 commit 22fddf6
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 85 deletions.
20 changes: 10 additions & 10 deletions scripts/tools/tmi2/add_cp_as_triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,42 +29,42 @@ def main():
print(e)

while True:
msgtype = iface._read_int32()
msgtype = iface.read_int32()
# =============================================
# READ INCOMING MESSAGES
# =============================================
if msgtype == int(MessageType.SC_RUN_STEP_SYNC):
_time = iface._read_int32()
_time = iface.read_int32()
# ============================
# BEGIN ON RUN STEP
# ============================
# ============================
# END ON RUN STEP
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC):
current = iface._read_int32()
target = iface._read_int32()
current = iface.read_int32()
target = iface.read_int32()
# ============================
# BEGIN ON CP COUNT
# ============================
# ============================
# END ON CP COUNT
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC):
iface._read_int32()
iface._respond_to_call(msgtype)
iface.read_int32()
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC):
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.C_SHUTDOWN):
iface.close()
elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC):
for i in range(0, len(checkpoint_positions), 1):
iface.execute_command(
f"add_trigger {checkpoint_positions[i][0] - 2} {checkpoint_positions[i][1] - 2} {checkpoint_positions[i][2] - 2} {checkpoint_positions[i][0] + 2} {checkpoint_positions[i][1] + 2} {checkpoint_positions[i][2] + 2}"
)
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
else:
pass

Expand Down
20 changes: 10 additions & 10 deletions scripts/tools/tmi2/add_vcp_as_triggers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,34 +29,34 @@ def main():
print(e)

while True:
msgtype = iface._read_int32()
msgtype = iface.read_int32()
# =============================================
# READ INCOMING MESSAGES
# =============================================
if msgtype == int(MessageType.SC_RUN_STEP_SYNC):
_time = iface._read_int32()
_time = iface.read_int32()
# ============================
# BEGIN ON RUN STEP
# ============================
# ============================
# END ON RUN STEP
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC):
current = iface._read_int32()
target = iface._read_int32()
current = iface.read_int32()
target = iface.read_int32()
# ============================
# BEGIN ON CP COUNT
# ============================
# ============================
# END ON CP COUNT
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC):
iface._read_int32()
iface._respond_to_call(msgtype)
iface.read_int32()
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC):
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.C_SHUTDOWN):
iface.close()
elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC):
Expand All @@ -67,7 +67,7 @@ def main():
# print(
# f"add_trigger {vcp[i][0] - 0.4:.2f} {vcp[i][1] - 0.4:.2f} {vcp[i][2] - 0.4:.2f} {vcp[i][0] + 0.4:.2f} {vcp[i][1] + 0.4:.2f} {vcp[i][2] + 0.4:.2f}"
# )
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
else:
pass

Expand Down
20 changes: 10 additions & 10 deletions scripts/tools/tmi2/empty_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,39 @@ def main():
print(e)

while True:
msgtype = iface._read_int32()
msgtype = iface.read_int32()
# =============================================
# READ INCOMING MESSAGES
# =============================================
if msgtype == int(MessageType.SC_RUN_STEP_SYNC):
_time = iface._read_int32()
_time = iface.read_int32()
# ============================
# BEGIN ON RUN STEP
# ============================

# ============================
# END ON RUN STEP
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC):
current = iface._read_int32()
target = iface._read_int32()
current = iface.read_int32()
target = iface.read_int32()
# ============================
# BEGIN ON CP COUNT
# ============================
# ============================
# END ON CP COUNT
# ============================
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC):
iface._read_int32()
iface._respond_to_call(msgtype)
iface.read_int32()
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC):
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.C_SHUTDOWN):
iface.close()
elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC):
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
else:
pass

Expand Down
22 changes: 11 additions & 11 deletions scripts/tools/video_stuff/inputs_to_gbx.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,10 @@ def press_enter(N_presses=1):
time.sleep(0.1)
press_enter()
if iface.registered:
msgtype = iface._read_int32()
msgtype = iface.read_int32()
if msgtype == int(MessageType.SC_RUN_STEP_SYNC):
# print("On step")
_time = iface._read_int32()
_time = iface.read_int32()
if not give_up_signal_has_been_sent:
iface.execute_command("load " + input_files[current_input_idx])
iface.give_up()
Expand All @@ -223,25 +223,25 @@ def press_enter(N_presses=1):
# expecting_replay_file = True
iface.execute_command("finish")
# request_map(iface,args.map_path)
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_CHECKPOINT_COUNT_CHANGED_SYNC):
# print("On CP")
current = iface._read_int32()
target = iface._read_int32()
current = iface.read_int32()
target = iface.read_int32()
if current == target and not expecting_replay_file: # Run finished
expecting_replay_file = True
# press_enter()
iface.close()
# iface.prevent_simulation_finish()
else:
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_LAP_COUNT_CHANGED_SYNC):
# print("On lap")
iface._read_int32()
iface._read_int32()
iface._respond_to_call(msgtype)
iface.read_int32()
iface.read_int32()
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.SC_REQUESTED_FRAME_SYNC):
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
elif msgtype == int(MessageType.C_SHUTDOWN):
iface.close()
elif msgtype == int(MessageType.SC_ON_CONNECT_SYNC):
Expand All @@ -262,7 +262,7 @@ def press_enter(N_presses=1):
map_loaded = True
# else:
# need_to_get_out_of_menu = True
iface._respond_to_call(msgtype)
iface.respond_to_call(msgtype)
else:
pass
close_game(tm_process_id)
Expand Down
10 changes: 5 additions & 5 deletions trackmania_rl/analysis_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,15 +189,15 @@ def highest_prio_transitions(buffer, save_dir):
shutil.rmtree(save_dir / "high_prio_figures", ignore_errors=True)
(save_dir / "high_prio_figures").mkdir(parents=True, exist_ok=True)

prios = [buffer._sampler._sum_tree.at(i) for i in range(len(buffer))]
prios = [buffer.sampler.sum_tree.at(i) for i in range(len(buffer))]

for high_error_idx in np.argsort(prios)[-20:]:
for idx in range(max(0, high_error_idx - 4), min(len(buffer) - 1, high_error_idx + 5)):
Image.fromarray(
np.hstack((buffer._storage[idx].state_img.squeeze(), buffer._storage[idx].next_state_img.squeeze()))
np.hstack((buffer.storage[idx].state_img.squeeze(), buffer.storage[idx].next_state_img.squeeze()))
.repeat(4, 0)
.repeat(4, 1)
).save(save_dir / "high_prio_figures" / f"{high_error_idx}_{idx}_{buffer._storage[idx].n_steps}_{prios[idx]:.2f}.png")
).save(save_dir / "high_prio_figures" / f"{high_error_idx}_{idx}_{buffer.storage[idx].n_steps}_{prios[idx]:.2f}.png")


def get_output_and_target_for_batch(batch, online_network, target_network, num_quantiles):
Expand Down Expand Up @@ -382,7 +382,7 @@ def distribution_curves(buffer, save_dir, online_network, target_network):
np.hstack(
(
np.expand_dims(
np.hstack((buffer._storage[i].state_img.squeeze(), buffer._storage[i].next_state_img.squeeze()))
np.hstack((buffer.storage[i].state_img.squeeze(), buffer.storage[i].next_state_img.squeeze()))
.repeat(4, 0)
.repeat(4, 1),
axis=-1,
Expand All @@ -391,5 +391,5 @@ def distribution_curves(buffer, save_dir, online_network, target_network):
)
)
)
).save(save_dir / "distribution_curves" / f"{i}_{buffer._storage[i].n_steps}.png")
).save(save_dir / "distribution_curves" / f"{i}_{buffer.storage[i].n_steps}.png")
plt.close()
22 changes: 11 additions & 11 deletions trackmania_rl/buffer_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
from torchrl.data import ListStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler, RandomSampler
from torchrl.data.replay_buffers.storages import Storage
from torchrl.data.replay_buffers.utils import INT_CLASSES, _to_numpy
from torchrl.data.replay_buffers.utils import INT_CLASSES
from torchrl.data.replay_buffers.utils import _to_numpy as to_numpy

from config_files import config_copy

Expand All @@ -26,7 +27,6 @@
"float64": torch.float32,
}


def fast_collate_cpu(batch, attr_name):
elem = getattr(batch[0], attr_name)
elem_array = hasattr(elem, "__len__")
Expand Down Expand Up @@ -239,8 +239,8 @@ def update_priority(self, index: Union[int, torch.Tensor], priority: Union[float
else:
if not (isinstance(priority, float) or len(priority) == 1 or len(index) == len(priority)):
raise RuntimeError("priority should be a number or an iterable of the same " "length as index")
index = _to_numpy(index)
priority = _to_numpy(priority)
index = to_numpy(index)
priority = to_numpy(priority)
# We track the _approximate_ number of memories in the buffer that have default priority :
self._uninitialized_memories -= 0.3 * len(index)
priority = np.power(priority + self._eps, self._alpha)
Expand All @@ -266,17 +266,17 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:


def copy_buffer_content_to_other_buffer(source_buffer: ReplayBuffer, target_buffer: ReplayBuffer) -> None:
assert source_buffer._storage.max_size <= target_buffer._storage.max_size
assert source_buffer.storage.max_size <= target_buffer.storage.max_size

target_buffer.extend(source_buffer._storage._storage)
target_buffer.extend(source_buffer.storage._storage)

if isinstance(source_buffer._sampler, CustomPrioritizedSampler) and isinstance(target_buffer._sampler, CustomPrioritizedSampler):
target_buffer._sampler._average_priority = source_buffer._sampler._average_priority
target_buffer._sampler._uninitialized_memories = source_buffer._sampler._uninitialized_memories
if isinstance(source_buffer.sampler, CustomPrioritizedSampler) and isinstance(target_buffer.sampler, CustomPrioritizedSampler):
target_buffer.sampler._average_priority = source_buffer.sampler.average_priority
target_buffer.sampler._uninitialized_memories = source_buffer.sampler.uninitialized_memories

if isinstance(source_buffer._sampler, PrioritizedSampler) and isinstance(target_buffer._sampler, PrioritizedSampler):
if isinstance(source_buffer.sampler, PrioritizedSampler) and isinstance(target_buffer.sampler, PrioritizedSampler):
for i in range(len(source_buffer)):
target_buffer._sampler._sum_tree[i] = source_buffer._sampler._sum_tree.at(i)
target_buffer.sampler.sum_tree[i] = source_buffer.sampler.sum_tree.at(i)


def make_buffers(buffer_size: int) -> tuple[ReplayBuffer, ReplayBuffer]:
Expand Down
20 changes: 10 additions & 10 deletions trackmania_rl/multiprocess/learner_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,10 +280,10 @@ def learner_process_fn(
param_group["epsilon"] = config_copy.adam_epsilon
param_group["betas"] = (config_copy.adam_beta1, config_copy.adam_beta2)

if isinstance(buffer._sampler, PrioritizedSampler):
buffer._sampler._alpha = config_copy.prio_alpha
buffer._sampler._beta = config_copy.prio_beta
buffer._sampler._eps = config_copy.prio_epsilon
if isinstance(buffer.sampler, PrioritizedSampler):
buffer.sampler._alpha = config_copy.prio_alpha
buffer.sampler._beta = config_copy.prio_beta
buffer.sampler._eps = config_copy.prio_epsilon

if config_copy.plot_race_time_left_curves and not is_explo and (loop_number // 5) % 17 == 0:
race_time_left_curves(rollout_results, inferer, save_dir, map_name)
Expand Down Expand Up @@ -494,7 +494,7 @@ def learner_process_fn(
loss, grad_norm = trainer.train_on_batch(buffer, do_learn=True)
accumulated_stats["cumul_number_single_memories_used"] += (
10 * config_copy.batch_size
if (len(buffer) < buffer._storage.max_size and buffer._storage.max_size > 200_000)
if (len(buffer) < buffer.storage.max_size and buffer.storage.max_size > 200_000)
else config_copy.batch_size
) # do fewer batches while memory is not full
train_on_batch_duration_history.append(time.perf_counter() - train_start_time)
Expand Down Expand Up @@ -574,8 +574,8 @@ def learner_process_fn(
f"{key}_max": np.max(val),
}
)
if isinstance(buffer._sampler, PrioritizedSampler):
all_priorities = np.array([buffer._sampler._sum_tree.at(i) for i in range(len(buffer))])
if isinstance(buffer.sampler, PrioritizedSampler):
all_priorities = np.array([buffer.sampler.sum_tree.at(i) for i in range(len(buffer))])
step_stats.update(
{
"priorities_min": np.min(all_priorities),
Expand Down Expand Up @@ -682,8 +682,8 @@ def learner_process_fn(
# BUFFER STATS
# ===============================================

mean_in_buffer = np.array([experience.state_float for experience in buffer._storage]).mean(axis=0)
std_in_buffer = np.array([experience.state_float for experience in buffer._storage]).std(axis=0)
mean_in_buffer = np.array([experience.state_float for experience in buffer.storage]).mean(axis=0)
std_in_buffer = np.array([experience.state_float for experience in buffer.storage]).std(axis=0)

print("Raw mean in buffer :", mean_in_buffer.round(1))
print("Raw std in buffer :", std_in_buffer.round(1))
Expand All @@ -701,7 +701,7 @@ def learner_process_fn(
# ===============================================
# HIGH PRIORITY TRANSITIONS
# ===============================================
if config_copy.make_highest_prio_figures and isinstance(buffer._sampler, PrioritizedSampler):
if config_copy.make_highest_prio_figures and isinstance(buffer.sampler, PrioritizedSampler):
highest_prio_transitions(buffer, save_dir)

# ===============================================
Expand Down
Loading

0 comments on commit 22fddf6

Please sign in to comment.