Skip to content

Commit

Permalink
Revert method names in .samples/* (#7)
Browse files Browse the repository at this point in the history
  • Loading branch information
knakamura13 committed Sep 3, 2024
1 parent e572c0c commit 1b63063
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions src/mlrose_ky/samples/synthetic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ class SyntheticDataGenerator:
Parameters
----------
seed : int
seed : int, optional, default=42
Random seed for reproducibility.
root_directory : str, optional
Directory to save the generated data, by default None.
root_directory : str, optional, default=None
Directory to save the generated data.
"""

def __init__(self, seed: int, root_directory: str | None = None):
self.seed = seed
self.root_directory = root_directory
def __init__(self, seed: int = 42, root_directory: str | None = None):
self.seed: int = seed
self.root_directory: str | None = root_directory

@staticmethod
def get_synthetic_features_and_classes(with_redundant_column: bool = False) -> tuple[list[str], list[str]]:
Expand All @@ -33,8 +33,8 @@ def get_synthetic_features_and_classes(with_redundant_column: bool = False) -> t
Parameters
----------
with_redundant_column : bool, optional
Whether to include a redundant column, by default False.
with_redundant_column : bool, optional, default=False
Whether to include a redundant column.
Returns
-------
Expand Down Expand Up @@ -71,7 +71,7 @@ def get_synthetic_data(
tuple[np.ndarray, list[str], list[str], str | None]
A tuple containing the synthetic data, feature names, class labels, and output directory.
"""
synthetic_data = self._create_synthetic_data(x_dim, y_dim, add_noise, add_redundant_column)
synthetic_data = self.__create_synthetic_data(x_dim, y_dim, add_noise, add_redundant_column)
synthetic_data_array = synthetic_data.values

output_directory = None
Expand All @@ -86,6 +86,7 @@ def get_synthetic_data(
pass

features, classes = self.get_synthetic_features_and_classes(add_redundant_column)

return synthetic_data_array, features, classes, output_directory

def setup_synthetic_data_test_train(
Expand Down Expand Up @@ -119,7 +120,7 @@ def setup_synthetic_data_test_train(

return x, y, x_train, x_test, y_train, y_test

def _create_synthetic_data(self, x_dim: int, y_dim: int, add_noise: float = 0.0, add_redundant_column: bool = False) -> pd.DataFrame:
def __create_synthetic_data(self, x_dim: int, y_dim: int, add_noise: float = 0.0, add_redundant_column: bool = False) -> pd.DataFrame:
"""
Create synthetic data.
Expand Down

0 comments on commit 1b63063

Please sign in to comment.