Skip to content

Commit

Permalink
Linesight-RL#62 Floating point numbers should not be tested for equality
Browse files Browse the repository at this point in the history
  • Loading branch information
Wuodan committed Aug 21, 2024
1 parent f651512 commit d6bcd62
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion trackmania_rl/analysis_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def distribution_curves(buffer, save_dir, online_network, target_network):
quantiles_output = np.sort(quantiles_output.ravel())
quantiles_target = np.sort(quantiles_target.ravel())

if (np.min(quantiles_output) == np.max(quantiles_output)) and (np.min(quantiles_output) == 0.0):
if np.min(quantiles_output) == np.max(quantiles_output) and _is_min_close_to_zero(quantiles_output):
# terminal transition, can't be interpreted as long term
continue

Expand Down Expand Up @@ -393,3 +393,7 @@ def distribution_curves(buffer, save_dir, online_network, target_network):
)
).save(save_dir / "distribution_curves" / f"{i}_{buffer.storage[i].n_steps}.png")
plt.close()


def _is_min_close_to_zero(quantiles_output):
return np.isclose(np.min(quantiles_output), 0.0, rtol=1e-09, atol=1e-09)

0 comments on commit d6bcd62

Please sign in to comment.