Skip to content

Commit

Permalink
Retries - ReverseProxy.ErrorHandler based approach (#61)
Browse files Browse the repository at this point in the history
Allow for retrying failed requests to backends. Default to 1 retry per
lingo-request.

Fixes #48 

Builds on work done by @alpe in #64 and #51.

Co-authored-by: Alex Peters <alpe@users.noreply.github.com>
  • Loading branch information
nstogner and alpe authored Feb 15, 2024
1 parent 5227052 commit 48df469
Show file tree
Hide file tree
Showing 7 changed files with 512 additions and 258 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ENVTEST_K8S_VERSION = 1.27.1

.PHONY: test
test: test-unit test-race test-integration test-e2e
test: test-unit test-integration test-e2e

.PHONY: test-unit
test-unit:
Expand Down
2 changes: 2 additions & 0 deletions cmd/lingo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ func run() error {

concurrency := getEnvInt("CONCURRENCY", 100)
scaleDownDelay := getEnvInt("SCALE_DOWN_DELAY", 30)
backendRetries := getEnvInt("BACKEND_RETRIES", 1)

var metricsAddr string
var probeAddr string
Expand Down Expand Up @@ -154,6 +155,7 @@ func run() error {

proxy.MustRegister(metricsRegistry)
proxyHandler := proxy.NewHandler(deploymentManager, endpointManager, queueManager)
proxyHandler.MaxRetries = backendRetries
proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler}

statsHandler := &stats.Handler{
Expand Down
238 changes: 135 additions & 103 deletions pkg/proxy/handler.go
Original file line number Diff line number Diff line change
@@ -1,153 +1,140 @@
package proxy

import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/http/httputil"
"net/url"
"strconv"

"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"

"github.com/substratusai/lingo/pkg/deployments"
"github.com/substratusai/lingo/pkg/endpoints"
"github.com/substratusai/lingo/pkg/queue"
)

var httpDuration = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "http_response_time_seconds",
Help: "Duration of HTTP requests.",
Buckets: prometheus.DefBuckets,
}, []string{"model", "status_code"})

func MustRegister(r prometheus.Registerer) {
r.MustRegister(httpDuration)
}

type DeploymentManager interface {
ResolveDeployment(model string) (string, bool)
AtLeastOne(model string)
}

type EndpointManager interface {
AwaitHostAddress(ctx context.Context, service, portName string) (string, error)
}

type QueueManager interface {
EnqueueAndWait(ctx context.Context, deploymentName, id string) func()
}

// Handler serves http requests for end-clients.
// It is also responsible for triggering scale-from-zero.
type Handler struct {
Deployments *deployments.Manager
Endpoints *endpoints.Manager
Queues *queue.Manager
Deployments DeploymentManager
Endpoints EndpointManager
Queues QueueManager

MaxRetries int
RetryCodes map[int]struct{}
}

func NewHandler(deployments *deployments.Manager, endpoints *endpoints.Manager, queues *queue.Manager) *Handler {
return &Handler{Deployments: deployments, Endpoints: endpoints, Queues: queues}
func NewHandler(
deployments DeploymentManager,
endpoints EndpointManager,
queues QueueManager,
) *Handler {
return &Handler{
Deployments: deployments,
Endpoints: endpoints,
Queues: queues,
}
}

var defaultRetryCodes = map[int]struct{}{
http.StatusInternalServerError: {},
http.StatusBadGateway: {},
http.StatusServiceUnavailable: {},
http.StatusGatewayTimeout: {},
}

func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
var modelName string
captureStatusRespWriter := newCaptureStatusCodeResponseWriter(w)
w = captureStatusRespWriter
timer := prometheus.NewTimer(prometheus.ObserverFunc(func(v float64) {
httpDuration.WithLabelValues(modelName, strconv.Itoa(captureStatusRespWriter.statusCode)).Observe(v)
}))
defer timer.ObserveDuration()

id := uuid.New().String()
log.Printf("request: %v", r.URL)
log.Printf("url: %v", r.URL)

w.Header().Set("X-Proxy", "lingo")

var (
proxyRequest *http.Request
err error
)
pr := newProxyRequest(r)
defer pr.done()

// TODO: Only parse model for paths that would have a model.
modelName, proxyRequest, err = parseModel(r)
if err != nil || modelName == "" {
modelName = "unknown"
log.Printf("error reading model from request body: %v", err)
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Bad request: unable to parse .model from JSON payload"))
if err := pr.parseModel(); err != nil {
pr.sendErrorResponse(w, http.StatusBadRequest, "unable to parse model: %v", err)
return
}
log.Println("model:", modelName)

deploy, found := h.Deployments.ResolveDeployment(modelName)
if !found {
log.Printf("deployment not found for model: %v", err)
w.WriteHeader(http.StatusNotFound)
w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName)))
log.Println("model:", pr.model)

var backendExists bool
pr.backendDeployment, backendExists = h.Deployments.ResolveDeployment(pr.model)
if !backendExists {
pr.sendErrorResponse(w, http.StatusNotFound, "model not found: %v", pr.model)
return
}

h.Deployments.AtLeastOne(deploy)
// Ensure the backend is scaled to at least one Pod.
h.Deployments.AtLeastOne(pr.backendDeployment)

log.Println("Entering queue", id)
complete := h.Queues.EnqueueAndWait(r.Context(), deploy, id)
log.Println("Admitted into queue", id)
log.Printf("Entering queue: %v", pr.id)

// Wait to until the request is admitted into the queue before proceeding with
// serving the request.
complete := h.Queues.EnqueueAndWait(r.Context(), pr.backendDeployment, pr.id)
defer complete()

