-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.go
321 lines (268 loc) · 8.92 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
package main
import (
"database/sql"
"encoding/csv"
"flag"
"fmt"
"io"
"log"
"net/http"
"os"
"regexp"
"strconv"
"strings"
json "github.com/json-iterator/go"
_ "github.com/mattn/go-sqlite3"
)
// version is the release string reported in the startup log and in the
// Server response header.
const version = "1.4.0"

// main hands the command-line arguments (without the program name) to cmd
// and terminates the process through log.Fatal if cmd returns an error.
func main() {
	err := cmd(os.Args[1:])
	if err == nil {
		return
	}
	log.Fatal(err)
}
// cmd parses cmdArgs, prepares the configured query and serves it over HTTP
// at /query. It returns an error when flag parsing or query initialization
// fails; otherwise it blocks in the HTTP server and returns whatever error
// stops it.
//
// Flags:
//
//	-db     filesystem path of the SQLite database (required)
//	-query  SQL query to prepare (required)
//	-port   HTTP port to listen on (default 80)
func cmd(cmdArgs []string) error {
	log.Printf("SQLiteQueryServer v%s\n", version)
	log.Println("https://github.com/assafmo/SQLiteQueryServer")

	// Parse cmd args
	flagSet := flag.NewFlagSet("cmd flags", flag.ContinueOnError)
	var (
		dbPath      string
		queryString string
		serverPort  uint
	)
	flagSet.StringVar(&dbPath, "db", "", "Filesystem path of the SQLite database")
	flagSet.StringVar(&queryString, "query", "", "SQL query to prepare for")
	flagSet.UintVar(&serverPort, "port", 80, "HTTP port to listen on")
	if err := flagSet.Parse(cmdArgs); err != nil {
		return err
	}

	// Init db and query
	queryHandler, err := initQueryHandler(dbPath, queryString, serverPort)
	if err != nil {
		return err
	}

	// Register on a private mux instead of http.DefaultServeMux: registering
	// "/query" on the global mux panics if cmd is ever called twice in the
	// same process (e.g. from tests).
	mux := http.NewServeMux()
	mux.HandleFunc("/query", queryHandler)

	// Start the server
	log.Printf("Starting server on port %d...\n", serverPort)
	log.Printf("Starting server with query '%s'...\n", queryString)
	// NOTE(review): no server timeouts are configured; consider an
	// http.Server with ReadHeaderTimeout to guard against slow clients.
	return http.ListenAndServe(fmt.Sprintf(":%d", serverPort), mux)
}
// queryResult is the JSON shape of a single query execution inside a /query
// response array.
type queryResult struct {
	// In holds the params bound to this execution: one CSV record of a POST
	// body, or an empty list for a GET.
	In      []string        `json:"in"`
	// Headers holds the column names returned by the query.
	Headers []string        `json:"headers"`
	// Out holds the result rows; each inner slice has one value per header.
	Out     [][]interface{} `json:"out"`
}
func initQueryHandler(dbPath string, queryString string, serverPort uint) (func(w http.ResponseWriter, r *http.Request), error) {
// Init db and query
if dbPath == "" {
return nil, fmt.Errorf("Must provide --db param")
}
if queryString == "" {
return nil, fmt.Errorf("Must provide --query param")
}
if _, err := os.Stat(dbPath); os.IsNotExist(err) {
return nil, fmt.Errorf("Database file '%s' doesn't exist", dbPath)
}
db, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?mode=rw&cache=shared&_journal_mode=WAL", dbPath))
if err != nil {
return nil, err
}
db.SetMaxOpenConns(1)
queryStmt, err := db.Prepare(queryString)
if err != nil {
db.Close()
return nil, err
}
helpMessage := buildHelpMessage("", queryString, queryStmt, serverPort)
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Server", "SQLiteQueryServer v"+version)
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("X-Content-Type-Options", "nosniff")
if r.URL.Path != "/query" {
http.Error(w, helpMessage, http.StatusNotFound)
return
}
if r.Method != "POST" && r.Method != "GET" {
http.Error(w, helpMessage, http.StatusMethodNotAllowed)
return
}
// Init fullResponse
fullResponse := []queryResult{}
var reqCsvReader *csv.Reader
if r.Method == "GET" {
// Static query
reqCsvReader = csv.NewReader(strings.NewReader(""))
} else {
// Parameterized query
reqCsvReader = csv.NewReader(r.Body)
}
reqCsvReader.FieldsPerRecord = -1
// Iterate over each query
for {
csvRecord, err := reqCsvReader.Read()
if r.Method == "POST" {
// Parameterized query
if err == io.EOF || err == http.ErrBodyReadAfterClose {
// EOF || last line is without \n
break
} else if err != nil {
http.Error(w, fmt.Sprintf("\n\nError reading request body: %v\n\n%s", err, helpMessage), http.StatusInternalServerError)
return
}
} else {
csvRecord = make([]string, 0)
}
// Init queryResponse
// Set queryResponse.Headers to the query's params (the fields of the csv record)
var queryResponse queryResult
queryResponse.In = csvRecord
queryParams := make([]interface{}, len(csvRecord))
for i := range csvRecord {
queryParams[i] = csvRecord[i]
}
rows, err := queryStmt.Query(queryParams...)
if err != nil {
http.Error(w, fmt.Sprintf("\n\nError executing query for params %#v: %v\n\n%s", csvRecord, err, helpMessage), http.StatusInternalServerError)
return
}
defer rows.Close()
// Set queryResponse.Headers to the query's columns
// Init queryResponse.Out
cols, err := rows.Columns()
if err != nil {
http.Error(w, fmt.Sprintf("\n\nError reading columns for query with params %#v: %v\n\n%s", csvRecord, err, helpMessage), http.StatusInternalServerError)
return
}
queryResponse.Headers = cols
queryResponse.Out = make([][]interface{}, 0)
// Iterate over returned rows for this query
// Append each row to queryResponse.Out
for rows.Next() {
row := make([]interface{}, len(cols))
pointers := make([]interface{}, len(row))
for i := range row {
pointers[i] = &row[i]
}
err = rows.Scan(pointers...)
if err != nil {
http.Error(w, fmt.Sprintf("\n\nError reading query results for params %#v: %v\n\n%s", csvRecord, err, helpMessage), http.StatusInternalServerError)
return
}
queryResponse.Out = append(queryResponse.Out, row)
}
err = rows.Err()
if err != nil {
http.Error(w, fmt.Sprintf("\n\nError executing query: %v\n\n%s", err, helpMessage), http.StatusInternalServerError)
return
}
fullResponse = append(fullResponse, queryResponse)
if r.Method == "GET" {
// Static query - execute only once
break
}
}
// Return json
w.Header().Add("Content-Type", "application/json")
answerJSON, err := json.Marshal(fullResponse)
if err != nil {
http.Error(w, fmt.Sprintf("\n\nError encoding json: %v\n\n%s", err, helpMessage), http.StatusInternalServerError)
return
}
_, err = w.Write(answerJSON)
if err != nil {
http.Error(w, fmt.Sprintf("\n\nError sending json to client: %v\n\n%s", err, helpMessage), http.StatusInternalServerError)
return
}
}, nil
}
// buildHelpMessage appends a plain-text usage guide for the served query to
// helpMessage and returns the combined text. The guide lists the query
// string and its placeholder count. The count is omitted, and the failure
// logged, when countParams cannot determine it. The guide ends with
// request/response examples whose URLs use serverPort.
//
// Note: countParams probes queryStmt by executing it with no arguments, so a
// query without placeholders is actually run once here.
func buildHelpMessage(helpMessage string, queryString string, queryStmt *sql.Stmt, serverPort uint) string {
	var b strings.Builder
	b.WriteString(helpMessage)

	fmt.Fprintf(&b, `Query:
%s
`, queryString)

	if queryParamsCount, err := countParams(queryStmt); err != nil {
		log.Printf("Error extracting params count from query: %v\n", err)
	} else {
		fmt.Fprintf(&b, `Params count (question marks in query):
%d
`, queryParamsCount)
	}

	// Every URL below uses the same port, hence the indexed %[1]d verbs.
	fmt.Fprintf(&b, `Request examples:
$ echo -e "$QUERY1_PARAM1,$QUERY1_PARAM2\n$QUERY2_PARAM1,$QUERY2_PARAM2" | curl "http://$ADDRESS:%[1]d/query" --data-binary @-
$ curl "http://$ADDRESS:%[1]d/query" -d "$PARAM_1,$PARAM_2,...,$PARAM_N"
- Request must be a HTTP POST to "http://$ADDRESS:%[1]d/query".
- Request body must be a valid CSV.
- Request body must not have a CSV header.
- Each request body line is a different query.
- Each param in a line corresponds to a query param (a question mark in the query string).
- Static query (without any query params):
- The request must be a HTTP GET to "http://$ADDRESS:%[1]d/query".
- The query executes only once.
`, serverPort)

	fmt.Fprintf(&b, `Response example:
$ echo -e "github.com\none.one.one.one\ngoogle-public-dns-a.google.com" | curl "http://$ADDRESS:%d/query" --data-binary @-
[
{
"in": ["github.com"],
"headers": ["ip","dns"],
"out": [
["192.30.253.112","github.com"],
["192.30.253.113","github.com"]
]
},
{
"in": ["one.one.one.one"],
"headers": ["ip","dns"],
"out": [
["1.1.1.1","one.one.one.one"]
]
},
{
"in": ["google-public-dns-a.google.com"],
"headers": ["ip","dns"],
"out": [
["8.8.8.8","google-public-dns-a.google.com"]
]
}
]
- Response is a JSON array (Content-Type: application/json).
- Each element in the array:
- Is a result of a query
- Has an "in" fields which is an array of the input params (a request body line).
- Has an "headers" fields which is an array of headers of the SQL query result.
- Has an "out" field which is an array of arrays of results. Each inner array is a result row.
- Element #1 is the result of query #1, Element #2 is the result of query #2, and so forth.
- Static query (without any query params):
- The response JSON has only one element.
For more info visit https://github.com/assafmo/SQLiteQueryServer
`, serverPort)

	return b.String()
}
// countParams reports how many positional parameters (question marks) the
// prepared queryStmt expects.
//
// database/sql offers no public accessor for this, so the count is probed.
// The statement is executed with zero arguments. If that succeeds, the query
// takes no params and the probe rows are closed right away. Otherwise the
// count is parsed from database/sql's error text
// "sql: expected N arguments, got 0". Any other failure returns -1 and an
// error.
//
// Note: for a query without params the probe really executes the query once.
func countParams(queryStmt *sql.Stmt) (int, error) {
	probeRows, probeErr := queryStmt.Query()
	if probeErr == nil {
		probeRows.Close()
		return 0, nil
	}

	argCountPattern := regexp.MustCompile(`sql: expected (\p{N}+) arguments, got 0`)
	match := argCountPattern.FindStringSubmatch(probeErr.Error())
	if len(match) != 2 {
		// The statement compiled (it is prepared), yet running it failed for
		// a reason other than an argument-count mismatch.
		return -1, fmt.Errorf("Cannot extract params count from query error: %v", probeErr)
	}

	count, convErr := strconv.Atoi(match[1])
	if convErr != nil {
		// \p{N} also matches non-ASCII numerals, which Atoi rejects.
		return -1, fmt.Errorf(`Cannot convert \p{N}+ regex to int: %v`, convErr)
	}
	return count, nil
}