-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathget_action_weights_old.py
131 lines (107 loc) · 3.55 KB
/
get_action_weights_old.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import datetime
import glob
import json
import math
import os
import shutil
import sqlite3
import string
import subprocess
import sys
import time
import random
import typing
import csv
from typing import Any
import numpy as np
import pickle
import glob
from http.server import BaseHTTPRequestHandler, HTTPServer
from scipy.linalg import lstsq
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.neural_network import MLPClassifier
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense, Flatten
from keras.models import load_model
from keras import optimizers
# Models loaded by load_data(), keyed by their file name inside ./data.
action_data: typing.Dict[str, Any] = {}
# Largest rowid of the L_CompareTo and L_ActionList tables in cardData.cdb;
# both are filled in by load_data(). action_count sizes the model input.
compare_count = action_count = 0
def load_data():
    """Load every model under ./data and read table sizes from cardData.cdb.

    Populates the module globals:
      * action_data   -- maps each regular file name in ./data to the model
                         returned by keras ``load_model`` for that file.
      * compare_count -- largest rowid in table L_CompareTo (0 if empty).
      * action_count  -- largest rowid in table L_ActionList (0 if empty);
                         get_predictions uses it as the model input length.

    Both paths are resolved relative to the current working directory.
    """
    global action_data, compare_count, action_count
    directory = './data'
    for filename in os.listdir(directory):
        path = os.path.join(directory, filename)
        if os.path.isfile(path):
            action_data[filename] = load_model(path)

    conn = sqlite3.connect(os.path.join(os.getcwd(), 'cardData.cdb'))
    try:
        cursor = conn.cursor()
        # max() over an empty table yields NULL (None); treat it as 0 so the
        # counts stay usable as integers downstream.
        cursor.execute('SELECT max(rowid) FROM L_CompareTo')
        compare_count = cursor.fetchone()[0] or 0
        cursor.execute('SELECT max(rowid) FROM L_ActionList')
        action_count = cursor.fetchone()[0] or 0
    finally:
        # Close even when a query fails (e.g. a missing table) so the
        # database handle is not leaked.
        conn.close()
    print(f"Compare Count:{compare_count}")
    print(f"Action Count:{action_count}")
def get_predictions(data: typing.Sequence[typing.Union[int, str]],
                    actions: typing.Sequence[typing.Union[int, str]],
                    name: str,
                    top_k: int = 4) -> typing.List[float]:
    """Run the model registered under *name* on a 0/1 encoding of *data*.

    Args:
        data: 1-based ids to switch on in the input vector. Each entry goes
            through ``int()``, so numeric strings (as sent by the HTTP
            handler) are accepted. Blank string entries, produced by splitting
            an empty or double-spaced string on ' ', are skipped instead of
            crashing ``int('')``. Ids outside [1, action_count] are ignored.
        actions: currently unused (see the note below); kept so existing
            callers keep working.
        name: key into ``action_data`` (the model's file name in ./data).
        top_k: how many of the highest-scoring outputs to print for
            debugging; clamped to the number of model outputs.

    Returns:
        The model's output row as a list of floats, or ``[]`` when no model
        is registered under *name*.
    """
    if action_data is None or name not in action_data:
        return []

    input_list = [0] * action_count
    for token in data:
        if isinstance(token, str) and not token.strip():
            continue
        index = int(token) - 1
        if 0 <= index < len(input_list):
            input_list[index] = 1
    # NOTE: *actions* is not encoded. An earlier version set position
    # int(id) - 1 + compare_count for each action id; that code is disabled.

    # Pass an explicit (1, action_count) array rather than a nested list.
    batch = np.asarray([input_list], dtype=np.float32)
    result = action_data[name].predict(batch, batch_size=1)
    scores = result[0]
    print("Estimate:" + str(np.argmax(result)) + " " + str(len(scores)))

    # argpartition raises if k exceeds the number of outputs, so clamp it.
    k = min(top_k, len(scores))
    print("Top k:")
    if k > 0:
        top = np.argpartition(scores, -k)[-k:]
        # Print in ascending score order, highest last.
        for i in top[np.argsort(scores[top])]:
            print(str(i) + ":" + str(scores[i]))
    return scores.tolist()
# def run_command_line():
# while True:
# data = input("Enter input data\n").split(' ')
# actions = input("Enter Actions\n").split(' ')
# predictions = get_predictions(data, actions)
# for action in predictions:
# print("Action" + str(action) + ":" +str(predictions[action]))
class handler(BaseHTTPRequestHandler):
    """Serve model predictions to POST requests.

    The request body must be a JSON object with string fields "data",
    "actions" and "name". "data" and "actions" are split on single spaces;
    only the first space-separated word of "name" is used as the model key.
    On success the response body is the JSON-encoded result of
    get_predictions.
    """

    def _send_json(self, status: int, payload: Any) -> None:
        """Write *payload* as a JSON response with the given status code."""
        body = json.dumps(payload).encode("utf-8")
        self.send_response(status)
        self.send_header('Content-type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_POST(self):
        """Decode the request, run the requested model and reply with scores.

        Malformed requests get a 400 response with a JSON error object. This
        covers a missing or invalid Content-Length, an undecodable or
        non-JSON body, and missing or non-string fields. Previously these
        raised out of the handler without any reply being written.
        """
        try:
            content_len = int(self.headers.get('Content-Length', 0))
            if content_len < 0:
                raise ValueError("negative Content-Length")
            raw_data = json.loads(self.rfile.read(content_len).decode())
            data = raw_data["data"].split(' ')
            actions = raw_data["actions"].split(' ')
            name = raw_data["name"].split(' ')[0]
        except (ValueError, KeyError, TypeError, AttributeError) as err:
            self._send_json(400, {"error": f"malformed request: {err}"})
            return
        predictions = get_predictions(data, actions, name)
        self._send_json(200, predictions)

    def log_message(self, format: str, *args: Any) -> None:
        """Override the base-class request logging with a no-op."""
        return None
def run_server(host: str = '', port: int = 8000):
    """Serve prediction requests with ``handler`` until the process stops.

    Args:
        host: address to bind; defaults to '' as before.
        port: TCP port to listen on; defaults to 8000 as before.
    """
    with HTTPServer((host, port), handler) as server:
        server.serve_forever()
if __name__ == "__main__":
    # Request handling reads the models and table sizes, so load them first.
    load_data()
    run_server()