Skip to content

Commit

Permalink
pre-commit syntax correction
Browse files: browse the repository at this point in the history
  • Loading branch information
chenyangkang committed Oct 24, 2024
1 parent d1b0678 commit 8826ea1
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 43 deletions.
32 changes: 15 additions & 17 deletions stemflow/model/AdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __init__(
temporal_bin_interval: Union[float, int] = 50,
temporal_bin_start_jitter: Union[float, int, str] = "adaptive",
spatio_bin_jitter_magnitude: Union[float, int] = "adaptive",
random_state = None,
random_state=None,
save_gridding_plot: bool = True,
save_tmp: bool = False,
save_dir: str = "./",
Expand All @@ -105,7 +105,7 @@ def __init__(
plot_ylims: Tuple[Union[float, int], Union[float, int]] = None,
verbosity: int = 0,
plot_empty: bool = False,
completely_random_rotation: bool = False
completely_random_rotation: bool = False,
):
"""Make an AdaSTEM object
Expand Down Expand Up @@ -210,7 +210,7 @@ def __init__(
# 1. Check random state
self.random_state = random_state
self.rng = check_random_state(random_state)

# 2. Base model
check_base_model(base_model)
base_model = model_wrapper(base_model)
Expand Down Expand Up @@ -279,9 +279,7 @@ def __init__(
else:
self.verbosity = 0

def split(
self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None, njobs: int = 1
):
def split(self, X_train: pd.core.frame.DataFrame, verbosity: Union[None, int] = None, ax=None, njobs: int = 1):
"""QuadTree indexing the input data
Args:
Expand Down Expand Up @@ -326,11 +324,7 @@ def split(
grid_len_lower,
)

check_temporal_scale(
X_train[self.Temporal1].min(),
X_train[self.Temporal1].min(),
self.temporal_bin_interval
)
check_temporal_scale(X_train[self.Temporal1].min(), X_train[self.Temporal1].min(), self.temporal_bin_interval)

spatio_bin_jitter_magnitude = check_transform_spatio_bin_jitter_magnitude(
X_train, self.Spatio1, self.Spatio2, self.spatio_bin_jitter_magnitude
Expand Down Expand Up @@ -372,15 +366,16 @@ def split(
Spatio2=self.Spatio2,
save_gridding_plot=self.save_gridding_plot,
ax=ax,
completely_random_rotation=self.completely_random_rotation
completely_random_rotation=self.completely_random_rotation,
)

if njobs > 1 and isinstance(njobs, int):
parallel = joblib.Parallel(n_jobs=njobs, return_as="generator")
output_generator = parallel(
joblib.delayed(partial_get_one_ensemble_quadtree)(
ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count)
) for ensemble_count in list(range(self.ensemble_fold))
)
for ensemble_count in list(range(self.ensemble_fold))
)
if verbosity > 0:
output_generator = tqdm(output_generator, total=self.ensemble_fold, desc="Generating Ensemble: ")
Expand All @@ -393,13 +388,16 @@ def split(
if verbosity > 0
else range(self.ensemble_fold)
)
ensemble_all_df_list = [partial_get_one_ensemble_quadtree(
ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count)
) for ensemble_count in iter_func_]
ensemble_all_df_list = [
partial_get_one_ensemble_quadtree(
ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count)
)
for ensemble_count in iter_func_
]

# concat
ensemble_df = pd.concat(ensemble_all_df_list).reset_index(drop=True)

del ensemble_all_df_list

# processing
Expand Down
10 changes: 6 additions & 4 deletions stemflow/model/SphereAdaSTEM.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
from ..utils.validation import (
check_base_model,
check_prediciton_aggregation,
check_random_state,
check_spatial_scale,
check_spatio_bin_jitter_magnitude,
check_task,
check_temporal_bin_start_jitter,
check_temporal_scale,
check_transform_njobs,
check_verbosity,
check_random_state
)
from ..utils.wrapper import model_wrapper
from .AdaSTEM import AdaSTEM, AdaSTEMClassifier, AdaSTEMRegressor
Expand Down Expand Up @@ -70,7 +70,7 @@ def __init__(
temporal_bin_interval: Union[float, int] = 50,
temporal_bin_start_jitter: Union[float, int, str] = "adaptive",
spatio_bin_jitter_magnitude: Union[float, int] = "adaptive",
random_state = None,
random_state=None,
save_gridding_plot: bool = True,
save_tmp: bool = False,
save_dir: str = "./",
Expand Down Expand Up @@ -305,7 +305,8 @@ def split(
output_generator = parallel(
joblib.delayed(partial_get_one_ensemble_sphere_quadtree)(
ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count)
) for ensemble_count in list(range(self.ensemble_fold))
)
for ensemble_count in list(range(self.ensemble_fold))
)
if verbosity > 0:
output_generator = tqdm(output_generator, total=self.ensemble_fold, desc="Generating Ensemble: ")
Expand All @@ -321,7 +322,8 @@ def split(
ensemble_all_df_list = [
partial_get_one_ensemble_sphere_quadtree(
ensemble_count=ensemble_count, rng=np.random.default_rng(self.rng.integers(1e9) + ensemble_count)
) for ensemble_count in iter_func_
)
for ensemble_count in iter_func_
]

