Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for HTTP X-Label-Selector headers to support Multitenancy #282

Merged
merged 13 commits into from
Oct 24, 2024
Binary file added docs/diagrams/multitenancy-labels.excalidraw.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
61 changes: 61 additions & 0 deletions docs/how-to/architect-for-multitenancy.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Architect for Multitenancy

KubeAI can support multitenancy by filtering the models that it serves via Kubernetes label selectors. These label selectors can be applied when accessing any of the OpenAI-compatible endpoints through the `X-Label-Selector` HTTP header and will match on labels specified on the `kind: Model` objects. The pattern is similar to using a `WHERE` clause in a SQL query.

Example Models:

```yaml
kind: Model
metadata:
name: llama-3.2
labels:
tenancy: public
spec:
# ...
---
kind: Model
metadata:
name: custom-private-model
labels:
tenancy: org-abc
spec:
# ...
```

Example HTTP requests:

```bash
# The returned list of models will be filtered.
curl http://$KUBEAI_ENDPOINT/openai/v1/models \
-H "X-Label-Selector: tenancy in (org-abc, public)"
nstogner marked this conversation as resolved.
Show resolved Hide resolved

# When running inference, if the label selector does not match
# the requested model's labels, a 404 will be returned.
curl http://$KUBEAI_ENDPOINT/openai/v1/completions \
-H "Content-Type: application/json" \
-H "X-Label-Selector: tenancy in (org-abc, public)" \
-d '{"prompt": "Hi", "model": "llama-3.2"}'
```

