Skip to content

Commit

Permalink
Adding configurable requests cache
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanoj3 committed Aug 4, 2019
1 parent b9ce7a3 commit fa5039d
Show file tree
Hide file tree
Showing 11 changed files with 193 additions and 5 deletions.
1 change: 1 addition & 0 deletions functional-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ SCAN_RESULT=$(./dist/dirstalk scan -h 2>&1 || true);
assert_contains "$SCAN_RESULT" "\-\-dictionary" "dictionary help is expected to be printed"
assert_contains "$SCAN_RESULT" "\-\-cookie" "cookie help is expected to be printed"
assert_contains "$SCAN_RESULT" "\-\-header" "header help is expected to be printed"
assert_contains "$SCAN_RESULT" "\-\-http-cache-requests" "http-cache-requests help is expected to be printed"
assert_contains "$SCAN_RESULT" "\-\-http-methods" "http-methods help is expected to be printed"
assert_contains "$SCAN_RESULT" "\-\-http-statuses-to-ignore" "http-statuses-to-ignore help is expected to be printed"
assert_contains "$SCAN_RESULT" "\-\-http-timeout" "http-timeout help is expected to be printed"
Expand Down
5 changes: 5 additions & 0 deletions pkg/cmd/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ func scanConfigFromCmd(cmd *cobra.Command) (*scan.Config, error) {
return nil, errors.Wrap(err, "failed to read http-timeout flag")
}

c.CacheRequests, err = cmd.Flags().GetBool(flagHTTPCacheRequests)
if err != nil {
return nil, errors.Wrap(err, "failed to read http-cache-requests flag")
}

