-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: test_chatbot_aas.py
53 lines (48 loc) · 2.12 KB
/
test_chatbot_aas.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import pickle
import sys
import os
import numpy as np
import tensorflow as tf
import data_utils
from corpora_tools import clean_sentence, sentences_to_indexes, prepare_sentences
from train_chatbot import get_seq2seq_model, path_l1_dict, path_l2_dict
# Directory holding the trained chatbot checkpoint. It is resolved against the
# current working directory, so the process has to be launched from the
# directory that contains tmp/chat. os.path.join keeps the separator correct
# on every platform instead of hard-coding "/".
model_dir = os.path.join(os.getcwd(), "tmp", "chat")
def prepare_sentence(sentence, dict_l1, max_length):
    """Convert one raw input sentence into the data set fed to the encoder.

    Args:
        sentence: raw user text; it is split on single space characters before
            being passed to ``clean_sentence``.
        dict_l1: word-to-index mapping for the input language (presumably the
            dictionary unpickled from ``path_l1_dict`` -- confirm with caller).
        max_length: length limit passed to ``prepare_sentences`` for both the
            input side and the (empty) target side.

    Returns:
        ``(sentences, data_set)`` where ``sentences`` is the tuple
        ``(cleaned_input_sentences, [[]])`` and ``data_set`` is the result of
        ``prepare_sentences`` on the indexed input with an empty target.
    """
    tokens = sentence.split(" ")
    # A one-element batch: the single cleaned sentence.
    cleaned_batch = [clean_sentence(tokens)]
    indexed_batch = sentences_to_indexes(cleaned_batch, dict_l1)
    # There is no reply yet, so the target side is a single empty sentence.
    no_target = [[]]
    data_set = prepare_sentences(indexed_batch, no_target, max_length, max_length)
    return (cleaned_batch, [[]]), data_set
def decode(data_set):
    """Run the trained seq2seq model on one prepared sentence and return its reply.

    A fresh session and model are created through ``get_seq2seq_model`` on every
    call, and the default TensorFlow graph is reset afterwards (even if decoding
    raises) so that the next call builds its model on a clean graph.

    Reads the module-level globals ``dict_lengths``, ``max_sentence_lengths``,
    ``model_dir`` and ``inv_dict_l2``, which are set up in the ``__main__``
    section of this script.

    Args:
        data_set: output of ``prepare_sentence``; only ``data_set[0][0]`` is fed
            to the model as the encoder input of a single example (structure
            assumed from this indexing -- TODO confirm against prepare_sentences).

    Returns:
        The decoded reply as a space-separated string of output-vocabulary words.
    """
    try:
        with tf.Session() as sess:
            # NOTE(review): the positional True presumably selects forward-only
            # (inference) mode -- confirm in train_chatbot.get_seq2seq_model.
            model = get_seq2seq_model(sess, True, dict_lengths, max_sentence_lengths, model_dir)
            model.batch_size = 1  # decode exactly one sentence
            bucket = 0
            encoder_inputs, decoder_inputs, target_weights = model.get_batch(
                {bucket: [(data_set[0][0], [])]}, bucket)
            _, _, output_logits = model.step(sess, encoder_inputs, decoder_inputs,
                                             target_weights, bucket, True)
            # Greedy decoding: pick the highest-scoring id at every output step.
            # The batch holds one example, so take element [0] explicitly rather
            # than calling int() on a 1-element array (deprecated in NumPy).
            outputs = [int(np.argmax(logit, axis=1)[0]) for logit in output_logits]
    finally:
        # Reset only after the session context has exited: TF 1.x refuses to
        # reset a default graph that is still nested inside a session/graph
        # context. Doing it in `finally` keeps failed requests from leaving a
        # dirty graph behind for the next call.
        tf.reset_default_graph()

    # The first decoded position is always dropped (the original code already
    # dropped it whenever EOS was present). NOTE(review): it appears to be a
    # start marker the model learns to emit -- confirm against the training data.
    # Cut at the first EOS if there is one, otherwise keep everything after it.
    if data_utils.EOS_ID in outputs:
        end = outputs.index(data_utils.EOS_ID)
    else:
        end = len(outputs)
    outputs = outputs[1:end]
    return " ".join(tf.compat.as_str(inv_dict_l2[output]) for output in outputs)
if __name__ == "__main__":
    # Load the input (l1) and output (l2) vocabularies saved at training time.
    # Context managers make sure the file handles are closed after loading.
    # NOTE(review): pickle can execute arbitrary code on load -- these paths
    # must only ever point at trusted, locally produced files.
    with open(path_l1_dict, "rb") as dict_file:
        dict_l1 = pickle.load(dict_file)
    dict_l1_length = len(dict_l1)
    with open(path_l2_dict, "rb") as dict_file:
        dict_l2 = pickle.load(dict_file)
    dict_l2_length = len(dict_l2)

    # Reverse mapping (index -> word) used by decode() to turn ids into text.
    inv_dict_l2 = {v: k for k, v in dict_l2.items()}

    # These module-level names are read as globals by decode() and the handler.
    max_lengths = 10
    dict_lengths = (dict_l1_length, dict_l2_length)
    max_sentence_lengths = (max_lengths, max_lengths)

    from bottle import route, run, request

    @route('/api')
    def api():
        """Answer GET /api?sentence=... with the chatbot's reply as JSON."""
        in_sentence = request.query.sentence
        _, data_set = prepare_sentence(in_sentence, dict_l1, max_lengths)
        resp = [{"in": in_sentence, "out": decode(data_set)}]
        return dict(data=resp)

    # NOTE(review): debug=True and reloader=True are development settings;
    # keep the server bound to loopback as it is here.
    run(host='127.0.0.1', port=8080, reloader=True, debug=True)