forked from hiive/mlrose
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor all fitness functions (peaks, queens, etc.)
- Add type hints for all methods/functions - Refactor file and method/function docstrings - General code style improvements - Vectorize suboptimal code
- Loading branch information
1 parent
a228184
commit 61946b4
Showing
29 changed files
with
837 additions
and
689 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,21 @@ | ||
"""Classes for defining fitness functions (i.e., optimization problems) for optimization algorithms.""" | ||
|
||
from .continuous_peaks import ContinuousPeaks | ||
|
||
from .flip_flop import FlipFlop | ||
|
||
from .four_peaks import FourPeaks | ||
from .six_peaks import SixPeaks | ||
from .continuous_peaks import ContinuousPeaks | ||
from .one_max import OneMax | ||
from .max_k_color import MaxKColor | ||
|
||
from .knapsack import Knapsack | ||
|
||
from .max_k_color import MaxKColor | ||
|
||
from .one_max import OneMax | ||
|
||
from .queens import Queens | ||
from .travelling_sales import TravellingSales | ||
|
||
from .six_peaks import SixPeaks | ||
|
||
from .travelling_salesperson import TravellingSalesperson | ||
|
||
from .custom_fitness import CustomFitness |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
"""Class defining the Discrete Peaks base fitness function for use with the Four Peaks, Six Peaks, and Custom fitness functions.""" | ||
|
||
# Authors: Genevieve Hayes (modified by Andrew Rollings, Kyle Nakamura) | ||
# License: BSD 3 clause | ||
|
||
import numpy as np | ||
|
||
|
||
class _DiscretePeaksBase: | ||
"""Base class for Discrete Peaks fitness functions. | ||
This class provides methods to determine the number of leading or trailing | ||
specific integer values within a vector. | ||
""" | ||
|
||
@staticmethod | ||
def count_leading_values(value: int, vector: np.ndarray) -> int: | ||
"""Determine the number of leading occurrences of `value` in `vector`. | ||
Parameters | ||
---------- | ||
value : int | ||
The integer value to count at the beginning of the vector. | ||
vector : np.ndarray | ||
A vector of integers. | ||
Returns | ||
------- | ||
int | ||
Number of leading occurrences of `value` in `vector`. | ||
Raises | ||
------ | ||
TypeError | ||
If `vector` is not an instance of `np.ndarray`. | ||
""" | ||
if not isinstance(vector, np.ndarray): | ||
raise TypeError(f"Expected vector to be np.ndarray, got {type(vector).__name__} instead.") | ||
|
||
# Use NumPy's cumulative sum to find leading values | ||
leading_mask = np.cumsum(vector != value) == 0 | ||
return np.sum(leading_mask) | ||
|
||
@staticmethod | ||
def count_trailing_values(value: int, vector: np.ndarray) -> int: | ||
"""Determine the number of trailing occurrences of `value` in `vector`. | ||
Parameters | ||
---------- | ||
value : int | ||
The integer value to count at the end of the vector. | ||
vector : np.ndarray | ||
A vector of integers. | ||
Returns | ||
------- | ||
int | ||
Number of trailing occurrences of `value` in `vector`. | ||
Raises | ||
------ | ||
TypeError | ||
If `vector` is not an instance of `np.ndarray`. | ||
""" | ||
if not isinstance(vector, np.ndarray): | ||
raise TypeError(f"Expected vector to be np.ndarray, got {type(vector).__name__} instead.") | ||
|
||
# Use NumPy's cumulative sum to find trailing values | ||
trailing_mask = np.cumsum(vector[::-1] != value) == 0 | ||
return np.sum(trailing_mask) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.