diff --git a/drivers/115/util.go b/drivers/115/util.go index 3fdd7a56686..aacaff253d9 100644 --- a/drivers/115/util.go +++ b/drivers/115/util.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/SheltonZhu/115driver/pkg/driver" + "github.com/alist-org/alist/v3/drivers/base" "github.com/pkg/errors" ) @@ -15,6 +16,7 @@ func (d *Pan115) login() error { driver.UA(UserAgent), } d.client = driver.New(opts...) + d.client.SetHttpClient(base.HttpClient) cr := &driver.Credential{} if d.Addition.QRCodeToken != "" { s := &driver.QRCodeSession{ diff --git a/drivers/base/client.go b/drivers/base/client.go index 2c29ce62abe..82f624338f1 100644 --- a/drivers/base/client.go +++ b/drivers/base/client.go @@ -9,28 +9,40 @@ import ( "github.com/go-resty/resty/v2" ) -var NoRedirectClient *resty.Client -var RestyClient = NewRestyClient() -var HttpClient = &http.Client{} +var ( + NoRedirectClient *resty.Client + RestyClient *resty.Client + HttpClient *http.Client +) var UserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/87.0.4280.88 Safari/537.36" var DefaultTimeout = time.Second * 30 -func init() { +func InitClient() { NoRedirectClient = resty.New().SetRedirectPolicy( resty.RedirectPolicyFunc(func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }), - ) + ).SetTLSClientConfig(&tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}) NoRedirectClient.SetHeader("user-agent", UserAgent) + + RestyClient = NewRestyClient() + HttpClient = NewHttpClient() } func NewRestyClient() *resty.Client { client := resty.New(). SetHeader("user-agent", UserAgent). SetRetryCount(3). - SetTimeout(DefaultTimeout) - if conf.Conf != nil && conf.Conf.TlsInsecureSkipVerify { - client = client.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) - } + SetTimeout(DefaultTimeout). + SetTLSClientConfig(&tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}) return client } + +func NewHttpClient() *http.Client { + return &http.Client{ + Timeout: DefaultTimeout, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify}, + }, + } +} diff --git a/drivers/lanzou/driver.go b/drivers/lanzou/driver.go index cae0c3eb05f..8016c2d4db5 100644 --- a/drivers/lanzou/driver.go +++ b/drivers/lanzou/driver.go @@ -5,7 +5,6 @@ import ( "fmt" "net/http" "regexp" - "time" "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" @@ -15,8 +14,6 @@ import ( "github.com/go-resty/resty/v2" ) -var upClient = base.NewRestyClient().SetTimeout(120 * time.Second) - type LanZou struct { Addition model.Storage @@ -209,11 +206,11 @@ func (d *LanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr var resp RespText[[]FileOrFolder] _, err := d._post(d.BaseUrl+"/fileup.php", func(req *resty.Request) { req.SetFormData(map[string]string{ - "task": "1", - "vie": "2", - "ve": "2", - "id": "WU_FILE_0", - "name": stream.GetName(), + "task": "1", + "vie": "2", + "ve": "2", + "id": "WU_FILE_0", + "name": stream.GetName(), "folder_id_bb_n": dstDir.GetID(), }).SetFileReader("upload_file", stream.GetName(), stream).SetContext(ctx) }, &resp, true) diff --git a/drivers/lanzou/util.go b/drivers/lanzou/util.go index 0803937c4d3..859eed34b97 100644 --- a/drivers/lanzou/util.go +++ b/drivers/lanzou/util.go @@ -7,6 +7,7 @@ import ( "regexp" "strconv" "strings" + "sync" "time" "github.com/alist-org/alist/v3/drivers/base" @@ -16,6 +17,9 @@ import ( log "github.com/sirupsen/logrus" ) +var upClient *resty.Client +var once sync.Once + func (d *LanZou) doupload(callback base.ReqCallback, resp interface{}) ([]byte, error) { return d.post(d.BaseUrl+"/doupload.php", func(req *resty.Request) { req.SetQueryParam("uid", d.uid) @@ -64,6 +68,9 @@ func (d *LanZou) _post(url string, callback base.ReqCallback, resp interface{}, func (d *LanZou) request(url string, method string, callback base.ReqCallback, up bool) ([]byte, error) { var req *resty.Request if up { + once.Do(func() { + upClient = base.NewRestyClient().SetTimeout(120 * time.Second) + }) req = upClient.R() } else { req = base.RestyClient.R() diff --git a/drivers/quark/driver.go b/drivers/quark/driver.go index 5882ab409ce..4185ffd5a1d 100644 --- a/drivers/quark/driver.go +++ b/drivers/quark/driver.go @@ -102,7 +102,7 @@ func (d *Quark) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( req.Header.Set("User-Agent", ua) req.Header.Set("Cookie", d.Cookie) req.Header.Set("Referer", "https://pan.quark.cn") - resp, err := http.DefaultClient.Do(req) + resp, err := base.HttpClient.Do(req) if err != nil { return err } diff --git a/drivers/trainbit/driver.go b/drivers/trainbit/driver.go index b8c514f5120..63bd0627f63 100644 --- a/drivers/trainbit/driver.go +++ b/drivers/trainbit/driver.go @@ -10,6 +10,7 @@ import ( "net/url" "strings" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" @@ -31,7 +32,7 @@ func (d *Trainbit) GetAddition() driver.Additional { } func (d *Trainbit) Init(ctx context.Context) error { - http.DefaultClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + base.HttpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse } var err error @@ -119,7 +120,7 @@ func (d *Trainbit) Put(ctx context.Context, dstDir model.Obj, stream model.FileS query := &url.Values{} query.Add("q", strings.Split(dstDir.GetID(), "_")[1]) query.Add("guid", guid) - query.Add("name", url.QueryEscape(local2provider(stream.GetName(), false) + ".")) + query.Add("name", url.QueryEscape(local2provider(stream.GetName(), false)+".")) endpoint.RawQuery = query.Encode() var total int64 total = 0 @@ -135,7 +136,7 @@ func (d *Trainbit) Put(ctx context.Context, dstDir model.Obj, stream model.FileS return err } req.Header.Set("Content-Type", "text/json; charset=UTF-8") - _, err = http.DefaultClient.Do(req) + _, err = base.HttpClient.Do(req) return err } diff --git a/drivers/trainbit/util.go b/drivers/trainbit/util.go index ab9c55cc255..afc111a8290 100644 --- a/drivers/trainbit/util.go +++ b/drivers/trainbit/util.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/model" ) @@ -38,7 +39,7 @@ func get(url string, apiKey string, AUSHELLPORTAL string) (*http.Response, error Value: apiKey, MaxAge: 2 * 60, }) - res, err := http.DefaultClient.Do(req) + res, err := base.HttpClient.Do(req) return res, err } @@ -65,7 +66,7 @@ func postForm(endpoint string, data url.Values, apiExpiredate string, apiKey str Value: apiKey, MaxAge: 2 * 60, }) - res, err := http.DefaultClient.Do(req) + res, err := base.HttpClient.Do(req) return res, err } diff --git a/drivers/webdav/odrvcookie/fetch.go b/drivers/webdav/odrvcookie/fetch.go index 93fa87aed54..a52fc68be0e 100644 --- a/drivers/webdav/odrvcookie/fetch.go +++ b/drivers/webdav/odrvcookie/fetch.go @@ -12,6 +12,7 @@ import ( "strings" "time" + "github.com/alist-org/alist/v3/drivers/base" "golang.org/x/net/publicsuffix" ) @@ -185,7 +186,7 @@ func (ca *CookieAuth) getSPToken() (*SuccessResponse, error) { return nil, err } - client := &http.Client{} + client := base.HttpClient resp, err := client.Do(req) if err != nil { return nil, err diff --git a/internal/bootstrap/config.go b/internal/bootstrap/config.go index 0ec99142630..9436dbca76d 100644 --- a/internal/bootstrap/config.go +++ b/internal/bootstrap/config.go @@ -1,7 +1,6 @@ package bootstrap import ( - "crypto/tls" "net/url" "os" "path/filepath" @@ -78,9 +77,7 @@ func InitConfig() { log.Fatalf("create temp dir error: %+v", err) } log.Debugf("config: %+v", conf.Conf) - if conf.Conf.TlsInsecureSkipVerify { - base.RestyClient = base.RestyClient.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) - } + base.InitClient() initURL() } diff --git a/internal/fs/util.go b/internal/fs/util.go index 6ff9f7b3be1..ca3cb48fb18 100644 --- a/internal/fs/util.go +++ b/internal/fs/util.go @@ -8,6 +8,7 @@ import ( stdpath "path" "strings" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" @@ -33,8 +34,6 @@ func containsByName(files []model.Obj, file model.Obj) bool { return false } -var httpClient = &http.Client{} - func getFileStreamFromLink(file model.Obj, link *model.Link) (*model.FileStream, error) { var rc io.ReadCloser mimetype := utils.GetMimeType(file.GetName()) @@ -60,7 +59,7 @@ func getFileStreamFromLink(file model.Obj, link *model.Link) (*model.FileStream, for h, val := range link.Header { req.Header[h] = val } - res, err := httpClient.Do(req) + res, err := base.HttpClient.Do(req) if err != nil { return nil, errors.Wrapf(err, "failed to get response for %s", link.URL) } diff --git a/server/common/proxy.go b/server/common/proxy.go index 397201fe22f..66724b67aa4 100644 --- a/server/common/proxy.go +++ b/server/common/proxy.go @@ -8,7 +8,9 @@ import ( "os" "strconv" "strings" + "sync" + "github.com/alist-org/alist/v3/drivers/base" "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/pkg/utils" @@ -16,16 +18,23 @@ import ( log "github.com/sirupsen/logrus" ) -var HttpClient = &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - if len(via) >= 10 { - return errors.New("stopped after 10 redirects") +func HttpClient() *http.Client { + once.Do(func() { + httpClient = base.NewHttpClient() + httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { + if len(via) >= 10 { + return errors.New("stopped after 10 redirects") + } + req.Header.Del("Referer") + return nil } - req.Header.Del("Referer") - return nil - }, + }) + return httpClient } +var once sync.Once +var httpClient *http.Client + func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model.Obj) error { // read data with native var err error @@ -90,7 +99,7 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. for h, val := range link.Header { req.Header[h] = val } - res, err := HttpClient.Do(req) + res, err := HttpClient().Do(req) if err != nil { return err }