feat: restructure plotting scripts
RuanJohn committed Aug 28, 2024
1 parent 0864455 commit b85d822
Showing 28 changed files with 876 additions and 224 deletions.
File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -30,23 +30,27 @@
     data_process_pipeline,
 )
 
+base_folder_name = "biggest-benchmark"
+
 ENV_NAME = "MaConnector"
 SAVE_PDF = False
 
-data_dir = "data/full-benchmark-update/merged_data/interim_seed_duplicated.json"
-png_plot_dir = "plots/full-benchmark-update/connector_no_retmat/png/"
-pdf_plot_dir = "plots/full-benchmark-update/connector_no_retmat/pdf/"
+data_dir = f"data/{base_folder_name}/merged_data/metrics_seed_processed.json"
+png_plot_dir = f"plots/{base_folder_name}/connector_no_retmat/png/"
+pdf_plot_dir = f"plots/{base_folder_name}/connector_no_retmat/pdf/"
 
-PLOT_METRIC = "win_rate"  # "mean_episode_return"
+PLOT_METRIC = "mean_episode_return"  # options: "mean_episode_return", "win_rate"
 
 legend_map = {
     "rec_mappo": "Rec MAPPO",
     "rec_ippo": "Rec IPPO",
     "ff_mappo": "FF MAPPO",
     "ff_ippo": "FF IPPO",
     "mat": "MAT",
-    # "retmat": "RetMAT",
+    "retmat": "RetMAT",
     "retmat_memory": "RetMAT Memory",
+    "ff_happo": "FF HAPPO",
+    "rec_happo": "Rec HAPPO",
     # "retmat_main_memory": "RetMAT Main Memory",
     # "retmat_yarn_memory": "RetMAT Yarn Memory",
 }
@@ -106,9 +110,13 @@
     ],
     legend_map=legend_map,
 )
-fig.figure.savefig(f"{png_plot_dir}_{PLOT_METRIC}_prob_of_improvement.png", bbox_inches="tight")
+fig.figure.savefig(
+    f"{png_plot_dir}_{PLOT_METRIC}_prob_of_improvement.png", bbox_inches="tight"
+)
 if SAVE_PDF:
-    fig.figure.savefig(f"{pdf_plot_dir}_{PLOT_METRIC}_prob_of_improvement.pdf", bbox_inches="tight")
+    fig.figure.savefig(
+        f"{pdf_plot_dir}_{PLOT_METRIC}_prob_of_improvement.pdf", bbox_inches="tight"
+    )
 
 # aggregate scores
 fig, _, _ = aggregate_scores(  # type: ignore
@@ -117,10 +125,15 @@
     metrics_to_normalize=METRICS_TO_NORMALIZE,
     save_tabular_as_latex=True,
     legend_map=legend_map,
+    tabular_results_file_path=f"{png_plot_dir[:-4]}aggregated_score",
 )
-fig.figure.savefig(f"{png_plot_dir}_{PLOT_METRIC}_aggregate_scores.png", bbox_inches="tight")
+fig.figure.savefig(
+    f"{png_plot_dir}_{PLOT_METRIC}_aggregate_scores.png", bbox_inches="tight"
+)
 if SAVE_PDF:
-    fig.figure.savefig(f"{pdf_plot_dir}_{PLOT_METRIC}_aggregate_scores.pdf", bbox_inches="tight")
+    fig.figure.savefig(
+        f"{pdf_plot_dir}_{PLOT_METRIC}_aggregate_scores.pdf", bbox_inches="tight"
+    )
 
 # performance profiles
 fig = performance_profiles(
@@ -129,9 +142,13 @@
     metrics_to_normalize=METRICS_TO_NORMALIZE,
     legend_map=legend_map,
 )
-fig.figure.savefig(f"{png_plot_dir}_{PLOT_METRIC}_performance_profile.png", bbox_inches="tight")
+fig.figure.savefig(
+    f"{png_plot_dir}_{PLOT_METRIC}_performance_profile.png", bbox_inches="tight"
+)
 if SAVE_PDF:
-    fig.figure.savefig(f"{pdf_plot_dir}_{PLOT_METRIC}_performance_profile.pdf", bbox_inches="tight")
+    fig.figure.savefig(
+        f"{pdf_plot_dir}_{PLOT_METRIC}_performance_profile.pdf", bbox_inches="tight"
+    )
 
 
 ##############################
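The new `tabular_results_file_path` argument leans on a convention visible in the hunk above: `png_plot_dir` always ends in "png/", so slicing four characters off the end leaves the task's base plot folder, and the LaTeX table lands alongside the png/ and pdf/ subfolders. A minimal sketch of that path arithmetic, reusing the values from this file:

base_folder_name = "biggest-benchmark"
png_plot_dir = f"plots/{base_folder_name}/connector_no_retmat/png/"

# "png/" is exactly four characters, so [:-4] strips the subfolder,
# leaving the directory shared by the png/ and pdf/ outputs.
tabular_results_file_path = f"{png_plot_dir[:-4]}aggregated_score"
print(tabular_results_file_path)
# plots/biggest-benchmark/connector_no_retmat/aggregated_score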
13 changes: 9 additions & 4 deletions plot_data_lbf.py → data_plotting_scripts/plot_data_lbf.py
@@ -30,21 +30,25 @@
     data_process_pipeline,
 )
 
+base_folder_name = "biggest-benchmark"
+
 ENV_NAME = "LevelBasedForaging"
 SAVE_PDF = False
 
-data_dir = "data/full-benchmark-update/merged_data/interim_seed_duplicated.json"
-png_plot_dir = "plots/full-benchmark-update/lbf_no_retmat/png/"
-pdf_plot_dir = "plots/full-benchmark-update/lbf_no_retmat/pdf/"
+data_dir = f"data/{base_folder_name}/merged_data/metrics_seed_processed.json"
+png_plot_dir = f"plots/{base_folder_name}/lbf_no_retmat/png/"
+pdf_plot_dir = f"plots/{base_folder_name}/lbf_no_retmat/pdf/"
 
 legend_map = {
     "rec_mappo": "Rec MAPPO",
     "rec_ippo": "Rec IPPO",
     "ff_mappo": "FF MAPPO",
     "ff_ippo": "FF IPPO",
     "mat": "MAT",
-    # "retmat": "RetMAT",
+    "retmat": "RetMAT",
     "retmat_memory": "RetMAT Memory",
+    "ff_happo": "FF HAPPO",
+    "rec_happo": "Rec HAPPO",
     # "retmat_main_memory": "RetMAT Main Memory",
     # "retmat_yarn_memory": "RetMAT Yarn Memory",
 }
@@ -115,6 +119,7 @@
     metrics_to_normalize=METRICS_TO_NORMALIZE,
     save_tabular_as_latex=True,
     legend_map=legend_map,
+    tabular_results_file_path=f"{png_plot_dir[:-4]}aggregated_score",
 )
 fig.figure.savefig(f"{png_plot_dir}aggregate_scores.png", bbox_inches="tight")
 if SAVE_PDF:
26 changes: 22 additions & 4 deletions plot_data_mabrax.py → data_plotting_scripts/plot_data_mabrax.py
@@ -30,24 +30,40 @@
     data_process_pipeline,
 )
 
+base_folder_name = "biggest-benchmark-sac"
+
 ENV_NAME = "MaBrax"
 SAVE_PDF = False
 
-data_dir = "data/full-benchmark-update/merged_data/interim_seed_duplicated.json"
-png_plot_dir = "plots/full-benchmark-update/mabrax_no_retmat/png/"
-pdf_plot_dir = "plots/full-benchmark-update/mabrax_no_retmat/pdf/"
+data_dir = f"data/{base_folder_name}/merged_data/metrics_seed_processed.json"
+png_plot_dir = f"plots/{base_folder_name}/mabrax/png/"
+pdf_plot_dir = f"plots/{base_folder_name}/mabrax/pdf/"
 
 legend_map = {
     "rec_mappo": "Rec MAPPO",
     "rec_ippo": "Rec IPPO",
     "ff_mappo": "FF MAPPO",
     "ff_ippo": "FF IPPO",
     "mat": "MAT",
-    # "retmat": "RetMAT",
+    "retmat": "RetMAT",
     "retmat_memory": "RetMAT Memory",
+    "ff_happo": "FF HAPPO",
+    "rec_happo": "Rec HAPPO",
+    "ff_masac": "FF MASAC",
+    "ff_hasac": "FF HASAC",
     # "retmat_main_memory": "RetMAT Main Memory",
     # "retmat_yarn_memory": "RetMAT Yarn Memory",
 }
+# legend_map = {
+#     "retmat_cont_memory_single-device-64-envs_mava-cont-system-lr-decay": "mava-nets-tpu-64-envs",
+#     "retmat_cont_memory_no-lr-decay-base": "default-no-lr-decay",
+#     "retmat_cont_memory_mava-cont-system_no-lr-decay": "mava-nets-no-lr-decay",
+#     "retmat_cont_memory_mava-cont-system_lr-decay": "mava-nets",
+#     "retmat_cont_memory_increase-epochs_mava-cont-system-lr-decay": "mava-nets-increase-epochs",
+#     "retmat_cont_memory_double-lr_mava-cont-system-lr-decay": "mava-nets-double-lr",
+#     "retmat_cont_memory_on-gpu-64-envs_mava-cont-system-lr-decay": "mava-nets-gpu-64-envs",
+# }
+# base_algo = "retmat_cont_memory_on-gpu-64-envs_mava-cont-system-lr-decay"
 
 ##############################
 # Read in and process data
@@ -102,6 +118,7 @@
         # ["retmat_memory", "retmat"],
         # ["retmat_yarn_memory", "mat"],
     ],
+    # algorithms_to_compare = [[base_algo, other_algo] for other_algo in legend_map.keys() if other_algo != base_algo],
     legend_map=legend_map,
 )
 fig.figure.savefig(f"{png_plot_dir}prob_of_improvement.png", bbox_inches="tight")
@@ -115,6 +132,7 @@
     metrics_to_normalize=METRICS_TO_NORMALIZE,
     save_tabular_as_latex=True,
     legend_map=legend_map,
+    tabular_results_file_path=f"{png_plot_dir[:-4]}aggregated_score",
 )
 fig.figure.savefig(f"{png_plot_dir}aggregate_scores.png", bbox_inches="tight")
 if SAVE_PDF:
13 changes: 9 additions & 4 deletions plot_data_rware.py → data_plotting_scripts/plot_data_rware.py
@@ -30,21 +30,25 @@
     data_process_pipeline,
 )
 
+base_folder_name = "biggest-benchmark"
+
 ENV_NAME = "RobotWarehouse"
 SAVE_PDF = False
 
-data_dir = "data/full-benchmark-update/merged_data/interim_seed_duplicated.json"
-png_plot_dir = "plots/full-benchmark-update/rware_no_retmat/png/"
-pdf_plot_dir = "plots/full-benchmark-update/rware_no_retmat/pdf/"
+data_dir = f"data/{base_folder_name}/merged_data/metrics_seed_processed.json"
+png_plot_dir = f"plots/{base_folder_name}/rware_no_retmat/png/"
+pdf_plot_dir = f"plots/{base_folder_name}/rware_no_retmat/pdf/"
 
 legend_map = {
     "rec_mappo": "Rec MAPPO",
     "rec_ippo": "Rec IPPO",
     "ff_mappo": "FF MAPPO",
     "ff_ippo": "FF IPPO",
     "mat": "MAT",
-    # "retmat": "RetMAT",
+    "retmat": "RetMAT",
     "retmat_memory": "RetMAT Memory",
+    "ff_happo": "FF HAPPO",
+    "rec_happo": "Rec HAPPO",
     # "retmat_main_memory": "RetMAT Main Memory",
     # "retmat_yarn_memory": "RetMAT Yarn Memory",
 }
@@ -115,6 +119,7 @@
     metrics_to_normalize=METRICS_TO_NORMALIZE,
     save_tabular_as_latex=True,
     legend_map=legend_map,
+    tabular_results_file_path=f"{png_plot_dir[:-4]}aggregated_score",
 )
 fig.figure.savefig(f"{png_plot_dir}aggregate_scores.png", bbox_inches="tight")
 if SAVE_PDF:
13 changes: 9 additions & 4 deletions plot_data_smax.py → data_plotting_scripts/plot_data_smax.py
@@ -30,21 +30,25 @@
     data_process_pipeline,
 )
 
+base_folder_name = "biggest-benchmark"
+
 ENV_NAME = "Smax"
 SAVE_PDF = False
 
-data_dir = "data/full-benchmark-update/merged_data/interim_seed_duplicated.json"
-png_plot_dir = "plots/full-benchmark-update/smax_no_retmat/png/"
-pdf_plot_dir = "plots/full-benchmark-update/smax_no_retmat/pdf/"
+data_dir = f"data/{base_folder_name}/merged_data/metrics_seed_processed.json"
+png_plot_dir = f"plots/{base_folder_name}/smax_no_retmat/png/"
+pdf_plot_dir = f"plots/{base_folder_name}/smax_no_retmat/pdf/"
 
 legend_map = {
     "rec_mappo": "Rec MAPPO",
     "rec_ippo": "Rec IPPO",
     "ff_mappo": "FF MAPPO",
     "ff_ippo": "FF IPPO",
     "mat": "MAT",
-    # "retmat": "RetMAT",
+    "retmat": "RetMAT",
     "retmat_memory": "RetMAT Memory",
+    "ff_happo": "FF HAPPO",
+    "rec_happo": "Rec HAPPO",
     # "retmat_main_memory": "RetMAT Main Memory",
     # "retmat_yarn_memory": "RetMAT Yarn Memory",
 }
@@ -115,6 +119,7 @@
     metrics_to_normalize=METRICS_TO_NORMALIZE,
     save_tabular_as_latex=True,
     legend_map=legend_map,
+    tabular_results_file_path=f"{png_plot_dir[:-4]}aggregated_score",
 )
 fig.figure.savefig(f"{png_plot_dir}aggregate_scores.png", bbox_inches="tight")
 if SAVE_PDF:
29 changes: 29 additions & 0 deletions data_processing_scripts/check_for_absolutes.py
@@ -0,0 +1,29 @@
+import json
+
+
+def check_for_absolutes(data):
+    for env_name in data:
+        for task_name in data[env_name]:
+            for algo_name in data[env_name][task_name]:
+                for seed in data[env_name][task_name][algo_name]:
+                    if (
+                        "absolute_metrics"
+                        not in data[env_name][task_name][algo_name][seed]
+                    ):
+                        print(
+                            f"No absolute metrics in {env_name}/{task_name}/{algo_name}/{seed}"
+                        )
+                    elif (
+                        "absolute_metrics" in data[env_name][task_name][algo_name][seed]
+                        # and "sac" in algo_name
+                    ):
+                        print(
+                            f"Found absolute metrics in {env_name}/{task_name}/{algo_name}/{seed}"
+                        )
+
+
+in_file_path = "data/biggest-benchmark-sac/merged_data/metrics_winrate_processed.json"
+with open(in_file_path, "r") as file:
+    data = json.load(file)
+
+check_for_absolutes(data)
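The four nested loops assume the merged JSON is keyed environment → task → algorithm → seed, with each seed's metrics dict at the leaves. A minimal fixture in that assumed shape (the names and values here are illustrative, not taken from the real benchmark data):

data = {
    "MaBrax": {
        "ant_4x2": {
            "ff_masac": {
                "seed_0": {"absolute_metrics": {"mean_episode_return": [102.3]}},
                "seed_1": {"mean_episode_return": [98.7]},  # missing absolute_metrics
            }
        }
    }
}

check_for_absolutes(data)
# Found absolute metrics in MaBrax/ant_4x2/ff_masac/seed_0
# No absolute metrics in MaBrax/ant_4x2/ff_masac/seed_1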
File renamed without changes.
39 changes: 39 additions & 0 deletions data_processing_scripts/keep_certain_tasks.py
@@ -0,0 +1,39 @@
+import json
+
+
+def filter_json(data, tasks_to_keep):
+    filtered_data = {}
+    for env_name, env_tasks in data.items():
+        kept_tasks = {
+            task: info for task, info in env_tasks.items() if task in tasks_to_keep
+        }
+        if kept_tasks:
+            filtered_data[env_name] = kept_tasks
+    return filtered_data
+
+
+base_folder_name = "biggest-benchmark-sac"
+
+# Example usage:
+input_file = f"./data/{base_folder_name}/merged_data/metrics_winrate_processed.json"
+output_file = f"./data/{base_folder_name}/merged_data/metrics_seed_processed.json"
+tasks_to_keep = [
+    "hopper_3x1",
+    "halfcheetah_6x1",
+    "walker2d_2x3",
+    "ant_4x2",
+    # "humanoid_9|8",
+]  # Replace with your list of tasks to keep
+
+# Read the input JSON file
+with open(input_file, "r") as f:
+    data = json.load(f)
+
+# Filter the data
+filtered_data = filter_json(data, tasks_to_keep)
+
+# Write the filtered data to the output JSON file
+with open(output_file, "w") as f:
+    json.dump(filtered_data, f, indent=2)
+
+print(f"Filtered data has been written to {output_file}")
File renamed without changes.
@@ -88,9 +88,9 @@ def main(json_filename, new_json_filename):
     save_json(new_json_filename, data)
 
 
+base_folder_name = "biggest-benchmark-sac"
+
 # Replace 'your_file.json' with your actual JSON file name
-json_filename = (
-    "./data/full-benchmark-update/merged_data/metrics.json"
-)
-new_json_filename = "./data/full-benchmark-update/merged_data/metrics_name_processed.json"
+json_filename = f"./data/{base_folder_name}/merged_data/metrics.json"
+new_json_filename = f"./data/{base_folder_name}/merged_data/metrics_name_processed.json"
 main(json_filename, new_json_filename)