-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathtlstunnel.ml
408 lines (363 loc) · 13.9 KB
/
tlstunnel.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
open Lwt.Infix
module Log = struct
  (* Render a socket address for log output: "ip:port" for Internet
     sockets, the filesystem path for Unix-domain sockets. *)
  let inet_to_string addr =
    match addr with
    | Lwt_unix.ADDR_UNIX path -> path
    | Lwt_unix.ADDR_INET (ip, port) ->
      Printf.sprintf "%s:%d" (Unix.string_of_inet_addr ip) port

  (* Write [event] to [out] (nothing happens when [out] is [None]),
     prefixed with the current UTC time as "[YYYY-MM-DDThh:mm:ssZ] ",
     and flush the channel immediately. *)
  let log_raw out event =
    match out with
    | None -> ()
    | Some chan ->
      let tm = Unix.gmtime (Unix.time ()) in
      Printf.fprintf chan "[%04d-%02d-%02dT%02d:%02d:%02dZ] %s\n%!"
        (tm.Unix.tm_year + 1900) (tm.Unix.tm_mon + 1) tm.Unix.tm_mday
        tm.Unix.tm_hour tm.Unix.tm_min tm.Unix.tm_sec
        event

  (* Log [event] attributed to the address [addr] ("addr: event"). *)
  let log out addr event =
    log_raw out (Printf.sprintf "%s: %s" (inet_to_string addr) event)

  (* Log a startup line: [event] immediately followed by the listening
     address [front], then ", forwarding to " and the target [back]. *)
  let log_initial out back event front =
    log_raw out
      (event ^ inet_to_string front ^ ", forwarding to " ^ inet_to_string back)
end
module Stats = struct
  (* Byte counters for one proxied connection; updated in place. *)
  type stats = {
    mutable read : int ;
    mutable written : int
  }

  (* Fresh counters, both starting at zero. *)
  let new_stats () = { written = 0 ; read = 0 }

  (* Add [n] to the read / written counter of [s]. *)
  let inc_read s n = s.read <- n + s.read
  let inc_written s n = s.written <- n + s.written

  (* Human-readable summary, e.g. "read 10 bytes, wrote 20 bytes". *)
  let print_stats { read ; written } =
    Printf.sprintf "read %d bytes, wrote %d bytes" read written
end
module Fd_logger = struct
  (* Descriptors currently tracked: everything registered through [add_fd]
     that [log] has not yet seen in the [Closed] state. *)
  let fds = ref []

  (* Total number of descriptors ever registered. *)
  let count = ref 0

  (* Register [fd] for periodic reporting. *)
  let add_fd fd =
    fds := fd :: !fds ;
    incr count

  (* The exception that aborted [ab], or "" if [ab] is not aborted. *)
  let aborted_to_string ab =
    match Lwt_unix.state ab with
    | Lwt_unix.Aborted exn -> Printexc.to_string exn
    | Lwt_unix.Opened | Lwt_unix.Closed -> ""

  (* Build a summary line of the tracked descriptors, followed by one line
     per aborted descriptor.  As a side effect, closed descriptors are
     dropped from [fds]; opened and aborted ones remain tracked. *)
  let log () =
    let opened, closed, aborted =
      List.fold_left (fun (o, c, a) fd ->
          match Lwt_unix.state fd with
          | Lwt_unix.Opened -> (fd :: o, c, a)
          | Lwt_unix.Closed -> (o, fd :: c, a)
          | Lwt_unix.Aborted _ -> (o, c, fd :: a))
        ([], [], []) !fds
    in
    fds := List.append opened aborted ;
    let n_opened = List.length opened
    and n_closed = List.length closed
    and n_aborted = List.length aborted in
    (* Every aborted entry goes on its own line with the same indentation. *)
    let details =
      String.concat "" (List.map (fun fd -> "\n " ^ aborted_to_string fd) aborted)
    in
    Printf.sprintf "fds: count %d, active %d, open %d, closed %d, aborted %d%s"
      !count (n_opened + n_aborted) n_opened n_closed n_aborted details

  (* Pass the output of [log] to [logger] every 60 seconds (repeating
     timer); returns the [Lwt_engine] event. *)
  let start logger () =
    Lwt_engine.on_timer 60. true (fun _ -> logger (log ()))
