Skip to content

Commit

Permalink
Replace pyvrp-based twoopt with custom one from DeepACO
Browse files Browse the repository at this point in the history
  • Loading branch information
hyeok9855 committed Apr 26, 2024
1 parent 12b702c commit 7ac643d
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 241 deletions.
183 changes: 183 additions & 0 deletions examples/advanced/3-local-search.ipynb

Large diffs are not rendered by default.

183 changes: 0 additions & 183 deletions examples/modeling/4-solution-improvement.ipynb

This file was deleted.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ dependencies = [
graph = ["torch_geometric"]
testing = ["pytest", "pytest-cov"]
dev = ["black", "ruff", "pre-commit>=3.3.3"]
routing = ["pyvrp>=0.7.0"]
routing = ["numba>=0.58.1", "pyvrp>=0.8.2"]

[project.urls]
"Homepage" = "https://github.com/ai4co/rl4co"
Expand Down
4 changes: 2 additions & 2 deletions rl4co/envs/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,11 @@ def check_solution_validity(self, td: TensorDict, actions: torch.Tensor) -> None
"""
raise NotImplementedError

def improve_solution(self, td: TensorDict, actions: torch.Tensor, **kwargs) -> torch.Tensor:
def local_search(self, td: TensorDict, actions: torch.Tensor, **kwargs) -> torch.Tensor:
    """Improve a complete solution with a local search procedure.

    Can be called by the agent to improve the current state. This is called
    with the full solution (i.e. all actions) at the end of the episode.

    Args:
        td: TensorDict describing the problem instances.
        actions: the full sequence of actions (the solution) to improve.
        **kwargs: options forwarded to the environment-specific implementation.

    Returns:
        The improved actions (in environments that override this method).

    Raises:
        NotImplementedError: always, in this base implementation. The message
            names the environment (``self.name``) so the caller can tell
            which one lacks local search support.
    """
    # A single, descriptive raise: a bare `raise NotImplementedError` placed
    # before it would make this message unreachable.
    raise NotImplementedError(f"Local search is not implemented yet for {self.name} environment")

def dataset(self, batch_size=[], phase="train", filename=None):
"""Return a dataset of observations
Expand Down
117 changes: 63 additions & 54 deletions rl4co/envs/routing/tsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,10 @@
from rl4co.utils.ops import gather_by_index, get_distance_matrix, get_tour_length
from rl4co.utils.pylogger import get_pylogger


try: # for local search
import pyvrp
except ImportError:
pyvrp = None
else:
import numpy as np

from pyvrp import Client, CostEvaluator, ProblemData, RandomNumberGenerator, Solution, VehicleType
from pyvrp.search import LocalSearch, TwoOpt, NeighbourhoodParams, compute_neighbours
from pyvrp.read import scale_and_truncate_to_decimals

# For local search
import concurrent.futures
import numpy as np
import numba as nb

log = get_pylogger(__name__)

Expand Down Expand Up @@ -175,48 +167,30 @@ def check_solution_validity(td: TensorDict, actions: torch.Tensor) -> None:
).all(), "Invalid tour"