ensemble_df = pd.concat(ensemble_all_df_list).reset_index(drop=True)
Expand Down
18 changes: 10 additions & 8 deletions stemflow/model_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def ST_CV(
yield X_train, X_test, y_train, y_test


class ST_KFold():
class ST_KFold:
def __init__(
self,
Spatio1: str = "longitude",
Expand Down Expand Up @@ -238,14 +238,14 @@ def __init__(
Spatio_blocks_count = 10,
Temporal_blocks_count = 10,
random_state = 42).split(X)
for train_indexes, test_indexes in ST_KFold_generator:
X_train = X.iloc[train_indexes,:]
X_test = X.iloc[test_indexes,:]
...
```
"""
self.rng = check_random_state(random_state)
self.Spatio1 = Spatio1
Expand All @@ -254,10 +254,10 @@ def __init__(
self.Spatio_blocks_count = Spatio_blocks_count
self.Temporal_blocks_count = Temporal_blocks_count
self.n_splits = n_splits

if not (isinstance(n_splits, int) and n_splits > 0):
raise ValueError("CV should be a positive integer")

def split(self, X: DataFrame) -> Generator[Tuple[ndarray, ndarray], None, None]:
"""split
Expand All @@ -281,7 +281,9 @@ def split(self, X: DataFrame) -> Generator[Tuple[ndarray, ndarray], None, None]:
indexes = [
str(a) + "_" + str(b) + "_" + str(c)
for a, b, c in zip(
np.digitize(X[self.Spatio1], Sindex1), np.digitize(X[self.Spatio2], Sindex2), np.digitize(X[self.Temporal1], Tindex1)
np.digitize(X[self.Spatio1], Sindex1),
np.digitize(X[self.Spatio2], Sindex2),
np.digitize(X[self.Temporal1], Tindex1),
)
]

Expand All @@ -304,4 +306,4 @@ def split(self, X: DataFrame) -> Generator[Tuple[ndarray, ndarray], None, None]:
# get train set record indexes
train_indexes = list(set(range(len(indexes))) - set(test_indexes))

yield train_indexes, test_indexes
yield train_indexes, test_indexes
14 changes: 9 additions & 5 deletions stemflow/utils/quadtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

from ..gridding.QTree import QTree
from ..gridding.QuadGrid import QuadGrid
from .validation import check_transform_spatio_bin_jitter_magnitude, check_transform_temporal_bin_start_jitter, check_random_state
from .validation import (
check_random_state,
check_transform_spatio_bin_jitter_magnitude,
check_transform_temporal_bin_start_jitter,
)

# from tqdm.contrib.concurrent import process_map

Expand Down Expand Up @@ -100,7 +104,7 @@ def get_one_ensemble_quadtree(
ax=None,
plot_empty: bool = False,
rng: np.random._generator.Generator = None,
completely_random_rotation = False
completely_random_rotation=False,
):
"""Generate QuadTree gridding based on the input dataframe
Expand Down Expand Up @@ -161,12 +165,12 @@ def get_one_ensemble_quadtree(
"""
rng = check_random_state(rng)

if completely_random_rotation:
rotation_angle = rng.uniform(0, 90)
else:
rotation_angle = (90 / size) * ensemble_count

calibration_point_x_jitter = rng.uniform(-spatio_bin_jitter_magnitude, spatio_bin_jitter_magnitude)
calibration_point_y_jitter = rng.uniform(-spatio_bin_jitter_magnitude, spatio_bin_jitter_magnitude)

Expand All @@ -177,7 +181,7 @@ def get_one_ensemble_quadtree(
step=temporal_step,
bin_interval=temporal_bin_interval,
temporal_bin_start_jitter=temporal_bin_start_jitter,
rng=rng
rng=rng,
)

ensemble_all_df_list = []
Expand Down
9 changes: 4 additions & 5 deletions stemflow/utils/sphere_quadtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ..gridding.Sphere_QTree import Sphere_QTree
from .quadtree import generate_temporal_bins
from .sphere.coordinate_transform import lonlat_cartesian_3D_transformer

from .validation import check_random_state

os.environ["MKL_NUM_THREADS"] = "1"
Expand Down Expand Up @@ -44,7 +43,7 @@ def get_one_ensemble_sphere_quadtree(
ax=None,
radius: Union[int, float] = 6371.0,
plot_empty: bool = False,
rng: np.random._generator.Generator = None
rng: np.random._generator.Generator = None,
):
"""Generate QuadTree gridding based on the input dataframe
A function to get quadtree results for the spherical indexing system. Twin of `get_ensemble_quadtree` in `quadtree.py`. Returns ensemble_df and plotting axes.
Expand Down Expand Up @@ -87,15 +86,15 @@ def get_one_ensemble_sphere_quadtree(
The radius of earth in km. Defaults to 6371.0.
rng:
random number generator.
Returns:
A tuple of <br>
1. ensemble dataframe;<br>
2. grid plot. np.nan if save_gridding_plot=False<br>
"""
rng = check_random_state(rng)

if spatio_bin_jitter_magnitude == "adaptive":
rotation_angle = rng.uniform(0, 90)
rotation_axis = rng.uniform(-1, 1, 3)
Expand All @@ -106,7 +105,7 @@ def get_one_ensemble_sphere_quadtree(
step=temporal_step,
bin_interval=temporal_bin_interval,
temporal_bin_start_jitter=temporal_bin_start_jitter,
rng=rng
rng=rng,
)

ensemble_all_df_list = []
Expand Down
2 changes: 1 addition & 1 deletion tests/test_model_selection.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from stemflow.model_selection import ST_CV, ST_train_test_split, ST_KFold
from stemflow.model_selection import ST_CV, ST_KFold, ST_train_test_split

from .set_up_data import set_up_data

Expand Down
4 changes: 1 addition & 3 deletions tests/test_random_state_reproducibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
from stemflow.model.AdaSTEM import AdaSTEM
from stemflow.model_selection import ST_train_test_split

from .make_models import (
make_AdaSTEMRegressor,
)
from .make_models import make_AdaSTEMRegressor
from .set_up_data import set_up_data

x_names, (X, y) = set_up_data()
Expand Down

0 comments on commit 8826ea1

Please sign in to comment.