Skip to content

Commit

Permalink
frontend: fix tests using httptest.ResponseRecorder
Browse files Browse the repository at this point in the history
Signed-off-by: Simon Pasquier <spasquie@redhat.com>
  • Loading branch information
simonpasquier committed Feb 20, 2025
1 parent 4033786 commit 5cfbdbb
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 66 deletions.
1 change: 1 addition & 0 deletions frontend/pkg/frontend/middleware_body.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

const megabyte int64 = (1 << 20)

// MiddlewareBody ensures that the request's body doesn't exceed the maximum size of 4MB.
func MiddlewareBody(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
switch r.Method {
case http.MethodPatch, http.MethodPost, http.MethodPut:
Expand Down
45 changes: 19 additions & 26 deletions frontend/pkg/frontend/middleware_body_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@ package frontend
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/Azure/ARO-HCP/internal/api/arm"
"github.com/stretchr/testify/assert"
)

func TestMiddlewareBody(t *testing.T) {
Expand Down Expand Up @@ -80,9 +82,7 @@ func TestMiddlewareBody(t *testing.T) {
writer := httptest.NewRecorder()

request, err := http.NewRequest(method, "", bytes.NewReader(tt.body))
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
request.Header = tt.header

next := func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -92,36 +92,29 @@ func TestMiddlewareBody(t *testing.T) {

MiddlewareBody(writer, request, next)

if tt.wantErr == "" {
if writer.Code != http.StatusOK {
t.Error(writer.Code)
}
res := writer.Result()
b, err := io.ReadAll(res.Body)
assert.NoError(t, err)

if writer.Body.String() != "" {
t.Error(writer.Body.String())
}
if tt.wantErr == "" {
assert.Equal(t, http.StatusOK, res.StatusCode)
assert.Empty(t, string(b))

if method != http.MethodGet {
body, err := BodyFromContext(request.Context())
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(body, tt.body) {
t.Error(string(body))
}
assert.NoError(t, err)
assert.Equal(t, string(tt.body), string(body))
}
} else {
var cloudErr *arm.CloudError
err = json.Unmarshal(writer.Body.Bytes(), &cloudErr)
if err != nil {
t.Fatal(err)
}
cloudErr.StatusCode = writer.Code

if tt.wantErr != cloudErr.Error() {
t.Error(cloudErr)
}
return
}

var cloudErr *arm.CloudError
err = json.Unmarshal(b, &cloudErr)
assert.NoError(t, err)

cloudErr.StatusCode = res.StatusCode
assert.Equal(t, tt.wantErr, cloudErr.Error())
})
}
}
Expand Down
1 change: 1 addition & 0 deletions frontend/pkg/frontend/middleware_validatestatic.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
var rxHCPOpenShiftClusterResourceName = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9-]{2,53}$`)
var rxNodePoolResourceName = regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9-]{2,14}$`)

// MiddlewareValidateStatic ensures that the URL path parses to a valid resource ID.
func MiddlewareValidateStatic(w http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// To conform with "OAPI012: Resource IDs must not be case sensitive"
// we need to use the original, non-lowercased resource ID components
Expand Down
26 changes: 7 additions & 19 deletions frontend/pkg/frontend/middleware_validatestatic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@ package frontend

import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/Azure/ARO-HCP/internal/api/arm"
"github.com/stretchr/testify/assert"
)

