Skip to content

Commit

Permalink
Minor updates to tests and examples
Browse files Browse the repository at this point in the history
  • Loading branch information
Federico-PizarroBejarano committed Jan 15, 2025
1 parent 5f20bac commit 0c3c028
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 28 deletions.
52 changes: 26 additions & 26 deletions examples/mpsc/mpsc_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def run(plot=True, training=False, n_episodes=1, n_steps=None, curr_path='.'):

# Run without safety filter
experiment = BaseExperiment(env, ctrl)
results, uncert_metrics = experiment.run_evaluation(n_episodes=n_episodes, n_steps=n_steps)
elapsed_time_uncert = results['timestamp'][0][-1] - results['timestamp'][0][0]
uncert_results, uncert_metrics = experiment.run_evaluation(n_episodes=n_episodes, n_steps=n_steps)
elapsed_time_uncert = uncert_results['timestamp'][0][-1] - uncert_results['timestamp'][0][0]

# Setup MPSC.
config.task_config['normalized_rl_action_space'] = False
Expand Down Expand Up @@ -92,12 +92,12 @@ def run(plot=True, training=False, n_episodes=1, n_steps=None, curr_path='.'):

# Run with safety filter
experiment = BaseExperiment(env, ctrl, safety_filter=safety_filter)
certified_results, cert_metrics = experiment.run_evaluation(n_episodes=n_episodes, n_steps=n_steps)
cert_results, cert_metrics = experiment.run_evaluation(n_episodes=n_episodes, n_steps=n_steps)
ctrl.close()
mpsc_results = certified_results['safety_filter_data'][0]
mpsc_results = cert_results['safety_filter_data'][0]
safety_filter.close()

elapsed_time_cert = results['timestamp'][0][-1] - results['timestamp'][0][0]
elapsed_time_cert = cert_results['timestamp'][0][-1] - cert_results['timestamp'][0][0]

corrections = mpsc_results['correction'][0] > 1e-6
corrections = np.append(corrections, False)
Expand All @@ -115,10 +115,10 @@ def run(plot=True, training=False, n_episodes=1, n_steps=None, curr_path='.'):
graph3_2 = 2

_, ax = plt.subplots()
ax.plot(results['obs'][0][:, graph1_1], results['obs'][0][:, graph1_2], 'r--', label='Uncertified')
ax.plot(certified_results['obs'][0][:, graph1_1], certified_results['obs'][0][:, graph1_2], '.-', label='Certified')
ax.plot(certified_results['obs'][0][corrections, graph1_1], certified_results['obs'][0][corrections, graph1_2], 'r.', label='Modified')
ax.scatter(results['obs'][0][0, graph1_1], results['obs'][0][0, graph1_2], color='g', marker='o', s=100, label='Initial State')
ax.plot(uncert_results['obs'][0][:, graph1_1], uncert_results['obs'][0][:, graph1_2], 'r--', label='Uncertified')
ax.plot(cert_results['obs'][0][:, graph1_1], cert_results['obs'][0][:, graph1_2], '.-', label='Certified')
ax.plot(cert_results['obs'][0][corrections, graph1_1], cert_results['obs'][0][corrections, graph1_2], 'r.', label='Modified')
ax.scatter(uncert_results['obs'][0][0, graph1_1], uncert_results['obs'][0][0, graph1_2], color='g', marker='o', s=100, label='Initial State')
if config.task == Environment.CARTPOLE:
theta_constraint = config.task_config['constraints'][0].upper_bounds[2]
elif config.task == Environment.QUADROTOR:
Expand All @@ -132,31 +132,31 @@ def run(plot=True, training=False, n_episodes=1, n_steps=None, curr_path='.'):

if config.task_config.task == Task.TRAJ_TRACKING and config.task == Environment.CARTPOLE:
_, ax2 = plt.subplots()
ax2.plot(np.linspace(0, 20, certified_results['obs'][0].shape[0]), safety_filter.env.X_GOAL[:, 0], 'g--', label='Reference')
ax2.plot(np.linspace(0, 20, results['obs'][0].shape[0]), results['obs'][0][:, 0], 'r--', label='Uncertified')
ax2.plot(np.linspace(0, 20, certified_results['obs'][0].shape[0]), certified_results['obs'][0][:, 0], '.-', label='Certified')
ax2.plot(np.linspace(0, 20, certified_results['obs'][0].shape[0])[corrections], certified_results['obs'][0][corrections, 0], 'r.', label='Modified')
ax2.plot(np.linspace(0, 20, cert_results['obs'][0].shape[0]), safety_filter.env.X_GOAL[:, 0], 'g--', label='Reference')
ax2.plot(np.linspace(0, 20, uncert_results['obs'][0].shape[0]), uncert_results['obs'][0][:, 0], 'r--', label='Uncertified')
ax2.plot(np.linspace(0, 20, cert_results['obs'][0].shape[0]), cert_results['obs'][0][:, 0], '.-', label='Certified')
ax2.plot(np.linspace(0, 20, cert_results['obs'][0].shape[0])[corrections], cert_results['obs'][0][corrections, 0], 'r.', label='Modified')
ax2.set_xlabel(r'Time')
ax2.set_ylabel(r'X')
ax2.set_box_aspect(0.5)
ax2.legend(loc='upper right')
elif config.task == Environment.QUADROTOR:
_, ax2 = plt.subplots()
ax2.plot(results['obs'][0][:, 1], results['obs'][0][:, 3], 'r--', label='Uncertified')
ax2.plot(certified_results['obs'][0][:, 1], certified_results['obs'][0][:, 3], '.-', label='Certified')
ax2.plot(certified_results['obs'][0][corrections, 1], certified_results['obs'][0][corrections, 3], 'r.', label='Modified')
ax2.plot(uncert_results['obs'][0][:, 1], uncert_results['obs'][0][:, 3], 'r--', label='Uncertified')
ax2.plot(cert_results['obs'][0][:, 1], cert_results['obs'][0][:, 3], '.-', label='Certified')
ax2.plot(cert_results['obs'][0][corrections, 1], cert_results['obs'][0][corrections, 3], 'r.', label='Modified')
ax2.set_xlabel(r'x_dot')
ax2.set_ylabel(r'z_dot')
ax2.set_box_aspect(0.5)
ax2.legend(loc='upper right')

