diff --git a/internal/http1parser/header.go b/internal/http1parser/header.go index 85d64050..d4ef3e64 100644 --- a/internal/http1parser/header.go +++ b/internal/http1parser/header.go @@ -1,165 +1,43 @@ package http1parser -import "errors" - -var ( - ErrBadProto = errors.New("bad protocol") - ErrMissingData = errors.New("missing data") +import ( + "errors" + "net/textproto" + "strings" ) -const ( - _eNextHeader int = iota - _eNextHeaderN - _eHeader - _eHeaderValueSpace - _eHeaderValue - _eHeaderValueN - _eMLHeaderStart - _eMLHeaderValue -) +var ErrBadProto = errors.New("bad protocol") // Http1ExtractHeaders is an HTTP/1.0 and HTTP/1.1 header-only parser, // to extract the original header names for the received request. -// Fully inspired by https://github.com/evanphx/wildcat -func Http1ExtractHeaders(input []byte) ([]string, error) { - total := len(input) - var path, version, headers int - var headerNames []string - - // First line: METHOD PATH VERSION - var methodOk bool - for i := 0; i < total; i++ { - switch input[i] { - case ' ', '\t': - methodOk = true - path = i + 1 - } - if methodOk { - break - } - } - - if !methodOk { - return nil, ErrMissingData - } - - var pathOk bool - for i := path; i < total; i++ { - switch input[i] { - case ' ', '\t': - pathOk = true - version = i + 1 - } - if pathOk { - break - } +// Fully inspired by readMIMEHeader() in +// https://github.com/golang/go/blob/master/src/net/textproto/reader.go +func Http1ExtractHeaders(r *textproto.Reader) ([]string, error) { + // Discard first line, it doesn't contain useful information, and it has + // already been validated in http.ReadRequest() + if _, err := r.ReadLine(); err != nil { + return nil, err } - if !pathOk { - return nil, ErrMissingData + // The first line cannot start with a leading space. + if buf, err := r.R.Peek(1); err == nil && (buf[0] == ' ' || buf[0] == '\t') { + return nil, ErrBadProto } - var versionOk bool - var readN bool - for i := version; i < total; i++ { - c := input[i] - - switch readN { - case false: - switch c { - case '\r': - readN = true - case '\n': - headers = i + 1 - versionOk = true - } - case true: - if c != '\n' { - return nil, ErrBadProto - } - headers = i + 1 - versionOk = true - } - if versionOk { - break + var headerNames []string + for { + kv, err := r.ReadContinuedLine() + if len(kv) == 0 { + // We have finished to parse the headers if we receive empty + // data without an error + return headerNames, err } - } - - if !versionOk { - return nil, ErrMissingData - } - // Header parsing - state := _eNextHeader - start := headers - - for i := headers; i < total; i++ { - switch state { - case _eNextHeader: - switch input[i] { - case '\r': - state = _eNextHeaderN - case '\n': - return headerNames, nil - case ' ', '\t': - state = _eMLHeaderStart - default: - start = i - state = _eHeader - } - case _eNextHeaderN: - if input[i] != '\n' { - return nil, ErrBadProto - } - - return headerNames, nil - case _eHeader: - if input[i] == ':' { - headerName := input[start:i] - headerNames = append(headerNames, string(headerName)) - state = _eHeaderValueSpace - } - case _eHeaderValueSpace: - switch input[i] { - case ' ', '\t': - continue - } - - start = i - state = _eHeaderValue - case _eHeaderValue: - switch input[i] { - case '\r': - state = _eHeaderValueN - case '\n': - state = _eNextHeader - default: - continue - } - case _eHeaderValueN: - if input[i] != '\n' { - return nil, ErrBadProto - } - state = _eNextHeader - case _eMLHeaderStart: - switch input[i] { - case ' ', '\t': - continue - } - - start = i - state = _eMLHeaderValue - case _eMLHeaderValue: - switch input[i] { - case '\r': - state = _eHeaderValueN - case '\n': - state = _eNextHeader - default: - continue - } + // Key ends at first colon. + k, _, ok := strings.Cut(kv, ":") + if !ok { + return nil, ErrBadProto } + headerNames = append(headerNames, k) } - - return nil, ErrMissingData } diff --git a/internal/http1parser/header_test.go b/internal/http1parser/header_test.go index 9ee7a0f6..1e905442 100644 --- a/internal/http1parser/header_test.go +++ b/internal/http1parser/header_test.go @@ -1,6 +1,9 @@ package http1parser_test import ( + "bufio" + "bytes" + "net/textproto" "testing" "github.com/elazarl/goproxy/internal/http1parser" @@ -11,7 +14,9 @@ import ( func TestHttp1ExtractHeaders_Empty(t *testing.T) { http1Data := "POST /index.html HTTP/1.1\r\n" + "\r\n" - headers, err := http1parser.Http1ExtractHeaders([]byte(http1Data)) + + textParser := textproto.NewReader(bufio.NewReader(bytes.NewReader([]byte(http1Data)))) + headers, err := http1parser.Http1ExtractHeaders(textParser) require.NoError(t, err) assert.Empty(t, headers) } @@ -19,13 +24,14 @@ func TestHttp1ExtractHeaders_Empty(t *testing.T) { func TestHttp1ExtractHeaders(t *testing.T) { http1Data := "POST /index.html HTTP/1.1\r\n" + "Host: www.test.com\r\n" + - "Accept: */*\r\n" + + "Accept: */ /*\r\n" + "Content-Length: 17\r\n" + "lowercase: 3z\r\n" + "\r\n" + `{"hello":"world"}` - headers, err := http1parser.Http1ExtractHeaders([]byte(http1Data)) + textParser := textproto.NewReader(bufio.NewReader(bytes.NewReader([]byte(http1Data)))) + headers, err := http1parser.Http1ExtractHeaders(textParser) require.NoError(t, err) assert.Len(t, headers, 4) assert.Contains(t, headers, "Content-Length") @@ -35,6 +41,8 @@ func TestHttp1ExtractHeaders(t *testing.T) { func TestHttp1ExtractHeaders_InvalidData(t *testing.T) { http1Data := "POST /index.html HTTP/1.1\r\n" + `{"hello":"world"}` - _, err := http1parser.Http1ExtractHeaders([]byte(http1Data)) + + textParser := textproto.NewReader(bufio.NewReader(bytes.NewReader([]byte(http1Data)))) + _, err := http1parser.Http1ExtractHeaders(textParser) require.Error(t, err) } diff --git a/internal/http1parser/request.go b/internal/http1parser/request.go index e63fc008..0e37bc2d 100644 --- a/internal/http1parser/request.go +++ b/internal/http1parser/request.go @@ -33,11 +33,17 @@ func NewRequestReader(preventCanonicalization bool, conn io.Reader) *RequestRead } } +// IsEOF returns true if there is no more data that can be read from the +// buffer and the underlying connection is closed. func (r *RequestReader) IsEOF() bool { _, err := r.reader.Peek(1) return errors.Is(err, io.EOF) } +// Reader is used to take over the buffered connection data +// (e.g. with HTTP/2 data). +// After calling this function, make sure to consume all the data related +// to the current request. func (r *RequestReader) Reader() *bufio.Reader { return r.reader } @@ -54,8 +60,9 @@ func (r *RequestReader) ReadRequest() (*http.Request, error) { return nil, err } - httpData := getRequestData(r.reader, r.cloned) - headers, _ := Http1ExtractHeaders(httpData) + httpDataReader := getRequestReader(r.reader, r.cloned) + headers, _ := Http1ExtractHeaders(httpDataReader) + for _, headerName := range headers { canonicalizedName := textproto.CanonicalMIMEHeaderKey(headerName) if canonicalizedName == headerName { @@ -73,12 +80,15 @@ func (r *RequestReader) ReadRequest() (*http.Request, error) { return req, nil } -func getRequestData(r *bufio.Reader, cloned *bytes.Buffer) []byte { +func getRequestReader(r *bufio.Reader, cloned *bytes.Buffer) *textproto.Reader { // "Cloned" buffer uses the raw connection as the data source. // However, the *bufio.Reader can read also bytes of another unrelated // request on the same connection, since it's buffered, so we have to // ignore them before passing the data to our headers parser. // Data related to the next request will remain inside the buffer for // later usage. - return cloned.Next(cloned.Len() - r.Buffered()) + data := cloned.Next(cloned.Len() - r.Buffered()) + return &textproto.Reader{ + R: bufio.NewReader(bytes.NewReader(data)), + } }