Skip to content

Commit

Permalink
Merge pull request #2 from kittinan/main
Browse files Browse the repository at this point in the history
create flask api
  • Loading branch information
titipata authored Oct 19, 2021
2 parents 47d1a32 + 6406620 commit 4d078fa
Showing 1 changed file with 48 additions and 0 deletions.
48 changes: 48 additions & 0 deletions api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
import timeit
from flask import Flask, render_template, request, jsonify
from flask_cors import CORS, cross_origin


os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

from torch.nn.functional import softmax
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline


# download model from hub
TOKENIZER = AutoTokenizer.from_pretrained("tupleblog/salim-classifier")
MODEL = AutoModelForSequenceClassification.from_pretrained("tupleblog/salim-classifier")

app = Flask(__name__)
cors = CORS(app)


def predict(model, tokenizer, text):
"""
Predict with model, tokeinzer, and text
"""
device = "cpu"
_inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model(**_inputs)
result = softmax(outputs[0], dim=1).cpu().data.numpy().round(6).tolist()
result = result[0]
format_result = [
{"label": label, "score": float(result[index])}
for index, label in model.config.id2label.items()
]
return format_result


@app.route("/", methods=["POST"])
def index():
text = request.form.get("text", "")
print(text)
start_time = timeit.default_timer()
result = predict(MODEL, TOKENIZER, text)
usage_time = round(timeit.default_timer() - start_time, 3)
return jsonify({"result": result, "usage_time": usage_time})


if __name__ == "__main__":
app.run(debug=True, host="0.0.0.0", port=5000)

0 comments on commit 4d078fa

Please sign in to comment.