Skip to content

Commit

Permalink
fix: properly manage NotFound & MethodNotAllowed request errors
Browse files Browse the repository at this point in the history
Signed-off-by: juan131 <jariza@vmware.com>
  • Loading branch information
juan131 committed Nov 16, 2023
1 parent 19fa42f commit 0719785
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 15 deletions.
24 changes: 21 additions & 3 deletions internal/service/r_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,20 +47,20 @@ func (svc *service) handleBatchMock(w http.ResponseWriter, r *http.Request) {
encodedBody, err := io.ReadAll(r.Body)
if err != nil {
logID := svc.LogRequestFailure(r, fmt.Sprintf("[handleBatchMock] body reading error: %+v", err), err)
renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", 1001, logID))
renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", api.CodeInvalidBody, logID))
return
}

decodedBody, err := url.ParseQuery(string(encodedBody))
if err != nil {
logID := svc.LogRequestFailure(r, fmt.Sprintf("[handleBatchMock] body reading error: %+v", err), err)
renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", 1001, logID))
renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", api.CodeInvalidBody, logID))
return
}

if err := json.Unmarshal([]byte(decodedBody.Get("batch")), &requests); err != nil {
logID := svc.LogRequestFailure(r, fmt.Sprintf("[handleBatchMock] body reading error: %+v", err), err)
renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", 1001, logID))
renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", api.CodeInvalidBody, logID))
return
}

Expand Down Expand Up @@ -93,6 +93,24 @@ func (svc *service) handleBatchMock(w http.ResponseWriter, r *http.Request) {
render.JSON(w, r, responses)
}

// handleNotFound handles not found requests
func (svc *service) handleNotFound(w http.ResponseWriter, r *http.Request) {
logID := svc.LogRequestFailure(r, "[handleNotFound] request to "+r.URL.Path, nil)
renderJSON(w, r, http.StatusNotFound, api.MakeHTTPErrorResponse("not found", api.CodeNotFound, logID))
}

// handleMethodNotAllowed handles not found requests
func (svc *service) handleMethodNotAllowed(w http.ResponseWriter, r *http.Request) {
logID := svc.LogRequestFailure(r, "[handleMethodNotAllowed] method "+r.Method, nil)
renderJSON(w, r, http.StatusNotFound, api.MakeHTTPErrorResponse("method not allowed", api.CodeMethodNotAllowed, logID))
}

// handleRateLimitExceeded handles rate limit exceeded requests
func (svc *service) handleRateLimitExceeded(w http.ResponseWriter, r *http.Request) {
render.Status(r, http.StatusTooManyRequests)
render.JSON(w, r, svc.cfg.rateExceededRespBody)
}

// shouldFail returns true if the request should fail based
// on a given success ratio and a request counter
func shouldFail(successRatio float64, requestsCounter int) bool {
Expand Down
88 changes: 88 additions & 0 deletions internal/service/r_mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,94 @@ func Test_service_handleBatchMock(t *testing.T) {
}
}

func Test_handleNotFound(t *testing.T) {

Check failure on line 195 in internal/service/r_mock_test.go

View workflow job for this annotation

GitHub Actions / check

Test_handleNotFound should call t.Parallel on the top level as well as its subtests (tparallel)
svc := &service{
logger: newStructuredLogger(),
}
tests := []struct {
name string
setupRequest func() *http.Request
respHandler func(tt *testing.T, resp *httptest.ResponseRecorder)
}{
{
name: "request to unknown route",
setupRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/v1/mock/unknown", nil)
return req
},
respHandler: func(tt *testing.T, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusNotFound {
tt.Errorf("expected status code %d, got %d", http.StatusNotFound, resp.Code)
}
if resp.Header().Get("Content-Type") != "application/json" {
tt.Errorf("expected content type %s, got %s", "application/json", resp.Header().Get("Content-Type"))
}
var errorResponse api.HTTPErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errorResponse); err != nil {
tt.Errorf("could not unmarshal response body: %+v", err)
}

if errorResponse.Error.Message != "not found" || errorResponse.Error.Code != api.CodeNotFound {
tt.Errorf("expected error message %s and code %d, got %s and %d", "not found", api.CodeNotFound, errorResponse.Error.Message, errorResponse.Error.Code)
}
},
},
}
for _, testToRun := range tests {
test := testToRun
t.Run(test.name, func(tt *testing.T) {
tt.Parallel()
resp := httptest.NewRecorder()
svc.handleNotFound(resp, test.setupRequest())
test.respHandler(tt, resp)
})
}
}

