Skip to content

Commit

Permalink
Run black on all files
Browse files Browse the repository at this point in the history
  • Loading branch information
knakamura13 committed Aug 26, 2024
1 parent 168773f commit 0ca3152
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 45 deletions.
15 changes: 8 additions & 7 deletions src/mlrose_ky/algorithms/decay/geometric_decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,11 @@ def __repr__(self) -> str:
def __eq__(self, other: object) -> bool:
if not isinstance(other, GeometricDecay):
return False
return (self.initial_temperature == other.initial_temperature
and self.decay_rate == other.decay_rate
and self.minimum_temperature == other.minimum_temperature)
return (
self.initial_temperature == other.initial_temperature
and self.decay_rate == other.decay_rate
and self.minimum_temperature == other.minimum_temperature
)

def evaluate(self, time: int) -> float:
"""
Expand All @@ -86,7 +88,7 @@ def evaluate(self, time: int) -> float:
float
The temperature parameter at the given time, respecting the minimum temperature.
"""
return max(self.initial_temperature * (self.decay_rate ** time), self.minimum_temperature)
return max(self.initial_temperature * (self.decay_rate**time), self.minimum_temperature)

def get_info(self, time: int | None = None, prefix: str = "") -> dict:
"""
Expand Down Expand Up @@ -127,9 +129,8 @@ def __new__(cls, *args, **kwargs):
Please use 'GeometricDecay' instead.
"""
warnings.warn(
"The class 'GeomDecay' is deprecated and will be removed in a future release. "
"Please use 'GeometricDecay' instead.",
"The class 'GeomDecay' is deprecated and will be removed in a future release. " "Please use 'GeometricDecay' instead.",
DeprecationWarning,
stacklevel=2
stacklevel=2,
)
return super(GeomDecay, cls).__new__(cls)
4 changes: 2 additions & 2 deletions src/mlrose_ky/decorators/short_name_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def get_short_name(v: Any) -> str:
str
The short name of the variable, if assigned; otherwise, returns the full variable name or the variable itself as a fallback.
"""
if hasattr(v, '__short_name__'):
if hasattr(v, "__short_name__"):
return v.__short_name__
elif hasattr(v, '__name__'):
elif hasattr(v, "__name__"):
return v.__name__
return v
38 changes: 19 additions & 19 deletions src/mlrose_ky/runners/skmlp_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,25 +111,25 @@ def _invoke_runner_callback(self):
)

def __init__(
self,
x_train,
y_train,
x_test,
y_test,
experiment_name,
seed,
iteration_list,
grid_search_parameters,
grid_search_scorer_method=skmt.balanced_accuracy_score,
early_stopping=True,
max_attempts=500,
n_jobs=1,
cv=5,
override_ctrl_c_handler=True,
generate_curves=True,
output_directory=None,
replay=False,
**kwargs,
self,
x_train,
y_train,
x_test,
y_test,
experiment_name,
seed,
iteration_list,
grid_search_parameters,
grid_search_scorer_method=skmt.balanced_accuracy_score,
early_stopping=True,
max_attempts=500,
n_jobs=1,
cv=5,
override_ctrl_c_handler=True,
generate_curves=True,
output_directory=None,
replay=False,
**kwargs,
):

# take a copy of the grid search parameters
Expand Down
9 changes: 1 addition & 8 deletions tests/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,7 @@
@pytest.fixture
def sample_data():
"""Return sample data for testing."""
X = np.array([
[0, 1, 0, 1],
[0, 0, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1],
[0, 0, 1, 1],
[1, 0, 0, 0]
]) # X.shape = (6, 4)
X = np.array([[0, 1, 0, 1], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 1, 1], [1, 0, 0, 0]]) # X.shape = (6, 4)

y_classifier = np.reshape(np.array([1, 1, 0, 0, 1, 1]), (X.shape[0], 1))
y_multiclass = np.array([[1, 1], [1, 0], [0, 0], [0, 0], [1, 0], [1, 1]])
Expand Down
1 change: 0 additions & 1 deletion tests/test_generators/test_flip_flop_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,3 @@ def test_generate_custom_size(self):
problem = FlipFlopGenerator.generate(SEED, size=size)

assert problem.length == size

9 changes: 7 additions & 2 deletions tests/test_neural/test_neural_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,13 @@ def test_fit_genetic_alg(self, sample_data):
hidden_nodes = [2]
bias = False
network = NeuralNetwork(
hidden_nodes=hidden_nodes, activation="identity", algorithm="genetic_alg", bias=bias, learning_rate=1, clip_max=1, max_attempts=1
hidden_nodes=hidden_nodes,
activation="identity",
algorithm="genetic_alg",
bias=bias,
learning_rate=1,
clip_max=1,
max_attempts=1,
)

node_list = [X.shape[1], *hidden_nodes, 2 if bias else 1]
Expand Down Expand Up @@ -158,4 +164,3 @@ def test_learning_curve(self):
train_sizes, train_scores, test_scores = learning_curve(network, X, y, train_sizes=train_sizes, cv=cv, scoring="accuracy")

assert not np.isnan(train_scores).any() and not np.isnan(test_scores).any()

1 change: 0 additions & 1 deletion tests/test_opt_probs/test_discrete_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,4 +185,3 @@ def test_sample_pop(self):
problem.eval_node_probs()
sample = problem.sample_pop(100)
assert np.shape(sample)[0] == 100 and np.shape(sample)[1] == 5 and 0 < np.sum(sample) < 500

2 changes: 1 addition & 1 deletion tests/test_runners/test_ga_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def test_garunner_initialization_with_additional_kwargs(self, problem, runner_kw
runner = GARunner(**runner_kwargs, **additional_kwargs)

assert runner.problem == problem
assert runner.get_runner_name() == 'ga'
assert runner.get_runner_name() == "ga"
assert runner._experiment_name == runner_kwargs["experiment_name"]
assert runner.seed == runner_kwargs["seed"]
assert runner.iteration_list == runner_kwargs["iteration_list"]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_runners/test_mimic_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def test_mimicrunner_initialization_with_additional_kwargs(self, problem, runner
runner = MIMICRunner(**runner_kwargs, **additional_kwargs)

assert runner.problem == problem
assert runner.get_runner_name() == 'mimic'
assert runner.get_runner_name() == "mimic"
assert runner._experiment_name == runner_kwargs["experiment_name"]
assert runner.seed == runner_kwargs["seed"]
assert runner.iteration_list == runner_kwargs["iteration_list"]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_runners/test_nngs_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def runner_kwargs(self, data):
"max_iters": [1, 2],
"learning_rate": [0.001, 0.002],
"hidden_layer_sizes": [[2], [2, 2]],
"activation": [mlrose_ky.relu, mlrose_ky.sigmoid]
"activation": [mlrose_ky.relu, mlrose_ky.sigmoid],
}

return {
Expand Down
2 changes: 1 addition & 1 deletion tests/test_runners/test_rhc_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_rhc_runner_initialization_with_additional_kwargs(self, problem, runner_
runner = RHCRunner(**runner_kwargs, **additional_kwargs)

assert runner.problem == problem
assert runner.get_runner_name() == 'rhc'
assert runner.get_runner_name() == "rhc"
assert runner._experiment_name == runner_kwargs["experiment_name"]
assert runner.seed == runner_kwargs["seed"]
assert runner.iteration_list == runner_kwargs["iteration_list"]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_runners/test_sa_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_sarunner_initialization_with_additional_kwargs(self, problem, runner_kw
runner = SARunner(**runner_kwargs, **additional_kwargs)

assert runner.problem == problem
assert runner.get_runner_name() == 'sa'
assert runner.get_runner_name() == "sa"
assert runner._experiment_name == runner_kwargs["experiment_name"]
assert runner.seed == runner_kwargs["seed"]
assert runner.iteration_list == runner_kwargs["iteration_list"]
Expand Down

0 comments on commit 0ca3152

Please sign in to comment.