Skip to content

Commit

Permalink
Fix an issue overriding config values on the DataContainer when running the prediction pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
Ludecan committed Apr 26, 2024
1 parent f5ca60a commit e8c8746
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 33 deletions.
11 changes: 11 additions & 0 deletions pipeline_lib/core/data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ def __init__(self, initial_data: Optional[dict] = None):
self.logger = logging.getLogger(self.__class__.__name__)
self.logger.debug(f"{self.__class__.__name__} initialized")

def update(self, other: DataContainer) -> None:
    """
    Merge the contents of another DataContainer into this one.

    The entries held in ``other.data`` are written into ``self.data``
    through its ``update`` method.

    Parameters
    ----------
    other : DataContainer
        The DataContainer whose data is copied into this container.
    """
    incoming = other.data
    self.data.update(incoming)

def add(self, key: str, value):
"""
Add a new item to the container.
Expand Down
78 changes: 45 additions & 33 deletions pipeline_lib/core/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import os
import pprint
import time
from datetime import datetime
from typing import Any, Optional
Expand All @@ -28,11 +29,19 @@ class Pipeline:
"explainer",
]

def __init__(self, initial_data: Optional[DataContainer] = None):
def __init__(
self,
save_data_path: str,
target: str,
columns_to_ignore_for_training: Optional[list[str]] = None,
):
self.data = DataContainer()
self.steps = []
self.initial_data = initial_data
self.save_data_path = save_data_path
self.data.target = target
self.data.prediction_column = f"{target}_prediction"
self.data.columns_to_ignore_for_training = columns_to_ignore_for_training or []
self.config = {}
self.save_data_path = None

def add_steps(self, steps: list[PipelineStep]):
"""Add steps to the pipeline."""
Expand All @@ -41,46 +50,24 @@ def add_steps(self, steps: list[PipelineStep]):
def run(self, is_train: bool, save: bool = True) -> DataContainer:
"""Run the pipeline on the given data."""

if "parameters" not in self.config["pipeline"]:
raise ValueError("Missing pipeline parameters section in the config file.")

if "save_data_path" not in self.config["pipeline"]["parameters"]:
raise ValueError(
"A path for saving the data must be provided. Use the `save_data_path` attribute "
'of the pipeline parameters" section in the config.'
)

if "target" not in self.config["pipeline"]["parameters"]:
raise ValueError(
"A target column must be provided. Use the `target` attribute of the pipeline"
' "parameters" section in the config.'
)

data = DataContainer()

self.save_data_path = self.config["pipeline"]["parameters"]["save_data_path"]
data.target = self.config["pipeline"]["parameters"]["target"]
data.prediction_column = f"{data.target}_prediction"
data.columns_to_ignore_for_training = self.config["pipeline"]["parameters"].get(
"columns_to_ignore_for_training", []
)

if is_train:
steps_to_run = [step for step in self.steps if step.used_for_training]
self.logger.info("Training the pipeline")
else:
data = DataContainer.from_pickle(self.save_data_path)
self.data.update(DataContainer.from_pickle(self.save_data_path))
steps_to_run = [step for step in self.steps if step.used_for_prediction]
self.logger.info("Predicting with the pipeline")

data.is_train = is_train
self.data.is_train = is_train

pprint.pprint(self.data.data)

for i, step in enumerate(steps_to_run):
start_time = time.time()
log_str = f"Running {step.__class__.__name__} - {i + 1} / {len(steps_to_run)}"
Pipeline.logger.info(log_str)

data = step.execute(data)
data = step.execute(self.data)

Pipeline.logger.info(f"{log_str} done. Took: {time.time() - start_time:.2f}s")

Expand All @@ -100,6 +87,23 @@ def predict(self) -> DataContainer:
"""Run the pipeline on the given data."""
return self.run(is_train=False)

@classmethod
def _validate_configuration(cls, config: dict[str, Any]) -> None:
    """
    Check that a loaded pipeline configuration has the required entries.

    The configuration must contain a ``"pipeline"`` section with a
    ``"parameters"`` sub-section, and that sub-section must define both
    ``"save_data_path"`` and ``"target"``.

    Parameters
    ----------
    config : dict[str, Any]
        The parsed configuration to validate.

    Raises
    ------
    ValueError
        If the ``"pipeline"`` section, its ``"parameters"`` sub-section,
        or the ``"save_data_path"`` / ``"target"`` entries are missing.
    """
    # Report a missing top-level section the same way as every other
    # missing entry instead of letting a bare KeyError escape.
    if "pipeline" not in config:
        raise ValueError("Missing pipeline section in the config file.")

    pipeline_config = config["pipeline"]
    if "parameters" not in pipeline_config:
        raise ValueError("Missing pipeline parameters section in the config file.")

    parameters = pipeline_config["parameters"]
    if "save_data_path" not in parameters:
        raise ValueError(
            "A path for saving the data must be provided. Use the `save_data_path` attribute "
            'of the pipeline "parameters" section in the config.'
        )

    if "target" not in parameters:
        raise ValueError(
            "A target column must be provided. Use the `target` attribute of the pipeline"
            ' "parameters" section in the config.'
        )

@classmethod
def from_json(cls, path: str) -> Pipeline:
"""Load a pipeline from a JSON file."""
Expand All @@ -110,15 +114,23 @@ def from_json(cls, path: str) -> Pipeline:
with open(path, "r") as config_file:
config = json.load(config_file)

Pipeline._validate_configuration(config)

custom_steps_path = config.get("custom_steps_path")
if custom_steps_path:
cls.step_registry.load_and_register_custom_steps(custom_steps_path)

pipeline = Pipeline()

pipeline = Pipeline(
save_data_path=config["pipeline"]["parameters"]["save_data_path"],
target=config["pipeline"]["parameters"]["target"],
columns_to_ignore_for_training=config["pipeline"]["parameters"].get(
"columns_to_ignore_for_training", []
),
)
pipeline.config = config
steps = []

steps = []
# step_config = list(config["pipeline"]["steps"])[0]
for step_config in config["pipeline"]["steps"]:
step_type = step_config["step_type"]
parameters = step_config.get("parameters", {})
Expand Down

0 comments on commit e8c8746

Please sign in to comment.