_, ax3 = plt.subplots()
ax3.plot(results['obs'][0][:, graph3_1], results['obs'][0][:, graph3_2], 'r--', label='Uncertified')
ax3.plot(certified_results['obs'][0][:, graph3_1], certified_results['obs'][0][:, graph3_2], '.-', label='Certified')
ax3.plot(uncert_results['obs'][0][:, graph3_1], uncert_results['obs'][0][:, graph3_2], 'r--', label='Uncertified')
ax3.plot(cert_results['obs'][0][:, graph3_1], cert_results['obs'][0][:, graph3_2], '.-', label='Certified')
if config.task_config.task == Task.TRAJ_TRACKING and config.task == Environment.QUADROTOR:
ax3.plot(safety_filter.env.X_GOAL[:, 0], safety_filter.env.X_GOAL[:, 2], 'g--', label='Reference')
ax3.plot(certified_results['obs'][0][corrections, graph3_1], certified_results['obs'][0][corrections, graph3_2], 'r.', label='Modified')
ax3.scatter(results['obs'][0][0, graph3_1], results['obs'][0][0, graph3_2], color='g', marker='o', s=100, label='Initial State')
ax3.plot(cert_results['obs'][0][corrections, graph3_1], cert_results['obs'][0][corrections, graph3_2], 'r.', label='Modified')
ax3.scatter(uncert_results['obs'][0][0, graph3_1], uncert_results['obs'][0][0, graph3_2], color='g', marker='o', s=100, label='Initial State')
ax3.set_xlabel(r'X')
if config.task == Environment.CARTPOLE:
ax3.set_ylabel(r'Vel')
Expand All @@ -167,16 +167,16 @@ def run(plot=True, training=False, n_episodes=1, n_steps=None, curr_path='.'):

_, ax_act = plt.subplots()
if config.task == Environment.CARTPOLE:
ax_act.plot(certified_results['current_physical_action'][0][:], 'b-', label='Certified Input')
ax_act.plot(cert_results['current_physical_action'][0][:], 'b-', label='Certified Input')
ax_act.plot(mpsc_results['uncertified_action'][0][:], 'r--', label='Attempted Input')
ax_act.plot(results['current_physical_action'][0][:], 'g--', label='Uncertified Input')
ax_act.plot(uncert_results['current_physical_action'][0][:], 'g--', label='Uncertified Input')
else:
ax_act.plot(certified_results['current_physical_action'][0][:, 0], 'b-', label='Certified Input 1')
ax_act.plot(certified_results['current_physical_action'][0][:, 1], 'b--', label='Certified Input 2')
ax_act.plot(cert_results['current_physical_action'][0][:, 0], 'b-', label='Certified Input 1')
ax_act.plot(cert_results['current_physical_action'][0][:, 1], 'b--', label='Certified Input 2')
ax_act.plot(mpsc_results['uncertified_action'][0][:, 0], 'r-', label='Attempted Input 1')
ax_act.plot(mpsc_results['uncertified_action'][0][:, 1], 'r--', label='Attempted Input 2')
ax_act.plot(results['current_physical_action'][0][:, 0], 'g-', label='Uncertified Input 1')
ax_act.plot(results['current_physical_action'][0][:, 1], 'g--', label='Uncertified Input 2')
ax_act.plot(uncert_results['current_physical_action'][0][:, 0], 'g-', label='Uncertified Input 1')
ax_act.plot(uncert_results['current_physical_action'][0][:, 1], 'g--', label='Uncertified Input 2')
ax_act.legend()
ax_act.set_title('Input comparison')
ax_act.set_xlabel('Step')
Expand Down
2 changes: 1 addition & 1 deletion tests/test_examples/test_lqr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ def test_lqr(SYS, TASK, ALGO):
'--kv_overrides',
'algo_config.max_iterations=2'
]
run(gui=False, n_episodes=None, n_steps=10, save_data=False)
run(gui=False, plot=False, n_episodes=None, n_steps=10, save_data=False)
2 changes: 1 addition & 1 deletion tests/test_examples/test_mpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@ def test_lqr(SYS, TASK, ALGO):
'--kv_overrides',
'algo_config.max_iterations=2'
]
run(gui=False, n_episodes=None, n_steps=10, save_data=False)
run(gui=False, plot=False, n_episodes=None, n_steps=10, save_data=False)

0 comments on commit 0c3c028

Please sign in to comment.