c.ScanDepth, err = cmd.Flags().GetInt(flagScanDepth)
if err != nil {
return nil, errors.Wrap(err, "failed to read http-timeout flag")
Expand Down
1 change: 1 addition & 0 deletions pkg/cmd/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const (
flagHTTPMethods = "http-methods"
flagHTTPStatusesToIgnore = "http-statuses-to-ignore"
flagHTTPTimeout = "http-timeout"
flagHTTPCacheRequests = "http-cache-requests"
flagScanDepth = "scan-depth"
flagThreads = "threads"
flagThreadsShort = "t"
Expand Down
11 changes: 10 additions & 1 deletion pkg/cmd/scan.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func NewScanCommand(logger *logrus.Logger) (*cobra.Command, error) {
cmd.Flags().IntSlice(
flagHTTPStatusesToIgnore,
[]int{http.StatusNotFound},
"comma separated list of http statuses to ignore when showing results; eg: 404,301",
"comma separated list of http statuses to ignore when showing and processing results; eg: 404,301",
)

cmd.Flags().IntP(
Expand All @@ -66,6 +66,14 @@ func NewScanCommand(logger *logrus.Logger) (*cobra.Command, error) {
"timeout in milliseconds",
)

cmd.Flags().BoolP(
flagHTTPCacheRequests,
"",
true,
"cache requests to avoid performing the same request multiple times within the same scan (EG if the "+
"server reply with the same redirect location multiple times, dirstalk will follow it only once)",
)

cmd.Flags().IntP(
flagScanDepth,
"",
Expand Down Expand Up @@ -152,6 +160,7 @@ func startScan(logger *logrus.Logger, cnf *scan.Config, u *url.URL) error {
cnf.UseCookieJar,
cnf.Cookies,
cnf.Headers,
cnf.CacheRequests,
u,
)
if err != nil {
Expand Down
8 changes: 8 additions & 0 deletions pkg/scan/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func NewClientFromConfig(
useCookieJar bool,
cookies []*http.Cookie,
headers map[string]string,
shouldCacheRequests bool,
u *url.URL,
) (*http.Client, error) {
transport := http.Transport{
Expand Down Expand Up @@ -78,5 +79,12 @@ func NewClientFromConfig(
}
}

if shouldCacheRequests {
c.Transport, err = decorateTransportWithRequestCacheDecorator(c.Transport)
if err != nil {
return nil, errors.Wrap(err, "NewClientFromConfig: failed to decorate transport")
}
}

return c, nil
}
40 changes: 40 additions & 0 deletions pkg/scan/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ func TestWhenRemoteIsTooSlowClientShouldTimeout(t *testing.T) {
false,
nil,
nil,
true,
nil,
)
assert.NoError(t, err)
Expand Down Expand Up @@ -74,6 +75,7 @@ func TestShouldForwardProvidedCookiesWhenUsingJar(t *testing.T) {
true,
cookies,
map[string]string{},
false,
u,
)
assert.NoError(t, err)
Expand Down Expand Up @@ -133,6 +135,7 @@ func TestShouldForwardCookiesWhenJarIsDisabled(t *testing.T) {
false,
cookies,
map[string]string{},
true,
u,
)
assert.NoError(t, err)
Expand Down Expand Up @@ -172,6 +175,7 @@ func TestShouldForwardProvidedHeader(t *testing.T) {
false,
nil,
map[string]string{headerName: headerValue},
true,
u,
)
assert.NoError(t, err)
Expand All @@ -198,10 +202,46 @@ func TestShouldFailToCreateAClientWithInvalidSocks5Url(t *testing.T) {
false,
nil,
map[string]string{},
true,
nil,
)
assert.Nil(t, c)
assert.Error(t, err)

assert.Contains(t, err.Error(), "unknown scheme")
}

// TestShouldNotRepeatTheSameRequestTwice checks that a client created with
// request caching enabled performs a given request only once: the second
// attempt must fail with ErrRequestRedundant and must not reach the server.
func TestShouldNotRepeatTheSameRequestTwice(t *testing.T) {
	testServer, serverAssertion := test.NewServerWithAssertion(
		http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}),
	)
	defer testServer.Close()

	u, err := url.Parse(testServer.URL)
	assert.NoError(t, err)

	c, err := client.NewClientFromConfig(
		100,
		nil,
		"",
		false,
		nil,
		nil,
		true,
		u,
	)
	assert.NoError(t, err)

	req, err := http.NewRequest(http.MethodGet, u.String(), nil)
	assert.NoError(t, err)

	res, err := c.Do(req)
	// Stop here on failure: res would be nil and dereferencing it below
	// would panic instead of reporting a test failure.
	if !assert.NoError(t, err) {
		t.FailNow()
	}
	assert.Equal(t, http.StatusOK, res.StatusCode)
	// Release the connection held by the first response.
	assert.NoError(t, res.Body.Close())

	res, err = c.Do(req)
	// Only inspect the message when there is an error; calling Error() on a
	// nil error would panic and hide the real assertion failure.
	if assert.Error(t, err) {
		assert.Contains(t, err.Error(), client.ErrRequestRedundant.Error())
	}
	assert.Nil(t, res)

	assert.Equal(t, 1, serverAssertion.Len())
}
44 changes: 44 additions & 0 deletions pkg/scan/client/request_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package client

import (
"errors"
"fmt"
"net/http"
"sync"
)

var (
	// ErrRequestRedundant is the sentinel error returned by the request cache
	// when the same request (method, host, path) is attempted more than once.
	ErrRequestRedundant = errors.New("this request has been made already")
)

// decorateTransportWithRequestCacheDecorator wraps the given round tripper in a
// requestCacheTransportDecorator. A nil round tripper is rejected with an error
// and a nil decorator.
func decorateTransportWithRequestCacheDecorator(decorated http.RoundTripper) (*requestCacheTransportDecorator, error) {
	if decorated == nil {
		err := errors.New("decorated round tripper is nil")
		return nil, err
	}

	wrapper := new(requestCacheTransportDecorator)
	wrapper.decorated = decorated

	return wrapper, nil
}

// requestCacheTransportDecorator is an http.RoundTripper that remembers the
// key of every request it has been asked to perform and refuses to forward a
// request whose key it has already seen.
type requestCacheTransportDecorator struct {
	// decorated is the round tripper that actually performs the request.
	decorated http.RoundTripper
	// requestMap is used as a set: its keys are the request keys already seen.
	requestMap sync.Map
}

// RoundTrip forwards r to the decorated round tripper the first time a request
// with a given key is seen. Any later request with the same key fails with
// ErrRequestRedundant and never reaches the decorated round tripper.
//
// The key is recorded before the request is forwarded, so a request that
// fails in the decorated round tripper is still considered as already made.
func (u *requestCacheTransportDecorator) RoundTrip(r *http.Request) (*http.Response, error) {
	// LoadOrStore checks and records the key in one atomic step; a separate
	// Load followed by Store would let two concurrent callers with the same
	// key both pass the check and both perform the request.
	if _, seen := u.requestMap.LoadOrStore(u.keyForRequest(r), struct{}{}); seen {
		return nil, ErrRequestRedundant
	}

	return u.decorated.RoundTrip(r)
}

// keyForRequest builds the cache key for r out of its method, host and path.
// The query string is not part of the key.
func (u *requestCacheTransportDecorator) keyForRequest(r *http.Request) string {
	host := r.Host
	if host == "" {
		// On client requests Host is only an optional override: when it is
		// empty the request is sent to URL.Host, so that is the host that
		// identifies the request.
		host = r.URL.Host
	}

	return fmt.Sprintf("%s~%s~%s", r.Method, host, r.URL.Path)
}
13 changes: 13 additions & 0 deletions pkg/scan/client/request_cache_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package client

import (
"testing"

"github.com/stretchr/testify/assert"
)

// TestRequestCacheTransportDecorator checks that decorating a nil round
// tripper yields no decorator and an error.
func TestRequestCacheTransportDecorator(t *testing.T) {
	decorator, decorateErr := decorateTransportWithRequestCacheDecorator(nil)

	assert.Nil(t, decorator)
	assert.Error(t, decorateErr)
}
1 change: 1 addition & 0 deletions pkg/scan/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ type Config struct {
HTTPStatusesToIgnore []int
Threads int
TimeoutInMilliseconds int
CacheRequests bool
ScanDepth int
Socks5Url *url.URL
UserAgent string
Expand Down
7 changes: 5 additions & 2 deletions pkg/scan/scanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/sirupsen/logrus"
"github.com/stefanoj3/dirstalk/pkg/common/urlpath"
"github.com/stefanoj3/dirstalk/pkg/scan/client"
)

// Target represents the target to scan
Expand Down Expand Up @@ -117,6 +118,10 @@ func (s *Scanner) processRequest(
baseURL url.URL,
) {
res, err := s.httpClient.Do(req)
if err != nil && strings.Contains(err.Error(), client.ErrRequestRedundant.Error()) {
l.WithError(err).Debug("skipping, request was already made")
return
}
if err != nil {
l.WithError(err).Error("failed to perform request")
return
Expand All @@ -125,8 +130,6 @@ func (s *Scanner) processRequest(
l.WithError(err).Warn("failed to close response body")
}

// TODO: also add a cached version of the client to avoid doing the same requests multiple times
// (eg if there are multiple pages redirecting to the same URL maybe we can do the call only once)
result := NewResult(target, res)
results <- result

Expand Down
67 changes: 65 additions & 2 deletions pkg/scan/scanner_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func TestScannerWillLogAnErrorWithInvalidDictionary(t *testing.T) {
false,
nil,
nil,
true,
test.MustParseUrl(t, testServer.URL),
)
assert.NoError(t, err)
Expand Down Expand Up @@ -102,6 +103,7 @@ func TestScannerWillNotRedirectIfStatusCodeIsInvalid(t *testing.T) {
false,
nil,
nil,
true,
test.MustParseUrl(t, testServer.URL),
)
assert.NoError(t, err)
Expand Down Expand Up @@ -168,6 +170,7 @@ func TestScannerWillChangeMethodForRedirect(t *testing.T) {
false,
nil,
nil,
true,
test.MustParseUrl(t, testServer.URL),
)
assert.NoError(t, err)
Expand All @@ -179,7 +182,7 @@ func TestScannerWillChangeMethodForRedirect(t *testing.T) {
logger,
)

results := make([]scan.Result, 0, 4)
results := make([]scan.Result, 0, 3)
resultsChannel := sut.Scan(test.MustParseUrl(t, testServer.URL), 1)

for r := range resultsChannel {
Expand Down Expand Up @@ -242,6 +245,7 @@ func TestScannerWhenOutOfDepthWillNotFollowRedirect(t *testing.T) {
false,
nil,
nil,
true,
test.MustParseUrl(t, testServer.URL),
)
assert.NoError(t, err)
Expand Down Expand Up @@ -305,6 +309,7 @@ func TestScannerWillSkipRedirectWhenLocationHostIsDifferent(t *testing.T) {
false,
nil,
nil,
true,
test.MustParseUrl(t, testServer.URL),
)
assert.NoError(t, err)
Expand All @@ -316,7 +321,7 @@ func TestScannerWillSkipRedirectWhenLocationHostIsDifferent(t *testing.T) {
logger,
)

results := make([]scan.Result, 0, 4)
results := make([]scan.Result, 0, 2)
resultsChannel := sut.Scan(test.MustParseUrl(t, testServer.URL), 1)

for r := range resultsChannel {
Expand Down Expand Up @@ -345,3 +350,61 @@ func TestScannerWillSkipRedirectWhenLocationHostIsDifferent(t *testing.T) {
assert.NotContains(t, loggerBufferAsString, "error")
assert.Equal(t, 2, serverAssertion.Len())
}

// TestScannerWillIgnoreRequestRedundantError feeds the scanner a dictionary
// containing the same path twice while request caching is enabled, and checks
// that only one request reaches the server, only one result is produced, and
// the skipped duplicate is mentioned in the log.
func TestScannerWillIgnoreRequestRedundantError(t *testing.T) {
	logger, loggerBuffer := test.NewLogger()

	// twice the same entry to trick the client into doing the same request twice
	dictionary := []string{"/home", "/home"}
	prod := producer.NewDictionaryProducer([]string{http.MethodGet}, dictionary, 3)

	alwaysNotFound := func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusNotFound)
	}
	testServer, serverAssertion := test.NewServerWithAssertion(http.HandlerFunc(alwaysNotFound))
	defer testServer.Close()

	c, err := client.NewClientFromConfig(
		1000, nil, "", false, nil, nil, true,
		test.MustParseUrl(t, testServer.URL),
	)
	assert.NoError(t, err)

	reProducer := producer.NewReProducer(prod, []int{http.StatusNotFound})
	sut := scan.NewScanner(c, prod, reProducer, logger)

	collected := make([]scan.Result, 0, 1)
	for result := range sut.Scan(test.MustParseUrl(t, testServer.URL), 1) {
		collected = append(collected, result)
	}

	homeTarget := scan.Target{Path: "/home", Method: http.MethodGet, Depth: 3}
	want := []scan.Result{
		{
			Target:     homeTarget,
			StatusCode: http.StatusNotFound,
			URL:        *test.MustParseUrl(t, testServer.URL+"/home"),
		},
	}
	assert.Equal(t, want, collected)

	logged := loggerBuffer.String()
	assert.Contains(t, logged, "/home")
	assert.Contains(t, logged, "/home: this request has been made already")
	assert.Equal(t, 1, serverAssertion.Len())
}

0 comments on commit fa5039d

Please sign in to comment.