Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add redirect support, client certificate support, HTTP/2 support, and built-in middleware #35

Merged
merged 1 commit into from
Feb 13, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .golangci.version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.57.2
1.64.2
6 changes: 3 additions & 3 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ linters:
- dogsled # Checks assignments with too many blank identifiers (e.g. x, , , _, := f())
- errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13.
- exhaustive # check exhaustiveness of enum switch statements
- exportloopref # checks for pointers to enclosing loop variables
- copyloopvar # checks for pointers to enclosing loop variables
# - gochecknoglobals # A global variable is a variable declared in package scope and that can be read and written to by any function within the package.
- gocritic # Provides diagnostics that check for bugs, performance and style issues.
# - goconst # Inspects source code for security problems
# - gocyclo # Computes and checks the cyclomatic complexity of functions
- goerr113 # Golang linter to check the errors handling expressions
- err113 # Golang linter to check the errors handling expressions
- gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification
- goimports # In addition to fixing imports, goimports also formats your code in the same style as gofmt.
- gomnd # An analyzer to detect magic numbers.
- mnd # An analyzer to detect magic numbers.
- goprintffuncname # Checks that printf-like functions are named with f at the end
- gosec # Inspects source code for security problems
- misspell # Finds commonly misspelled English words in comments
Expand Down
143 changes: 137 additions & 6 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@ package requests

import (
"crypto/tls"
"crypto/x509"
"net/http"
"net/http/cookiejar"
"os"
"path/filepath"
"strings"
"sync"
"time"

"golang.org/x/net/http2"
)

// Client represents an HTTP client
Expand Down Expand Up @@ -44,6 +50,7 @@ type Config struct {
RetryStrategy BackoffStrategy // The backoff strategy function
RetryIf RetryIfFunc // Custom function to determine retry based on request and response
Logger Logger // Logger instance for the client
HTTP2 bool // Whether to use HTTP/2,The priority of http2 is lower than that of Transport
}

// URL creates a new HTTP client with the given base URL.
Expand Down Expand Up @@ -85,15 +92,27 @@ func Create(config *Config) *Client {
TLSConfig: config.TLSConfig,
}

