
streamlined args parsing
FilippoAiraldi committed Jul 9, 2024
1 parent e59e1d4 commit 67f656a
Showing 4 changed files with 86 additions and 148 deletions.
42 changes: 24 additions & 18 deletions benchmarking/plot.py
@@ -442,10 +442,10 @@ def summary_tables(
f.write(latex)


- if __name__ == "__main__":
- # parse the arguments
+ def parse_args(name: str, multiproblem: bool = True) -> argparse.Namespace:
+ """Parses the command line arguments."""
parser = argparse.ArgumentParser(
- description="Visualization of benchmark results.",
+ description=f"Visualization of {name} results.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
@@ -455,34 +455,35 @@ def summary_tables(
help="Filenames of the results to be visualized.",
)
group = parser.add_argument_group("Include/Exclude options")
+ if multiproblem:
+ group.add_argument(
+ "--include-problems",
+ type=str,
+ nargs="+",
+ default=[],
+ help="List of benchmark problems patterns to plot.",
+ )
+ group.add_argument(
+ "--exclude-problems",
+ type=str,
+ nargs="+",
+ default=[],
+ help="List of benchmark problems patterns not to plot.",
+ )
group.add_argument(
"--include-methods",
type=str,
nargs="+",
default=[],
help="List of methods patterns to plot.",
)
- group.add_argument(
- "--include-problems",
- type=str,
- nargs="+",
- default=[],
- help="List of benchmark problems patterns to plot.",
- )
group.add_argument(
"--exclude-methods",
type=str,
nargs="+",
default=[],
help="List of methods patterns not to plot.",
)
- group.add_argument(
- "--exclude-problems",
- type=str,
- nargs="+",
- default=[],
- help="List of benchmark problems patterns not to plot.",
- )
group = parser.add_mutually_exclusive_group()
group.add_argument(
"--plot",
@@ -499,7 +500,12 @@ def summary_tables(
action="store_true",
help="Generates the data files for PGFPLOTS.",
)
- args = parser.parse_args()
+ return parser.parse_args()


+ if __name__ == "__main__":
+ # parse the arguments
+ args = parse_args("benchmark")
fplot, fsummary, fpgfplotstables = args.plot, args.summary, args.pgfplotstables

# load each result and plot
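Taken together, the changes above replace module-level argument parsing with a reusable parse_args(name, multiproblem=...) helper: name only customizes the --help description, while multiproblem gates the --include-problems/--exclude-problems filters. The two call styles that appear later in this commit, repeated here for reference:

    args = parse_args("benchmark")                        # benchmarking/plot.py: problem filters included
    args = parse_args("MPC tuning", multiproblem=False)   # mpc-tuning/plot.py: problem filters omitted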
60 changes: 33 additions & 27 deletions benchmarking/run.py
@@ -385,33 +385,33 @@ def run_benchmarks(
)


- if __name__ == "__main__":
- # parse the arguments
+ def parse_args(name: str, multiproblem: bool = True) -> argparse.Namespace:
+ """Parses the command-line arguments for the benchmarking script."""
parser = argparse.ArgumentParser(
- description="Benchmarking of Global Optimization strategies on synthetic "
- "benchmark problems.",
+ description=f"Benchmarking of Global Optimization strategies on {name}.",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
group = parser.add_argument_group("Benchmarking options")
group.add_argument(
"--methods",
type=check_methods_arg,
nargs="+",
help="Methods to run the benchmarking on. Greedy algorithms include `ei` and "
" `myopic`. Non-myopic multi-step algorithms have the following semantic: "
"`ms-sampler.m1.m2. ...`, where `ms` stands for multi-step, `sampler` is either"
"`gh` or `mc` (for Gauss Hermite and Monte Carlo, respectively), while `m1`, "
"`m2` and so on are the number of fantasies at each stage. The overall horizon "
"of an `ms` method is the number of fantasies plus one.",
help="Methods to run. Greedy algorithms include `ei` and `myopic`. Non-myopic "
"multi-step algorithms have the following semantic: `ms-sampler.m1.m2. ...`, "
"where `ms` stands for multi-step, `sampler` is either `gh` or `mc` (for "
"Gauss-Hermite and Monte Carlo, respectively), while `m1`, `m2` and so on are "
"the number of fantasies at each stage. The overall horizon of an `ms` method "
"is the number of fantasies plus one.",
required=True,
)
- group.add_argument(
- "--problems",
- choices=["all"] + BENCHMARK_PROBLEMS,
- nargs="+",
- default=["all"],
- help="Problems to include in the benchmarking.",
- )
+ if multiproblem:
+ group.add_argument(
+ "--problems",
+ choices=["all"] + BENCHMARK_PROBLEMS,
+ nargs="+",
+ default=["all"],
+ help="Problems to include in the benchmarking.",
+ )
group.add_argument(
"--n-trials", type=int, default=30, help="Number of trials to run per problem."
)
@@ -428,23 +428,29 @@ def run_benchmarks(
default=["cpu"],
help="List of torch devices to use, e.g., `cpu`, `cuda:0`, etc..",
)
- args = parser.parse_args()
+ return parser.parse_args()

- # if the output csv is not specified, create it, and write header if anew
- if args.csv is None or args.csv == "":
- args.csv = f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
- elif not args.csv.endswith(".csv"):
- args.csv += ".csv"
- if not Path(args.csv).is_file():
- lock_write(args.csv, "problem;method;stage-reward;best-so-far;time")

- # run the benchmarks
+ def create_csv_if_needed(filename: str, header: str) -> str:
+ """Creates the output csv file if it does not exist."""
+ if filename is None or filename == "":
+ filename = f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
+ elif not filename.endswith(".csv"):
+ filename += ".csv"
+ if not Path(filename).is_file():
+ lock_write(filename, header)
+ return filename


+ if __name__ == "__main__":
+ args = parse_args("synthetic/real benchmark problems")
+ csv = create_csv_if_needed(args.csv, "problem;method;stage-reward;best-so-far;time")
run_benchmarks(
args.methods,
args.problems,
args.n_trials,
args.seed,
args.n_jobs,
- args.csv,
+ csv,
args.devices,
)
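As an aside, the `ms-sampler.m1.m2. ...` naming scheme documented in the --methods help text above can be unpacked mechanically. The helper below is a hypothetical illustration only, not the repository's check_methods_arg validator: a minimal sketch of how such a string decomposes into a sampler and per-stage fantasy counts, with the horizon taken as the number of fantasy stages plus one.

    def describe_ms_method(method: str) -> tuple[str, tuple[int, ...], int]:
        """Splits an `ms-sampler.m1.m2. ...` string into sampler, fantasies, and horizon."""
        head, *counts = method.split(".")              # e.g. "ms-gh", "1", "2", "4"
        prefix, sampler = head.split("-")              # "ms" and "gh" (Gauss-Hermite) or "mc" (Monte Carlo)
        assert prefix == "ms" and sampler in ("gh", "mc"), "not a multi-step method string"
        fantasies = tuple(int(c) for c in counts)      # fantasies requested at each stage
        return sampler, fantasies, len(fantasies) + 1  # overall horizon = number of fantasy stages + 1

    print(describe_ms_method("ms-gh.1.2.4"))           # ('gh', (1, 2, 4), 4)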
71 changes: 22 additions & 49 deletions mpc-tuning/plot.py
@@ -1,4 +1,3 @@
- import argparse
import os
import sys
from typing import Optional
@@ -17,7 +16,13 @@

sys.path.append(os.getcwd())

- from benchmarking.plot import load_data, plot_converges, plot_timings, summarize
+ from benchmarking.plot import (
+ itertime_vs_gap,
+ load_data,
+ optimiser_convergences,
+ parse_args,
+ summary_tables,
+ )


def _extract_envdata(row: pd.Series) -> pd.Series:
@@ -77,55 +82,23 @@ def plot_envdata(df: pd.DataFrame, figtitle: Optional[str]) -> None:

if __name__ == "__main__":
# parse the arguments
- parser = argparse.ArgumentParser(
- description="Visualization of benchmark results.",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
- )
- parser.add_argument(
- "filenames",
- type=str,
- nargs="+",
- help="Filenames of the results to be visualized.",
- )
- group = parser.add_mutually_exclusive_group()
- group.add_argument(
- "--include",
- type=str,
- nargs="+",
- default=[],
- help="List of methods and/or benchmark patterns to plot.",
- )
- group.add_argument(
- "--exclude",
- type=str,
- nargs="+",
- default=[],
- help="List of methods and/or benchmark patterns not to plot.",
- )
- group = parser.add_mutually_exclusive_group()
- group.add_argument(
- "--no-plot",
- action="store_true",
- help="Only print the summary and do not show the plots.",
- )
- group.add_argument(
- "--no-summary",
- action="store_true",
- help="Only show the plot and do not print the summary.",
- )
- args = parser.parse_args()

- setup_mpc_tuning()
+ args = parse_args("MPC tuning", multiproblem=False)
+ fplot, fsummary, fpgfplotstables = args.plot, args.summary, args.pgfplotstables

# load each result and plot
+ setup_mpc_tuning()
include_title = len(args.filenames) > 1
for filename in args.filenames:
- title = filename if include_title else None
- dataframe = load_data(filename, args.include, args.exclude)
- if not args.no_plot:
- plot_converges(dataframe, title, n_cols=1)
- plot_timings(dataframe, title, single_problem=True)
- plot_envdata(dataframe, title)
- if not args.no_summary:
- summarize(dataframe, title)
+ stitle = filename if include_title else None
+ dataframe = load_data(
+ filename, args.include_methods, [], args.exclude_methods, []
+ )
+ if fplot or fpgfplotstables:
+ optimiser_convergences(
+ dataframe, fplot, fpgfplotstables, "best-so-far", stitle
+ )
+ itertime_vs_gap(dataframe, fplot, fpgfplotstables, stitle)
+ # plot_envdata(dataframe, stitle)
+ if fsummary or fpgfplotstables:
+ summary_tables(dataframe, fsummary, fpgfplotstables, stitle)
plt.show()
61 changes: 7 additions & 54 deletions mpc-tuning/tune.py
@@ -1,10 +1,7 @@
- import argparse
import os
import sys
from collections.abc import Iterable
- from datetime import datetime
from math import prod
- from pathlib import Path
from typing import Any, Optional
from warnings import filterwarnings

@@ -32,7 +29,7 @@

# I am lazy so let's import all the helpful functions defined in benchmarking/run.py
# instead of coding them again here
- from run import check_methods_arg, lock_write, run_benchmarks
+ from run import create_csv_if_needed, parse_args, run_benchmarks

PROBLEM_NAME = "cstr-mpc-tuning"
INIT_ITER = 5
@@ -433,63 +430,19 @@ def save_callback(problem: CstrMpcControllerTuning) -> str:


if __name__ == "__main__":
- # parse the arguments
- parser = argparse.ArgumentParser(
- description="Benchmarking of Global Optimization strategies on synthetic "
- "benchmark problems.",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ args = parse_args("MPC tuning", multiproblem=False)
+ header = (
+ "problem;method;stage-reward;best-so-far;time;"
+ "env-states;env-actions;env-rewards"
)
- group = parser.add_argument_group("Benchmarking options")
- group.add_argument(
- "--methods",
- type=check_methods_arg,
- nargs="+",
- help="Methods to run the benchmarking on. Greedy algorithms include `ei` and "
- " `myopic`. Non-myopic multi-step algorithms have the following semantic: "
- "`ms-sampler.m1.m2. ...`, where `ms` stands for multi-step, `sampler` is either"
- "`gh` or `mc` (for Gauss Hermite and Monte Carlo, respectively), while `m1`, "
- "`m2` and so on are the number of fantasies at each stage. The overall horizon "
- "of an `ms` method is the number of fantasies plus one.",
- required=True,
- )
- group.add_argument(
- "--n-trials", type=int, default=30, help="Number of trials to run per problem."
- )
- group = parser.add_argument_group("Simulation options")
- group.add_argument(
- "--n-jobs", type=int, default=2, help="Number (positive) of parallel processes."
- )
- group.add_argument("--seed", type=int, default=0, help="RNG seed.")
- group.add_argument("--csv", type=str, default="", help="Output csv filename.")
- group.add_argument(
- "--devices",
- type=str,
- nargs="+",
- default=["cpu"],
- help="List of torch devices to use, e.g., `cpu`, `cuda:0`, etc..",
- )
- args = parser.parse_args()

- # if the output csv is not specified, create it, and write header if anew
- if args.csv is None or args.csv == "":
- args.csv = f"results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
- elif not args.csv.endswith(".csv"):
- args.csv += ".csv"
- if not Path(args.csv).is_file():
- header = (
- "problem;method;stage-reward;best-so-far;time;"
- "env-states;env-actions;env-rewards"
- )
- lock_write(args.csv, header)

- # run the benchmarks
+ csv = create_csv_if_needed(args.csv, header)
run_benchmarks(
args.methods,
[PROBLEM_NAME],
args.n_trials,
args.seed,
args.n_jobs,
- args.csv,
+ csv,
args.devices,
n_init=INIT_ITER,
setup_callback=setup_mpc_tuning,
