Skip to content

Commit

Permalink
Resolve innocuous typing warnings related to TSPGenerator
Browse files Browse the repository at this point in the history
  • Loading branch information
knakamura13 committed Aug 13, 2024
1 parent c1bc209 commit 4f8d06f
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion mlrose_hiive/fitness/travelling_salesperson.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class TravellingSalesperson:
fitness function object.
"""

def __init__(self, coords: list[tuple[float, float]] = None, distances: list[tuple[int, int, float]] = None) -> None:
def __init__(self, coords: list[tuple] = None, distances: list[tuple] = None) -> None:
# Ensure that at least one of coords or distances is provided
if coords is None and distances is None:
raise ValueError("At least one of coords and distances must be specified.")
Expand Down
10 changes: 5 additions & 5 deletions mlrose_hiive/opt_probs/tsp_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,21 +44,21 @@ class TSPOpt(DiscreteOpt):
argument is ignored if fitness_fn or coords is not :code:`None`.
"""

def __init__(self, length=None, fitness_fn=None, maximize=False, coords=None,
distances=None, source_graph=None):
def __init__(self, length=None, fitness_fn=None, maximize=False, coords: list = None,
distances: list[tuple] = None, source_graph=None):
if (fitness_fn is None) and (coords is None) and (distances is None):
raise Exception("""At least one of fitness_fn, coords and"""
+ """ distances must be specified.""")
elif fitness_fn is None:
fitness_fn = TravellingSalesperson(coords=coords, distances=distances)
self.distances = distances
self.coords = coords
self.distances: list[tuple] = distances
self.coords: list | None = coords
if length is None:
if coords is not None:
length = len(coords)
elif distances is not None:
length = len(set([x for (x, _, _) in distances] + [x for (_, x, _) in distances]))
self.length = length
self.length: int = length
DiscreteOpt.__init__(self, length, fitness_fn, maximize, max_val=length,
crossover=TSPCrossover(self), mutator=GeneSwapMutator(self))

Expand Down
1 change: 1 addition & 0 deletions tests/test_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,6 +544,7 @@ def test_generate_custom_size(self):

assert problem.length == size


# noinspection PyTypeChecker
class TestTSPGenerator:

Expand Down

0 comments on commit 4f8d06f

Please sign in to comment.