From 8826ea1a6bf186ceb90c5724c0625b8c3134821c Mon Sep 17 00:00:00 2001 From: chenyangkang Date: Thu, 24 Oct 2024 13:09:45 -0500 Subject: [PATCH] pre-commit syntax correction --- stemflow/model/AdaSTEM.py | 32 ++++++++++------------ stemflow/model/SphereAdaSTEM.py | 10 ++++--- stemflow/model_selection.py | 18 ++++++------ stemflow/utils/quadtree.py | 14 ++++++---- stemflow/utils/sphere_quadtree.py | 9 +++--- tests/test_model_selection.py | 2 +- tests/test_random_state_reproducibility.py | 4 +-- 7 files changed, 46 insertions(+), 43 deletions(-) diff --git a/stemflow/model/AdaSTEM.py b/stemflow/model/AdaSTEM.py index 61e8306..899c153 100644 --- a/stemflow/model/AdaSTEM.py +++ b/stemflow/model/AdaSTEM.py @@ -88,7 +88,7 @@ def __init__( temporal_bin_interval: Union[float, int] = 50, temporal_bin_start_jitter: Union[float, int, str] = "adaptive", spatio_bin_jitter_magnitude: Union[float, int] = "adaptive", - random_state = None, + random_state=None, save_gridding_plot: bool = True, save_tmp: bool = False, save_dir: str = "./", @@ -105,7 +105,7 @@ def __init__( plot_ylims: Tuple[Union[float, int], Union[float, int]] = None, verbosity: int = 0, plot_empty: bool = False, - completely_random_rotation: bool = False + completely_random_rotation: bool = False, ): """Make an AdaSTEM object @@ -210,7 +210,7 @@ def __init__( # 1. Check random state self.random_state = random_state self.rng = check_random_state(random_state) - + # 2. Base model check_base_model(base_model) base_model = model_wrapper(base_model) @@ -279,9 +279,7 @@ def __init__( else: self.verbosity = 0 - def split( - self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None, njobs: int = 1 - ): + def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None, njobs: int = 1): """QuadTree indexing the input data Args: @@ -326,11 +324,7 @@ def split( grid_len_lower, ) - check_temporal_scale( - X_train[self.Temporal1].min(), - X_train[self.Temporal1].min(), - self.temporal_bin_interval - ) + check_temporal_scale(X_train[self.Temporal1].min(), X_train[self.Temporal1].min(), self.temporal_bin_interval) spatio_bin_jitter_magnitude = check_transform_spatio_bin_jitter_magnitude( X_train, self.Spatio1, self.Spatio2, self.spatio_bin_jitter_magnitude @@ -372,7 +366,7 @@ def split( Spatio2=self.Spatio2, save_gridding_plot=self.save_gridding_plot, ax=ax, - completely_random_rotation=self.completely_random_rotation + completely_random_rotation=self.completely_random_rotation, ) if njobs > 1 and isinstance(njobs, int): @@ -380,7 +374,8 @@ def split( output_generator = parallel( joblib.delayed(partial_get_one_ensemble_quadtree)( ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count) - ) for ensemble_count in list(range(self.ensemble_fold)) + ) + for ensemble_count in list(range(self.ensemble_fold)) ) if verbosity > 0: output_generator = tqdm(output_generator, total=self.ensemble_fold, desc="Generating Ensemble: ") @@ -393,13 +388,16 @@ def split( if verbosity > 0 else range(self.ensemble_fold) ) - ensemble_all_df_list = [partial_get_one_ensemble_quadtree( - ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count) - ) for ensemble_count in iter_func_] + ensemble_all_df_list = [ + partial_get_one_ensemble_quadtree( + ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count) + ) + for ensemble_count in iter_func_ + ] # concat ensemble_df = pd.concat(ensemble_all_df_list).reset_index(drop=True) - + del ensemble_all_df_list # processing diff --git a/stemflow/model/SphereAdaSTEM.py b/stemflow/model/SphereAdaSTEM.py index 7af9359..1ced2ae 100644 --- a/stemflow/model/SphereAdaSTEM.py +++ b/stemflow/model/SphereAdaSTEM.py @@ -22,6 +22,7 @@ from ..utils.validation import ( check_base_model, check_prediciton_aggregation, + check_random_state, check_spatial_scale, check_spatio_bin_jitter_magnitude, check_task, @@ -29,7 +30,6 @@ check_temporal_scale, check_transform_njobs, check_verbosity, - check_random_state ) from ..utils.wrapper import model_wrapper from .AdaSTEM import AdaSTEM, AdaSTEMClassifier, AdaSTEMRegressor @@ -70,7 +70,7 @@ def __init__( temporal_bin_interval: Union[float, int] = 50, temporal_bin_start_jitter: Union[float, int, str] = "adaptive", spatio_bin_jitter_magnitude: Union[float, int] = "adaptive", - random_state = None, + random_state=None, save_gridding_plot: bool = True, save_tmp: bool = False, save_dir: str = "./", @@ -305,7 +305,8 @@ def split( output_generator = parallel( joblib.delayed(partial_get_one_ensemble_sphere_quadtree)( ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count) - ) for ensemble_count in list(range(self.ensemble_fold)) + ) + for ensemble_count in list(range(self.ensemble_fold)) ) if verbosity > 0: output_generator = tqdm(output_generator, total=self.ensemble_fold, desc="Generating Ensemble: ") @@ -321,7 +322,8 @@ def split( ensemble_all_df_list = [ partial_get_one_ensemble_sphere_quadtree( ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count) - ) for ensemble_count in iter_func_ + ) + for ensemble_count in iter_func_ ] ensemble_df = pd.concat(ensemble_all_df_list).reset_index(drop=True) diff --git a/stemflow/model_selection.py b/stemflow/model_selection.py index fbaa369..cac533f 100644 --- a/stemflow/model_selection.py +++ b/stemflow/model_selection.py @@ -195,7 +195,7 @@ def ST_CV( yield X_train, X_test, y_train, y_test -class ST_KFold(): +class ST_KFold: def __init__( self, Spatio1: str = "longitude", @@ -238,14 +238,14 @@ def __init__( Spatio_blocks_count = 10, Temporal_blocks_count = 10, random_state = 42).split(X) - + for train_indexes, test_indexes in ST_KFold_generator: X_train = X.iloc[train_indexes,:] X_test = X.iloc[test_indexes,:] ... - + ``` - + """ self.rng = check_random_state(random_state) self.Spatio1 = Spatio1 @@ -254,10 +254,10 @@ def __init__( self.Spatio_blocks_count = Spatio_blocks_count self.Temporal_blocks_count = Temporal_blocks_count self.n_splits = n_splits - + if not (isinstance(n_splits, int) and n_splits > 0): raise ValueError("CV should be a positive integer") - + def split(self, X: DataFrame) -> Generator[Tuple[ndarray, ndarray], None, None]: """split @@ -281,7 +281,9 @@ def split(self, X: DataFrame) -> Generator[Tuple[ndarray, ndarray], None, None]: indexes = [ str(a) + "_" + str(b) + "_" + str(c) for a, b, c in zip( - np.digitize(X[self.Spatio1], Sindex1), np.digitize(X[self.Spatio2], Sindex2), np.digitize(X[self.Temporal1], Tindex1) + np.digitize(X[self.Spatio1], Sindex1), + np.digitize(X[self.Spatio2], Sindex2), + np.digitize(X[self.Temporal1], Tindex1), ) ] @@ -304,4 +306,4 @@ def split(self, X: DataFrame) -> Generator[Tuple[ndarray, ndarray], None, None]: # get train set record indexes train_indexes = list(set(range(len(indexes))) - set(test_indexes)) - yield train_indexes, test_indexes \ No newline at end of file + yield train_indexes, test_indexes diff --git a/stemflow/utils/quadtree.py b/stemflow/utils/quadtree.py index 1a43c81..c62f2c9 100644 --- a/stemflow/utils/quadtree.py +++ b/stemflow/utils/quadtree.py @@ -16,7 +16,11 @@ from ..gridding.QTree import QTree from ..gridding.QuadGrid import QuadGrid -from .validation import check_transform_spatio_bin_jitter_magnitude, check_transform_temporal_bin_start_jitter, check_random_state +from .validation import ( + check_random_state, + check_transform_spatio_bin_jitter_magnitude, + check_transform_temporal_bin_start_jitter, +) # from tqdm.contrib.concurrent import process_map @@ -100,7 +104,7 @@ def get_one_ensemble_quadtree( ax=None, plot_empty: bool = False, rng: np.random._generator.Generator = None, - completely_random_rotation = False + completely_random_rotation=False, ): """Generate QuadTree gridding based on the input dataframe @@ -161,12 +165,12 @@ def get_one_ensemble_quadtree( """ rng = check_random_state(rng) - + if completely_random_rotation: rotation_angle = rng.uniform(0, 90) else: rotation_angle = (90 / size) * ensemble_count - + calibration_point_x_jitter = rng.uniform(-spatio_bin_jitter_magnitude, spatio_bin_jitter_magnitude) calibration_point_y_jitter = rng.uniform(-spatio_bin_jitter_magnitude, spatio_bin_jitter_magnitude) @@ -177,7 +181,7 @@ def get_one_ensemble_quadtree( step=temporal_step, bin_interval=temporal_bin_interval, temporal_bin_start_jitter=temporal_bin_start_jitter, - rng=rng + rng=rng, ) ensemble_all_df_list = [] diff --git a/stemflow/utils/sphere_quadtree.py b/stemflow/utils/sphere_quadtree.py index 2e73971..97e1e1f 100644 --- a/stemflow/utils/sphere_quadtree.py +++ b/stemflow/utils/sphere_quadtree.py @@ -16,7 +16,6 @@ from ..gridding.Sphere_QTree import Sphere_QTree from .quadtree import generate_temporal_bins from .sphere.coordinate_transform import lonlat_cartesian_3D_transformer - from .validation import check_random_state os.environ["MKL_NUM_THREADS"] = "1" @@ -44,7 +43,7 @@ def get_one_ensemble_sphere_quadtree( ax=None, radius: Union[int, float] = 6371.0, plot_empty: bool = False, - rng: np.random._generator.Generator = None + rng: np.random._generator.Generator = None, ): """Generate QuadTree gridding based on the input dataframe A function to get quadtree results for spherical indexing system. Twins to `get_ensemble_quadtree` in `quadtree.py`, Returns ensemble_df and plotting axes. @@ -87,7 +86,7 @@ def get_one_ensemble_sphere_quadtree( The radius of earth in km. Defaults to 6371.0. rng: random number generator. - + Returns: A tuple of
1. ensemble dataframe;
@@ -95,7 +94,7 @@ def get_one_ensemble_sphere_quadtree( """ rng = check_random_state(rng) - + if spatio_bin_jitter_magnitude == "adaptive": rotation_angle = rng.uniform(0, 90) rotation_axis = rng.uniform(-1, 1, 3) @@ -106,7 +105,7 @@ def get_one_ensemble_sphere_quadtree( step=temporal_step, bin_interval=temporal_bin_interval, temporal_bin_start_jitter=temporal_bin_start_jitter, - rng=rng + rng=rng, ) ensemble_all_df_list = [] diff --git a/tests/test_model_selection.py b/tests/test_model_selection.py index ce675cc..adf842d 100644 --- a/tests/test_model_selection.py +++ b/tests/test_model_selection.py @@ -1,4 +1,4 @@ -from stemflow.model_selection import ST_CV, ST_train_test_split, ST_KFold +from stemflow.model_selection import ST_CV, ST_KFold, ST_train_test_split from .set_up_data import set_up_data diff --git a/tests/test_random_state_reproducibility.py b/tests/test_random_state_reproducibility.py index 7bd70f0..aac89f2 100644 --- a/tests/test_random_state_reproducibility.py +++ b/tests/test_random_state_reproducibility.py @@ -4,9 +4,7 @@ from stemflow.model.AdaSTEM import AdaSTEM from stemflow.model_selection import ST_train_test_split -from .make_models import ( - make_AdaSTEMRegressor, -) +from .make_models import make_AdaSTEMRegressor from .set_up_data import set_up_data x_names, (X, y) = set_up_data()