Skip to content

Commit

Permalink
merged with main
Browse files — browse the repository at this point in the history
  • Loading branch information
LTluttmann committed Apr 4, 2024
2 parents 3c956f7 + 8bbd5d6 commit 7d0d6bf
Show file tree
Hide file tree
Showing 26 changed files with 800 additions and 281 deletions.
9 changes: 9 additions & 0 deletions configs/env/ffsp.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
_target_: rl4co.envs.FFSPEnv
name: ffsp

num_stage: 3
num_machine: 4
num_job: 20
flatten_stages: False

data_dir: ${paths.root_dir}/data/ffsp
15 changes: 7 additions & 8 deletions configs/main.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,12 @@ seed: null
#https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
matmul_precision: "medium"

# metrics to be logged
metrics:
train: ["loss", "reward"]
val: ["reward"]
test: ["reward"]
log_on_step: True

# Set to True to generate data automatically on the first run
model:
generate_default_data: True
generate_default_data: True
# metrics to be logged
metrics:
train: ["loss", "reward"]
val: ["reward"]
test: ["reward"]
log_on_step: True
7 changes: 7 additions & 0 deletions configs/model/matnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
_target_: rl4co.models.MatNet

metrics:
train: ["loss", "reward", "max_reward"]
val: ["max_reward"]
test: ["max_reward"]
log_on_step: True
7 changes: 7 additions & 0 deletions rl4co/envs/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from rl4co.data.dataset import TensorDictDataset
from rl4co.data.utils import load_npz_to_tensordict
from rl4co.utils.ops import get_num_starts, select_start_nodes
from rl4co.utils.pylogger import get_pylogger

log = get_pylogger(__name__)
Expand Down Expand Up @@ -173,6 +174,12 @@ def get_action_mask(self, td: TensorDict) -> torch.Tensor:
"""
raise NotImplementedError

def get_num_starts(self, td):
return get_num_starts(td, self.name)

def select_start_nodes(self, td, num_starts):
return select_start_nodes(td, self, num_starts)

def check_solution_validity(self, td, actions) -> TensorDict:
"""Function to check whether the solution is valid. Can be called by the agent to check the validity of the current state
This is called with the full solution (i.e. all actions) at the end of the episode
Expand Down
13 changes: 6 additions & 7 deletions rl4co/envs/routing/mtsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,8 @@ def _step(td: TensorDict) -> TensorDict:
# If done is True, then we make the depot available again, so that it will be selected as the next node with prob 1
available[..., 0] = torch.logical_or(done, available[..., 0])

# If current agent is different from previous agent, then we have a new subtour and reset the length, otherwise we add the new distance
current_length = torch.where(
cur_agent_idx != td["agent_idx"],
torch.zeros_like(td["current_length"]),
td["current_length"] + get_distance(cur_loc, prev_loc),
)
# Update the current length
current_length = td["current_length"] + get_distance(cur_loc, prev_loc)

# If done, we add the distance from the current_node to the depot as well
current_length = torch.where(
Expand All @@ -109,6 +105,9 @@ def _step(td: TensorDict) -> TensorDict:
td["max_subtour_length"],
)

# If current agent is different from previous agent, then we have a new subtour and reset the length
current_length *= (cur_agent_idx == td["agent_idx"]).float()

# The reward is the negative of the max_subtour_length (minmax objective)
reward = -max_subtour_length

Expand Down Expand Up @@ -225,7 +224,7 @@ def get_reward(self, td, actions=None) -> TensorDict:
return td["reward"].squeeze(-1)

# With distance, same as TSP
elif self.cost_type == "distance":
elif self.cost_type == "sum":
locs = td["locs"]
locs_ordered = locs.gather(1, actions.unsqueeze(-1).expand_as(locs))
return -get_tour_length(locs_ordered)
Expand Down
Loading

0 comments on commit 7d0d6bf

Please sign in to comment.