diff --git a/client.go b/client.go index 8bba060..083fcf6 100644 --- a/client.go +++ b/client.go @@ -22,11 +22,13 @@ import ( "github.com/cloudwego/hertz/pkg/protocol" "net/http" "net/url" + "regexp" "sync" ) type Client struct { QueryParam url.Values + FormData map[string]string PathParams map[string]string Header http.Header Cookies []*http.Cookie @@ -45,8 +47,49 @@ type ( ResponseMiddleware func(*Client, *Response) error ) -func createClient(c *client.Client) *Client { - return &Client{client: c} +var ( + hdrUserAgentKey = http.CanonicalHeaderKey("User-Agent") + hdrAcceptKey = http.CanonicalHeaderKey("Accept") + hdrContentTypeKey = http.CanonicalHeaderKey("Content-Type") + hdrContentLengthKey = http.CanonicalHeaderKey("Content-Length") + hdrContentEncodingKey = http.CanonicalHeaderKey("Content-Encoding") + hdrLocationKey = http.CanonicalHeaderKey("Location") + hdrAuthorizationKey = http.CanonicalHeaderKey("Authorization") + hdrWwwAuthenticateKey = http.CanonicalHeaderKey("WWW-Authenticate") + + plainTextType = "text/plain; charset=utf-8" + jsonContentType = "application/json" + formContentType = "application/x-www-form-urlencoded" + formDataContentType = "multipart/form-data" + + jsonCheck = regexp.MustCompile(`(?i:(application|text)/(.*json.*)(;|$))`) + xmlCheck = regexp.MustCompile(`(?i:(application|text)/(.*xml.*)(;|$))`) +) + +func createClient(cc *client.Client) *Client { + c := &Client{ + QueryParam: url.Values{}, + PathParams: make(map[string]string), + Header: http.Header{}, + Cookies: make([]*http.Cookie, 0), + + udBeforeRequestLock: &sync.RWMutex{}, + afterResponseLock: &sync.RWMutex{}, + + client: cc, + } + + c.beforeRequest = []RequestMiddleware{ + parseRequestURL, + parseRequestHeader, + parseRequestBody, + } + + c.udBeforeRequest = []RequestMiddleware{} + + c.afterResponse = []ResponseMiddleware{} + + return c } func (c *Client) SetQueryParam(param, value string) *Client { @@ -114,7 +157,7 @@ func (c *Client) SetHeaders(headers map[string]string) *Client { func (c *Client) SetHeaderMultiValues(headers map[string][]string) *Client { for k, header := range headers { for _, v := range header { - c.Header.Set(k, v) + c.Header.Add(k, v) } } return c @@ -188,6 +231,7 @@ func (c *Client) R() *Request { Header: http.Header{}, Cookies: make([]*http.Cookie, 0), PathParams: map[string]string{}, + RawRequest: &protocol.Request{}, client: c, } diff --git a/easy_http.go b/easy_http.go index ff1fc86..5578a0d 100644 --- a/easy_http.go +++ b/easy_http.go @@ -20,7 +20,7 @@ import "github.com/cloudwego/hertz/pkg/app/client" func New() (*Client, error) { c, err := client.NewClient() - return &Client{client: c}, err + return createClient(c), err } func MustNew() *Client { @@ -28,7 +28,7 @@ func MustNew() *Client { if err != nil { panic(err) } - return &Client{client: c} + return createClient(c) } func NewWithHertzClient(c *client.Client) *Client { diff --git a/example/main.go b/example/main.go new file mode 100644 index 0000000..4243c97 --- /dev/null +++ b/example/main.go @@ -0,0 +1,16 @@ +package main + +import ( + "fmt" + "github.com/hertz-contrib/easy_http" +) + +func main() { + c := easy_http.MustNew() + + res, err := c.SetHeader("test", "test").SetQueryParam("test1", "test1").R().Get("http://www.baidu.com") + if err != nil { + panic(err) + } + fmt.Println(res) +} diff --git a/middleware.go b/middleware.go index 3a3909d..36e03ce 100644 --- a/middleware.go +++ b/middleware.go @@ -15,3 +15,127 @@ */ package easy_http + +import ( + "net/http" + "net/url" + "reflect" + "strings" +) + +func parseRequestURL(c *Client, r *Request) error { + if l := len(c.PathParams) + len(r.PathParams); l > 0 { + params := make(map[string]string, l) + + // GitHub #103 Path Params + for p, v := range r.PathParams { + params[p] = url.PathEscape(v) + } + for p, v := range c.PathParams { + if _, ok := params[p]; !ok { + params[p] = url.PathEscape(v) + } + } + + if len(params) > 0 { + + } + + for k, v := range params { + r.URL = strings.Replace(r.URL, "{"+k+"}", v, 1) + } + } + + // Parsing request URL + reqURL, err := url.Parse(r.URL) + if err != nil { + return err + } + + // Adding Query Param + if len(c.QueryParam)+len(r.QueryParam) > 0 { + for k, v := range c.QueryParam { + // skip query parameter if it was set in request + if _, ok := r.QueryParam[k]; ok { + continue + } + + r.QueryParam[k] = v[:] + } + + if len(r.QueryParam) > 0 { + if len(strings.TrimSpace(reqURL.RawQuery)) == 0 { + reqURL.RawQuery = r.QueryParam.Encode() + } else { + reqURL.RawQuery = reqURL.RawQuery + "&" + r.QueryParam.Encode() + } + } + } + + r.URL = reqURL.String() + + return nil +} + +func parseRequestHeader(c *Client, r *Request) error { + for k, v := range c.Header { + if _, ok := r.Header[k]; ok { + continue + } + r.Header[k] = v[:] + } + + return nil +} + +func parseRequestBody(c *Client, r *Request) error { + switch { + case r.RawRequest.HasMultipartForm(): // Handling Multipart + handleMultipart(c, r) + case len(c.FormData) > 0 || len(r.FormData) > 0: // Handling Form Data + handleFormData(c, r) + //case r.RawRequest.Body() != nil: // Handling Request body + // handleContentType(c, r) + } + + return nil +} + +func handleMultipart(c *Client, r *Request) { + r.RawRequest.SetMultipartFormData(c.FormData) + + r.Header.Set(hdrContentTypeKey, formDataContentType) +} + +func handleFormData(c *Client, r *Request) { + r.RawRequest.SetFormData(c.FormData) + + r.Header.Set(hdrContentTypeKey, formContentType) +} + +//func handleContentType(c *Client, r *Request) { +// contentType := r.Header.Get(hdrContentTypeKey) +// if len(strings.TrimSpace(contentType)) == 0 { +// contentType = DetectContentType(r.RawRequest.Body()) +// r.Header.Set(hdrContentTypeKey, contentType) +// } +//} + +func DetectContentType(body interface{}) string { + contentType := plainTextType + kind := reflect.Indirect(reflect.ValueOf(body)).Kind() + switch kind { + case reflect.Struct, reflect.Map: + contentType = jsonContentType + case reflect.String: + contentType = plainTextType + default: + if b, ok := body.([]byte); ok { + contentType = http.DetectContentType(b) + } else if kind == reflect.Slice { + contentType = jsonContentType + } + } + + return contentType +} diff --git a/request.go b/request.go index 4469b3a..df8a392 100644 --- a/request.go +++ b/request.go @@ -26,14 +26,14 @@ import ( type Request struct { client *Client - Url string + URL string Method string QueryParam url.Values Header http.Header Cookies []*http.Cookie PathParams map[string]string - FormParams map[string]string - FileParams map[string]string + FormData map[string]string + FileData map[string]string BodyParams interface{} RawRequest *protocol.Request Ctx context.Context @@ -65,7 +65,6 @@ const ( MethodOptions = "OPTIONS" ) - func (r *Request) Get(url string) (*Response, error) { return r.Execute(MethodGet, url) } @@ -95,16 +94,19 @@ func (r *Request) Patch(url string) (*Response, error) { } func (r *Request) Send() (*Response, error) { - return r.Execute(r.Method, r.Url) + return r.Execute(r.Method, r.URL) } func (r *Request) Execute(method, url string) (*Response, error) { r.Method = method - r.Url = url + + r.RawRequest.SetRequestURI(url) res := &Response{ - Request: r, + Request: r, RawResponse: &protocol.Response{}, } - err :=r.client.client.Do(context.Background(),r.RawRequest,res.RawResponse) - return res,err -} \ No newline at end of file + + var err error + res, err = r.client.execute(r) + return res, err +}