diff --git a/rl4co/envs/routing/cvrp.py b/rl4co/envs/routing/cvrp.py index d3e3763b..4a28f7bc 100644 --- a/rl4co/envs/routing/cvrp.py +++ b/rl4co/envs/routing/cvrp.py @@ -156,7 +156,7 @@ def _reset( def get_action_mask(td: TensorDict) -> torch.Tensor: # For demand steps_dim is inserted by indexing with id, for used_capacity insert node dim for broadcasting exceeds_cap = ( - td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"] + td["demand"][:, None, :] + td["used_capacity"][..., None] > td["vehicle_capacity"][..., None] ) # Nodes that cannot be visited are already visited or too much demand to be served now