Skip to content

Commit

Permalink
add test case for multipart form data model
Browse files Browse the repository at this point in the history
  • Loading branch information
samos123 committed Sep 4, 2024
1 parent 225d9d9 commit dd9a973
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
16 changes: 16 additions & 0 deletions internal/modelproxy/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,22 @@ func TestHandler(t *testing.T) {
},
expBackendRequestCount: 1,
},
"happy 200 model in form data": {
reqHeaders: map[string]string{"Content-Type": "multipart/form-data; boundary=12345"},
reqBody: fmt.Sprintf(
"--12345\r\nContent-Disposition: form-data; name=\"model\"\r\n\r\n%s\r\n--12345--\r\n",
model1,
),
backendCode: http.StatusOK,
backendBody: `{"result":"ok"}`,
expCode: http.StatusOK,
expBody: `{"result":"ok"}`,
expLabels: map[string]string{
"model": model1,
"status_code": "200",
},
expBackendRequestCount: 1,
},
"retryable 500": {
reqBody: fmt.Sprintf(`{"model":%q}`, model1),
backendCode: http.StatusInternalServerError,
Expand Down
10 changes: 8 additions & 2 deletions internal/modelproxy/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"log"
"mime"
"mime/multipart"
"net/http"
"strconv"
Expand Down Expand Up @@ -72,9 +73,14 @@ func (pr *proxyRequest) parseModel() error {
return fmt.Errorf("read: %w", err)
}

if pr.r.Header.Get("Content-Type") == "multipart/form-data" {
// Error is ignored because we default to JSON handling if content-type is not set.
mediaType, params, _ := mime.ParseMediaType(pr.r.Header.Get("Content-Type"))
if mediaType == "multipart/form-data" {
// Parse the form data to get the model.
mpReader := multipart.NewReader(bytes.NewReader(pr.body), pr.r.Header.Get("Content-Type"))
if params["boundary"] == "" {
return fmt.Errorf("no boundary specified in multipart form data")
}
mpReader := multipart.NewReader(bytes.NewReader(pr.body), params["boundary"])
form, err := mpReader.ReadForm(1 << 20)
if err != nil {
return fmt.Errorf("read form: %w", err)
Expand Down

0 comments on commit dd9a973

Please sign in to comment.