func Test_handleMethodNotAllowed(t *testing.T) {

Check failure on line 239 in internal/service/r_mock_test.go

View workflow job for this annotation

GitHub Actions / check

Test_handleMethodNotAllowed should call t.Parallel on the top level as well as its subtests (tparallel)
svc := &service{
logger: newStructuredLogger(),
}
tests := []struct {
name string
setupRequest func() *http.Request
respHandler func(tt *testing.T, resp *httptest.ResponseRecorder)
}{
{
name: "request to unknown route",
setupRequest: func() *http.Request {
req := httptest.NewRequest(http.MethodGet, "/v1/mock/unknown", nil)
return req
},
respHandler: func(tt *testing.T, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusNotFound {
tt.Errorf("expected status code %d, got %d", http.StatusNotFound, resp.Code)
}
if resp.Header().Get("Content-Type") != "application/json" {
tt.Errorf("expected content type %s, got %s", "application/json", resp.Header().Get("Content-Type"))
}
var errorResponse api.HTTPErrorResponse
if err := json.Unmarshal(resp.Body.Bytes(), &errorResponse); err != nil {
tt.Errorf("could not unmarshal response body: %+v", err)
}

if errorResponse.Error.Message != "method not allowed" || errorResponse.Error.Code != api.CodeMethodNotAllowed {
tt.Errorf("expected error message %s and code %d, got %s and %d", "method not allowed", api.CodeMethodNotAllowed, errorResponse.Error.Message, errorResponse.Error.Code)
}
},
},
}
for _, testToRun := range tests {
test := testToRun
t.Run(test.name, func(tt *testing.T) {
tt.Parallel()
resp := httptest.NewRecorder()
svc.handleMethodNotAllowed(resp, test.setupRequest())
test.respHandler(tt, resp)
})
}
}

func Test_shouldFail(t *testing.T) {
type args struct {
successRatio float64
Expand Down
16 changes: 4 additions & 12 deletions internal/service/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,28 +36,20 @@ func (svc *service) MakeRouter() {
middleware.Heartbeat("/ready"),
render.SetContentType(render.ContentTypeJSON),
)
router.NotFound(func(w http.ResponseWriter, r *http.Request) {
render.Status(r, http.StatusNotFound)
render.JSON(w, r, map[string]interface{}{"success": false, "error": "not found"})
})
router.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) {
render.Status(r, http.StatusNotFound)
render.JSON(w, r, map[string]interface{}{"success": false, "error": "method not allowed"})
})

// Endpoints handled by the service
router.Route(uriPrefix, func(r chi.Router) {
r.NotFound(svc.handleNotFound)
r.MethodNotAllowed(svc.handleMethodNotAllowed)

r.Use(svc.RequestLogger())
if svc.cfg.apiToken != "" {
r.Use(authn.BearerTokenAuth(svc.cfg.apiToken))
}
r.Use(httprate.Limit(
svc.cfg.rateLimit, // requests
time.Second, // per duration
httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) {
render.Status(r, http.StatusTooManyRequests)
render.JSON(w, r, svc.cfg.rateExceededRespBody)
}),
httprate.WithLimitHandler(svc.handleRateLimitExceeded),
))
r.Use(svc.incReqCounter())

Expand Down
9 changes: 9 additions & 0 deletions pkg/api/codes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package api

const (
// 1xxx errors
requestBase = 1000
CodeInvalidBody = requestBase + 1
CodeNotFound = requestBase + 2
CodeMethodNotAllowed = requestBase + 3
)

0 comments on commit 0719785

Please sign in to comment.