end
module Haproxy1 = struct
  (* implementation of the PROXY protocol as used by haproxy
   * http://www.haproxy.org/download/1.5/doc/proxy-protocol.txt
   * This module only implements the protocol detailed in the
   * "2.1. Human-readable header format (Version 1)" section.
   *
   * *)

  (* Header for connections whose addresses cannot be expressed as TCP4 or
     TCP6.  The v1 spec defines "UNKNOWN" for this case; receivers ignore
     the rest of the line. *)
  let unknown_header = "PROXY UNKNOWN\r\n"

  (* [header_of_addrs ~own ~peer] builds the v1 header line
       PROXY TCP4|TCP6 SOURCEIP DESTIP SRCPORT DESTPORT\r\n
     where the source is [peer] (the client) and the destination is [own]
     (the local end of the accepted socket).  Non-INET addresses yield
     [unknown_header] instead of failing. *)
  let header_of_addrs ~own ~peer =
    match own, peer with
    | Lwt_unix.ADDR_INET (own_ip, own_port), Lwt_unix.ADDR_INET (peer_ip, peer_port) ->
      let protocol =
        if Unix.domain_of_sockaddr own = Unix.PF_INET6 then "TCP6" else "TCP4"
      in
      String.concat " "
        [ "PROXY" ; protocol ;
          Unix.string_of_inet_addr peer_ip ; Unix.string_of_inet_addr own_ip ;
          string_of_int peer_port ; string_of_int own_port ] ^ "\r\n"
    | _ -> unknown_header

  (* [make_header socket] describes the connection on [socket]: the
     getpeername result is the source, the getsockname result the
     destination.  getsockname/getpeername may raise [Unix.Unix_error]. *)
  let make_header socket =
    let own = Lwt_unix.getsockname socket in
    let peer = Lwt_unix.getpeername socket in
    header_of_addrs ~own ~peer
end
(* Load the server certificate chain [cert] and key [priv_key] (PEM files)
   and build the TLS server configuration.  Renegotiation is enabled only
   when [auth] is given, since the worker then renegotiates to request a
   client certificate. *)
