Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement statusCodeFilter option for client #13

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ type Response struct {
// Header is the cached response header.
Header http.Header

// StatusCode is the cached response status code.
StatusCode int

// Expiration is the cached response expiration date.
Expiration time.Time

Expand All @@ -66,6 +69,7 @@ type Client struct {
ttl time.Duration
refreshKey string
methods []string
statusCodeFilter func(int) bool
writeExpiresHeader bool
}

Expand Down Expand Up @@ -120,13 +124,13 @@ func (c *Client) Middleware(next http.Handler) http.Handler {
response.Frequency++
c.adapter.Set(key, response.Bytes(), response.Expiration)

//w.WriteHeader(http.StatusNotModified)
for k, v := range response.Header {
w.Header().Set(k, strings.Join(v, ","))
}
if c.writeExpiresHeader {
w.Header().Set("Expires", response.Expiration.UTC().Format(http.TimeFormat))
}
w.WriteHeader(response.StatusCode)
w.Write(response.Value)
return
}
Expand All @@ -143,10 +147,11 @@ func (c *Client) Middleware(next http.Handler) http.Handler {
value := rec.Body.Bytes()
now := time.Now()
expires := now.Add(c.ttl)
if statusCode < 400 {
if c.statusCodeFilter(statusCode) {
response := Response{
Value: value,
Header: result.Header,
StatusCode: statusCode,
Expiration: expires,
LastAccess: now,
Frequency: 1,
Expand Down Expand Up @@ -244,6 +249,9 @@ func NewClient(opts ...ClientOption) (*Client, error) {
if c.methods == nil {
c.methods = []string{http.MethodGet}
}
if c.statusCodeFilter == nil {
c.statusCodeFilter = func(code int) bool { return code < 400 }
}

return c, nil
}
Expand Down Expand Up @@ -293,6 +301,15 @@ func ClientWithMethods(methods []string) ClientOption {
}
}

// ClientWithStatusCodeFilter sets the acceptable status codes to be cached.
// Optional setting. If not set, default filter allows caching of every response with status code below 400.
func ClientWithStatusCodeFilter(filter func(int) bool) ClientOption {
return func(c *Client) error {
c.statusCodeFilter = filter
return nil
}
}

// ClientWithExpiresHeader enables middleware to add an Expires header to responses.
// Optional setting. If not set, default is false.
func ClientWithExpiresHeader() ClientOption {
Expand Down
58 changes: 58 additions & 0 deletions cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,22 @@ func TestMiddleware(t *testing.T) {
store: map[uint64][]byte{
14974843192121052621: Response{
Value: []byte("value 1"),
StatusCode: 200,
Expiration: time.Now().Add(1 * time.Minute),
}.Bytes(),
14974839893586167988: Response{
Value: []byte("value 2"),
StatusCode: 200,
Expiration: time.Now().Add(1 * time.Minute),
}.Bytes(),
14974840993097796199: Response{
Value: []byte("value 3"),
StatusCode: 200,
Expiration: time.Now().Add(-1 * time.Minute),
}.Bytes(),
10956846073361780255: Response{
Value: []byte("value 4"),
StatusCode: 200,
Expiration: time.Now().Add(-1 * time.Minute),
}.Bytes(),
},
Expand Down Expand Up @@ -473,9 +477,63 @@ func TestNewClient(t *testing.T) {
t.Errorf("NewClient() error = %v, wantErr %v", err, tt.wantErr)
return
}

if tt.want != nil {
got.statusCodeFilter = nil
tt.want.statusCodeFilter = nil
}

if !reflect.DeepEqual(got, tt.want) {
t.Errorf("NewClient() = %v, want %v", got, tt.want)
}
})
}
}

func TestNewClientWithStatusCodeFilter(t *testing.T) {
adapter := &adapterMock{}

tests := []struct {
name string
opts []ClientOption
wantCache []int
wantSkip []int
}{
{
"returns new client with status code filter",
[]ClientOption{
ClientWithAdapter(adapter),
ClientWithTTL(1 * time.Millisecond),
},
[]int{200, 300},
[]int{400, 500},
},
{
"returns new client with status code filter",
[]ClientOption{
ClientWithAdapter(adapter),
ClientWithTTL(1 * time.Millisecond),
ClientWithStatusCodeFilter(func(code int) bool { return code < 350 || code > 450 }),
},
[]int{200, 300, 500},
[]int{400},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, _ := NewClient(tt.opts...)

for _, c := range tt.wantCache {
if got.statusCodeFilter(c) == false {
t.Errorf("NewClient() allows caching of status code %v, don't want it to", c)
}
}

for _, c := range tt.wantSkip {
if got.statusCodeFilter(c) == true {
t.Errorf("NewClient() doesn't allow caching of status code %v, want it to", c)
}
}
})
}
}