[Chore] better typing #237
fedebotu committed Jan 17, 2025
1 parent abf8516 commit e378ddc
Showing 71 changed files with 435 additions and 407 deletions.
8 changes: 3 additions & 5 deletions rl4co/data/dataset.py
@@ -1,5 +1,3 @@
from typing import Union

import tensordict
import torch

@@ -35,7 +33,7 @@ def add_key(self, key, value):
return ExtraKeyDataset(self, value, key_name=key)

@staticmethod
def collate_fn(batch: Union[dict, TensorDict]):
def collate_fn(batch: dict | TensorDict):
"""Collate function compatible with TensorDicts that reassembles a list of dicts."""
return batch

@@ -66,7 +64,7 @@ def add_key(self, key, value):
return ExtraKeyDataset(self, value, key_name=key)

@staticmethod
def collate_fn(batch: Union[dict, TensorDict]):
def collate_fn(batch: dict | TensorDict):
"""Collate function compatible with TensorDicts that reassembles a list of dicts."""
return TensorDict(
{key: torch.stack([b[key] for b in batch]) for key in batch[0].keys()},
@@ -129,6 +127,6 @@ def add_key(self, key, value):
return self

@staticmethod
def collate_fn(batch: Union[dict, TensorDict]):
def collate_fn(batch: dict | TensorDict):
"""Equivalent to collating with `lambda x: x`"""
return batch
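
The pattern above repeats throughout the commit: Union[dict, TensorDict] becomes the PEP 604 form dict | TensorDict. A minimal sketch, not taken from the repository, of how the new annotation behaves; it assumes Python 3.10+, or from __future__ import annotations on 3.8/3.9, where the annotation is then only stored lazily as a string:

# Sketch only: illustrates the PEP 604 union annotation adopted in this commit.
# Assumes Python 3.10+ (or `from __future__ import annotations` on 3.8/3.9).
import torch
from tensordict import TensorDict


def collate_fn(batch: dict | TensorDict):
    """Accepts either a plain dict or a TensorDict and returns it unchanged."""
    return batch


td = TensorDict({"locs": torch.rand(4, 10, 2)}, batch_size=[4])
assert isinstance(collate_fn(td), TensorDict)
assert isinstance(collate_fn({"a": 1}), dict)
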
39 changes: 21 additions & 18 deletions rl4co/data/generate_data.py
@@ -3,8 +3,6 @@
import os
import sys

from typing import List, Union

import numpy as np

from rl4co.data.utils import check_extension
@@ -20,7 +18,7 @@
"op": ["const", "unif", "dist"],
"mdpp": [None],
"pdp": [None],
"atsp": [None]
"atsp": [None],
}


@@ -213,29 +211,30 @@ def generate_mdpp_data(
"action_mask": available.astype(bool),
}


def generate_atsp_data(dataset_size, atsp_size, tmat_class: bool = True):
cost_matrix = np.random.uniform(size=(dataset_size, atsp_size, atsp_size))
cost_matrix[..., np.arange(atsp_size), np.arange(atsp_size)] = 0
if tmat_class:
for i in range(atsp_size):
cost_matrix = np.minimum(cost_matrix, cost_matrix[..., :, [i]] + cost_matrix[..., [i], :])
return {
"cost_matrix": cost_matrix.astype(np.float32)
}
cost_matrix = np.minimum(
cost_matrix, cost_matrix[..., :, [i]] + cost_matrix[..., [i], :]
)
return {"cost_matrix": cost_matrix.astype(np.float32)}


def generate_dataset(
filename: Union[str, List[str]] = None,
filename: str | list[str] | None = None,
data_dir: str = "data",
name: str = None,
problem: Union[str, List[str]] = "all",
name: str | None = None,
problem: str | list[str] = "all",
data_distribution: str = "all",
dataset_size: int = 10000,
graph_sizes: Union[int, List[int]] = [20, 50, 100],
graph_sizes: int | list[int] = [20, 50, 100],
overwrite: bool = False,
seed: int = 1234,
disable_warning: bool = True,
distributions_per_problem: Union[int, dict] = None,
distributions_per_problem: int | dict = None,
):
"""We keep a similar structure as in Kool et al. 2019 but save and load the data as npz
This is way faster and more memory efficient than pickle and also allows for easy transfer to TensorDict
@@ -266,9 +265,11 @@ def generate_dataset(
problems = distributions_per_problem
else:
problems = {
problem: distributions_per_problem[problem]
if data_distribution == "all"
else [data_distribution]
problem: (
distributions_per_problem[problem]
if data_distribution == "all"
else [data_distribution]
)
}

# Support multiple filenames if necessary
@@ -286,9 +287,11 @@
datadir,
"{}{}{}_{}_seed{}.npz".format(
problem,
"_{}".format(distribution)
if distribution is not None
else "",
(
"_{}".format(distribution)
if distribution is not None
else ""
),
graph_size,
name,
seed,
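
The reformatted loop in generate_atsp_data above is a Floyd-Warshall-style relaxation: taking the element-wise minimum of d and d[:, i] + d[i, :] for every intermediate node i leaves a matrix that satisfies the triangle inequality, which is what the tmat_class flag promises. A small self-contained check, written for this note rather than taken from the repository:

# Sketch only: verifies on a random instance that the tmat_class relaxation
# enforces d[i, j] <= d[i, k] + d[k, j] for all i, j, k.
import numpy as np

rng = np.random.default_rng(0)
n = 6
d = rng.uniform(size=(1, n, n))
d[..., np.arange(n), np.arange(n)] = 0
for k in range(n):
    d = np.minimum(d, d[..., :, [k]] + d[..., [k], :])

d0 = d[0]
lhs = d0[:, :, None]                     # d[i, j], broadcast over k
rhs = d0[:, None, :] + d0.T[None, :, :]  # d[i, k] + d[k, j]
assert np.all(lhs <= rhs + 1e-9)
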
17 changes: 10 additions & 7 deletions rl4co/data/transforms.py
@@ -1,5 +1,7 @@
import math
from typing import Union

from typing import Callable

import torch

from tensordict.tensordict import TensorDict
@@ -8,7 +10,6 @@
from rl4co.utils.ops import batchify
from rl4co.utils.pylogger import get_pylogger


log = get_pylogger(__name__)


@@ -91,22 +92,24 @@ def min_max_normalize(x):
return (x - x.min()) / (x.max() - x.min())


def get_augment_function(augment_fn: Union[str, callable]):
if callable(augment_fn):
def get_augment_function(augment_fn: str | Callable):
if isinstance(augment_fn, Callable):
return augment_fn
if augment_fn == "dihedral8":
return dihedral_8_augmentation_wrapper
if augment_fn == "symmetric":
return symmetric_augmentation
raise ValueError(f"Unknown augment_fn: {augment_fn}. Available options: 'symmetric', 'dihedral8' or a custom callable")
raise ValueError(
f"Unknown augment_fn: {augment_fn}. Available options: 'symmetric', 'dihedral8' or a custom callable"
)


class StateAugmentation(object):
"""Augment state by N times via symmetric rotation/reflection transform
Args:
num_augment: number of augmentations
augment_fn: augmentation function to use, e.g. 'symmetric' (default) or 'dihedral8', if callable,
augment_fn: augmentation function to use, e.g. 'symmetric' (default) or 'dihedral8', if callable,
then use the function directly. If 'dihedral8', then num_augment must be 8
first_aug_identity: whether to augment the first data point too
normalize: whether to normalize the augmented data
@@ -116,7 +119,7 @@ class StateAugmentation(object):
def __init__(
self,
num_augment: int = 8,
augment_fn: Union[str, callable] = 'symmetric',
augment_fn: str | Callable = "symmetric",
first_aug_identity: bool = True,
normalize: bool = False,
feats: list = None,
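
The dispatcher above either passes a user callable through unchanged or maps the two supported strings to the built-in augmentations; anything else raises the ValueError. A short usage sketch (the import path matches this file, the rest is illustrative):

# Usage sketch for the dispatcher defined in rl4co/data/transforms.py.
from rl4co.data.transforms import get_augment_function

sym = get_augment_function("symmetric")  # -> symmetric_augmentation
dih = get_augment_function("dihedral8")  # -> dihedral_8_augmentation_wrapper


def my_augment(td):
    # Any callable is returned as-is, so custom augmentations plug in directly.
    return td


assert get_augment_function(my_augment) is my_augment

try:
    get_augment_function("unknown")
except ValueError as err:
    print(err)  # lists the available options
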
4 changes: 2 additions & 2 deletions rl4co/envs/common/utils.py
@@ -1,6 +1,6 @@
import abc

from typing import Callable, Union
from typing import Callable

import torch

@@ -33,7 +33,7 @@ def _generate(self, batch_size, **kwargs) -> TensorDict:

def get_sampler(
val_name: str,
distribution: Union[int, float, str, type, Callable],
distribution: int | float | str | type | Callable,
low: float = 0,
high: float = 1.0,
**kwargs,
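
get_sampler now advertises the same widened distribution type that the generators below pass in. A hedged usage sketch: it assumes that handing in the torch.distributions.Uniform class (the generators' default argument) yields a uniform sampler over [low, high]; the set of accepted strings is not visible in this diff and is not illustrated.

# Sketch only: the Uniform-class form mirrors the generators' defaults, and
# .sample() is the call their _generate() methods make on the result.
from torch.distributions import Uniform

from rl4co.envs.common.utils import get_sampler

loc_sampler = get_sampler("loc", Uniform, low=0.0, high=1.0)
locs = loc_sampler.sample((32, 20, 2))  # e.g. 32 instances of 20 planar points
print(locs.shape)
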
4 changes: 2 additions & 2 deletions rl4co/envs/graph/flp/generator.py
@@ -1,6 +1,6 @@
import math

from typing import Callable, Union
from typing import Callable

import torch

@@ -37,7 +37,7 @@ def __init__(
num_loc: int = 100,
min_loc: float = 0.0,
max_loc: float = 1.0,
loc_distribution: Union[int, float, str, type, Callable] = Uniform,
loc_distribution: int | float | str | type | Callable = Uniform,
to_choose: int = 10,
**kwargs,
):
6 changes: 3 additions & 3 deletions rl4co/envs/graph/mcp/generator.py
@@ -1,4 +1,4 @@
from typing import Callable, Union
from typing import Callable

import torch

@@ -66,8 +66,8 @@ def __init__(
min_size: int = 5,
max_size: int = 15,
n_sets_to_choose: int = 10,
size_distribution: Union[int, float, str, type, Callable] = Uniform,
weight_distribution: Union[int, float, str, type, Callable] = Uniform,
size_distribution: int | float | str | type | Callable = Uniform,
weight_distribution: int | float | str | type | Callable = Uniform,
**kwargs,
):
self.num_items = num_items
13 changes: 6 additions & 7 deletions rl4co/envs/routing/atsp/generator.py
@@ -1,12 +1,12 @@
from typing import Union, Callable
from typing import Callable

import torch

from torch.distributions import Uniform
from tensordict.tensordict import TensorDict
from torch.distributions import Uniform

from rl4co.envs.common.utils import Generator, get_sampler
from rl4co.utils.pylogger import get_pylogger
from rl4co.envs.common.utils import get_sampler, Generator

log = get_pylogger(__name__)

@@ -27,16 +27,15 @@ class ATSPGenerator(Generator):
A TensorDict with the following keys:
locs [batch_size, num_loc, 2]: locations of each customer
"""

def __init__(
self,
num_loc: int = 10,
min_dist: float = 0.0,
max_dist: float = 1.0,
dist_distribution: Union[
int, float, str, type, Callable
] = Uniform,
dist_distribution: int | float | str | type | Callable = Uniform,
tmat_class: bool = True,
**kwargs
**kwargs,
):
self.num_loc = num_loc
self.min_dist = min_dist
18 changes: 10 additions & 8 deletions rl4co/envs/routing/cvrp/generator.py
@@ -1,4 +1,4 @@
from typing import Callable, Union
from typing import Callable

import torch

@@ -57,11 +57,11 @@ def __init__(
num_loc: int = 20,
min_loc: float = 0.0,
max_loc: float = 1.0,
loc_distribution: Union[int, float, str, type, Callable] = Uniform,
depot_distribution: Union[int, float, str, type, Callable] = None,
loc_distribution: int | float | str | type | Callable = Uniform,
depot_distribution: int | float | str | type | Callable = None,
min_demand: int = 1,
max_demand: int = 10,
demand_distribution: Union[int, float, type, Callable] = Uniform,
demand_distribution: int | float | type | Callable = Uniform,
vehicle_capacity: float = 1.0,
capacity: float = None,
**kwargs,
@@ -85,9 +85,11 @@ def __init__(
if kwargs.get("depot_sampler", None) is not None:
self.depot_sampler = kwargs["depot_sampler"]
else:
self.depot_sampler = get_sampler(
"depot", depot_distribution, min_loc, max_loc, **kwargs
) if depot_distribution is not None else None
self.depot_sampler = (
get_sampler("depot", depot_distribution, min_loc, max_loc, **kwargs)
if depot_distribution is not None
else None
)

# Demand distribution
if kwargs.get("demand_sampler", None) is not None:
@@ -114,7 +116,7 @@ def __init__(
self.capacity = capacity

def _generate(self, batch_size) -> TensorDict:

# Sample locations: depot and customers
if self.depot_sampler is not None:
depot = self.depot_sampler.sample((*batch_size, 2))
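
When depot_distribution is left at its None default, the parenthesized conditional above leaves self.depot_sampler unset and the depot branch of _generate shown here is skipped; passing a distribution builds a dedicated depot sampler via get_sampler. A small construction sketch grounded in these hunks; it assumes the remaining constructor defaults (including the capacity lookup outside this diff) succeed for num_loc=20:

# Sketch only: checks the two depot-sampling configurations of CVRPGenerator.
from torch.distributions import Uniform

from rl4co.envs.routing.cvrp.generator import CVRPGenerator

gen_shared = CVRPGenerator(num_loc=20)  # depot_distribution defaults to None
assert gen_shared.depot_sampler is None

gen_separate = CVRPGenerator(num_loc=20, depot_distribution=Uniform)
assert gen_separate.depot_sampler is not None
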
8 changes: 4 additions & 4 deletions rl4co/envs/routing/cvrp/local_search.py
@@ -1,6 +1,6 @@
from functools import partial
from multiprocessing import Pool
from typing import Tuple, Union
from typing import Tuple

import numpy as np
import torch
@@ -38,7 +38,7 @@ def local_search(
td: TensorDict,
actions: torch.Tensor,
max_trials: int = 10,
neighbourhood_params: Union[dict, None] = None,
neighbourhood_params: dict | None = None,
load_penalty: float = 0.2,
allow_infeasible_solution: bool = False,
seed: int = 0,
@@ -113,7 +113,7 @@ def local_search_single(
positions: np.ndarray,
demands: np.ndarray,
distances: np.ndarray,
neighbourhood_params: Union[dict, None] = None,
neighbourhood_params: dict | None = None,
allow_infeasible_solution: bool = False,
load_penalty: float = 0.2,
max_trials: int = 10,
@@ -178,7 +178,7 @@ def make_solution(data: ProblemData, path: np.ndarray) -> Solution:


def make_search_operator(
data: ProblemData, seed=0, neighbourhood_params: Union[dict, None] = None
data: ProblemData, seed=0, neighbourhood_params: dict | None = None
) -> LocalSearch:
rng = RandomNumberGenerator(seed)
neighbours = compute_neighbours(
21 changes: 7 additions & 14 deletions rl4co/envs/routing/cvrptw/generator.py
@@ -1,9 +1,9 @@
from typing import Union, Callable
from typing import Callable

import torch

from torch.distributions import Uniform
from tensordict.tensordict import TensorDict
from torch.distributions import Uniform

from rl4co.envs.routing.cvrp.generator import CVRPGenerator
from rl4co.utils.ops import get_distance
@@ -39,22 +39,17 @@ class CVRPTWGenerator(CVRPGenerator):
durations [batch_size, num_loc]: service durations of each location
time_windows [batch_size, num_loc, 2]: time windows of each location
"""

def __init__(
self,
num_loc: int = 20,
min_loc: float = 0.0,
max_loc: float = 150.0,
loc_distribution: Union[
int, float, str, type, Callable
] = Uniform,
depot_distribution: Union[
int, float, str, type, Callable
] = None,
loc_distribution: int | float | str | type | Callable = Uniform,
depot_distribution: int | float | str | type | Callable = Uniform,
min_demand: int = 1,
max_demand: int = 10,
demand_distribution: Union[
int, float, type, Callable
] = Uniform,
demand_distribution: int | float | str | type | Callable = Uniform,
vehicle_capacity: float = 1.0,
capacity: float = None,
max_time: float = 480,
@@ -86,9 +81,7 @@ def _generate(self, batch_size) -> TensorDict:

## define service durations
# generate randomly (first assume service durations of 0, to be changed later)
durations = torch.zeros(
*batch_size, self.num_loc + 1, dtype=torch.float32
)
durations = torch.zeros(*batch_size, self.num_loc + 1, dtype=torch.float32)

## define time windows
# 1. get distances from depot