Skip to content

Commit

Permalink
Store train/val/test splits on DataContainer
Browse files Browse the repository at this point in the history
  • Loading branch information
Ludecan committed Apr 12, 2024
1 parent 15aa923 commit 72f55a8
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 29 deletions.
32 changes: 30 additions & 2 deletions pipeline_lib/core/data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,34 @@ def raw(self, value: pd.DataFrame):
"""
self["raw"] = value

@property
def split_values(self) -> dict[str, list[Any]]:
    """
    Get the split values used in the SplitStep from the DataContainer.

    The values live under their own "split_values" key. They must not share the
    "train" key: the sibling accessors store each dataset under a key matching its
    name ("raw", "validation", "test"), so "train" is presumably where the train
    DataFrame is kept, and reusing it would make the two overwrite each other.

    Returns
    -------
    dict[str, list[Any]]
        A dictionary with keys "train", "validation" and "test", where each key maps to the
        list of values used for performing the train, validation and test splits
        respectively. Test set values may be empty.
    """
    return self["split_values"]

@split_values.setter
def split_values(self, value: dict[str, list[Any]]):
    """
    Set the split values used in the SplitStep in the DataContainer.

    Parameters
    ----------
    value
        A dictionary with keys "train", "validation" and "test", where each key maps to the
        list of values used for performing the train, validation and test splits
        respectively. Test set values may be empty.
    """
    # Dedicated key: writing to "train" here would clobber the train dataset entry.
    self["split_values"] = value

@property
def train(self) -> pd.DataFrame:
"""
Expand Down Expand Up @@ -381,7 +409,7 @@ def validation(self, value: pd.DataFrame):
self["validation"] = value

@property
def test(self) -> pd.DataFrame:
def test(self) -> Optional[pd.DataFrame]:
"""
Get the test data from the DataContainer.
Expand All @@ -393,7 +421,7 @@ def test(self) -> pd.DataFrame:
return self["test"]

@test.setter
def test(self, value: pd.DataFrame):
def test(self, value: Optional[pd.DataFrame]):
"""
Set the test data in the DataContainer.
Expand Down
75 changes: 48 additions & 27 deletions pipeline_lib/core/steps/tabular_split.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Optional
from typing import Iterable, Optional, Tuple

import pandas as pd
from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -102,33 +102,21 @@ def __init__(
"validation_percentage must be provided when test_percentage is specified."
)

def execute(self, data: DataContainer) -> DataContainer:
"""Execute the random train-validation-test split."""
self.logger.info("Splitting tabular data...")
df = data.flow

def perform_split(
self,
df: pd.DataFrame,
data: DataContainer,
) -> Tuple[pd.DataFrame, pd.DataFrame, Optional[pd.DataFrame]]:
if self.group_by_columns is not None:
concatted_groupby_columns = _concatenate_columns(df, self.group_by_columns)
split_values = concatted_groupby_columns.unique().tolist()
else:
concatted_groupby_columns = None
split_values = df.index.tolist()

if self.test_percentage is not None:
train_val_values, test_values = train_test_split(
split_values, test_size=self.test_percentage, random_state=42
)
train_values, validation_values = train_test_split(
train_val_values,
train_size=self.train_percentage
/ (self.train_percentage + self.validation_percentage),
random_state=42,
)
else:
train_values, validation_values = train_test_split(
split_values, train_size=self.train_percentage, random_state=42
)
test_values = None
train_values, validation_values, test_values = (
data.split_values["train"],
data.split_values["validation"],
data.split_values["test"],
)

if self.group_by_columns is not None:
train_df = df[concatted_groupby_columns.isin(set(train_values))]
Expand Down Expand Up @@ -181,9 +169,42 @@ def execute(self, data: DataContainer) -> DataContainer:
f"Number of rows in test set: {test_rows} | {test_rows / total_rows:.2%}"
)

data.train = train_df
data.validation = validation_df
if test_df is not None:
data.test = test_df
return train_df, validation_df, test_df

def execute(self, data: DataContainer) -> DataContainer:
    """Randomly split ``data.flow`` into train, validation and (optionally) test sets.

    The values being split are the unique concatenated ``group_by_columns`` values
    when grouping is configured, otherwise the DataFrame index. The selected values
    are recorded on ``data.split_values`` and the frames returned by
    ``perform_split`` are stored on ``data.train``, ``data.validation`` and
    ``data.test``.

    Parameters
    ----------
    data
        Container whose ``flow`` DataFrame is split.

    Returns
    -------
    DataContainer
        The same container instance, updated with the split values and datasets.
    """
    self.logger.info("Splitting tabular data...")
    df = data.flow

    # Pick the unit of splitting: whole groups if requested, single rows otherwise.
    if self.group_by_columns is None:
        candidates = df.index.tolist()
    else:
        candidates = _concatenate_columns(df, self.group_by_columns).unique().tolist()

    if self.test_percentage is None:
        train_part, validation_part = train_test_split(
            candidates, train_size=self.train_percentage, random_state=42
        )
        # No test split requested: record an empty list of test values.
        test_part = []
    else:
        # Carve the test values off first, then divide what is left between train
        # and validation, rescaling the train share to the remaining fraction.
        remainder, test_part = train_test_split(
            candidates, test_size=self.test_percentage, random_state=42
        )
        train_share = self.train_percentage / (
            self.train_percentage + self.validation_percentage
        )
        train_part, validation_part = train_test_split(
            remainder, train_size=train_share, random_state=42
        )

    # perform_split reads the values back from the container, so store them first.
    data.split_values = {
        "train": train_part,
        "validation": validation_part,
        "test": test_part,
    }

    data.train, data.validation, data.test = self.perform_split(df, data)

    return data

0 comments on commit 72f55a8

Please sign in to comment.