From 0719785c264b650fcbfb740f40ba5089252562c0 Mon Sep 17 00:00:00 2001 From: juan131 Date: Thu, 16 Nov 2023 16:40:48 +0100 Subject: [PATCH] fix: properly manage NotFound & MethodNotAllowed request errors Signed-off-by: juan131 --- internal/service/r_mock.go | 24 +++++++-- internal/service/r_mock_test.go | 88 +++++++++++++++++++++++++++++++++ internal/service/router.go | 16 ++---- pkg/api/codes.go | 9 ++++ 4 files changed, 122 insertions(+), 15 deletions(-) create mode 100644 pkg/api/codes.go diff --git a/internal/service/r_mock.go b/internal/service/r_mock.go index 085f125..c111711 100644 --- a/internal/service/r_mock.go +++ b/internal/service/r_mock.go @@ -47,20 +47,20 @@ func (svc *service) handleBatchMock(w http.ResponseWriter, r *http.Request) { encodedBody, err := io.ReadAll(r.Body) if err != nil { logID := svc.LogRequestFailure(r, fmt.Sprintf("[handleBatchMock] body reading error: %+v", err), err) - renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", 1001, logID)) + renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", api.CodeInvalidBody, logID)) return } decodedBody, err := url.ParseQuery(string(encodedBody)) if err != nil { logID := svc.LogRequestFailure(r, fmt.Sprintf("[handleBatchMock] body reading error: %+v", err), err) - renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", 1001, logID)) + renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", api.CodeInvalidBody, logID)) return } if err := json.Unmarshal([]byte(decodedBody.Get("batch")), &requests); err != nil { logID := svc.LogRequestFailure(r, fmt.Sprintf("[handleBatchMock] body reading error: %+v", err), err) - renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", 1001, logID)) + renderJSON(w, r, http.StatusBadRequest, api.MakeHTTPErrorResponse("body parsing error", api.CodeInvalidBody, logID)) return } @@ -93,6 +93,24 @@ func (svc *service) handleBatchMock(w http.ResponseWriter, r *http.Request) { render.JSON(w, r, responses) } +// handleNotFound handles not found requests +func (svc *service) handleNotFound(w http.ResponseWriter, r *http.Request) { + logID := svc.LogRequestFailure(r, "[handleNotFound] request to "+r.URL.Path, nil) + renderJSON(w, r, http.StatusNotFound, api.MakeHTTPErrorResponse("not found", api.CodeNotFound, logID)) +} + +// handleMethodNotAllowed handles not found requests +func (svc *service) handleMethodNotAllowed(w http.ResponseWriter, r *http.Request) { + logID := svc.LogRequestFailure(r, "[handleMethodNotAllowed] method "+r.Method, nil) + renderJSON(w, r, http.StatusNotFound, api.MakeHTTPErrorResponse("method not allowed", api.CodeMethodNotAllowed, logID)) +} + +// handleRateLimitExceeded handles rate limit exceeded requests +func (svc *service) handleRateLimitExceeded(w http.ResponseWriter, r *http.Request) { + render.Status(r, http.StatusTooManyRequests) + render.JSON(w, r, svc.cfg.rateExceededRespBody) +} + // shouldFail returns true if the request should fail based // on a given success ratio and a request counter func shouldFail(successRatio float64, requestsCounter int) bool { diff --git a/internal/service/r_mock_test.go b/internal/service/r_mock_test.go index c0d9b45..f8c8cc9 100644 --- a/internal/service/r_mock_test.go +++ b/internal/service/r_mock_test.go @@ -192,6 +192,94 @@ func Test_service_handleBatchMock(t *testing.T) { } } +func Test_handleNotFound(t *testing.T) { + svc := &service{ + logger: newStructuredLogger(), + } + tests := []struct { + name string + setupRequest func() *http.Request + respHandler func(tt *testing.T, resp *httptest.ResponseRecorder) + }{ + { + name: "request to unknown route", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/v1/mock/unknown", nil) + return req + }, + respHandler: func(tt *testing.T, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusNotFound { + tt.Errorf("expected status code %d, got %d", http.StatusNotFound, resp.Code) + } + if resp.Header().Get("Content-Type") != "application/json" { + tt.Errorf("expected content type %s, got %s", "application/json", resp.Header().Get("Content-Type")) + } + var errorResponse api.HTTPErrorResponse + if err := json.Unmarshal(resp.Body.Bytes(), &errorResponse); err != nil { + tt.Errorf("could not unmarshal response body: %+v", err) + } + + if errorResponse.Error.Message != "not found" || errorResponse.Error.Code != api.CodeNotFound { + tt.Errorf("expected error message %s and code %d, got %s and %d", "not found", api.CodeNotFound, errorResponse.Error.Message, errorResponse.Error.Code) + } + }, + }, + } + for _, testToRun := range tests { + test := testToRun + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + resp := httptest.NewRecorder() + svc.handleNotFound(resp, test.setupRequest()) + test.respHandler(tt, resp) + }) + } +} + +func Test_handleMethodNotAllowed(t *testing.T) { + svc := &service{ + logger: newStructuredLogger(), + } + tests := []struct { + name string + setupRequest func() *http.Request + respHandler func(tt *testing.T, resp *httptest.ResponseRecorder) + }{ + { + name: "request to unknown route", + setupRequest: func() *http.Request { + req := httptest.NewRequest(http.MethodGet, "/v1/mock/unknown", nil) + return req + }, + respHandler: func(tt *testing.T, resp *httptest.ResponseRecorder) { + if resp.Code != http.StatusNotFound { + tt.Errorf("expected status code %d, got %d", http.StatusNotFound, resp.Code) + } + if resp.Header().Get("Content-Type") != "application/json" { + tt.Errorf("expected content type %s, got %s", "application/json", resp.Header().Get("Content-Type")) + } + var errorResponse api.HTTPErrorResponse + if err := json.Unmarshal(resp.Body.Bytes(), &errorResponse); err != nil { + tt.Errorf("could not unmarshal response body: %+v", err) + } + + if errorResponse.Error.Message != "method not allowed" || errorResponse.Error.Code != api.CodeMethodNotAllowed { + tt.Errorf("expected error message %s and code %d, got %s and %d", "method not allowed", api.CodeMethodNotAllowed, errorResponse.Error.Message, errorResponse.Error.Code) + } + }, + }, + } + for _, testToRun := range tests { + test := testToRun + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + resp := httptest.NewRecorder() + svc.handleMethodNotAllowed(resp, test.setupRequest()) + test.respHandler(tt, resp) + }) + } +} + func Test_shouldFail(t *testing.T) { type args struct { successRatio float64 diff --git a/internal/service/router.go b/internal/service/router.go index 487ab9c..bc110cc 100644 --- a/internal/service/router.go +++ b/internal/service/router.go @@ -36,17 +36,12 @@ func (svc *service) MakeRouter() { middleware.Heartbeat("/ready"), render.SetContentType(render.ContentTypeJSON), ) - router.NotFound(func(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusNotFound) - render.JSON(w, r, map[string]interface{}{"success": false, "error": "not found"}) - }) - router.MethodNotAllowed(func(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusNotFound) - render.JSON(w, r, map[string]interface{}{"success": false, "error": "method not allowed"}) - }) // Endpoints handled by the service router.Route(uriPrefix, func(r chi.Router) { + r.NotFound(svc.handleNotFound) + r.MethodNotAllowed(svc.handleMethodNotAllowed) + r.Use(svc.RequestLogger()) if svc.cfg.apiToken != "" { r.Use(authn.BearerTokenAuth(svc.cfg.apiToken)) @@ -54,10 +49,7 @@ func (svc *service) MakeRouter() { r.Use(httprate.Limit( svc.cfg.rateLimit, // requests time.Second, // per duration - httprate.WithLimitHandler(func(w http.ResponseWriter, r *http.Request) { - render.Status(r, http.StatusTooManyRequests) - render.JSON(w, r, svc.cfg.rateExceededRespBody) - }), + httprate.WithLimitHandler(svc.handleRateLimitExceeded), )) r.Use(svc.incReqCounter()) diff --git a/pkg/api/codes.go b/pkg/api/codes.go new file mode 100644 index 0000000..124cfed --- /dev/null +++ b/pkg/api/codes.go @@ -0,0 +1,9 @@ +package api + +const ( + // 1xxx errors + requestBase = 1000 + CodeInvalidBody = requestBase + 1 + CodeNotFound = requestBase + 2 + CodeMethodNotAllowed = requestBase + 3 +)