Skip to content

Commit

Permalink
rename model_params to model_parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
diegomarvid committed May 4, 2024
1 parent d0df0c4 commit 7090d1a
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions pipeline_lib/core/steps/fit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ def optimize(
X_validation,
y_validation,
model_class: Type[Model],
model_params: dict,
model_parameters: dict,
) -> dict:
def objective(trial):
# Create a copy of model_params, then update with the optuna suggested hyperparameters
# Create a copy of model_parameters, then update with the optuna hyperparameters
param = {}
param.update(model_params)
param.update(model_parameters)
param.update(self._create_trial_params(trial))

model = model_class(**param)
Expand Down Expand Up @@ -97,14 +97,14 @@ class ModelStep(PipelineStep):
def __init__(
self,
model_class: Type[Model],
model_params: Optional[dict] = None,
model_parameters: Optional[dict] = None,
optuna_params: Optional[dict] = None,
save_path: Optional[str] = None,
) -> None:
super().__init__()
self.init_logger()
self.model_class = model_class
self.model_params = model_params or {}
self.model_parameters = model_parameters or {}
self.optuna_params = optuna_params
self.save_path = save_path

Expand All @@ -116,7 +116,7 @@ def execute(self, data: DataContainer) -> DataContainer:

def train(self, data: DataContainer) -> DataContainer:
self.logger.info(f"Fitting the {self.model_class.__name__} model")
model_params = self.model_params
model_parameters = self.model_parameters

if self.optuna_params:
optimizer = OptunaOptimizer(self.optuna_params, self.logger)
Expand All @@ -126,13 +126,13 @@ def train(self, data: DataContainer) -> DataContainer:
data.X_validation,
data.y_validation,
self.model_class,
model_params,
model_parameters,
)
model_params.update(optuna_model_params)
self.logger.info(f"Optimized model parameters: \n{json.dumps(model_params)}")
model_parameters.update(optuna_model_params)
self.logger.info(f"Optimized model parameters: \n{json.dumps(model_parameters)}")
self.logger.info("Re-fitting the model with optimized parameters")

model = self.model_class(**model_params)
model = self.model_class(**model_parameters)
model.fit(
data.X_train,
data.y_train,
Expand Down

0 comments on commit 7090d1a

Please sign in to comment.