-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmlflow_client.py
160 lines (125 loc) · 6.18 KB
/
mlflow_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import os
import yaml
import json
import mlflow
import base64
import pandas as pd
from typing import Any, Dict, List
from datetime import datetime
from json import JSONDecodeError
from io import StringIO, BytesIO
from mlflow import MlflowClient, MlflowException
from models import model_handlers
def flatten_dict(d: dict, parent_key='', sep='_'):
items = []
for k, v in d.items():
new_key = f"{parent_key}{sep}{k}" if parent_key else k
if isinstance(v, dict):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
items.append((new_key, v))
return dict(items)
class Client:
def __init__(self):
os.environ['MLFLOW_TRACKING_USERNAME'] = os.getenv('MLFLOW_TRACKING_USERNAME')
os.environ['MLFLOW_TRACKING_PASSWORD'] = os.getenv('MLFLOW_TRACKING_PASSWORD')
os.environ['AWS_ACCESS_KEY_ID'] = os.getenv('AWS_ACCESS_KEY_ID')
os.environ['AWS_SECRET_ACCESS_KEY'] = os.getenv('AWS_SECRET_ACCESS_KEY')
os.environ['MLFLOW_S3_ENDPOINT_URL'] = os.getenv('MLFLOW_S3_ENDPOINT_URL')
os.environ['MLFLOW_TRACKING_INSECURE_TLS'] = os.getenv('MLFLOW_TRACKING_INSECURE_TLS')
os.environ['MLFLOW_HTTP_REQUEST_TIMEOUT'] = "1000"
mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI"))
self.client = MlflowClient()
def models(self):
returns = []
models_list = self.client.search_registered_models("tag.mage_model = 'true'")
if not models_list:
return None
for model in models_list:
timestamp_s = model.creation_timestamp / 1000
returns.append({
"name": model.name,
"creation_date": datetime.utcfromtimestamp(timestamp_s).strftime('%Y-%m-%d %H:%M:%S')
})
return returns
def model_parameters(self, name: str, version: str) -> Dict[str, Any] | None:
run_id = self.client.get_registered_model(name).latest_versions[0].run_id if version is None \
else self.client.get_model_version(name, version).run_id
if not run_id:
return None
parameters = self.client.get_run(run_id).data.params
try:
parameters = {k: json.loads(v.replace("'", '"')) for k, v in parameters.items()}
except JSONDecodeError:
pass
return flatten_dict(parameters)
def model_versions(self, name: str) -> List[Dict[str, str]] | None:
versions = self.client.search_model_versions(f"name='{name}'")
return [{"version": version.version,
"type": "" if version.tags.get("model_type") is None else version.tags["model_type"]}
for version in versions]
def model_metrics(self, name: str, version: str) -> Dict[str, Any] | None:
run_id = self.client.get_registered_model(name).latest_versions[0].run_id if version is None \
else self.client.get_model_version(name, version).run_id
metrics = self.client.get_run(run_id).data.metrics
return metrics
def model_dataset(self, name: str, version: str):
run_id = self.client.get_registered_model(name).latest_versions[0].run_id if version is None \
else self.client.get_model_version(name, version).run_id
if not run_id:
return None
try:
dataset = mlflow.artifacts.load_text(f"runs:/{run_id}/dataset")
return pd.read_csv(StringIO(dataset))
except MlflowException:
return None
def model_images(self, name: str, version: str):
run_id = self.client.get_registered_model(name).latest_versions[0].run_id if version is None \
else self.client.get_model_version(name, version).run_id
if not run_id:
return None
try:
artifacts = self.client.list_artifacts(run_id, path="figures")
images = {}
for artifact in artifacts:
print(artifact.path)
if ".png" in artifact.path:
image = mlflow.artifacts.load_image(f"runs:/{run_id}/{artifact.path}")
buffered = BytesIO()
image_format = image.format if image.format else 'PNG'
image.save(buffered, format=image_format)
images[str(artifact.path).split("/")[-1]] = (f"data:image/png;base64,"
f"{(base64.b64encode(buffered.getvalue()).decode('utf-8'))}")
except MlflowException:
return None
return images
def model_predict(self, name: str, df: pd.DataFrame):
run_id = self.client.get_registered_model(name).latest_versions[0].run_id
if not run_id:
return None
artifacts = self.client.list_artifacts(run_id)
artifacts = [artifact.path for artifact in artifacts if "model" in artifact.path and artifact.is_dir]
content = mlflow.artifacts.load_text(f"runs:/{run_id}/{artifacts[0]}/MLmodel")
content = yaml.safe_load(StringIO(content))
model_type = content["flavors"]["python_function"]["loader_module"].split(".")[1]
if model_type in list(model_handlers.keys()):
handler = model_handlers[model_type]
handler.load_model(model_uri=f"runs:/{run_id}/{artifacts[0]}")
return handler.predict(df)
return None
def model_package(self, name: str):
run_id = self.client.get_registered_model(name).latest_versions[0].run_id
if not run_id:
return None
artifacts = self.client.list_artifacts(run_id)
artifacts = [artifact.path for artifact in artifacts if "model" in artifact.path and artifact.is_dir]
content = mlflow.artifacts.load_text(f"runs:/{run_id}/{artifacts[0]}/MLmodel")
content = yaml.safe_load(StringIO(content))
model_type = content["flavors"]["python_function"]["loader_module"].split(".")[1]
if model_type in list(model_handlers.keys()):
handler = model_handlers[model_type]
return handler.save_model(f"runs:/{run_id}/{artifacts[0]}")
return None
@staticmethod
def model_register(run_id: str, model_name: str):
return mlflow.register_model(f"runs:/{run_id}", model_name)