-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
119 lines (102 loc) · 4.67 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from fastapi import FastAPI, Security, BackgroundTasks, Response
from fastapi.security.api_key import APIKey, APIKeyHeader
from celery.result import AsyncResult
from huggingface_hub import HfApi
import aioredis
import uuid
import json
from models import GenerationAndCommitRequest, GenerationAndUpdateRequest, ChatViewRequest, ModelData
from tasks import generate_and_push_data, generate_and_update_data, generate_data
from worker import celery_app
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted, and
# credentialed cross-origin requests are allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True is a
# risky CORS pairing. Clients of this API authenticate with the
# X-OpenAI-Key / X-HuggingFace-Key headers declared below, so confirm that
# credential (cookie) support is actually needed.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=origins,
    allow_headers=["*"],
    allow_methods=["*"],
)

# API-key security schemes: each one reads the named request header and is
# injected into endpoints through ``Security(...)``.
openai_key_scheme = APIKeyHeader(name="X-OpenAI-Key")
huggingface_key_scheme = APIKeyHeader(name="X-HuggingFace-Key")

# Shared async Redis client used for task bookkeeping. It stays None until the
# startup handler assigns it, and the shutdown handler closes it.
redis_pool = None
@app.on_event("startup")
async def startup_event():
    """Create the shared Redis client used to track task state.

    The connection URL comes from the ``REDIS_URL`` environment variable and
    defaults to ``redis://localhost``, so deployments where Redis is not
    co-located no longer need a code change. The client is created with
    ``decode_responses=True``.

    NOTE(review): the Celery worker presumably configures its own Redis/broker
    URL (see ``worker.py``); keep the two pointing at the same instance.
    """
    global redis_pool
    import os

    redis_pool = aioredis.from_url(
        os.environ.get("REDIS_URL", "redis://localhost"),
        decode_responses=True,
    )
@app.on_event("shutdown")
async def shutdown_event():
    """Close the shared Redis client created by the startup handler.

    If the client was never created (``redis_pool`` is still ``None``), this
    does nothing instead of failing with ``AttributeError`` on ``None.close()``.
    """
    if redis_pool is not None:
        await redis_pool.close()
@app.post("/data/view", status_code=202)
async def chat_view(
    req: ChatViewRequest,
    background_tasks: BackgroundTasks,
    openai_key: str = Security(openai_key_scheme),
):
    """Register a new task and run ``generate_data`` for it in the background.

    A fresh UUID4 task id is written to Redis with an initial "Starting" state
    *before* the background task is queued, so ``GET /track/{task_id}`` can
    find it immediately. The background task receives the shared Redis client
    and the task id, presumably so it can report its own progress (confirm in
    ``tasks.generate_data``).

    Args:
        req: Request body, forwarded unchanged to ``generate_data``.
        background_tasks: FastAPI background-task queue for this request.
        openai_key: Value of the ``X-OpenAI-Key`` header. The header scheme
            yields the header string, so it is annotated ``str``; ``APIKey``
            is the OpenAPI scheme-description model, not the key value.

    Returns:
        ``{"status": "Accepted", "task_id": <id>}`` with HTTP 202.
    """
    task_id = str(uuid.uuid4())
    await redis_pool.hset(
        task_id,
        mapping={"status": "Starting", "Progress": "None", "Detail": "None"},
    )
    background_tasks.add_task(generate_data, redis_pool, task_id, req, openai_key)
    return {"status": "Accepted", "task_id": task_id}
