zenml_pipeline.py
import bentoml
from zenml.environment import Environment
from zenml.integrations.mlflow.mlflow_environment import MLFLOW_ENVIRONMENT_NAME
from zenml.integrations.mlflow.mlflow_step_decorator import enable_mlflow
from zenml.pipelines import pipeline
from zenml.steps import BaseStepConfig, step

from model import SimpleConvNet
from train import cross_validate, test_model, train


class TrainerConfig(BaseStepConfig):
    """Trainer params."""

    epochs: int = 1
    k_folds: int = 2
    lr: float = 0.001


@enable_mlflow
@step
def cross_validate_dataset(config: TrainerConfig) -> dict:
    """Run k-fold cross-validation and return the aggregated fold metrics."""
    return cross_validate(
        epochs=config.epochs, k_folds=config.k_folds, learning_rate=config.lr
    )


@enable_mlflow
@step
def train_model(config: TrainerConfig) -> SimpleConvNet:
    """Train a SimpleConvNet with the given hyperparameters."""
    return train(epochs=config.epochs, learning_rate=config.lr)


@enable_mlflow
@step
def test_model_performance(model: SimpleConvNet) -> dict:
    """Evaluate the trained model and return its test metrics."""
    return test_model(model=model, _test_loader=None)


@step
def _save_model(cv_results: dict, test_results: dict, model: SimpleConvNet) -> None:
    """Persist the trained model to the BentoML model store with its metrics."""
    metadata = {
        "acc": float(test_results["correct"]) / test_results["total"],
        "cv_stats": cv_results,
    }
    model_name = "pytorch_mnist"
    bentoml.pytorch.save(
        model_name,
        model,
        metadata=metadata,
    )
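    # Illustrative note (an assumption, not from the original file): with this
    # generation of the BentoML API, the saved model can be loaded back
    # symmetrically, e.g. `bentoml.pytorch.load("pytorch_mnist")`; exact call
    # names vary across BentoML releases.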


@pipeline(enable_cache=False)
def mnist_pipeline(_cross_validator, _trainer, _test_model, _save_model):
    """Links all the steps together in a pipeline."""
    # ZenML infers the step execution order from these data dependencies.
    cv_results = _cross_validator()
    model = _trainer()
    test_results = _test_model(model=model)
    _save_model(cv_results=cv_results, test_results=test_results, model=model)


if __name__ == "__main__":
    # Run the pipeline once per hyperparameter configuration.
    configs = [
        {"epochs": 1, "k_folds": 2, "lr": 0.0003},
        {"epochs": 2, "k_folds": 2, "lr": 0.0004},
    ]
    for config in configs:
        pipeline_def = mnist_pipeline(
            _cross_validator=cross_validate_dataset(
                config=TrainerConfig(
                    epochs=config["epochs"], k_folds=config["k_folds"], lr=config["lr"]
                )
            ),
            _trainer=train_model(
                config=TrainerConfig(epochs=config["epochs"], lr=config["lr"])
            ),
            _test_model=test_model_performance(),
            _save_model=_save_model(),
        )
        pipeline_def.run()

    mlflow_env = Environment()[MLFLOW_ENVIRONMENT_NAME]
    print(
        "Now run\n"
        f"    mlflow ui --backend-store-uri {mlflow_env.tracking_uri}\n"
        "to inspect your experiment runs within the MLflow UI.\n"
        "You can find your runs tracked within the `mnist_pipeline` "
        "experiment, where you'll also be able to compare the two runs."
    )
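

# Illustrative follow-up (an assumption, not part of the original pipeline):
# a model saved this way is typically served from a separate BentoML service
# definition, along the lines of:
#
#   runner = bentoml.pytorch.load_runner("pytorch_mnist:latest")
#   svc = bentoml.Service("pytorch_mnist_service", runners=[runner])
#
# API names differ between BentoML releases, so treat this as a sketch only.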