Skip to content

Commit

Permalink
Update .gitignore and remove unnecessary files and code
Browse files Browse the repository at this point in the history
  • Loading branch information
amine-akrout committed Mar 13, 2024
1 parent a828a4f commit bc2ec85
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 13 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ tf_training/mlruns/*

# saved models
tf_training/tmp/*
tf_serving/tmp/*

# data
data/*
Expand Down
File renamed without changes.
File renamed without changes.
13 changes: 6 additions & 7 deletions tf_training/docker-compose.yml → tf_serving/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
version: '3.2'
version: "3.2"
services:
tf_serving:
ports:
- '8501:8501'
- "8501:8501"
container_name: reviews_prediction
environment:
- MODEL_NAME=reviews_preds
image: tensorflow/serving
volumes:
- type: bind
source: ./tmp/swivel
source: ./tmp/best_model/model/data/model
target: /models/reviews_preds/1
networks:
- app
streamlit:
build: ./app
image: 3aak/streamlit
ports:
- '9000:9000'
- "9000:9000"
container_name: streamlit_ui
volumes:
- './app:/app'
- "./app:/app"
depends_on:
- tf_serving
command: streamlit run app.py --server.port=9000 --browser.serverAddress=0.0.0.0
networks:
- app
networks:
app:
external:
name: app
driver: bridge
File renamed without changes.
File renamed without changes.
4 changes: 1 addition & 3 deletions tf_training/lstm_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def run_lstm():
train_padded = np.array(train_padded)
test_padded = np.array(test_padded)
model = create_lstm_model(vocab_size, embedding_dim, max_len, metrics)
train_model(
"baseline", model, train_padded, y_train, test_padded, y_test, num_epoch
)
train_model("lstm", model, train_padded, y_train, test_padded, y_test, num_epoch)


if __name__ == "__main__":
Expand Down
37 changes: 37 additions & 0 deletions tf_training/model_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""
Select the best model from the MLflow experiment and save it to a directory""
"""

import os

import mlflow


# Function to find the best model
def get_best_model(experiment_name):
"""Get the best model from the MLflow experiment"""
runs = mlflow.search_runs(experiment_names=[experiment_name])
best_run = runs.sort_values("metrics.val_recall", ascending=False).iloc[0]
print(f"best model: ", best_run["tags.mlflow.runName"])
return best_run


def save_model(best_run, model_path="../tf_serving/tmp/best_model"):
"""Save the best model to a directory"""
# Model artifacts
model_uri = best_run["artifact_uri"] + "/model"
os.makedirs(model_path, exist_ok=True)
mlflow.artifacts.download_artifacts(artifact_uri=model_uri, dst_path=model_path)


def select_best_model():
"""Select the best model from the MLflow experiment and save it to a directory"""
# MLflow server details
mlflow.set_tracking_uri("http://localhost:5000")
experiment_name = "Reviews_Classification"
best_run = get_best_model(experiment_name)
save_model(best_run)


if __name__ == "__main__":
select_best_model()
13 changes: 10 additions & 3 deletions tf_training/training_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from bert_training import run_bert
from cnn_training import run_cnn
from lstm_training import run_lstm
from model_selection import select_best_model
from prefect import flow, task
from swivel_training import run_swivel

Expand All @@ -21,23 +22,29 @@ def train_cnn():
run_cnn()


# @task
# def train_bert():
# run_bert()
@task
def train_bert():
run_bert()


@task
def train_swivel():
run_swivel()


@task
def select_best_model_task():
select_best_model()


@flow(name="Training Flow", log_prints=True)
def training_flow():
train_baseline()
# train_bert()
train_cnn()
train_lstm()
train_swivel()
select_best_model_task()


if __name__ == "__main__":
Expand Down

0 comments on commit bc2ec85

Please sign in to comment.