type CloudErrorContainer struct {
Expand Down Expand Up @@ -111,29 +110,18 @@ func TestMiddlewareValidateStatic(t *testing.T) {
// Execute the middleware
MiddlewareValidateStatic(w, req, nextHandler)

res := w.Result()

// Check the response status code
if status := w.Code; status != tc.expectedStatusCode {
t.Errorf("handler returned wrong status code: got %v want %v",
status, tc.expectedStatusCode)
}
assert.Equal(t, tc.expectedStatusCode, res.StatusCode)

if tc.expectedStatusCode != http.StatusOK {

var resp CloudErrorContainer
body, err := io.ReadAll(http.MaxBytesReader(w, w.Result().Body, 4*megabyte))
if err != nil {
t.Fatalf("failed to read response body: %v", err)
}
err = json.Unmarshal(body, &resp)
if err != nil {
t.Fatalf("failed to unmarshal response body: %v", err)
}
err := json.NewDecoder(res.Body).Decode(&resp)
assert.NoError(t, err)

// Check if the error message contains the expected text
if !strings.Contains(resp.Error.Message, tc.expectedBody) {
t.Errorf("handler returned unexpected body: got %v want %v",
resp.Error.Message, tc.expectedBody)
}
assert.Contains(t, tc.expectedBody, resp.Error.Message)
}
})
}
Expand Down
38 changes: 17 additions & 21 deletions frontend/pkg/frontend/middleware_validatesubscription_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,17 @@ package frontend
import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"

"github.com/google/go-cmp/cmp"
"go.uber.org/mock/gomock"

"github.com/Azure/ARO-HCP/internal/api/arm"
"github.com/Azure/ARO-HCP/internal/database"
"github.com/Azure/ARO-HCP/internal/mocks"
"github.com/stretchr/testify/assert"
)

func TestMiddlewareValidateSubscription(t *testing.T) {
Expand Down Expand Up @@ -186,9 +185,7 @@ func TestMiddlewareValidateSubscription(t *testing.T) {
writer := httptest.NewRecorder()

request, err := http.NewRequest(tt.httpMethod, tt.requestPath, nil)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)

// Add a logger to the context so parsing errors will be logged.
ctx := request.Context()
Expand All @@ -207,28 +204,27 @@ func TestMiddlewareValidateSubscription(t *testing.T) {

MiddlewareValidateSubscriptionState(writer, request, next)

res := writer.Result()
if tt.expectedError != nil {
var actualError *arm.CloudError
body, _ := io.ReadAll(http.MaxBytesReader(writer, writer.Result().Body, 4*megabyte))
_ = json.Unmarshal(body, &actualError)
if (writer.Result().StatusCode != tt.expectedError.StatusCode) || actualError.Code != tt.expectedError.Code || actualError.Message != tt.expectedError.Message {
t.Errorf("unexpected CloudError, wanted %v, got %v", tt.expectedError, actualError)
}
} else {
if doc.Subscription.State != tt.expectedState {
t.Error(cmp.Diff(doc.Subscription.State, tt.expectedState))
}
var actualError arm.CloudError
err = json.NewDecoder(res.Body).Decode(&actualError)
assert.NoError(t, err)

assert.Equal(t, tt.expectedError.StatusCode, res.StatusCode)
assert.Equal(t, tt.expectedError.Code, actualError.Code)
assert.Equal(t, tt.expectedError.Message, actualError.Message)
return
}

assert.Equal(t, tt.expectedState, doc.Subscription.State)
})
}

t.Run("nil DB client in the context", func(t *testing.T) {
writer := httptest.NewRecorder()

request, err := http.NewRequest(http.MethodGet, defaultRequestPath, nil)
if err != nil {
t.Fatal(err)
}
assert.NoError(t, err)
request.SetPathValue(PathSegmentSubscriptionID, subscriptionId)

ctx := request.Context()
Expand All @@ -237,8 +233,8 @@ func TestMiddlewareValidateSubscription(t *testing.T) {

next := func(w http.ResponseWriter, r *http.Request) {}
MiddlewareValidateSubscriptionState(writer, request, next)
if writer.Code != http.StatusInternalServerError {
t.Errorf("expected status code %d, got %d", http.StatusInternalServerError, writer.Code)
}

res := writer.Result()
assert.Equal(t, http.StatusInternalServerError, res.StatusCode)
})
}

0 comments on commit 5cfbdbb

Please sign in to comment.