let server_config auth cert priv_key =
  let reneg = match auth with Some _ -> true | None -> false in
  X509_lwt.private_of_pems ~cert ~priv_key >|= fun own_cert ->
  Tls.Config.server ~reneg ~certificates:(`Single own_cert) ()
(* Create a listening TCP socket bound to [frontend] (SO_REUSEADDR set,
   backlog 10) and report it through [log_raw "listener started on "].
   Unix errors raised synchronously while setting the socket up are
   reported by [Unix.handle_unix_error], which terminates the program.
   NOTE(review): a failure delivered through the [bind] promise (rather
   than raised) is not caught by [handle_unix_error]. *)
let init_socket log_raw frontend =
  Unix.handle_unix_error (fun () ->
      (* Use the address family of [frontend] so that IPv6 listen
         addresses work; IPv4 addresses still get a PF_INET socket. *)
      let domain = Unix.domain_of_sockaddr frontend in
      let open Lwt_unix in
      let s = socket domain SOCK_STREAM 0 in
      setsockopt s SO_REUSEADDR true ;
      bind s frontend >|= fun () ->
      listen s 10 ;
      log_raw "listener started on " frontend ;
      s) ()
(* Size in bytes of the chunk requested by each read in [read_write]. *)
let bufsize = 4096
(* Outcome of one [read_write] step: [Stop] ends the relay loop,
   [Continue] goes on to read the next chunk. *)
type res = Stop | Continue
(* [read_write debug log closing close cnt ic oc] relays data from [ic] to
   [oc] until end of input, an exception, or until [!closing] is true.
   - [cnt] receives the byte count of every read (including the final 0
     at end of input).
   - On end of input or any exception, [close ()] is called and the relay
     stops; exceptions other than EBADF are logged via [log].
   - When [closing] is already set, [close ()] is called and its promise
     returned.
   - With [debug], each read (including the data) and write is logged. *)
let read_write debug log closing close cnt ic oc =
  (* One buffer per direction, reused for every chunk: a chunk is fully
     handed to [oc] before the next read overwrites the buffer. *)
  let buf = Bytes.create bufsize in
  let step () =
    Lwt_io.read_into ic buf 0 bufsize >>= fun l ->
    cnt l ;
    if l > 0 then begin
      if debug then
        log (String.concat " " ["read"; string_of_int l; "bytes:"; Bytes.sub_string buf 0 l]) ;
      (* [Lwt_io.write_from] may accept fewer than [l] bytes and report a
         short count, losing the rest; [write_from_exactly] writes the
         whole chunk. *)
      Lwt_io.write_from_exactly oc buf 0 l >|= fun () ->
      if debug then log (Printf.sprintf "wrote %d bytes" l) ;
      Continue
    end else begin
      if debug then log "closing" ;
      close () >|= fun () -> Stop
    end
  in
  let on_error = function
    | Unix.Unix_error (Unix.EBADF, _, _) ->
      (* expected once the other direction has closed the descriptors *)
      if debug then log "EBADF, closing" ;
      close () >|= fun () -> Stop
    | exn ->
      log ("failed in read_write " ^ Printexc.to_string exn) ;
      close () >|= fun () -> Stop
  in
  let rec loop () =
    if !closing then
      close ()
    else
      Lwt.catch step on_error >>= function
      | Stop -> Lwt.return_unit
      | Continue -> loop ()
  in
  loop ()
(* Describe the current TLS session of [t] for logging: protocol version,
   ciphersuite and, if the peer presented a certificate, its subject and
   serial number.  If no session data is available, a placeholder string
   is returned instead of failing, since this is only used for log lines. *)
let tls_info t =
  match Tls_lwt.Unix.epoch t with
  | `Error -> "no TLS session information available"
  | `Ok data ->
    let version =
      Sexplib.Sexp.to_string_hum (Tls.Core.sexp_of_tls_version data.Tls.Core.protocol_version)
    and cipher =
      Sexplib.Sexp.to_string_hum (Tls.Ciphersuite.sexp_of_ciphersuite data.Tls.Core.ciphersuite)
    and cert =
      (match data.Tls.Core.peer_certificate with
       | None -> ""
       | Some x ->
         let serial = Z.to_string (X509.Certificate.serial x)
         and subject =
           Fmt.to_to_string X509.Distinguished_name.pp (X509.Certificate.subject x)
         in
         ", authenticated using " ^ subject ^ " (serial: " ^ serial ^ ")")
    in
    version ^ ", " ^ cipher ^ cert
(* Mark the connection as closing, then close the TLS session [tls] (if
   any) followed by the raw descriptor [fd].  Best effort: every error
   raised while closing is deliberately ignored. *)
let safe_close closing tls fd () =
  closing := true ;
  let ignore_errors thunk = Lwt.catch thunk (fun _ -> Lwt.return_unit) in
  let close_tls =
    match tls with
    | None -> Lwt.return_unit
    | Some session -> ignore_errors (fun () -> Tls_lwt.Unix.close session)
  in
  close_tls >>= fun () ->
  ignore_errors (fun () -> Lwt_unix.close fd)
(* Serve one accepted client socket [s]:
   - perform the TLS server handshake ([trace] is passed to the TLS layer);
   - if [auth] carries CA certificates, renegotiate with a chain-of-trust
     authenticator built from them;
   - connect to [backend], optionally sending a PROXY v1 header first;
   - relay data in both directions until either side closes.
   Failures are logged through [log] and the involved sockets are closed. *)
let worker config auth backend log s haproxy1 logfds debug trace () =
  let closing = ref false in
  Lwt.catch (fun () ->
      Tls_lwt.Unix.server_of_fd config ?trace s >>= fun t ->
      log ("connection established (" ^ (tls_info t) ^ ")") ;
      (match auth with
       | None -> Lwt.return_unit
       | Some cas ->
         let acceptable_cas = List.map X509.Certificate.subject cas in
         let time = Ptime_clock.now () in
         let authenticator = X509.Authenticator.chain_of_trust ~time cas in
         Tls_lwt.Unix.reneg ~authenticator ~acceptable_cas t >|= fun () ->
         log ("connection renegotiated (" ^ (tls_info t) ^ ")")) >>= fun () ->
      let ic, oc = Tls_lwt.of_t t in
      let stats = Stats.new_stats () in
      (* Use the backend's address family so IPv6 backends can be reached;
         IPv4 backends still get a PF_INET socket. *)
      let fd = Lwt_unix.socket (Unix.domain_of_sockaddr backend) Unix.SOCK_STREAM 0 in
      if logfds then Fd_logger.add_fd fd ;
      let close = safe_close closing (Some t) fd in
      Lwt.catch (fun () ->
          Lwt_unix.connect fd backend >>= fun () ->
          let pic = Lwt_io.of_fd ~close ~mode:Lwt_io.Input fd
          and poc = Lwt_io.of_fd ~close ~mode:Lwt_io.Output fd in
          (if haproxy1 then Lwt_io.write poc (Haproxy1.make_header s)
           else Lwt.return_unit) >>= fun () ->
          Lwt.join [
            read_write debug log closing close (Stats.inc_read stats) ic poc ;
            read_write debug log closing close (Stats.inc_written stats) pic oc
          ] >|= fun () ->
          log ("connection closed " ^ (Stats.print_stats stats)))
        (function
          | Unix.Unix_error (e, f, _) ->
            let msg = Unix.error_message e in
            log ("backend refused connection: " ^ msg ^ " while calling " ^ f) ;
            close ()
          | exn ->
            close () >|= fun () ->
            log ("received inner exception " ^ Printexc.to_string exn)))
    (fun exn ->
       safe_close closing None s () >|= fun () ->
       log ("failed to establish TLS connection: " ^ Printexc.to_string exn))
(* Process-wide setup: register readable printers for TLS alert/failure
   exceptions, and report exceptions escaping [Lwt.async] to [out] (or to
   a channel on standard output when [out] is [None]). *)
let init out =
  let describe_tls_exn = function
    | Tls_lwt.Tls_alert a -> Some ("TLS alert: " ^ Tls.Packet.alert_type_to_string a)
    | Tls_lwt.Tls_failure f -> Some ("TLS failure: " ^ Tls.Engine.string_of_failure f)
    | _ -> None
  in
  Printexc.register_printer describe_tls_exn ;
  let chan =
    match out with
    | Some c -> c
    | None -> Unix.out_channel_of_descr Unix.stdout
  in
  Lwt.async_exception_hook :=
    (fun exn -> Printf.fprintf chan "async error %s\n%!" (Printexc.to_string exn))
(* Accept connections on the listening socket [s] forever, handing each
   client socket to [worker] in its own [Lwt.async] task.  Accept errors
   are logged through [log_raw] and the loop continues. *)
let accept_loop s log_raw log_conn tls_config auth backend haproxy1 logfds debug trace =
  let accept_one () =
    Lwt_unix.accept s >|= fun (client_socket, addr) ->
    (* log_conn addr "accepted incoming connection" ; *)
    if logfds then Fd_logger.add_fd client_socket ;
    Lwt.async (worker tls_config auth backend (log_conn addr) client_socket haproxy1 logfds debug trace)
  in
  let on_error = function
    | Unix.Unix_error ((Unix.EMFILE | Unix.ENFILE) as e, f, _) ->
      log_raw ("accept failed " ^ Unix.error_message e ^ " in " ^ f) ;
      (* Out of file descriptors: retrying immediately would spin and flood
         the log, so wait briefly for connections to be released.  The
         delay (0.1 s) is a heuristic. *)
      Lwt_unix.sleep 0.1
    | Unix.Unix_error (e, f, _) ->
      log_raw ("accept failed " ^ Unix.error_message e ^ " in " ^ f) ;
      Lwt.return_unit
    | exn ->
      log_raw ("failure in accept_loop: " ^ Printexc.to_string exn) ;
      Lwt.return_unit
  in
  (* The recursion happens outside [Lwt.catch], so each iteration's error
     handler only covers that iteration's accept. *)
  let rec loop () = Lwt.catch accept_one on_error >>= loop in
  loop ()
(* Set up logging, load the TLS configuration (and, if [auth] is a PEM
   file, the client-authentication CA certificates), open the listener on
   [(fip, fport)] and run the accept loop forwarding to [(bip, bport)].
   [logfd = None] disables access logging.  With [debug], TLS traces are
   printed to the log channel (or to standard output). *)
let serve (fip, fport) (bip, bport) certificate privkey auth haproxy1 logfd logfds debug =
  let logchan = match logfd with
    | Some fd -> Some (Unix.out_channel_of_descr fd)
    | None -> None
  in
  init logchan ;
  let frontend = Lwt_unix.ADDR_INET (fip, fport)
  and backend = Lwt_unix.ADDR_INET (bip, bport)
  in
  server_config auth certificate privkey >>= fun tls_config ->
  (match auth with
   | None -> Lwt.return None
   | Some f -> X509_lwt.certs_of_pem f >|= fun a -> Some a) >>= fun auth ->
  init_socket (Log.log_initial logchan backend) frontend >>= fun server_socket ->
  let raw_log = Log.log_raw logchan in
  if logfds then ignore (Fd_logger.start raw_log ()) ;
  let trace =
    if debug then
      let out = match logchan with
        | None -> Unix.out_channel_of_descr Unix.stdout
        | Some x -> x
      in
      (* Flush after each trace, like every other log line in this file;
         otherwise traces linger in the channel buffer and appear late or
         out of order relative to the access log. *)
      Some (fun sexp -> Printf.fprintf out "%s\n%!" (Sexplib.Sexp.to_string_hum sexp))
    else
      None
  in
  (* drop privileges here! *)
  accept_loop server_socket raw_log (Log.log logchan) tls_config auth backend haproxy1 logfds debug trace
(* Command-line entry point: ignore SIGPIPE, choose the log destination
   (stdout, an appended log file created with mode 0o640, or none when
   [quiet]), resolve the certificate/key files (the key defaults to the
   certificate file), then run [serve] to completion.
   Raises [Invalid_argument] for "--quiet" combined with a log file, or
   when no certificate is given. *)
let run_server frontend backend certificate privkey auth haproxy1 log quiet logfds debug =
  Sys.set_signal Sys.sigpipe Sys.Signal_ignore ;
  let logfd =
    if quiet then begin
      match log with
      | None -> None
      | Some _ -> invalid_arg "cannot specify logfile and quiet"
    end else begin
      match log with
      | None -> Some Unix.stdout
      | Some file ->
        Some (Unix.openfile file [Unix.O_WRONLY ; Unix.O_APPEND; Unix.O_CREAT] 0o640)
    end
  in
  let cert_file, key_file =
    match certificate with
    | None -> invalid_arg "missing certificate file"
    | Some c -> (c, (match privkey with Some p -> p | None -> c))
  in
  Lwt_main.run (serve frontend backend cert_file key_file auth haproxy1 logfd logfds debug)
open Cmdliner
(* Return the first address [Unix.gethostbyname] reports for [name].
   Raises [Not_found] (from gethostbyname) if the lookup fails, and
   [Invalid_argument] if the host has no address. *)
let resolve name =
  match Array.to_list (Unix.gethostbyname name).Unix.h_addr_list with
  | addr :: _ -> addr
  | [] -> invalid_arg ("no address for " ^ name)
(* [host_port default] is a Cmdliner converter for "[HOST:]PORT".
   HOST (the text before the first ':') is resolved with [resolve]; when it
   is absent or empty, [default] is used.  PORT must parse with
   [int_of_string] and lie in 0..65535.  An unresolvable host or a bad
   port is reported as a converter `Error instead of an escaping
   exception. *)