// If a TLS configuration is provided, apply it to the Transport.
if client.TLSConfig != nil && httpClient.Transport != nil {
httpTransport := httpClient.Transport.(*http.Transport)
httpTransport.TLSClientConfig = client.TLSConfig
} else if client.TLSConfig != nil {
httpClient.Transport = &http.Transport{
// Configure Transport, handle both TLS and HTTP/2
if client.TLSConfig != nil && config.HTTP2 {
// Use HTTP/2
client.HTTPClient.Transport = &http2.Transport{
TLSClientConfig: client.TLSConfig,
}
}
if client.TLSConfig != nil && !config.HTTP2 {
if httpClient.Transport != nil {
if transport, ok := httpClient.Transport.(*http.Transport); ok {
transport.TLSClientConfig = client.TLSConfig
}
} else {
client.HTTPClient.Transport = &http.Transport{
TLSClientConfig: client.TLSConfig,
}
}
}
if client.TLSConfig == nil && config.HTTP2 {
client.HTTPClient.Transport = &http2.Transport{}
}

if config.Middlewares != nil {
client.Middlewares = config.Middlewares
Expand Down Expand Up @@ -195,6 +214,82 @@ func (c *Client) InsecureSkipVerify() *Client {
return c
}

// SetCertificates sets the TLS certificates for the client.
func (c *Client) SetCertificates(certs ...tls.Certificate) *Client {
c.mu.Lock()
defer c.mu.Unlock()

if c.TLSConfig == nil {
c.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}
}
c.TLSConfig.Certificates = certs
return c
}

// SetRootCertificate sets the root certificate for the client.
func (c *Client) SetRootCertificate(pemFilePath string) *Client {
cleanPath := filepath.Clean(pemFilePath)
if !strings.HasPrefix(cleanPath, "/expected/base/path") {
return c
}
rootPemData, err := os.ReadFile(pemFilePath)
if err != nil {
return c
}
c.handleCAs("root", rootPemData)
return c
}

// SetRootCertificateFromString sets the root certificate for the client from a string.
func (c *Client) SetRootCertificateFromString(pemCerts string) *Client {
return c.handleCAs("root", []byte(pemCerts))
}

// SetClientRootCertificate sets the client root certificate for the client.
func (c *Client) SetClientRootCertificate(pemFilePath string) *Client {
cleanPath := filepath.Clean(pemFilePath)
if !strings.HasPrefix(cleanPath, "/expected/base/path") {
return c
}
rootPemData, err := os.ReadFile(pemFilePath)
if err != nil {
return c
}
return c.handleCAs("client", rootPemData)
}

// SetClientRootCertificateFromString sets the client root certificate for the client from a string.
func (c *Client) SetClientRootCertificateFromString(pemCerts string) *Client {
return c.handleCAs("client", []byte(pemCerts))
}

// handleCAs sets the TLS certificates for the client.
func (c *Client) handleCAs(scope string, permCerts []byte) *Client {
c.mu.Lock()
defer c.mu.Unlock()

if c.TLSConfig == nil {
c.TLSConfig = &tls.Config{
MinVersion: tls.VersionTLS12,
}
}
switch scope {
case "root":
if c.TLSConfig.RootCAs == nil {
c.TLSConfig.RootCAs = x509.NewCertPool()
}
c.TLSConfig.RootCAs.AppendCertsFromPEM(permCerts)
case "client":
if c.TLSConfig.ClientCAs == nil {
c.TLSConfig.ClientCAs = x509.NewCertPool()
}
c.TLSConfig.ClientCAs.AppendCertsFromPEM(permCerts)
}
return c
}

// SetHTTPClient sets the HTTP client for the client
func (c *Client) SetHTTPClient(httpClient *http.Client) {
c.mu.Lock()
Expand Down Expand Up @@ -322,41 +417,59 @@ func (c *Client) DelDefaultCookie(name string) {

// SetJSONMarshal sets the JSON marshal function for the client's JSONEncoder
func (c *Client) SetJSONMarshal(marshalFunc func(v any) ([]byte, error)) {
c.mu.Lock()
defer c.mu.Unlock()

c.JSONEncoder = &JSONEncoder{
MarshalFunc: marshalFunc,
}
}

// SetJSONUnmarshal sets the JSON unmarshal function for the client's JSONDecoder
func (c *Client) SetJSONUnmarshal(unmarshalFunc func(data []byte, v any) error) {
c.mu.Lock()
defer c.mu.Unlock()

c.JSONDecoder = &JSONDecoder{
UnmarshalFunc: unmarshalFunc,
}
}

// SetXMLMarshal sets the XML marshal function for the client's XMLEncoder
func (c *Client) SetXMLMarshal(marshalFunc func(v any) ([]byte, error)) {
c.mu.Lock()
defer c.mu.Unlock()

c.XMLEncoder = &XMLEncoder{
MarshalFunc: marshalFunc,
}
}

// SetXMLUnmarshal sets the XML unmarshal function for the client's XMLDecoder
func (c *Client) SetXMLUnmarshal(unmarshalFunc func(data []byte, v any) error) {
c.mu.Lock()
defer c.mu.Unlock()

c.XMLDecoder = &XMLDecoder{
UnmarshalFunc: unmarshalFunc,
}
}

// SetYAMLMarshal sets the YAML marshal function for the client's YAMLEncoder
func (c *Client) SetYAMLMarshal(marshalFunc func(v any) ([]byte, error)) {
c.mu.Lock()
defer c.mu.Unlock()

c.YAMLEncoder = &YAMLEncoder{
MarshalFunc: marshalFunc,
}
}

// SetYAMLUnmarshal sets the YAML unmarshal function for the client's YAMLDecoder
func (c *Client) SetYAMLUnmarshal(unmarshalFunc func(data []byte, v any) error) {
c.mu.Lock()
defer c.mu.Unlock()

c.YAMLDecoder = &YAMLDecoder{
UnmarshalFunc: unmarshalFunc,
}
Expand Down Expand Up @@ -391,11 +504,29 @@ func (c *Client) SetRetryIf(retryIf RetryIfFunc) *Client {

// SetAuth configures an authentication method for the client.
func (c *Client) SetAuth(auth AuthMethod) {
c.mu.Lock()
defer c.mu.Unlock()

if auth.Valid() {
c.auth = auth
}
}

// SetRedirectPolicy sets the redirect policy for the client
func (c *Client) SetRedirectPolicy(policies ...RedirectPolicy) *Client {
c.mu.Lock()
defer c.mu.Unlock()
c.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
for _, p := range policies {
if err := p.Apply(req, via); err != nil {
return err
}
}
return nil
}
return c
}

// SetLogger sets logger instance in client.
func (c *Client) SetLogger(logger Logger) *Client {
c.mu.Lock()
Expand Down
Loading