From ca9efc1e6fc165d95d450a4a6befbb2e4b1a4efa Mon Sep 17 00:00:00 2001 From: Chuanbo Hua Date: Wed, 13 Mar 2024 15:34:19 +0900 Subject: [PATCH] [BugFix] fix the vehicle_capacity dimension mismatching bug --- rl4co/envs/routing/cvrp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rl4co/envs/routing/cvrp.py b/rl4co/envs/routing/cvrp.py index d3e3763b..4a28f7bc 100644 --- a/rl4co/envs/routing/cvrp.py +++ b/rl4co/envs/routing/cvrp.py @@ -156,7 +156,7 @@ def _reset( def get_action_mask(td: TensorDict) -> torch.Tensor: # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting exceeds_cap = ( - td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"] + td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"][..., None] ) # Nodes that cannot be visited are already visited or too much demand to be served now