-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathkai_listener.py
executable file
·67 lines (53 loc) · 2.01 KB
/
kai_listener.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# import main Flask class and request object
from flask import Flask, request
from generation_request import GenerationRequest
import argparse
# Command-line interface: the text-generation backend is chosen at start-up
# through the required --generator option.
parser = argparse.ArgumentParser()
parser.add_argument('--generator', type=str, required=True,
                    help="Generator. Either 'oobabooga', 'llama.cpp'")
args = parser.parse_args()

# Each backend module is imported only inside the branch that selects it.
if args.generator == "oobabooga":
    from gen_oobabooga import OobaboogaGenerator
    gen = OobaboogaGenerator()
elif args.generator == "llama.cpp":
    from gen_llamacpp import LlamaCppGenerator
    gen = LlamaCppGenerator()
else:
    # Report the offending value via args.generator. The namespace has no
    # "parser" attribute, so formatting args.parser here raised
    # AttributeError and hid the real problem from the user.
    raise Exception(f"Unexpected generator {args.generator}")

# Flask application object that the route decorators below register against.
app = Flask(__name__)
@app.route('/api/v1/model')
def model_name():
    """Report the model name advertised by this server.

    Returns:
        A dict with a single ``result`` key whose value is the fixed string
        ``'hoperator/oobabooga'``; the value does not depend on which
        generator was selected on the command line.
    """
    return {'result': 'hoperator/oobabooga'}
@app.route('/api/v1/generate', methods=['POST'])
def generate_text(*args, **kwargs):
    """Run one text-generation request through the configured generator.

    Reads the JSON body of the POST request, copies its settings onto a
    fresh ``GenerationRequest`` and hands that object to ``gen.run``.
    Keys that are absent from the body are stored as ``None``. The
    positional and keyword arguments are accepted but not used.

    Returns:
        ``{'results': [{'text': <value returned by gen.run>}]}``.
    """
    # Pairs of (GenerationRequest attribute, key in the JSON request body).
    attribute_to_key = (
        ('prompt', 'prompt'),
        ('max_new_length', 'max_length'),
        ('max_context_length', 'max_context_length'),
        ('repetition_penalty', 'rep_pen'),
        ('repetition_penalty_slope', 'rep_pen_slope'),
        ('repetition_penalty_range', 'rep_pen_range'),
        ('temperature', 'temperature'),
        ('top_p', 'top_p'),
        ('top_k', 'top_k'),
        ('top_a', 'top_a'),
        ('tail_free_sampling', 'tfs'),
        ('typical', 'typical'),
        ('batch_count', 'n'),
    )

    payload = request.json
    generation_request = GenerationRequest()
    for attribute, json_key in attribute_to_key:
        setattr(generation_request, attribute, payload.get(json_key))

    generated = gen.run(generation_request)
    return {'results': [{"text": generated}]}
@app.route('/', defaults={'path': ''})
@app.route('/<path:path>')
def catch_all(path):
    """Log and echo the requested path for '/' and any '/<path>' URL.

    Args:
        path: The URL path captured by the route; the empty string for '/'.

    Returns:
        The text ``'You want path: <path>'`` with HTML special characters
        in ``path`` escaped.
    """
    # Imported locally so this function does not rely on a file-level import.
    from html import escape

    # The console log keeps the raw path exactly as requested.
    print('You want path: %s' % path)
    # The path comes straight from the client's URL and a str response is
    # served as HTML, so escape it to prevent reflected markup injection.
    return 'You want path: %s' % escape(path)
if __name__ == "__main__":
    # Start Flask's built-in server on port 11111 when this file is executed
    # directly rather than imported.
    # NOTE(review): debug=True turns on Flask's debug mode; confirm this
    # server is never reachable from untrusted networks.
    app.run(port=11111, debug=True)