@app.post("/data", status_code=202)
async def chat_completion(
    req: GenerationAndCommitRequest,
    background_tasks: BackgroundTasks,
    openai_key: str = Security(openai_key_scheme),
    huggingface_key: str = Security(huggingface_key_scheme),
):
    """Register a new task and run ``generate_and_push_data`` in the background.

    A fresh UUID4 task id is written to Redis with an initial "Starting" state
    *before* the background task is queued, so ``GET /track/{task_id}`` can
    find it immediately. The background task receives the shared Redis client,
    the task id, the request and both API keys.

    Args:
        req: Request body, forwarded unchanged to ``generate_and_push_data``.
        background_tasks: FastAPI background-task queue for this request.
        openai_key: Value of the ``X-OpenAI-Key`` header.
        huggingface_key: Value of the ``X-HuggingFace-Key`` header.
            Both keys are annotated ``str`` because the header schemes yield
            the header string (``APIKey`` is the OpenAPI scheme model).

    Returns:
        ``{"status": "Accepted", "task_id": <id>}`` with HTTP 202.
    """
    task_id = str(uuid.uuid4())
    await redis_pool.hset(
        task_id,
        mapping={"status": "Starting", "Progress": "None", "Detail": "None"},
    )
    background_tasks.add_task(
        generate_and_push_data, redis_pool, task_id, req, openai_key, huggingface_key
    )
    return {"status": "Accepted", "task_id": task_id}
@app.put("/data", status_code=202)
async def chat_updation(
    req: GenerationAndUpdateRequest,
    background_tasks: BackgroundTasks,
    openai_key: str = Security(openai_key_scheme),
    huggingface_key: str = Security(huggingface_key_scheme),
):
    """Register a new task and run ``generate_and_update_data`` in the background.

    A fresh UUID4 task id is written to Redis with an initial "Starting" state
    *before* the background task is queued, so ``GET /track/{task_id}`` can
    find it immediately. The background task receives the shared Redis client,
    the task id, the request and both API keys.

    Args:
        req: Request body, forwarded unchanged to ``generate_and_update_data``.
        background_tasks: FastAPI background-task queue for this request.
        openai_key: Value of the ``X-OpenAI-Key`` header.
        huggingface_key: Value of the ``X-HuggingFace-Key`` header.
            Both keys are annotated ``str`` because the header schemes yield
            the header string (``APIKey`` is the OpenAPI scheme model).

    Returns:
        ``{"status": "Accepted", "task_id": <id>}`` with HTTP 202.
    """
    task_id = str(uuid.uuid4())
    await redis_pool.hset(
        task_id,
        mapping={"status": "Starting", "Progress": "None", "Detail": "None"},
    )
    background_tasks.add_task(
        generate_and_update_data, redis_pool, task_id, req, openai_key, huggingface_key
    )
    return {"status": "Accepted", "task_id": task_id}
@app.post("/train", status_code=202)
async def train_model(
    req: ModelData,
    huggingface_key: str = Security(huggingface_key_scheme),
):
    """Dispatch a training job to Celery and register it for progress tracking.

    The Redis record (``status="Acknowledged"``, ``handler="Celery"``) is
    written *before* the Celery message is sent, using a pre-generated task id.
    ``get_progress`` reads ``logs`` and ``status`` from this same hash once the
    task succeeds, and this endpoint never writes ``logs``, so presumably the
    worker updates the hash. Sending first would let a fast worker's ``status``
    be overwritten by the "Acknowledged" entry written afterwards.

    Args:
        req: Training parameters, sent to ``worker.train_task`` as a dict.
        huggingface_key: Value of the ``X-HuggingFace-Key`` header (a string;
            ``APIKey`` is only the OpenAPI scheme model).

    Returns:
        ``{"task_id": <id>}`` with HTTP 202.

    Raises:
        Any error raised by ``send_task``; the Redis record is removed first so
        no orphaned "Acknowledged" entry is left behind.
    """
    task_id = str(uuid.uuid4())
    await redis_pool.hset(task_id, mapping={"status": "Acknowledged", "handler": "Celery"})
    try:
        # NOTE(review): send_task talks to the broker synchronously, briefly
        # blocking the event loop.
        # NOTE(review): dict(req) is a shallow conversion; nested models inside
        # ModelData would not be JSON-serialisable by Celery. Confirm the
        # model is flat.
        task = celery_app.send_task(
            "worker.train_task",
            args=[dict(req), huggingface_key],
            task_id=task_id,
        )
    except Exception:
        await redis_pool.delete(task_id)
        raise
    return {"task_id": str(task.id)}
