diff --git a/stemflow/model/AdaSTEM.py b/stemflow/model/AdaSTEM.py
index 61e8306..899c153 100644
--- a/stemflow/model/AdaSTEM.py
+++ b/stemflow/model/AdaSTEM.py
@@ -88,7 +88,7 @@ def __init__(
temporal_bin_interval: Union[float, int] = 50,
temporal_bin_start_jitter: Union[float, int, str] = "adaptive",
spatio_bin_jitter_magnitude: Union[float, int] = "adaptive",
- random_state = None,
+ random_state=None,
save_gridding_plot: bool = True,
save_tmp: bool = False,
save_dir: str = "./",
@@ -105,7 +105,7 @@ def __init__(
plot_ylims: Tuple[Union[float, int], Union[float, int]] = None,
verbosity: int = 0,
plot_empty: bool = False,
- completely_random_rotation: bool = False
+ completely_random_rotation: bool = False,
):
"""Make an AdaSTEM object
@@ -210,7 +210,7 @@ def __init__(
# 1. Check random state
self.random_state = random_state
self.rng = check_random_state(random_state)
-
+
# 2. Base model
check_base_model(base_model)
base_model = model_wrapper(base_model)
@@ -279,9 +279,7 @@ def __init__(
else:
self.verbosity = 0
- def split(
- self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None, njobs: int = 1
- ):
+ def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None, njobs: int = 1):
"""QuadTree indexing the input data
Args:
@@ -326,11 +324,7 @@ def split(
grid_len_lower,
)
- check_temporal_scale(
- X_train[self.Temporal1].min(),
- X_train[self.Temporal1].min(),
- self.temporal_bin_interval
- )
+ check_temporal_scale(X_train[self.Temporal1].min(), X_train[self.Temporal1].min(), self.temporal_bin_interval)
spatio_bin_jitter_magnitude = check_transform_spatio_bin_jitter_magnitude(
X_train, self.Spatio1, self.Spatio2, self.spatio_bin_jitter_magnitude
@@ -372,7 +366,7 @@ def split(
Spatio2=self.Spatio2,
save_gridding_plot=self.save_gridding_plot,
ax=ax,
- completely_random_rotation=self.completely_random_rotation
+ completely_random_rotation=self.completely_random_rotation,
)
if njobs > 1 and isinstance(njobs, int):
@@ -380,7 +374,8 @@ def split(
output_generator = parallel(
joblib.delayed(partial_get_one_ensemble_quadtree)(
ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count)
- ) for ensemble_count in list(range(self.ensemble_fold))
+ )
+ for ensemble_count in list(range(self.ensemble_fold))
)
if verbosity > 0:
output_generator = tqdm(output_generator, total=self.ensemble_fold, desc="Generating Ensemble: ")
@@ -393,13 +388,16 @@ def split(
if verbosity > 0
else range(self.ensemble_fold)
)
- ensemble_all_df_list = [partial_get_one_ensemble_quadtree(
- ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count)
- ) for ensemble_count in iter_func_]
+ ensemble_all_df_list = [
+ partial_get_one_ensemble_quadtree(
+ ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count)
+ )
+ for ensemble_count in iter_func_
+ ]
# concat
ensemble_df = pd.concat(ensemble_all_df_list).reset_index(drop=True)
-
+
del ensemble_all_df_list
# processing
diff --git a/stemflow/model/SphereAdaSTEM.py b/stemflow/model/SphereAdaSTEM.py
index 7af9359..1ced2ae 100644
--- a/stemflow/model/SphereAdaSTEM.py
+++ b/stemflow/model/SphereAdaSTEM.py
@@ -22,6 +22,7 @@
from ..utils.validation import (
check_base_model,
check_prediciton_aggregation,
+ check_random_state,
check_spatial_scale,
check_spatio_bin_jitter_magnitude,
check_task,
@@ -29,7 +30,6 @@
check_temporal_scale,
check_transform_njobs,
check_verbosity,
- check_random_state
)
from ..utils.wrapper import model_wrapper
from .AdaSTEM import AdaSTEM, AdaSTEMClassifier, AdaSTEMRegressor
@@ -70,7 +70,7 @@ def __init__(
temporal_bin_interval: Union[float, int] = 50,
temporal_bin_start_jitter: Union[float, int, str] = "adaptive",
spatio_bin_jitter_magnitude: Union[float, int] = "adaptive",
- random_state = None,
+ random_state=None,
save_gridding_plot: bool = True,
save_tmp: bool = False,
save_dir: str = "./",
@@ -305,7 +305,8 @@ def split(
output_generator = parallel(
joblib.delayed(partial_get_one_ensemble_sphere_quadtree)(
ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count)
- ) for ensemble_count in list(range(self.ensemble_fold))
+ )
+ for ensemble_count in list(range(self.ensemble_fold))
)
if verbosity > 0:
output_generator = tqdm(output_generator, total=self.ensemble_fold, desc="Generating Ensemble: ")
@@ -321,7 +322,8 @@ def split(
ensemble_all_df_list = [
partial_get_one_ensemble_sphere_quadtree(
ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count)
- ) for ensemble_count in iter_func_
+ )
+ for ensemble_count in iter_func_
]
ensemble_df = pd.concat(ensemble_all_df_list).reset_index(drop=True)
diff --git a/stemflow/model_selection.py b/stemflow/model_selection.py
index fbaa369..cac533f 100644
--- a/stemflow/model_selection.py
+++ b/stemflow/model_selection.py
@@ -195,7 +195,7 @@ def ST_CV(
yield X_train, X_test, y_train, y_test
-class ST_KFold():
+class ST_KFold:
def __init__(
self,
Spatio1: str = "longitude",
@@ -238,14 +238,14 @@ def __init__(
Spatio_blocks_count = 10,
Temporal_blocks_count = 10,
random_state = 42).split(X)
-
+
for train_indexes, test_indexes in ST_KFold_generator:
X_train = X.iloc[train_indexes,:]
X_test = X.iloc[test_indexes,:]
...
-
+
```
-
+
"""
self.rng = check_random_state(random_state)
self.Spatio1 = Spatio1
@@ -254,10 +254,10 @@ def __init__(
self.Spatio_blocks_count = Spatio_blocks_count
self.Temporal_blocks_count = Temporal_blocks_count
self.n_splits = n_splits
-
+
if not (isinstance(n_splits, int) and n_splits > 0):
raise ValueError("CV should be a positive integer")
-
+
def split(self, X: DataFrame) -> Generator[Tuple[ndarray, ndarray], None, None]:
"""split
@@ -281,7 +281,9 @@ def split(self, X: DataFrame) -> Generator[Tuple[ndarray, ndarray], None, None]:
indexes = [
str(a) + "_" + str(b) + "_" + str(c)
for a, b, c in zip(
- np.digitize(X[self.Spatio1], Sindex1), np.digitize(X[self.Spatio2], Sindex2), np.digitize(X[self.Temporal1], Tindex1)
+ np.digitize(X[self.Spatio1], Sindex1),
+ np.digitize(X[self.Spatio2], Sindex2),
+ np.digitize(X[self.Temporal1], Tindex1),
)
]
@@ -304,4 +306,4 @@ def split(self, X: DataFrame) -> Generator[Tuple[ndarray, ndarray], None, None]:
# get train set record indexes
train_indexes = list(set(range(len(indexes))) - set(test_indexes))
- yield train_indexes, test_indexes
\ No newline at end of file
+ yield train_indexes, test_indexes
diff --git a/stemflow/utils/quadtree.py b/stemflow/utils/quadtree.py
index 1a43c81..c62f2c9 100644
--- a/stemflow/utils/quadtree.py
+++ b/stemflow/utils/quadtree.py
@@ -16,7 +16,11 @@
from ..gridding.QTree import QTree
from ..gridding.QuadGrid import QuadGrid
-from .validation import check_transform_spatio_bin_jitter_magnitude, check_transform_temporal_bin_start_jitter, check_random_state
+from .validation import (
+ check_random_state,
+ check_transform_spatio_bin_jitter_magnitude,
+ check_transform_temporal_bin_start_jitter,
+)
# from tqdm.contrib.concurrent import process_map
@@ -100,7 +104,7 @@ def get_one_ensemble_quadtree(
ax=None,
plot_empty: bool = False,
rng: np.random._generator.Generator = None,
- completely_random_rotation = False
+ completely_random_rotation=False,
):
"""Generate QuadTree gridding based on the input dataframe
@@ -161,12 +165,12 @@ def get_one_ensemble_quadtree(
"""
rng = check_random_state(rng)
-
+
if completely_random_rotation:
rotation_angle = rng.uniform(0, 90)
else:
rotation_angle = (90 / size) * ensemble_count
-
+
calibration_point_x_jitter = rng.uniform(-spatio_bin_jitter_magnitude, spatio_bin_jitter_magnitude)
calibration_point_y_jitter = rng.uniform(-spatio_bin_jitter_magnitude, spatio_bin_jitter_magnitude)
@@ -177,7 +181,7 @@ def get_one_ensemble_quadtree(
step=temporal_step,
bin_interval=temporal_bin_interval,
temporal_bin_start_jitter=temporal_bin_start_jitter,
- rng=rng
+ rng=rng,
)
ensemble_all_df_list = []
diff --git a/stemflow/utils/sphere_quadtree.py b/stemflow/utils/sphere_quadtree.py
index 2e73971..97e1e1f 100644
--- a/stemflow/utils/sphere_quadtree.py
+++ b/stemflow/utils/sphere_quadtree.py
@@ -16,7 +16,6 @@
from ..gridding.Sphere_QTree import Sphere_QTree
from .quadtree import generate_temporal_bins
from .sphere.coordinate_transform import lonlat_cartesian_3D_transformer
-
from .validation import check_random_state
os.environ["MKL_NUM_THREADS"] = "1"
@@ -44,7 +43,7 @@ def get_one_ensemble_sphere_quadtree(
ax=None,
radius: Union[int, float] = 6371.0,
plot_empty: bool = False,
- rng: np.random._generator.Generator = None
+ rng: np.random._generator.Generator = None,
):
"""Generate QuadTree gridding based on the input dataframe
A function to get quadtree results for spherical indexing system. Twins to `get_ensemble_quadtree` in `quadtree.py`, Returns ensemble_df and plotting axes.
@@ -87,7 +86,7 @@ def get_one_ensemble_sphere_quadtree(
The radius of earth in km. Defaults to 6371.0.
rng:
random number generator.
-
+
Returns:
A tuple of
1. ensemble dataframe;
@@ -95,7 +94,7 @@ def get_one_ensemble_sphere_quadtree(
"""
rng = check_random_state(rng)
-
+
if spatio_bin_jitter_magnitude == "adaptive":
rotation_angle = rng.uniform(0, 90)
rotation_axis = rng.uniform(-1, 1, 3)
@@ -106,7 +105,7 @@ def get_one_ensemble_sphere_quadtree(
step=temporal_step,
bin_interval=temporal_bin_interval,
temporal_bin_start_jitter=temporal_bin_start_jitter,
- rng=rng
+ rng=rng,
)
ensemble_all_df_list = []
diff --git a/tests/test_model_selection.py b/tests/test_model_selection.py
index ce675cc..adf842d 100644
--- a/tests/test_model_selection.py
+++ b/tests/test_model_selection.py
@@ -1,4 +1,4 @@
-from stemflow.model_selection import ST_CV, ST_train_test_split, ST_KFold
+from stemflow.model_selection import ST_CV, ST_KFold, ST_train_test_split
from .set_up_data import set_up_data
diff --git a/tests/test_random_state_reproducibility.py b/tests/test_random_state_reproducibility.py
index 7bd70f0..aac89f2 100644
--- a/tests/test_random_state_reproducibility.py
+++ b/tests/test_random_state_reproducibility.py
@@ -4,9 +4,7 @@
from stemflow.model.AdaSTEM import AdaSTEM
from stemflow.model_selection import ST_train_test_split
-from .make_models import (
- make_AdaSTEMRegressor,
-)
+from .make_models import make_AdaSTEMRegressor
from .set_up_data import set_up_data
x_names, (X, y) = set_up_data()