diff --git a/.travis.yml b/.travis.yml index e2d1599f..54ce171f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,13 +1,14 @@ language: go +go: + - 1.1 env: - TRAVIS="yes" install: - go get github.com/shadowsocks/shadowsocks-go/shadowsocks - go get github.com/cyfdecyf/bufio - go get github.com/cyfdecyf/leakybuf - - go get code.google.com/p/go.crypto/blowfish - - go get code.google.com/p/go.crypto/cast5 script: - - cd $TRAVIS_BUILD_DIR + - pushd $TRAVIS_BUILD_DIR - go test -v - ./script/test.sh + - popd diff --git a/CHANGELOG b/CHANGELOG index abbb930b..12d98884 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,7 @@ +0.7.2 (2013-07-01) + * Close idle server connections earlier: avoid opening too many sockets + * Support authenticating multiple users (can limit port for each user) + 0.7.1 (2013-06-08) * Fix parent proxy fallback bug diff --git a/README.md b/README.md index aac04693..7d692281 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ COW 是一个利用二级代理帮助自动化翻墙的 HTTP 代理服务器。它能自动检测被墙网站,且仅对被墙网站使用二级代理。 -当前版本:0.7.0 [CHANGELOG](CHANGELOG) +当前版本:0.7.2 [CHANGELOG](CHANGELOG) [![Build Status](https://travis-ci.org/cyfdecyf/cow.png?branch=master)](https://travis-ci.org/cyfdecyf/cow) **如果要发 pull request,请在最新的 develop branch 上进行开发。** diff --git a/auth.go b/auth.go index 2da891d0..8f9ec73d 100644 --- a/auth.go +++ b/auth.go @@ -2,8 +2,11 @@ package main import ( "bytes" + "errors" "fmt" + "github.com/cyfdecyf/bufio" "net" + "os" "strconv" "strings" "text/template" @@ -29,25 +32,61 @@ type netAddr struct { mask net.IPMask } +type authUser struct { + // user name is the key to auth.user, no need to store here + passwd string + ha1 string // used in request digest, initialized ondemand + port uint16 // 0 means any port +} + var auth struct { required bool - user string - passwd string - ha1 string // used in request digest + user map[string]*authUser allowedClient []netAddr - authed *TimeoutSet // cache authentication based on client ip + authed *TimeoutSet // cache authenticated users based on ip template *template.Template } +func (au *authUser) initHA1(user string) { + if au.ha1 == "" { + au.ha1 = md5sum(user + ":" + authRealm + ":" + au.passwd) + } +} + +func parseUserPasswd(userPasswd string) (user string, au *authUser, err error) { + arr := strings.Split(userPasswd, ":") + n := len(arr) + if n == 1 || n > 3 { + err = errors.New("user password: " + userPasswd + + " syntax wrong, should be username:password[:port]") + return + } + user, passwd := arr[0], arr[1] + if user == "" || passwd == "" { + err = errors.New("user password " + userPasswd + + " should not contain empty user name or password") + return "", nil, err + } + var port int + if n == 3 && arr[2] != "" { + port, err = strconv.Atoi(arr[2]) + if err != nil || port <= 0 || port > 0xffff { + err = errors.New("user password: " + userPasswd + " invalid port") + return "", nil, err + } + } + au = &authUser{passwd, "", uint16(port)} + return user, au, nil +} + func parseAllowedClient(val string) { if val == "" { return } - auth.required = true arr := strings.Split(val, ",") auth.allowedClient = make([]netAddr, len(arr)) for i, v := range arr { @@ -58,13 +97,13 @@ func parseAllowedClient(val string) { } ip := net.ParseIP(ipAndMask[0]) if ip == nil { - Fatal("allowedClient syntax error %s: ip address not valid\n", s) + Fatalf("allowedClient syntax error %s: ip address not valid\n", s) } var mask net.IPMask if len(ipAndMask) == 2 { nbit, err := strconv.Atoi(ipAndMask[1]) if err != nil { - Fatal("allowedClient syntax error %s: %v\n", s, err) + Fatalf("allowedClient syntax error %s: %v\n", s, err) } if nbit > 32 { Fatal("allowedClient error: mask number should <= 32") @@ -77,30 +116,55 @@ func parseAllowedClient(val string) { } } -func parseUserPasswd(val string) { +func addUserPasswd(val string) { if val == "" { return } - auth.required = true - // password format checking is done in checkConfig in config.go - arr := strings.SplitN(val, ":", 2) - auth.user, auth.passwd = arr[0], arr[1] + user, au, err := parseUserPasswd(val) + debug.Println("user:", user, "port:", au.port) + if err != nil { + Fatal(err) + } + if _, ok := auth.user[user]; ok { + Fatal("duplicate user:", user) + } + auth.user[user] = au } -func initAuth() { - parseUserPasswd(config.UserPasswd) - parseAllowedClient(config.AllowedClient) - - if !auth.required { +func loadUserPasswdFile(file string) { + if file == "" { return } + f, err := os.Open(file) + if err != nil { + Fatal("error opening user passwd fle:", err) + } - auth.authed = NewTimeoutSet(time.Duration(config.AuthTimeout) * time.Hour) + r := bufio.NewReader(f) + s := bufio.NewScanner(r) + for s.Scan() { + addUserPasswd(s.Text()) + } + f.Close() +} - if auth.user == "" { +func initAuth() { + if config.UserPasswd != "" || + config.UserPasswdFile != "" || + config.AllowedClient != "" { + auth.required = true + } else { return } - auth.ha1 = md5sum(auth.user + ":" + authRealm + ":" + auth.passwd) + + auth.user = make(map[string]*authUser) + + addUserPasswd(config.UserPasswd) + loadUserPasswdFile(config.UserPasswdFile) + parseAllowedClient(config.AllowedClient) + + auth.authed = NewTimeoutSet(time.Duration(config.AuthTimeout) * time.Hour) + rawTemplate := "HTTP/1.1 407 Proxy Authentication Required\r\n" + "Proxy-Authenticate: Digest realm=\"" + authRealm + "\", nonce=\"{{.Nonce}}\", qop=\"auth\"\r\n" + "Content-Type: text/html\r\n" + @@ -108,7 +172,7 @@ func initAuth() { "Content-Length: " + fmt.Sprintf("%d", len(authRawBodyTmpl)) + "\r\n\r\n" + authRawBodyTmpl var err error if auth.template, err = template.New("auth").Parse(rawTemplate); err != nil { - Fatal("Internal error generating auth template:", err) + Fatal("internal error generating auth template:", err) } } @@ -123,11 +187,14 @@ func Authenticate(conn *clientConn, r *Request) (err error) { if authIP(clientIP) { // IP is allowed return } - // No user specified - if auth.user == "" { - sendErrorPage(conn, "403 Forbidden", "Access forbidden", "You are not allowed to use the proxy.") - return errShouldClose - } + /* + // No user specified + if auth.user == "" { + sendErrorPage(conn, "403 Forbidden", "Access forbidden", + "You are not allowed to use the proxy.") + return errShouldClose + } + */ err = authUserPasswd(conn, r) if err == nil { auth.authed.add(clientIP) @@ -175,7 +242,7 @@ func calcRequestDigest(kv map[string]string, ha1, method string) string { return md5sum(buf.String()) } -func checkProxyAuthorization(r *Request) error { +func checkProxyAuthorization(conn *clientConn, r *Request) error { debug.Println("authorization:", r.ProxyAuthorization) arr := strings.SplitN(r.ProxyAuthorization, " ", 2) if len(arr) != 2 { @@ -200,21 +267,39 @@ func checkProxyAuthorization(r *Request) error { if time.Now().Sub(time.Unix(nonceTime, 0)) > time.Minute { return errAuthRequired } - if authHeader["username"] != auth.user { - errl.Println("auth: username mismatch:", authHeader["username"]) + + user := authHeader["username"] + au, ok := auth.user[user] + if !ok { + errl.Println("auth: no such user:", authHeader["username"]) return errAuthRequired } + + if au.port != 0 { + // check port + _, portStr := splitHostPort(conn.LocalAddr().String()) + port, _ := strconv.Atoi(portStr) + if uint16(port) != au.port { + errl.Println("auth: user", user, "port not match") + return errAuthRequired + } + } + if authHeader["qop"] != "auth" { - errl.Println("auth: qop wrong:", authHeader["qop"]) - return errBadRequest + msg := "auth: qop wrong: " + authHeader["qop"] + errl.Println(msg) + return errors.New(msg) } + response, ok := authHeader["response"] if !ok { - errl.Println("auth: no request-digest") - return errBadRequest + msg := "auth: no request-digest" + errl.Println(msg) + return errors.New(msg) } - digest := calcRequestDigest(authHeader, auth.ha1, r.Method) + au.initHA1(user) + digest := calcRequestDigest(authHeader, au.ha1, r.Method) if response == digest { return nil } @@ -225,13 +310,14 @@ func checkProxyAuthorization(r *Request) error { func authUserPasswd(conn *clientConn, r *Request) (err error) { if r.ProxyAuthorization != "" { // client has sent authorization header - err = checkProxyAuthorization(r) + err = checkProxyAuthorization(conn, r) if err == nil { return } else if err != errAuthRequired { - sendErrorPage(conn, errCodeBadReq, "Bad authorization request", "") + sendErrorPage(conn, errCodeBadReq, "Bad authorization request", err.Error()) return } + // auth required to through the following } nonce := genNonce() diff --git a/auth_test.go b/auth_test.go index d5cc302b..4655fdaa 100644 --- a/auth_test.go +++ b/auth_test.go @@ -5,6 +5,40 @@ import ( "testing" ) +func TestParseUserPasswd(t *testing.T) { + testData := []struct { + val string + user string + au *authUser + }{ + {"foo:bar", "foo", &authUser{"bar", "", 0}}, + {"foo:bar:-1", "", nil}, + {"hello:world:", "hello", &authUser{"world", "", 0}}, + {"hello:world:0", "", nil}, + {"hello:world:1024", "hello", &authUser{"world", "", 1024}}, + {"hello:world:65535", "hello", &authUser{"world", "", 65535}}, + } + + for _, td := range testData { + user, au, err := parseUserPasswd(td.val) + if td.au == nil { + if err == nil { + t.Error(td.val, "should return error") + } + continue + } + if td.user != user { + t.Error(td.val, "user should be:", td.user, "got:", user) + } + if td.au.passwd != au.passwd { + t.Error(td.val, "passwd should be:", td.au.passwd, "got:", au.passwd) + } + if td.au.port != au.port { + t.Error(td.val, "port should be:", td.au.port, "got:", au.port) + } + } +} + func TestCalcDigest(t *testing.T) { a1 := md5sum("cyf" + ":" + authRealm + ":" + "wlx") auth := map[string]string{ diff --git a/config.go b/config.go index 34998078..b2929a26 100644 --- a/config.go +++ b/config.go @@ -14,7 +14,7 @@ import ( ) const ( - version = "0.7.1" + version = "0.7.2" defaultListenAddr = "127.0.0.1:7777" ) @@ -38,9 +38,10 @@ type Config struct { hasHttpParent bool // authenticate client - UserPasswd string - AllowedClient string - AuthTimeout time.Duration + UserPasswd string + UserPasswdFile string // file that contains user:passwd:[port] pairs + AllowedClient string + AuthTimeout time.Duration // advanced options DialTimeout time.Duration @@ -83,6 +84,9 @@ func init() { config.ReadTimeout = defaultReadTimeout } +// Whether command line options specifies listen addr +var cmdHasListenAddr bool + func parseCmdLineConfig() *Config { var c Config var listenAddr string @@ -96,6 +100,7 @@ func parseCmdLineConfig() *Config { flag.Parse() if listenAddr != "" { configParser{}.ParseListen(listenAddr) + cmdHasListenAddr = true // must come after ParseListen } return &c } @@ -151,13 +156,14 @@ func (p configParser) ParseLogFile(val string) { } func (p configParser) ParseListen(val string) { - // Command line options has already specified listenAddr - if config.ListenAddr != nil { + if cmdHasListenAddr { return } arr := strings.Split(val, ",") - config.ListenAddr = make([]string, len(arr)) - for i, s := range arr { + if config.ListenAddr == nil { + config.ListenAddr = make([]string, 0, len(arr)) + } + for _, s := range arr { s = strings.TrimSpace(s) host, port := splitHostPort(s) if port == "" { @@ -169,7 +175,7 @@ func (p configParser) ParseListen(val string) { "%s represents all ip addresses on this host.\n", s) } } - config.ListenAddr[i] = s + config.ListenAddr = append(config.ListenAddr, s) } } @@ -334,6 +340,17 @@ func (p configParser) ParseUserPasswd(val string) { } } +func (p configParser) ParseUserPasswdFile(val string) { + exist, err := isFileExists(val) + if err != nil { + Fatal("userPasswdFile error:", err) + } + if !exist { + Fatal("userPasswdFile", val, "does not exist") + } + config.UserPasswdFile = val +} + func (p configParser) ParseAllowedClient(val string) { config.AllowedClient = val } diff --git a/doc/sample-config/rc b/doc/sample-config/rc index 1ef60ee6..fa8fad0c 100644 --- a/doc/sample-config/rc +++ b/doc/sample-config/rc @@ -1,4 +1,4 @@ -# 代理服务器监听地址,用逗号分隔多个地址 +# 代理服务器监听地址,用逗号分隔多个地址,或者重复多次来指定多个地址 # 0.0.0.0 表示监听本机所有 IP 地址 # PAC url 为 http:///pac,其中的代理服务器地址会根据 client 访问 ip 正确生成 # 如果使用端口映射来将 COW 的服务提供到外网,请参考高级选项中的 addrInPAC 选项 @@ -77,6 +77,12 @@ listen = 127.0.0.1:7777 # COW 总是先验证 IP 是否在 allowedClient 中,若不在其中再通过用户名密码认证 #userPasswd = username:password +# 如需指定多个用户名密码,可在下面选项指定的文件中列出,文件中每行内容如下 +# username:password[:port] +# port 为可选项,若指定,则该用户只能从指定端口连接 COW +# 注意:如有重复用户,COW 会报错退出 +#userPasswdFile = /path/to/file + # 认证失效时间 # 语法:2h3m4s 表示 2 小时 3 分钟 4 秒 #authTimeout = 2h diff --git a/error.go b/error.go index cc7b9145..f77d02be 100644 --- a/error.go +++ b/error.go @@ -13,7 +13,6 @@ var errPageRawTmpl = `

{{.H1}}

{{.Msg}} - {{.Form}}
Generated by COW at {{.T}} @@ -28,7 +27,7 @@ var headRawTmpl = "HTTP/1.1 {{.CodeReason}}\r\n" + "Content-Type: text/html\r\n" + "Content-Length: {{.Length}}\r\n" -var errPageTmpl, headTmpl, blockedFormTmpl, directFormTmpl *template.Template +var errPageTmpl, headTmpl *template.Template func init() { var err error @@ -40,17 +39,15 @@ func init() { } } -func genErrorPage(h1, msg, form string) (string, error) { +func genErrorPage(h1, msg string) (string, error) { var err error data := struct { - H1 string - Msg string - Form string - T string + H1 string + Msg string + T string }{ h1, msg, - form, time.Now().Format(time.ANSIC), } @@ -59,8 +56,8 @@ func genErrorPage(h1, msg, form string) (string, error) { return buf.String(), err } -func sendPageGeneric(w io.Writer, codeReason, h1, msg, form, addHeader string) { - page, err := genErrorPage(h1, msg, form) +func sendPageGeneric(w io.Writer, codeReason, h1, msg string) { + page, err := genErrorPage(h1, msg) if err != nil { errl.Println("Error generating error page:", err) return @@ -79,12 +76,11 @@ func sendPageGeneric(w io.Writer, codeReason, h1, msg, form, addHeader string) { return } - buf.WriteString(addHeader) buf.WriteString("\r\n") buf.WriteString(page) w.Write(buf.Bytes()) } func sendErrorPage(w io.Writer, codeReason, h1, msg string) { - sendPageGeneric(w, codeReason, "[Error] "+h1, msg, "", "") + sendPageGeneric(w, codeReason, "[Error] "+h1, msg) } diff --git a/http.go b/http.go index 2748f6ce..de4c8c68 100644 --- a/http.go +++ b/http.go @@ -416,7 +416,6 @@ func (h *Header) parseHeader(reader *bufio.Reader, raw *bytes.Buffer, url *URL) raw.Write(s) // debug.Printf("len %d %s", len(s), s) } - return } // Parse the request line and header, does not touch body @@ -426,6 +425,9 @@ func parseRequest(c *clientConn, r *Request) (err error) { setConnReadTimeout(c, clientConnTimeout, "parseRequest") // parse request line if s, err = reader.ReadSlice('\n'); err != nil { + if isErrTimeout(err) { + return errClientTimeout + } return err } unsetConnReadTimeout(c, "parseRequest") diff --git a/install-cow.sh b/install-cow.sh index e40cd51a..a8c3940d 100755 --- a/install-cow.sh +++ b/install-cow.sh @@ -1,6 +1,6 @@ #!/bin/bash -version=0.7.1 +version=0.7.2 arch=`uname -m` case $arch in diff --git a/main.go b/main.go index 684025df..00e5e163 100644 --- a/main.go +++ b/main.go @@ -78,14 +78,10 @@ func main() { go runEstimateTimeout() done := make(chan byte, 1) - // save 1 goroutine (a few KB) for the common case with only 1 listen address - if len(config.ListenAddr) > 1 { - for i, addr := range config.ListenAddr[1:] { - go NewProxy(addr, config.AddrInPAC[i+1]).Serve(done) - } + for i, addr := range config.ListenAddr { + go NewProxy(addr, config.AddrInPAC[i]).Serve(done) } - NewProxy(config.ListenAddr[0], config.AddrInPAC[0]).Serve(done) - for i := 0; i < len(config.ListenAddr); i++ { + for _, _ = range config.ListenAddr { <-done } } diff --git a/pac.go b/pac.go index 43051074..ba5a7517 100644 --- a/pac.go +++ b/pac.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "strings" + "sync" "text/template" "time" ) @@ -13,6 +14,24 @@ var pac struct { template *template.Template topLevelDomain string directList string + // Assignments and reads to directList are in different goroutines. Go + // does not guarantee atomic assignment, so we should protect these racing + // access. + dLRWMutex sync.RWMutex +} + +func getDirectList() string { + pac.dLRWMutex.RLock() + dl := pac.directList + pac.dLRWMutex.RUnlock() + return dl +} + +func updateDirectList() { + dl := strings.Join(siteStat.GetDirectList(), "\",\n\"") + pac.dLRWMutex.Lock() + pac.directList = dl + pac.dLRWMutex.Unlock() } func init() { @@ -129,7 +148,9 @@ func genPAC(c *clientConn) []byte { proxyAddr = net.JoinHostPort(host, c.proxy.port) } - if pac.directList == "" { + dl := getDirectList() + + if dl == "" { // Empty direct domain list buf.Write(pacHeader) pacproxy := fmt.Sprintf("function FindProxyForURL(url, host) { return 'PROXY %s; DIRECT'; };", @@ -144,7 +165,7 @@ func genPAC(c *clientConn) []byte { TopLevel string }{ proxyAddr, - pac.directList, + dl, pac.topLevelDomain, } @@ -157,11 +178,13 @@ func genPAC(c *clientConn) []byte { } func initPAC() { - pac.directList = strings.Join(siteStat.GetDirectList(), "\",\n\"") + // we can't control goroutine scheduling, this is to make sure when + // initPAC is done, direct list is updated + updateDirectList() go func() { for { - time.Sleep(10 * time.Minute) - pac.directList = strings.Join(siteStat.GetDirectList(), "\",\n\"") + time.Sleep(5 * time.Minute) + updateDirectList() } }() } diff --git a/parent_proxy.go b/parent_proxy.go index 0aa320f3..d1bb200f 100644 --- a/parent_proxy.go +++ b/parent_proxy.go @@ -20,7 +20,7 @@ func connectByParentProxy(url *URL) (srvconn conn, err error) { firstId := 0 if config.LoadBalance == loadBalanceHash { firstId = int(stringHash(url.Host) % uint64(nproxy)) - debug.Println("use proxy ", firstId) + debug.Println("use proxy", firstId) } for i := 0; i < nproxy; i++ { @@ -201,7 +201,7 @@ func (sp socksParent) connect(url *URL) (cn conn, err error) { // version/method selection repBuf := make([]byte, 2) - _, err = c.Read(repBuf) + _, err = io.ReadFull(c, repBuf) if err != nil { errl.Printf("read ver/method selection error %v\n", err) hasErr = true diff --git a/proxy.go b/proxy.go index 3cf2bdec..c65056c1 100644 --- a/proxy.go +++ b/proxy.go @@ -29,15 +29,19 @@ const httpBufSize = 8192 // holding post data. var httpBuf = leakybuf.NewLeakyBuf(512, httpBufSize) -// Close client connection if no new request received in some time. Keep it -// small to avoid keeping too many client connections (which associates with -// server connections) and causing too much open file error. (On OS X, the -// default soft limit of open file descriptor is 256, which is really -// conservative.) On the other hand, making this too small may cause reading -// normal client request timeout. -const clientConnTimeout = 10 * time.Second +// Close client connection if no new request received in some time. To prevent +// keeping too many idle (keep-alive) server connections, COW timeout read for +// the initial request line after clientConnTimeout, and closes client +// connection after clientMaxTimeoutCnt timeout. +// (On OS X, the default soft limit of open file descriptor is 256, which is +// very conservative and easy to cause problem if we are not careful.) +const clientConnTimeout = 5 * time.Second +const clientMaxTimeoutCnt = 2 const keepAliveHeader = "Keep-Alive: timeout=10\r\n" +// Remove idle server connection every cleanServerInterval second. +const cleanServerInterval = 5 * time.Second + // If client closed connection for HTTP CONNECT method in less then 1 second, // consider it as an ssl error. This is only effective for Chrome which will // drop connection immediately upon SSL error. @@ -115,6 +119,8 @@ type clientConn struct { bufRd *bufio.Reader buf []byte // buffer for the buffered reader serverConn map[string]*serverConn // request serverConn, host:port as key + timeoutCnt int // number of timeouts reading requests + cleanedOn time.Time // time of last idle server clean up proxy *Proxy } @@ -124,6 +130,7 @@ var ( errShouldClose = errors.New("Error can only be handled by close connection") errInternal = errors.New("Internal error") errNoParentProxy = errors.New("No parent proxy") + errClientTimeout = errors.New("Read client request timeout") errChunkedEncode = errors.New("Invalid chunked encoding") errMalformHeader = errors.New("Malformed HTTP header") @@ -208,8 +215,7 @@ func isSelfURL(url string) bool { func (c *clientConn) getRequest(r *Request) (err error) { if err = parseRequest(c, r); err != nil { - c.handleClientReadError(r, err, "parse client request") - return err + return c.handleClientReadError(r, err, "parse client request") } return nil } @@ -279,19 +285,19 @@ func (c *clientConn) serve() { // Refer to implementation.md for the design choices on parsing the request // and response. - cnt := 0 for { if c.bufRd == nil || c.buf == nil { - errl.Printf("%s client read buffer nil, served %d requests", - c.RemoteAddr(), cnt) - if r.URL != nil { - errl.Println("previous request:", &r) - } panic("client read buffer nil") } - cnt++ + // clean up idle server connection before waiting for client request + if c.shouldCleanServerConn() { + c.cleanServerConn() + } if err = c.getRequest(&r); err != nil { - if !isErrTimeout(err) { + if err == errClientTimeout { + continue + } + if err != errShouldClose { sendErrorPage(c, "404 Bad request", "Bad request", err.Error()) } return @@ -420,17 +426,28 @@ func (c *clientConn) handleServerWriteError(r *Request, sv *serverConn, err erro } func (c *clientConn) handleClientReadError(r *Request, err error, msg string) error { + if err == errClientTimeout { + c.timeoutCnt++ + if c.timeoutCnt >= clientMaxTimeoutCnt { + debug.Printf("%s client maybe has closed\n", msg) + return errShouldClose + } + debug.Printf("%s client read timeout\n", msg) + c.cleanServerConn() + return err + } + + if !debug { + return err + } + if err == io.EOF { debug.Printf("%s client closed connection", msg) - } else if debug { - if isErrConnReset(err) { - debug.Printf("%s connection reset", msg) - } else if isErrTimeout(err) { - debug.Printf("%s client read timeout, maybe has closed\n", msg) - } else { - // may reach here when header is larger than buffer size - debug.Printf("handleClientReadError: %s %v %v\n", msg, err, r) - } + } else if isErrConnReset(err) { + debug.Printf("%s connection reset", msg) + } else { + // may reach here when header is larger than buffer size or got malformed HTTP request + debug.Printf("handleClientReadError: %s %v %v\n", msg, err, r) } return err } @@ -525,6 +542,25 @@ func (c *clientConn) readResponse(sv *serverConn, r *Request, rp *Response) (err return } +func (c *clientConn) shouldCleanServerConn() bool { + return len(c.serverConn) > 0 && + time.Now().Sub(c.cleanedOn) > cleanServerInterval +} + +// Remove all maybe closed server connection +func (c *clientConn) cleanServerConn() { + if debug { + debug.Printf("%s client clean up idle server connection", c.RemoteAddr()) + } + now := time.Now() + c.cleanedOn = now + for _, sv := range c.serverConn { + if now.After(sv.willCloseOn) { + c.removeServerConn(sv) + } + } +} + func (c *clientConn) getServerConn(r *Request) (sv *serverConn, err error) { sv, ok := c.serverConn[r.URL.HostPort] if ok && sv.mayBeClosed() { @@ -794,7 +830,6 @@ func copyServer2Client(sv *serverConn, c *clientConn, r *Request) (err error) { sv.updateVisit() } } - return } type serverWriter struct { @@ -912,7 +947,6 @@ func copyClient2Server(c *clientConn, sv *serverConn, r *Request, srvStopped not } // debug.Printf("cli(%s)->srv(%s) sent %d bytes data\n", c.RemoteAddr(), r.URL.HostPort, n) } - return } var connEstablished = []byte("HTTP/1.1 200 Tunnel established\r\n\r\n") @@ -934,7 +968,7 @@ func (sv *serverConn) doConnect(r *Request, c *clientConn) (err error) { // debug.Printf("send connection confirmation to %s->%s\n", c.RemoteAddr(), r.URL.HostPort) if _, err = c.Write(connEstablished); err != nil { if debug { - debug.Printf("%v Error sending 200 Connecion established: %v\n", c.RemoteAddr(), err) + debug.Printf("%s error sending 200 Connecion established: %v\n", c.RemoteAddr(), err) } return err } @@ -1098,7 +1132,6 @@ func sendBodyChunked(r *bufio.Reader, w io.Writer, rdSize int) (err error) { return } } - return } const CRLF = "\r\n" @@ -1138,7 +1171,6 @@ func sendBodySplitIntoChunk(r *bufio.Reader, w io.Writer) (err error) { return } } - return } // Send message body. If req is not nil, read from client, send to server. If diff --git a/script/build.sh b/script/build.sh index 8175f819..d68ad95a 100755 --- a/script/build.sh +++ b/script/build.sh @@ -24,7 +24,7 @@ build() { if [[ $1 == "windows" ]]; then mv cow.exe script pushd script - cp ../doc/sample-config/rc sample-rc.txt + sed -e 's/$/\r/' ../doc/sample-config/rc > sample-rc.txt zip $name.zip cow.exe cow-taskbar.exe sample-rc.txt rm -f cow.exe sample-rc.txt mv $name.zip ../bin/ diff --git a/site_direct.go b/site_direct.go index 63d0ee75..8583772c 100644 --- a/site_direct.go +++ b/site_direct.go @@ -25,6 +25,7 @@ var directDomainList = []string{ "360buy.com", "360buyimg.com", + "jd.com", "51buy.com", "icson.com", diff --git a/sitestat.go b/sitestat.go index 6c09b18b..84125f28 100644 --- a/sitestat.go +++ b/sitestat.go @@ -421,5 +421,4 @@ func loadSiteList(fpath string) (lst []string, err error) { } lst = append(lst, strings.TrimSpace(site)) } - return } diff --git a/util.go b/util.go index e97e98a3..5a84b077 100644 --- a/util.go +++ b/util.go @@ -34,7 +34,6 @@ func (n notification) hasNotified() bool { default: return false } - return false } // ReadLine read till '\n' is found or encounter error. The returned line does @@ -437,10 +436,11 @@ func NewNbitIPv4Mask(n int) net.IPMask { var topLevelDomain = map[string]bool{ "ac": true, "co": true, - "org": true, "com": true, - "net": true, "edu": true, + "gov": true, + "net": true, + "org": true, } func trimLastDot(s string) string {