@app.get("/commit")
async def commit(repo_id: str, response: Response):
    """List commits of a Hugging Face repo whose title mentions ``pytorch_model.bin``.

    Args:
        repo_id: Hugging Face repository id, taken from the query string.
        response: Used to set HTTP 404 when the commit listing fails.

    Returns:
        ``{"response": [{"version": <commit id>, "date": <created_at>}, ...]}``
        on success, or ``{"response": <error text>}`` with HTTP 404 on failure.
    """
    import asyncio

    api = HfApi()
    loop = asyncio.get_running_loop()
    try:
        # HfApi is a synchronous HTTP client. Run the call in the default
        # thread pool so this coroutine does not stall the event loop (and
        # every other in-flight request) while the Hub responds.
        commits = await loop.run_in_executor(None, api.list_repo_commits, repo_id)
    except Exception as e:
        # NOTE(review): every failure (missing repo, auth, network) is reported
        # as 404 with the raw exception text. Consider narrowing this to
        # huggingface_hub's RepositoryNotFoundError.
        response.status_code = 404
        return {"response": str(e)}
    versions = [
        {"version": item.commit_id, "date": item.created_at}
        for item in commits
        if "pytorch_model.bin" in item.title
    ]
    return {"response": versions}
@app.get("/track/{task_id}")
async def get_progress(task_id: str, response: Response):
    """Report the state of a task created by one of the job endpoints.

    The task's Redis hash decides how it is tracked:

    * ``handler == "Celery"`` (written by ``POST /train``): state comes from the
      Celery result backend. Once the task has succeeded and the hash holds
      ``logs``, those logs are returned with the hash's ``status``.
    * otherwise (the ``/data`` endpoints): state comes from the hash's
      ``status`` / ``Progress`` / ``Detail`` fields.

    Args:
        task_id: Id returned by the endpoint that created the task.
        response: Used to set the HTTP status code.

    Returns:
        ``{"status": ..., "response": ...}``. Unknown ids (no hash in Redis)
        get HTTP 404 with ``{"status": "Not Found", "response": None}`` instead
        of continuing to read fields from the empty hash.
    """
    record = await redis_pool.hgetall(task_id)
    if not record:
        response.status_code = 404
        return {"status": "Not Found", "response": None}
    response.status_code = 200
    if record.get("handler") == "Celery":
        return _celery_task_progress(task_id, record)
    return _local_task_progress(record)


def _celery_task_progress(task_id, record):
    """Build the progress payload for a task dispatched to Celery."""
    # NOTE(review): AsyncResult.status queries the result backend synchronously.
    # It is read once here rather than once per comparison and once per return.
    result = AsyncResult(task_id, app=celery_app)
    state = str(result.status)
    if state == "SUCCESS" and "logs" in record:
        logs = record["logs"]
        if isinstance(logs, str):
            try:
                logs = json.loads(logs)
            except ValueError:
                # Not valid JSON: return the stored text unchanged.
                pass
        return {"status": record.get("status", state), "response": logs}
    # Still running, failed, or succeeded without logs in the hash yet:
    # report Celery's own state and metadata.
    return {"status": state, "response": result.info}


def _local_task_progress(record):
    """Build the progress payload for a task run as a FastAPI background task."""
    status = record.get("status")
    detail = record.get("Detail")
    if isinstance(detail, str):
        try:
            # Detail may hold a JSON object; its "data" value is the response.
            return {"status": status, "response": json.loads(detail)["data"]}
        except (ValueError, TypeError, KeyError):
            # Not JSON, not an object, or no "data" key: fall back to raw fields.
            pass
    return {
        "status": status,
        "response": {"Detail": detail, "Progress": record.get("Progress")},
    }