forked from hiive/mlrose
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
In progress: Refactor all decay classes (arith, geom, etc.)
- Add type hints for all methods/functions - Refactor file and method/function docstrings - General code style improvements
- Loading branch information
1 parent
8651387
commit 3cc64b8
Showing
17 changed files
with
118 additions
and
102 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,107 +1,118 @@ | ||
"""Classes for defining decay schedules for simulated annealing.""" | ||
|
||
|
||
# Author: Genevieve Hayes | ||
# Authors: Genevieve Hayes (modified by Andrew Rollings, Kyle Nakamura) | ||
# License: BSD 3 clause | ||
|
||
|
||
class ArithDecay: | ||
class ArithmeticDecay: | ||
""" | ||
Schedule for arithmetically decaying the simulated | ||
annealing temperature parameter T according to the formula: | ||
Schedule for arithmetically decaying the temperature parameter T in a | ||
simulated annealing process, calculated using the formula: | ||
.. math:: | ||
T(t) = \\max(T_{0} - rt, T_{min}) | ||
where: | ||
* :math:`T_{0}` is the initial temperature (at time t = 0); | ||
* :math:`r` is the rate of arithmetic decay; and | ||
* :math:`T_{min}` is the minimum temperature value. | ||
where :math:`T_{0}` is the initial temperature at time `t = 0`, `r` is the decay rate, and :math:`T_{min}` | ||
is the minimum temperature. | ||
Parameters | ||
---------- | ||
init_temp: float, default: 1.0 | ||
Initial value of temperature parameter T. Must be greater than 0. | ||
decay: float, default: 0.0001 | ||
Temperature decay parameter, r. Must be greater than 0. | ||
min_temp: float, default: 0.001 | ||
Minimum value of temperature parameter. Must be greater than 0. | ||
Example | ||
------- | ||
.. highlight:: python | ||
.. code-block:: python | ||
>>> import mlrose_hiive | ||
>>> schedule = mlrose_hiive.ArithDecay(init_temp=10, decay=0.95, min_temp=1) | ||
>>> schedule.evaluate(5) | ||
5.25 | ||
initial_temperature : float | ||
Initial value of the temperature parameter T. Must be greater than 0. | ||
decay_rate : float | ||
Temperature decay parameter. Must be a positive value and less than or equal to 1. | ||
minimum_temperature : float | ||
Minimum allowable value of the temperature parameter. Must be positive and less than `initial_temperature`. | ||
Attributes | ||
---------- | ||
initial_temperature : float | ||
Stores the initial temperature. | ||
decay_rate : float | ||
Stores the rate of temperature decay. | ||
minimum_temperature : float | ||
Stores the minimum temperature. | ||
Examples | ||
-------- | ||
>>> schedule = ArithmeticDecay(initial_temperature=10, decay_rate=0.95, minimum_temperature=1) | ||
>>> schedule.evaluate(5) | ||
5.25 | ||
""" | ||
|
||
def __init__(self, init_temp=1.0, decay=0.0001, min_temp=0.001): | ||
def __init__(self, initial_temperature: float = 1.0, decay_rate: float = 0.0001, minimum_temperature: float = 0.001) -> None: | ||
self.initial_temperature: float = initial_temperature | ||
self.decay_rate: float = decay_rate | ||
self.minimum_temperature: float = minimum_temperature | ||
|
||
self.init_temp = init_temp | ||
self.decay = decay | ||
self.min_temp = min_temp | ||
if self.initial_temperature <= 0: | ||
raise ValueError("Initial temperature must be greater than 0.") | ||
if not (0 < self.decay_rate <= 1): | ||
raise ValueError("Decay rate must be greater than 0 and less than or equal to 1.") | ||
if not (0 < self.minimum_temperature < self.initial_temperature): | ||
raise ValueError("Minimum temperature must be greater than 0 and less than initial temperature.") | ||
|
||
if self.init_temp <= 0: | ||
raise Exception("""init_temp must be greater than 0.""") | ||
def __str__(self) -> str: | ||
return (f'ArithmeticDecay(initial_temperature={self.initial_temperature}, ' | ||
f'decay_rate={self.decay_rate}, ' | ||
f'minimum_temperature={self.minimum_temperature})') | ||
|
||
if (self.decay <= 0) or (self.decay > 1): | ||
raise Exception("""decay must be greater than 0.""") | ||
def __repr__(self) -> str: | ||
return self.__str__() | ||
|
||
if self.min_temp < 0: | ||
raise Exception("""min_temp must be greater than 0.""") | ||
elif self.min_temp > self.init_temp: | ||
raise Exception("""init_temp must be greater than min_temp.""") | ||
def __eq__(self, other: object) -> bool: | ||
if not isinstance(other, ArithmeticDecay): | ||
return False | ||
return (self.initial_temperature == other.initial_temperature and | ||
self.decay_rate == other.decay_rate and | ||
self.minimum_temperature == other.minimum_temperature) | ||
|
||
def evaluate(self, t): | ||
"""Evaluate the temperature parameter at time t. | ||
def evaluate(self, time: int) -> float: | ||
""" | ||
Calculate and return the temperature parameter at the given time. | ||
Parameters | ||
---------- | ||
t: int | ||
Time at which the temperature paramter T is evaluated. | ||
time : int | ||
The time at which to evaluate the temperature parameter. | ||
Returns | ||
------- | ||
temp: float | ||
Temperature parameter at time t. | ||
float | ||
The temperature parameter at the given time, respecting the minimum temperature. | ||
""" | ||
temperature = max(self.initial_temperature - (self.decay_rate * time), self.minimum_temperature) | ||
return temperature | ||
|
||
temp = self.init_temp - (self.decay * t) | ||
def get_info(self, time: int = None, prefix: str = '') -> dict: | ||
""" | ||
Generate a dictionary containing the decay schedule's settings and optionally its current value. | ||
if temp < self.min_temp: | ||
temp = self.min_temp | ||
Parameters | ||
---------- | ||
time : int, optional | ||
The time at which to evaluate the current temperature value. | ||
If provided, the current value is included in the returned dictionary. | ||
prefix : str, optional | ||
A prefix to prepend to each key in the dictionary, useful for integrating this info into larger structured data. | ||
return temp | ||
Returns | ||
------- | ||
dict | ||
A dictionary with keys reflecting the decay schedule's parameters and optionally the current temperature. | ||
""" | ||
info_prefix = f'{prefix}schedule_' if prefix else 'schedule_' | ||
|
||
def get_info__(self, t=None, prefix=''): | ||
prefix = f'_{prefix}__schedule_' if len(prefix) > 0 else 'schedule_' | ||
info = { | ||
f'{prefix}type': 'arithmetic', | ||
f'{prefix}init_temp': self.init_temp, | ||
f'{prefix}decay': self.decay, | ||
f'{prefix}min_temp': self.min_temp, | ||
f'{info_prefix}type': 'arithmetic', | ||
f'{info_prefix}initial_temperature': self.initial_temperature, | ||
f'{info_prefix}decay_rate': self.decay_rate, | ||
f'{info_prefix}minimum_temperature': self.minimum_temperature, | ||
} | ||
if t is not None: | ||
info[f'{prefix}current_value'] = self.evaluate(t) | ||
return info | ||
|
||
def __str__(self): | ||
return str(self.init_temp) | ||
if time is not None: | ||
info[f'{info_prefix}current_value'] = self.evaluate(time) | ||
|
||
def __repr__(self): | ||
return f'{self.__class__.__name__}(init_temp={self.init_temp}, ' \ | ||
f'decay={self.decay}, min_temp={self.min_temp})' | ||
|
||
def __eq__(self, other): | ||
try: | ||
return (self.__class__.__name__ == other.__class__.__name__ | ||
and self.init_temp == other.init_temp | ||
and self.decay == other.decay | ||
and self.min_temp == other.min_temp) | ||
except AttributeError: | ||
return False | ||
return info |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.