// abort when deployment was removed meanwhile
if _, exists := h.Deployments.ResolveDeployment(modelName); !exists {
log.Printf("deployment not active for model removed: %v", err)
w.WriteHeader(http.StatusNotFound)
_, _ = w.Write([]byte(fmt.Sprintf("Deployment for model not found: %v", modelName)))
log.Printf("Admitted into queue: %v", pr.id)

// After waiting for the request to be admitted, double check that the model
// still exists. It's possible that the model was deleted while waiting.
// This would lead to a long subequent wait with the host lookup.
pr.backendDeployment, backendExists = h.Deployments.ResolveDeployment(pr.model)
if !backendExists {
pr.sendErrorResponse(w, http.StatusNotFound, "model not found after being dequeued: %v", pr.model)
return
}

log.Println("Waiting for IPs", id)
host, err := h.Endpoints.AwaitHostAddress(r.Context(), deploy, "http")
h.proxyHTTP(w, pr)
}

// AdditionalProxyRewrite is an injection point for modifying proxy requests.
// Used in tests.
var AdditionalProxyRewrite = func(*httputil.ProxyRequest) {}

func (h *Handler) proxyHTTP(w http.ResponseWriter, pr *proxyRequest) {
log.Printf("Waiting for host: %v", pr.id)

host, err := h.Endpoints.AwaitHostAddress(pr.r.Context(), pr.backendDeployment, "http")
if err != nil {
log.Printf("error while finding the host address %v", err)
switch {
case errors.Is(err, context.Canceled):
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("Request cancelled"))
pr.sendErrorResponse(w, http.StatusInternalServerError, "request cancelled while finding host: %v", err)
return
case errors.Is(err, context.DeadlineExceeded):
w.WriteHeader(http.StatusGatewayTimeout)
_, _ = w.Write([]byte(fmt.Sprintf("Request timed out for model: %v", modelName)))
pr.sendErrorResponse(w, http.StatusGatewayTimeout, "request timeout while finding host: %v", err)
return
default:
w.WriteHeader(http.StatusInternalServerError)
_, _ = w.Write([]byte("Internal server error"))
pr.sendErrorResponse(w, http.StatusGatewayTimeout, "unable to find host: %v", err)
return
}
}
log.Printf("Got host: %v, id: %v\n", host, id)

// TODO: Avoid creating new reverse proxies for each request.
// TODO: Consider implementing a round robin scheme.
log.Printf("Proxying request to host %v: %v\n", host, id)
newReverseProxy(host).ServeHTTP(w, proxyRequest)
}

// parseModel parses the model name from the request
// returns empty string when none found or an error for failures on the proxy request object
func parseModel(r *http.Request) (string, *http.Request, error) {
if model := r.Header.Get("X-Model"); model != "" {
return model, r, nil
}
// parse request body for model name, ignore errors
body, err := io.ReadAll(r.Body)
if err != nil {
return "", r, nil
}

var payload struct {
Model string `json:"model"`
}
var model string
if err := json.Unmarshal(body, &payload); err == nil {
model = payload.Model
}
log.Printf("Got host: %v, id: %v\n", host, pr.id)

// create new request object
proxyReq, err := http.NewRequestWithContext(r.Context(), r.Method, r.URL.String(), bytes.NewReader(body))
if err != nil {
return "", nil, fmt.Errorf("create proxy request: %w", err)
}
proxyReq.Header = r.Header
if err := proxyReq.ParseForm(); err != nil {
return "", nil, fmt.Errorf("parse proxy form: %w", err)
}
return model, proxyReq, nil
}

// AdditionalProxyRewrite is an injection point for modifying proxy requests.
// Used in tests.
var AdditionalProxyRewrite = func(*httputil.ProxyRequest) {}

func newReverseProxy(host string) *httputil.ReverseProxy {
proxy := &httputil.ReverseProxy{
Rewrite: func(r *httputil.ProxyRequest) {
r.SetURL(&url.URL{
Expand All @@ -158,5 +145,50 @@ func newReverseProxy(host string) *httputil.ReverseProxy {
AdditionalProxyRewrite(r)
},
}
return proxy

proxy.ModifyResponse = func(r *http.Response) error {
// Record the response for metrics.
pr.status = r.StatusCode

// This point is reached if a response code is received.
if h.isRetryCode(r.StatusCode) && pr.attempt < h.MaxRetries {
// Returning an error will trigger the ErrorHandler.
return ErrRetry
}

return nil
}

proxy.ErrorHandler = func(w http.ResponseWriter, r *http.Request, err error) {
// This point could be reached if a bad response code was sent by the backend
// or
// if there was an issue with the connection and no response was ever received.
if err != nil && r.Context().Err() == nil && pr.attempt < h.MaxRetries {
pr.attempt++

log.Printf("Retrying request (%v/%v): %v", pr.attempt, h.MaxRetries, pr.id)
h.proxyHTTP(w, pr)
return
}

if !errors.Is(err, ErrRetry) {
pr.sendErrorResponse(w, http.StatusBadGateway, "proxy: exceeded retries: %v/%v", pr.attempt, h.MaxRetries)
}
}

log.Printf("Proxying request to host %v: %v\n", host, pr.id)
proxy.ServeHTTP(w, pr.httpRequest())
}

var ErrRetry = errors.New("retry")

func (h *Handler) isRetryCode(status int) bool {
var retry bool
// TODO: avoid the nil check here and set a default map in the constructor.
if h.RetryCodes != nil {
_, retry = h.RetryCodes[status]
} else {
_, retry = defaultRetryCodes[status]
}
return retry
}
Loading

0 comments on commit 48df469

Please sign in to comment.