let host_port default : (Unix.inet_addr * int) Arg.converter =
  let parse s =
    let name, port_str =
      match String.index s ':' with
      | colon -> String.sub s 0 colon, String.sub s (colon + 1) (String.length s - colon - 1)
      | exception Not_found -> "", s
    in
    (* Any non-empty host part is resolved, including single-character
       names. *)
    match (if name = "" then default else resolve name) with
    | exception (Not_found | Invalid_argument _) -> `Error ("cannot resolve host " ^ name)
    | host ->
      match int_of_string_opt port_str with
      | Some port when port >= 0 && port <= 65535 -> `Ok (host, port)
      | _ -> `Error ("invalid port " ^ port_str)
  in
  parse, fun ppf (h, p) -> Format.fprintf ppf "%s:%d" (Unix.string_of_inet_addr h) p
(* "-b/--backend [HOST:]PORT": where connections are forwarded to;
   defaults to 127.0.0.1:8080. *)
let backend =
  let loopback = Unix.inet_addr_loopback in
  let doc = "The hostname and port of the backend [connect] service (default is [127.0.0.1]:8080)" in
  let i = Arg.info ["b" ; "backend"] ~docv:"backend" ~doc in
  Arg.value (Arg.opt (host_port loopback) (loopback, 8080) i)
(* "-f/--frontend [HOST:]PORT": the address to listen on; defaults to all
   interfaces, port 4433. *)
let frontend =
  let any = Unix.inet_addr_any in
  let doc = "The hostname and port to listen on for incoming connections (default is [*]:4433)" in
  let i = Arg.info ["f" ; "frontend"] ~docv:"frontend" ~doc in
  Arg.value (Arg.opt (host_port any) (any, 4433) i)
(* "--cert FILE": PEM certificate chain (may also contain the key). *)
let certificate =
  let doc = "The full path to PEM encoded certificate chain FILE (may also include the private key)" in
  Arg.value (Arg.opt (Arg.some Arg.file) None (Arg.info ["cert"] ~docv:"FILE" ~doc))
(* "--key FILE": PEM private key; [run_server] falls back to the --cert
   file when it is omitted. *)
let privkey =
  let doc = "The full path to PEM encoded unencrypted private key in FILE (defaults to the FILE given with --cert)" in
  Arg.value (Arg.opt (Arg.some Arg.file) None (Arg.info ["key"] ~docv:"FILE" ~doc))
(* "--haproxy1": prepend a PROXY protocol v1 header on backend connections. *)
let haproxy1 =
  let doc = "Forward protocol, IP, and port numbers to the destination using HA PROXY protocol v1 (for use with nginx, Varnish, etc - see http://www.haproxy.org/download/1.5/doc/proxy-protocol.txt)" in
  Arg.value (Arg.flag (Arg.info ["haproxy1"] ~doc))
(* "--auth FILE": CA certificates used to authenticate clients. *)
let auth =
  let doc = "The full path to PEM encoded CA certificate FILE used for client authentication" in
  Arg.value (Arg.opt (Arg.some Arg.file) None (Arg.info ["auth"] ~docv:"FILE" ~doc))
(* "-l/--logfile FILE": append the access log to FILE. *)
let log =
  let doc = "Write accesses to FILE (by default, logging is done to standard output)." in
  let i = Arg.info ["l"; "logfile"] ~docv:"FILE" ~doc in
  Arg.value (Arg.opt (Arg.some Arg.string) None i)
(* "--logfds": periodically report tracked file descriptors. *)
let logfds =
  let i = Arg.info ["logfds"] ~doc:"Log file descriptors" in
  Arg.value (Arg.flag i)
(* "--debug": verbose relay logging and TLS traces. *)
let debug =
  let i = Arg.info ["debug"] ~doc:"Debug, show full traces" in
  Arg.value (Arg.flag i)
(* "-q/--quiet": disable the access log. *)
let quiet =
  let i = Arg.info ["q"; "quiet"] ~doc:"Be quiet, no logging of accesses." in
  Arg.value (Arg.flag i)
(* The Cmdliner command: [run_server] applied to all command-line
   arguments, plus its name, version watermark, and man page. *)
let cmd =
  let doc = "Proxy TLS connections to a standard TCP service" in
  let man = [
    `S "DESCRIPTION" ;
    `P "$(tname) listens on a given port and forwards request to the specified hostname" ;
    `S "BUGS" ;
    `P "Please report bugs on the issue tracker at <https://github.com/hannesm/tlstunnel/issues>" ;
    `S "SEE ALSO" ;
    `P "$(b,stunnel)(1), $(b,stud)(1)" ]
  in
  let term =
    Term.(pure run_server $ frontend $ backend $ certificate $ privkey $ auth $ haproxy1 $ log $ quiet $ logfds $ debug)
  and info = Term.info "tlstunnel" ~version:"%%VERSION_NUM%%" ~doc ~man in
  (term, info)
(* Program entry: evaluate the command line; exit status 1 on a Cmdliner
   error, 0 otherwise. *)
let () =
  let status =
    match Term.eval cmd with
    | `Error _ -> 1
    | _ -> 0
  in
  exit status