import os
import sys
# Fix for the Windows embedded Python environment: make modules next to this file importable
file_dir = os.path.dirname(__file__)
sys.path.append(file_dir)
import random
import asyncio
import coloredlogs
import logging
from argparse import ArgumentParser
from dacite import from_dict
from hypercorn import Config
from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware
from api import log_request
from api.auth import get_auth_username
from utils import *
from utils import model as M
from utils import state
from utils.state import ServerConfig
from utils.cli import parse_args
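
# Dependencies applied to every route; an auth check is appended below
# when authentication is enabled.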
dependencies = [
    Depends(log_request),
]
args = parse_args()
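
# Install colored console logging at the level requested on the CLI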
coloredlogs.install(level=args.logLevel.upper())
logger = logging.getLogger(__name__)
logger.debug(f"Current Log Level: {args.logLevel}")
host, port = args.listen.rsplit(":", 1)  # rsplit keeps hosts that contain ":" intact
ServerConfig.address = host
ServerConfig.port = int(port)

# Hidden trick to disable auth by passing ":" as the credentials,
# useful when you use docker-compose
if args.auth == ":":
    args.auth = None
    args.no_auth = True

auth = [None, None]
if args.no_auth:
    logger.warning("Auth is disabled!")
else:
    if not args.auth:
        # Generate random auth credentials
        args.auth = f"sakura:{random.randint(114514, 19194545)}"
        logger.warning(f"Using random auth credentials: {args.auth}")
    auth = args.auth.split(":", 1)  # split once so the password may contain ":"
    # Insert the HTTP basic auth check into every route
    dependencies.append(Depends(get_auth_username))
ServerConfig.username = auth[0]
ServerConfig.password = auth[1]

app = FastAPI(dependencies=dependencies)
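
# Mount the API routers: legacy Sakura endpoints, the OpenAI-compatible
# endpoints (including chat completions), and the core endpoints.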
from api.legacy import router as legacy_router
app.include_router(legacy_router)
from api.openai.v1 import router as openai_router
app.include_router(openai_router)
from api.openai.v1.chat import router as openai_chat_router
app.include_router(openai_chat_router)
from api.core import router as core_router
app.include_router(core_router)
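
# Allow cross-origin requests from any origin so browser frontends can call the API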
origins = [
    "*",
]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

if __name__ == "__main__":
    logger.info(f"Current server config: {ServerConfig.show()}")

    # Build the model config from the parsed CLI args
    cfg = from_dict(data_class=M.SakuraModelConfig, data=args.__dict__)
    logger.info(f"Current model config: {cfg}")
    state.init_model(cfg)
    state.get_model().check_model_by_magic()  # sanity-check the loaded model file

    logger.info(
        f"Server will run at http://{ServerConfig.address}:{ServerConfig.port}, preparing...")

    # Keep a single worker: the LLM model is not thread safe.
    use_uvicorn = False  # flip to serve with uvicorn instead of hypercorn
    if use_uvicorn:
        import uvicorn
        uvicorn.run(
            "server:app",
            host=ServerConfig.address,
            port=ServerConfig.port,
            log_level=args.logLevel,
            workers=1,
        )
    else:
        from hypercorn.asyncio import serve
        config = Config()
        binding = f"{ServerConfig.address}:{ServerConfig.port}"
        logger.debug(f"hypercorn binding: {binding}")
        config.bind = [binding]
        config.loglevel = args.logLevel
        config.debug = args.logLevel == "debug"
        asyncio.run(serve(app, config))