@staticmethod
def improve_solution(td: TensorDict, actions: torch.Tensor, **kwargs) -> torch.Tensor:
    """Improve the solution using local search (pyvrp-based 2-opt).

    Each batch element is converted into a pyvrp problem, the tour given by
    ``actions`` is turned into a single-route pyvrp solution, and a pyvrp
    ``LocalSearch`` with the ``TwoOpt`` node operator is run on it.

    Args:
        td: TensorDict from the env; only ``td["locs"]`` is read here.
        actions: tour indices, one row per batch element. Each row must
            contain node 0, which is treated as the depot.
        **kwargs: ``seed`` (optional) seeds the pyvrp random number
            generator; a missing or falsy seed falls back to 0.

    Returns:
        A ``torch.LongTensor`` on ``actions.device`` holding the improved
        tours, each starting with node 0.

    Raises:
        AssertionError: if ``pyvrp`` could not be imported.
    """
    assert pyvrp is not None, (
        "`pyvrp` not found. Please install pyvrp using instructions from "
        "https://github.com/PyVRP/PyVRP or just use pip install pyvrp."
    )

    # `or 0` also maps an explicit seed=None (or seed=0) to 0
    rng = RandomNumberGenerator(kwargs.get("seed") or 0)

    action_list = actions.cpu().detach().numpy().tolist()
    # NOTE(review): presumably pyvrp expects integer-scaled coordinates/distances,
    # hence the scaling by 7 decimals -- confirm against the pyvrp docs
    scaled_locs = scale_and_truncate_to_decimals(td["locs"].cpu().detach().numpy(), 7)
    scaled_distmat = scale_and_truncate_to_decimals(
        get_distance_matrix(td["locs"]).cpu().detach().numpy(), 7
    )

    improved_action_list = []
    # One independent pyvrp problem per batch element
    for _locs, _distmat, _acts in zip(scaled_locs, scaled_distmat, action_list):
        # Create a pyvrp problem data object
        # The first location acts as the depot, all remaining ones as clients
        data = ProblemData(
            clients=[Client(x=loc[0], y=loc[1]) for loc in _locs[1:]],
            depots=[Client(x=_locs[0, 0], y=_locs[0, 1])],
            vehicle_types=[VehicleType()],  # nothing need to be specified for TSP
            distance_matrix=_distmat,
            duration_matrix=np.zeros((len(_locs), len(_locs)), dtype=np.float32),
        )

        # Create a pyvrp local search operator with 2-opt
        # nb_granular is set to the number of clients (the full neighbourhood size)
        neighbours = compute_neighbours(data, NeighbourhoodParams(nb_granular=data.num_clients))
        local_search = LocalSearch(data=data, rng=rng, neighbours=neighbours)
        local_search.add_node_operator(TwoOpt(data=data))

        # Create a pyvrp solution object from the actions
        # Rotate the tour so it starts right after node 0; node 0 itself is left out of the route
        depot_action = _acts.index(0)
        solution = Solution(data=data, routes=[_acts[depot_action + 1:] + _acts[:depot_action]])

        # Run the local search
        improved_solution = local_search.search(solution=solution, cost_evaluator=CostEvaluator(0, 0))
        # Re-insert node 0 at the front of the improved route
        improved_action_list.append([0] + improved_solution.get_routes()[0].visits())

    improved_actions = torch.LongTensor(improved_action_list).to(actions.device)

    return improved_actions
def local_search(td: TensorDict, actions: torch.Tensor, **kwargs) -> torch.Tensor:
    """Improve the solution using local search, especially 2-opt for TSP.

    Implementation credits to: https://github.com/henry-yeh/DeepACO

    Args:
        td: TensorDict, td from env with shape [batch_size,]. ``td["locs"]``
            is only read when ``distances`` is not given.
        actions: torch.Tensor, tour indices with shape [batch_size, num_loc].
        **kwargs:
            max_iterations: int, maximum number of 2-opt passes per tour
                (default 1000).
            distances: torch.Tensor or numpy array, distance matrix with
                shape [batch_size, num_loc, num_loc]. If None (or missing),
                it is calculated from ``td["locs"]``.

    Returns:
        torch.Tensor of dtype int64 on ``actions.device`` holding the
        improved tours, one row per input tour.
    """
    max_iterations = kwargs.get("max_iterations", 1000)

    # Do not use `distances or fallback` here: the truth value of a tensor or
    # array with more than one element is ambiguous and raises.
    distances = kwargs.get("distances", None)
    if distances is None:
        distances = get_distance_matrix(td["locs"])
    if isinstance(distances, torch.Tensor):
        distances = distances.detach().cpu().numpy()
    # The compiled 2-opt kernel is declared for float32 matrices only
    dists = np.asarray(distances, dtype=np.float32)
    # Fill the diagonal with a large number. The addition is out of place, so a
    # caller-supplied matrix is not modified.
    dists = dists + 1e9 * np.eye(dists.shape[1], dtype=np.float32)[None, :, :]

    # astype() copies, so the in-place 2-opt never touches `actions` itself.
    # NOTE(review): uint16 assumes node indices fit in 16 bits -- confirm for large instances
    tours = actions.detach().cpu().numpy().astype(np.uint16)
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # One task per instance; the kernel is compiled with nogil=True
        futures = [
            executor.submit(_two_opt_python, distmat=dist, tour=tour, max_iterations=max_iterations)
            for dist, tour in zip(dists, tours)
        ]
        improved = np.stack([future.result() for future in futures])
    return torch.from_numpy(improved.astype(np.int64)).to(actions.device)

