Skip to content

Commit

Permalink
feat: Added redirect configuration, client certificate support, Http2…
Browse files Browse the repository at this point in the history
… support, and built-in middleware including header, cookie, cache middleware, and wrote corresponding functional documentation
  • Loading branch information
CodeyCoo committed Feb 12, 2025
1 parent cec6baa commit 3c82aa3
Show file tree
Hide file tree
Showing 15 changed files with 1,210 additions and 8 deletions.
132 changes: 125 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ package requests

import (
"crypto/tls"
"crypto/x509"
"net/http"
"net/http/cookiejar"
"os"
"sync"
"time"

"golang.org/x/net/http2"
)

// Client represents an HTTP client
Expand Down Expand Up @@ -44,6 +48,7 @@ type Config struct {
RetryStrategy BackoffStrategy // The backoff strategy function
RetryIf RetryIfFunc // Custom function to determine retry based on request and response
Logger Logger // Logger instance for the client
HTTP2 bool // Whether to use HTTP/2,The priority of http2 is lower than that of Transport
}

// URL creates a new HTTP client with the given base URL.
Expand Down Expand Up @@ -85,14 +90,27 @@ func Create(config *Config) *Client {
TLSConfig: config.TLSConfig,
}

// If a TLS configuration is provided, apply it to the Transport.
if client.TLSConfig != nil && httpClient.Transport != nil {
httpTransport := httpClient.Transport.(*http.Transport)
httpTransport.TLSClientConfig = client.TLSConfig
} else if client.TLSConfig != nil {
httpClient.Transport = &http.Transport{
TLSClientConfig: client.TLSConfig,
// Configure Transport, handle both TLS and HTTP/2
if client.TLSConfig != nil {
if config.HTTP2 {
// Use HTTP/2
client.HTTPClient.Transport = &http2.Transport{
TLSClientConfig: client.TLSConfig,
}
} else {
// Use HTTP/1.1
if httpClient.Transport != nil {
if transport, ok := httpClient.Transport.(*http.Transport); ok {
transport.TLSClientConfig = client.TLSConfig
}
} else {
client.HTTPClient.Transport = &http.Transport{
TLSClientConfig: client.TLSConfig,
}
}
}
} else if config.HTTP2 {
client.HTTPClient.Transport = &http2.Transport{}
}

if config.Middlewares != nil {
Expand Down Expand Up @@ -195,6 +213,70 @@ func (c *Client) InsecureSkipVerify() *Client {
return c
}

// SetCertificates sets the TLS certificates for the client.
func (c *Client) SetCertificates(certs ...tls.Certificate) *Client {
c.mu.Lock()
defer c.mu.Unlock()

if c.TLSConfig == nil {
c.TLSConfig = &tls.Config{}
}
c.TLSConfig.Certificates = certs
return c
}

// SetRootCertificate sets the root certificate for the client.
func (c *Client) SetRootCertificate(pemFilePath string) *Client {
rootPemData, err := os.ReadFile(pemFilePath)
if err != nil {
return c
}
c.handleCAs("root", rootPemData)
return c
}

// SetRootCertificateFromString sets the root certificate for the client from a string.
func (c *Client) SetRootCertificateFromString(pemCerts string) *Client {
return c.handleCAs("root", []byte(pemCerts))
}

// SetClientRootCertificate sets the client root certificate for the client.
func (c *Client) SetClientRootCertificate(pemFilePath string) *Client {
rootPemData, err := os.ReadFile(pemFilePath)
if err != nil {
return c
}
return c.handleCAs("client", rootPemData)
}

// SetClientRootCertificateFromString sets the client root certificate for the client from a string.
func (c *Client) SetClientRootCertificateFromString(pemCerts string) *Client {
return c.handleCAs("client", []byte(pemCerts))
}

// handleCAs sets the TLS certificates for the client.
func (c *Client) handleCAs(scope string, permCerts []byte) *Client {
c.mu.Lock()
defer c.mu.Unlock()

if c.TLSConfig == nil {
c.TLSConfig = &tls.Config{}
}
switch scope {
case "root":
if c.TLSConfig.RootCAs == nil {
c.TLSConfig.RootCAs = x509.NewCertPool()
}
c.TLSConfig.RootCAs.AppendCertsFromPEM(permCerts)
case "client":
if c.TLSConfig.ClientCAs == nil {
c.TLSConfig.ClientCAs = x509.NewCertPool()
}
c.TLSConfig.ClientCAs.AppendCertsFromPEM(permCerts)
}
return c
}

// SetHTTPClient sets the HTTP client for the client
func (c *Client) SetHTTPClient(httpClient *http.Client) {
c.mu.Lock()
Expand Down Expand Up @@ -322,41 +404,59 @@ func (c *Client) DelDefaultCookie(name string) {

// SetJSONMarshal sets the JSON marshal function for the client's JSONEncoder
func (c *Client) SetJSONMarshal(marshalFunc func(v any) ([]byte, error)) {
c.mu.Lock()
defer c.mu.Unlock()

c.JSONEncoder = &JSONEncoder{
MarshalFunc: marshalFunc,
}
}

// SetJSONUnmarshal sets the JSON unmarshal function for the client's JSONDecoder
func (c *Client) SetJSONUnmarshal(unmarshalFunc func(data []byte, v any) error) {
c.mu.Lock()
defer c.mu.Unlock()

c.JSONDecoder = &JSONDecoder{
UnmarshalFunc: unmarshalFunc,
}
}

// SetXMLMarshal sets the XML marshal function for the client's XMLEncoder
func (c *Client) SetXMLMarshal(marshalFunc func(v any) ([]byte, error)) {
c.mu.Lock()
defer c.mu.Unlock()

c.XMLEncoder = &XMLEncoder{
MarshalFunc: marshalFunc,
}
}

// SetXMLUnmarshal sets the XML unmarshal function for the client's XMLDecoder
func (c *Client) SetXMLUnmarshal(unmarshalFunc func(data []byte, v any) error) {
c.mu.Lock()
defer c.mu.Unlock()

c.XMLDecoder = &XMLDecoder{
UnmarshalFunc: unmarshalFunc,
}
}

// SetYAMLMarshal sets the YAML marshal function for the client's YAMLEncoder
func (c *Client) SetYAMLMarshal(marshalFunc func(v any) ([]byte, error)) {
c.mu.Lock()
defer c.mu.Unlock()

c.YAMLEncoder = &YAMLEncoder{
MarshalFunc: marshalFunc,
}
}

// SetYAMLUnmarshal sets the YAML unmarshal function for the client's YAMLDecoder
func (c *Client) SetYAMLUnmarshal(unmarshalFunc func(data []byte, v any) error) {
c.mu.Lock()
defer c.mu.Unlock()

c.YAMLDecoder = &YAMLDecoder{
UnmarshalFunc: unmarshalFunc,
}
Expand Down Expand Up @@ -391,11 +491,29 @@ func (c *Client) SetRetryIf(retryIf RetryIfFunc) *Client {

// SetAuth configures an authentication method for the client.
func (c *Client) SetAuth(auth AuthMethod) {
c.mu.Lock()
defer c.mu.Unlock()

if auth.Valid() {
c.auth = auth
}
}

// SetRedirectPolicy sets the redirect policy for the client
func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client {
c.mu.Lock()
defer c.mu.Unlock()
c.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
for _, p := range policies {
if err := p.Apply(req, via); err != nil {
return err
}
}
return nil
}
return c
}

// SetLogger sets logger instance in client.
func (c *Client) SetLogger(logger Logger) *Client {
c.mu.Lock()
Expand Down
169 changes: 169 additions & 0 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -877,3 +877,172 @@ func TestSetRetryIf(t *testing.T) {
t.Errorf("Expected 2 retries, got %d", retryCount)
}
}

func TestClientCertificates(t *testing.T) {
serverCert, err := tls.LoadX509KeyPair(".github/testdata/cert.pem", ".github/testdata/key.pem")
require.NoError(t, err, "load server certificate failed")

server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.TLS != nil && len(r.TLS.PeerCertificates) > 0 {
w.WriteHeader(http.StatusOK)
w.Write([]byte("certificate verification successful"))

Check failure on line 888 in client_test.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `w.Write` is not checked (errcheck)
} else {
w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte("lack of client certificate"))

Check failure on line 891 in client_test.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `w.Write` is not checked (errcheck)
}
}))
clientCertPool := x509.NewCertPool()
clientCertData, err := os.ReadFile(".github/testdata/cert.pem")
require.NoError(t, err, "load client certificate failed")
clientCertPool.AppendCertsFromPEM(clientCertData)
clientCertPath := ".github/testdata/cert.pem"

server.TLS = &tls.Config{
Certificates: []tls.Certificate{serverCert},
ClientCAs: clientCertPool,
ClientAuth: tls.RequireAndVerifyClientCert,
}
server.StartTLS()
defer server.Close()

client := Create(&Config{
BaseURL: server.URL,
})

t.Run("use client certificate", func(t *testing.T) {
clientCert, err := tls.LoadX509KeyPair(".github/testdata/cert.pem", ".github/testdata/key.pem")
require.NoError(t, err, "load client certificate failed")

client.SetTLSConfig(&tls.Config{
InsecureSkipVerify: true,
})
client.SetCertificates(clientCert)
client.SetClientRootCertificate(clientCertPath)
resp, err := client.Get("/").Send(context.Background())
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Close()

Check failure on line 925 in client_test.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `resp.Close` is not checked (errcheck)

assert.Equal(t, http.StatusOK, resp.StatusCode(), "status code not correct")
assert.Equal(t, "certificate verification successful", resp.String(), "response content not correct")
})

t.Run("do not use client certificate", func(t *testing.T) {
clientWithoutCert := Create(&Config{
BaseURL: server.URL,
})
clientWithoutCert.SetTLSConfig(&tls.Config{
InsecureSkipVerify: true,
})
clientWithoutCert.SetClientRootCertificate(clientCertPath)

_, err := clientWithoutCert.Get("/").Send(context.Background())
assert.Error(t, err, "expect request failed")
})
}

func TestClientSetRootCertificate(t *testing.T) {
t.Run("root cert", func(t *testing.T) {
filePath := ".testdata/sample_root.pem"

client := Create(nil)
client.SetRootCertificate(filePath)

if transport, ok := client.HTTPClient.Transport.(*http.Transport); ok {
assert.NotNil(t, transport.TLSClientConfig.RootCAs)
}
})

t.Run("root cert not exists", func(t *testing.T) {
filePath := "../.testdata/not-exists-sample-root.pem"

client := Create(nil)
client.SetRootCertificate(filePath)

if transport, ok := client.HTTPClient.Transport.(*http.Transport); ok {
assert.Nil(t, transport.TLSClientConfig)
}
})

t.Run("root cert from string", func(t *testing.T) {
client := Create(nil)

cert := `-----BEGIN CERTIFICATE-----`

client.SetRootCertificateFromString(cert)
if transport, ok := client.HTTPClient.Transport.(*http.Transport); ok {
assert.NotNil(t, transport.TLSClientConfig.RootCAs)
}
})
}

func TestHttp2Scenarios(t *testing.T) {
tests := []struct {
name string
config *Config
url string
expectedVersion string
expectedError string
}{
{
name: "Default HTTP version, request to use http2 version URL",
config: &Config{},
url: "https://tools.scrapfly.io/api/fp/anything",
expectedVersion: "HTTP/2.0",
expectedError: "",
},
{
name: "Explicit HTTP/2, request to use http2 version URL",
config: &Config{HTTP2: true},
url: "https://tools.scrapfly.io/api/fp/anything",
expectedVersion: "HTTP/2.0",
expectedError: "",
},
{
name: "Set Transport, request to use http2 version URL,The priority of http2 is lower than that of Transport",
config: &Config{Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}},
url: "https://tools.scrapfly.io/api/fp/anything",
expectedVersion: "",
expectedError: "Get \"https://tools.scrapfly.io/api/fp/anything\": EOF",
},
{
name: "Set Transport, request to use http1.1 version URL",
config: &Config{Transport: &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}},
url: "https://www.baidu.com",
expectedVersion: "HTTP/1.1",
expectedError: "",
},
{
name: "Explicit HTTP/2 with Baidu",
config: &Config{HTTP2: true},
url: "https://www.baidu.com",
expectedVersion: "",
expectedError: "Get \"https://www.baidu.com\": http2: unexpected ALPN protocol \"\"; want \"h2\"",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client := Create(tt.config)

resp, err := client.Get(tt.url).Send(context.Background())
if err != nil {
assert.Equal(t, tt.expectedError, err.Error(), "Protocol settings are incorrect")
return
} else {
defer resp.Close()

Check failure on line 1042 in client_test.go

View workflow job for this annotation

GitHub Actions / Lint

Error return value of `resp.Close` is not checked (errcheck)

assert.Equal(t, tt.expectedVersion, resp.RawResponse.Proto, "Protocol version mismatch")
}
})
}
}
Loading

0 comments on commit 3c82aa3

Please sign in to comment.