Skip to content

Commit

Permalink
Update to use RandomState to avoid compatibility issues with train_te…
Browse files Browse the repository at this point in the history
…st_split
  • Loading branch information
ovejabu committed Apr 18, 2024
1 parent 6c074f3 commit 86749ac
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 22 deletions.
4 changes: 2 additions & 2 deletions pipeline_lib/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from pipeline_lib.core.data_container import DataContainer
from pipeline_lib.core.model_registry import ModelRegistry
from pipeline_lib.core.random_generator import initialize_generator
from pipeline_lib.core.random_state_generator import initialize_random_state
from pipeline_lib.core.step_registry import StepRegistry
from pipeline_lib.core.steps import PipelineStep

Expand Down Expand Up @@ -89,7 +89,7 @@ def from_json(cls, path: str) -> Pipeline:
seed = config["pipeline"]["seed"]
else:
seed = 42
initialize_generator(seed)
initialize_random_state(seed)

steps = []

Expand Down
14 changes: 0 additions & 14 deletions pipeline_lib/core/random_generator.py

This file was deleted.

28 changes: 28 additions & 0 deletions pipeline_lib/core/random_state_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Optional

import numpy as np
from numpy.random import RandomState

_random_state = None


def get_random_state() -> Optional[RandomState]:
"""
Get the global random state object.
Returns:
RandomState or None: The global random state object if initialized, else None.
"""
global _random_state
return _random_state


def initialize_random_state(seed: int):
"""
Initialize the global random state object with the provided seed.
Args:
seed (int): The seed value to initialize the random state object.
"""
global _random_state
_random_state = np.random.RandomState(seed)
4 changes: 2 additions & 2 deletions pipeline_lib/core/steps/explainer_dashboard.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from explainerdashboard import RegressionExplainer

from pipeline_lib.core import DataContainer
from pipeline_lib.core.random_generator import get_random_generator
from pipeline_lib.core.random_state_generator import get_random_state
from pipeline_lib.core.steps.base import PipelineStep


Expand Down Expand Up @@ -38,7 +38,7 @@ def execute(self, data: DataContainer) -> DataContainer:
f" {self.max_samples}."
)
self.logger.info(f"Sampling {self.max_samples} data points from the dataset.")
df = df.sample(n=self.max_samples, random_state=get_random_generator().integers(0, 100))
df = df.sample(n=self.max_samples, random_state=get_random_state())

drop_columns = data._drop_columns
if drop_columns:
Expand Down
8 changes: 4 additions & 4 deletions pipeline_lib/core/steps/tabular_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sklearn.model_selection import train_test_split

from pipeline_lib.core import DataContainer
from pipeline_lib.core.random_generator import get_random_generator
from pipeline_lib.core.random_state_generator import get_random_state
from pipeline_lib.core.steps.base import PipelineStep


Expand Down Expand Up @@ -61,19 +61,19 @@ def execute(self, data: DataContainer) -> DataContainer:
train_val_df, test_df = train_test_split(
df,
test_size=self.test_percentage,
random_state=get_random_generator().integers(0, 100),
random_state=get_random_state(),
)
train_df, validation_df = train_test_split(
train_val_df,
train_size=self.train_percentage
/ (self.train_percentage + self.validation_percentage),
random_state=get_random_generator().integers(0, 100),
random_state=get_random_state(),
)
else:
train_df, validation_df = train_test_split(
df,
train_size=self.train_percentage,
random_state=get_random_generator().integers(0, 100),
random_state=get_random_state(),
)
test_df = None

Expand Down

0 comments on commit 86749ac

Please sign in to comment.