def generate_data(self, batch_size) -> TensorDict:
batch_size = [batch_size] if isinstance(batch_size, int) else batch_size
Expand Down Expand Up @@ -270,3 +244,38 @@ def render(td, actions=None, ax=None):
# Setup limits and show
ax.set_xlim(-0.05, 1.05)
ax.set_ylim(-0.05, 1.05)


@nb.njit(nb.float32(nb.float32[:,:], nb.uint16[:], nb.uint16), nogil=True)
def two_opt_once(distmat, tour, fixed_i=0):
    """Apply the single best 2-opt move to ``tour`` (in-place operation).

    Every candidate pair of positions (i, j) with i < j is scored by the
    change in tour length that reversing ``tour[i..j]`` would cause, and the
    most negative change is remembered. With ``fixed_i == 0`` all start
    positions from 1 to n - 2 are scanned (position 0 is never moved);
    otherwise only ``i == fixed_i`` is considered.

    Returns the applied change in length (negative) when it is below -1e-6,
    otherwise 0.0 and the tour is left untouched.
    """
    n = tour.shape[0]
    best_i = 0
    best_j = 0
    best_change = 0

    # Bounds of the outer scan: the full range, or the single requested position
    if fixed_i == 0:
        first, stop = 1, n - 1
    else:
        first, stop = fixed_i, fixed_i + 1

    for i in range(first, stop):
        for j in range(i + 1, n):
            head, tail = tour[i], tour[j]
            before, after = tour[i - 1], tour[(j + 1) % n]
            # Only score moves whose two replaced edges are distinct
            if before != tail and after != head:
                change = (
                    distmat[before, tail] + distmat[head, after]
                    - distmat[before, head] - distmat[tail, after]
                )
                if change < best_change:
                    best_i, best_j, best_change = i, j, change

    if best_change < -1e-6:
        segment = tour[best_i: best_j + 1]
        tour[best_i: best_j + 1] = np.flip(segment)
        return best_change
    return 0.0


@nb.njit(nb.uint16[:](nb.float32[:,:], nb.uint16[:], nb.int64), nogil=True)
def _two_opt_python(distmat, tour, max_iterations=1000):
    """Run 2-opt passes on ``tour`` until no pass improves it.

    Calls ``two_opt_once`` at most ``max_iterations`` times and stops as soon
    as a pass reports a change that is not below -1e-6. The tour is modified
    in place and also returned.
    """
    for _ in range(max_iterations):
        change = two_opt_once(distmat, tour, 0)
        if not (change < -1e-6):
            break
    return tour
2 changes: 1 addition & 1 deletion tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_routing(env_cls, batch_size=2, size=20):
reward, td, actions = rollout(env, env.reset(batch_size=[batch_size]), random_policy)
env.render(td, actions)
try:
env.improve_solution(td, actions)
env.local_search(td, actions)
except NotImplementedError:
pass
assert reward.shape == (batch_size,)
Expand Down

0 comments on commit 7ac643d

Please sign in to comment.