Skip to content

Commit

Permalink
[Feat] test POMO's MS with more envs #102
Browse files Browse the repository at this point in the history
  • Loading branch information
fedebotu committed Nov 29, 2023
1 parent 91a76cd commit d4acda9
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions tests/test_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@ def test_base_policy(env_name, size=20, batch_size=2):
assert out["reward"].shape == (batch_size,)


@pytest.mark.parametrize("env_name", ["tsp", "cvrp", "pctsp", "spctsp"])
@pytest.mark.parametrize(
"env_name", ["tsp", "cvrp", "pctsp", "spctsp", "sdvrp", "op", "pdp"]
)
def test_base_policy_multistart(env_name, size=20, batch_size=2):
env, x = generate_env_data(env_name, size, batch_size)
td = env.reset(x)
policy = AutoregressivePolicy(env.name)
out = policy(td, env, decode_type="greedy_multistart", num_starts=size)
num_starts = size // 2 if env.name in ["pdp"] else size
out = policy(td, env, decode_type="greedy_multistart", num_starts=num_starts)
assert out["reward"].shape == (
batch_size * size,
batch_size * num_starts,
) # to evaluate, we could just unbatchify


Expand Down

0 comments on commit d4acda9

Please sign in to comment.