The header value can be any valid [Kubernetes label selector](https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#label-selectors). Some examples include:

```
X-Label-Selector: tenancy=org-abc
X-Label-Selector: tenancy in (org-abc, public)
X-Label-Selector: tenancy!=private
```

Multiple `X-Label-Selector` headers can be specified in the same HTTP request and will be treated as a logical `AND`. For example, the following request will only match Models
that have both the label `tenant: org-abc` and the label `user: sam`:

```bash
curl http://$KUBEAI_ENDPOINT/openai/v1/completions \
-H "Content-Type: application/json" \
-H "X-Label-Selector: tenant=org-abc" \
-H "X-Label-Selector: user=sam" \
-d '{"prompt": "Hi", "model": "llama-3.2"}'
```

Example architecture:

![Multitenancy](../diagrams/multitenancy-labels.excalidraw.png)
4 changes: 2 additions & 2 deletions internal/messenger/messenger.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func NewMessenger(
}

// ModelScaler is the subset of model operations the messenger needs:
// checking that a Model exists/matches selectors, and scaling it up.
type ModelScaler interface {
// ModelExists reports whether the named Model exists.
ModelExists(ctx context.Context, model string) (bool, error)
// LookupModel reports whether the named Model exists and matches every
// one of the given label selectors (logical AND). A nil or empty
// selectors slice matches any existing Model.
LookupModel(ctx context.Context, model string, selectors []string) (bool, error)
// ScaleAtLeastOneReplica ensures the Model is scaled to at least one replica.
ScaleAtLeastOneReplica(ctx context.Context, model string) error
}

Expand Down Expand Up @@ -204,7 +204,7 @@ func (m *Messenger) handleRequest(ctx context.Context, msg *pubsub.Message) {
metrics.InferenceRequestsActive.Add(ctx, 1, metricAttrs)
defer metrics.InferenceRequestsActive.Add(ctx, -1, metricAttrs)

modelExists, err := m.modelScaler.ModelExists(ctx, req.model)
modelExists, err := m.modelScaler.LookupModel(ctx, req.model, nil)
if err != nil {
m.sendResponse(req, m.jsonError("error checking if model exists: %v", err), http.StatusInternalServerError)
return
Expand Down
6 changes: 3 additions & 3 deletions internal/modelproxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
)

// ModelScaler is the subset of model operations the proxy handler needs:
// resolving a requested model (with optional label selectors) and scaling it up.
type ModelScaler interface {
// ModelExists reports whether the named Model exists.
ModelExists(ctx context.Context, model string) (bool, error)
// LookupModel reports whether the named Model exists and matches every
// one of the given label selectors (logical AND). A nil or empty
// selectors slice matches any existing Model.
LookupModel(ctx context.Context, model string, selectors []string) (bool, error)
// ScaleAtLeastOneReplica ensures the Model is scaled to at least one replica.
ScaleAtLeastOneReplica(ctx context.Context, model string) error
}

Expand Down Expand Up @@ -60,7 +60,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
pr := newProxyRequest(r)

// TODO: Only parse model for paths that would have a model.
if err := pr.parseModel(); err != nil {
if err := pr.parse(); err != nil {
pr.sendErrorResponse(w, http.StatusBadRequest, "unable to parse model: %v", err)
return
}
Expand All @@ -74,7 +74,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
metrics.InferenceRequestsActive.Add(pr.r.Context(), 1, metricAttrs)
defer metrics.InferenceRequestsActive.Add(pr.r.Context(), -1, metricAttrs)

modelExists, err := h.modelScaler.ModelExists(r.Context(), pr.model)
modelExists, err := h.modelScaler.LookupModel(r.Context(), pr.model, pr.selectors)
if err != nil {
pr.sendErrorResponse(w, http.StatusInternalServerError, "unable to resolve model: %v", err)
return
Expand Down
2 changes: 1 addition & 1 deletion internal/modelproxy/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ type testModelInterface struct {
models map[string]string
}

func (t *testModelInterface) ModelExists(ctx context.Context, model string) (bool, error) {
// LookupModel is a test fake: it reports whether the model name is present
// in the fake's models map. The selectors argument is accepted only to
// satisfy the ModelScaler interface and is intentionally ignored.
// NOTE: renamed the parameter from `selector` to `selectors` for
// consistency with the interface declaration.
func (t *testModelInterface) LookupModel(ctx context.Context, model string, selectors []string) (bool, error) {
	_, ok := t.models[model]
	return ok, nil
}
Expand Down
9 changes: 5 additions & 4 deletions internal/modelproxy/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ type proxyRequest struct {
// in order to determine the model.
body []byte

// metadata:
selectors []string

id string
status int
Expand All @@ -38,14 +38,15 @@ func newProxyRequest(r *http.Request) *proxyRequest {
}

return pr

}

// parseModel attempts to determine the model from the request.
// parse attempts to determine the model from the request.
// It first checks the "X-Model" header, and if that is not set, it
// attempts to unmarshal the request body as JSON and extract the
// .model field.
func (pr *proxyRequest) parseModel() error {
func (pr *proxyRequest) parse() error {
pr.selectors = pr.r.Header.Values("X-Label-Selector")

// Try to get the model from the header first
if headerModel := pr.r.Header.Get("X-Model"); headerModel != "" {
pr.model = headerModel
Expand Down
22 changes: 20 additions & 2 deletions internal/modelscaler/scaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
kubeaiv1 "github.com/substratusai/kubeai/api/v1"
autoscalingv1 "k8s.io/api/autoscaling/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/controller-runtime/pkg/client"
)
Expand All @@ -26,13 +27,30 @@ func NewModelScaler(client client.Client, namespace string) *ModelScaler {
return &ModelScaler{client: client, namespace: namespace, consecutiveScaleDowns: map[string]int{}}
}

func (s *ModelScaler) ModelExists(ctx context.Context, model string) (bool, error) {
if err := s.client.Get(ctx, types.NamespacedName{Name: model, Namespace: s.namespace}, &kubeaiv1.Model{}); err != nil {
// LookupModel reports whether the named Model exists and matches every one of
// the given label selectors (logical AND). It returns (false, nil) when the
// Model does not exist or fails a selector, and a non-nil error on an invalid
// selector or an unexpected API error.
func (s *ModelScaler) LookupModel(ctx context.Context, model string, labelSelectors []string) (bool, error) {
	// Parse all selectors up front so an invalid (caller-supplied) selector
	// fails fast, before spending an API server round-trip.
	parsed := make([]labels.Selector, 0, len(labelSelectors))
	for _, sel := range labelSelectors {
		p, err := labels.Parse(sel)
		if err != nil {
			return false, fmt.Errorf("parse label selector %q: %w", sel, err)
		}
		parsed = append(parsed, p)
	}

	m := &kubeaiv1.Model{}
	if err := s.client.Get(ctx, types.NamespacedName{Name: model, Namespace: s.namespace}, m); err != nil {
		if apierrors.IsNotFound(err) {
			return false, nil
		}
		return false, err
	}

	// labels.Set(nil) is safe: a nil map simply has no labels, so any
	// requirement on a key fails to match as expected.
	set := labels.Set(m.GetLabels())
	for _, p := range parsed {
		if !p.Matches(set) {
			return false, nil
		}
	}

	return true, nil
}

Expand Down
18 changes: 16 additions & 2 deletions internal/openaiserver/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"

kubeaiv1 "github.com/substratusai/kubeai/api/v1"
"k8s.io/apimachinery/pkg/labels"
"sigs.k8s.io/controller-runtime/pkg/client"
)

Expand All @@ -22,14 +23,27 @@ func (h *Handler) getModels(w http.ResponseWriter, r *http.Request) {
features = []string{kubeaiv1.ModelFeatureTextGeneration}
}

var listOpts []client.ListOption
headerSelectors := r.Header.Values("X-Label-Selector")
for _, sel := range headerSelectors {
parsedSel, err := labels.Parse(sel)
if err != nil {
sendErrorResponse(w, http.StatusBadRequest, "failed to parse label selector: %v", err)
return
}
listOpts = append(listOpts, client.MatchingLabelsSelector{Selector: parsedSel})
}

var k8sModels []kubeaiv1.Model
k8sModelNames := map[string]struct{}{}
for _, feature := range features {
// NOTE(nstogner): Could not find a way to do an OR query with the client,
// NOTE: At time of writing an OR query is not supported with the
// Kubernetes API server
// so we just do multiple queries and merge the results.
labelSelector := client.MatchingLabels{kubeaiv1.ModelFeatureLabelDomain + "/" + feature: "true"}
list := &kubeaiv1.ModelList{}
if err := h.K8sClient.List(r.Context(), list, labelSelector); err != nil {
opts := append([]client.ListOption{labelSelector}, listOpts...)
if err := h.K8sClient.List(r.Context(), list, opts...); err != nil {
sendErrorResponse(w, http.StatusInternalServerError, "failed to list models: %v", err)
return
}
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
63 changes: 63 additions & 0 deletions proposals/multitenancy.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Multitenancy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we remove this doc? It's confusing to keep it imo.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What I have seen in the past is that proposals are not generally removed even after implementation. It can be nice to understand why a particular path was not pursued.


The goal of this proposal is to allow KubeAI to be used in a multitenancy environment where
some users only have access to some models.

## Implementation Option 1: Auth Labels

In this implementation, KubeAI has well-known labels that correspond to groups that are allowed to access models.

The KubeAI system is configured to trust a configured header.

```yaml
auth:
http:
trustedHeader: X-Auth-Groups
# Possibly in future: configure Model roles.
# modelRoles:
# user: ["list", "describe", "infer"]
```

The groups associated with a request are passed in a trusted header.

```bash
curl http://localhost:8000/openai/v1/completions \
-H "X-Auth-Groups: grp-a, grp-b"
```

The groups that are allowed to access a given model are configured as labels on the Model.

```yaml
kind: Model
metadata:
name: llama-3.2
labels:
auth.kubeai.org/grp-a: <role>
auth.kubeai.org/grp-c: <role>
```

## Implementation Option 2: General Label Selector

**CURRENT PREFERENCE** (Unless there is a reason to introduce auth-specific configuration.)

In this implementation, label selectors are used to filter models. The decision of which labels to use are up to the architects of the system that KubeAI is a part of. These label selectors could be enforced by a server that is an intermediary between KubeAI and the end users.

![Auth with Label Selector](./diagrams/auth-with-label-selector.excalidraw.png)

```bash
curl http://localhost:8000/openai/v1/completions \
-H "X-Label-Selector: key1=value1"

curl http://localhost:8000/openai/v1/models \
-H "X-Label-Selector: key1=value1"
```

Models just need to have the labels set.

```yaml
kind: Model
metadata:
name: llama-3.2
labels:
key1: value1
```
13 changes: 7 additions & 6 deletions test/integration/proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,32 +55,33 @@ func TestProxy(t *testing.T) {
// Wait for controller cache to sync.
time.Sleep(3 * time.Second)

// Send request number 1
var wg sync.WaitGroup
sendRequests(t, &wg, m.Name, 1, http.StatusOK, "request 1")

// Send request number 1
sendRequests(t, &wg, m.Name, nil, 1, http.StatusOK, "", "request 1")

requireModelReplicas(t, m, 1, "Replicas should be scaled up to 1 to process messaging request", 5*time.Second)
requireModelPods(t, m, 1, "Pod should be created for the messaging request", 5*time.Second)
markAllModelPodsReady(t, m)
completeRequests(backendComplete, 1)
closeChannels(backendComplete, 1)
require.Equal(t, int32(1), totalBackendRequests.Load(), "ensure the request made its way to the backend")

const autoscaleUpWait = 25 * time.Second
// Ensure the deployment is autoscaled past 1.
// Simulate the backend processing the request.
sendRequests(t, &wg, m.Name, 2, http.StatusOK, "request 2,3")
sendRequests(t, &wg, m.Name, nil, 2, http.StatusOK, "", "request 2,3")
requireModelReplicas(t, m, 2, "Replicas should be scaled up to 2 to process pending messaging request", autoscaleUpWait)
requireModelPods(t, m, 2, "2 Pods should be created for the messaging requests", 5*time.Second)
markAllModelPodsReady(t, m)

// Make sure deployment will not be scaled past max (3).
sendRequests(t, &wg, m.Name, 2, http.StatusOK, "request 4,5")
sendRequests(t, &wg, m.Name, nil, 2, http.StatusOK, "", "request 4,5")
require.Never(t, func() bool {
assert.NoError(t, testK8sClient.Get(testCtx, client.ObjectKeyFromObject(m), m))
return *m.Spec.Replicas > *m.Spec.MaxReplicas
}, autoscaleUpWait, time.Second/10, "Replicas should not be scaled past MaxReplicas")

completeRequests(backendComplete, 4)
closeChannels(backendComplete, 4)
require.Equal(t, int32(5), totalBackendRequests.Load(), "ensure all the requests made their way to the backend")

// Ensure the deployment is autoscaled back down to MinReplicas.
Expand Down
Loading
Loading