diff --git a/charts/kubeai/templates/_helpers.tpl b/charts/kubeai/templates/_helpers.tpl
index db18d1c2..fe0e8400 100644
--- a/charts/kubeai/templates/_helpers.tpl
+++ b/charts/kubeai/templates/_helpers.tpl
@@ -88,4 +88,11 @@ Create the name of the huggingface secret to use
 {{- end }}
 {{- .Values.secrets.huggingface.name }}
 {{- end }}
+{{- end }}
+
+{{/*
+Set the name of the configmap to use for storing model autoscaling state
+*/}}
+{{- define "models.autoscalerStateConfigMapName" -}}
+{{- default (printf "%s-autoscaler-state" (include "kubeai.fullname" .)) .Values.modelAutoscaling.stateConfigMapName }}
 {{- end }}
\ No newline at end of file
diff --git a/charts/kubeai/templates/autoscalerstateconfigmap.yaml b/charts/kubeai/templates/autoscalerstateconfigmap.yaml
new file mode 100644
index 00000000..c0a54f38
--- /dev/null
+++ b/charts/kubeai/templates/autoscalerstateconfigmap.yaml
@@ -0,0 +1,4 @@
+apiVersion: v1
+kind: ConfigMap
+metadata:
+  name: {{ include "models.autoscalerStateConfigMapName" . }}
\ No newline at end of file
diff --git a/charts/kubeai/templates/configmap.yaml b/charts/kubeai/templates/configmap.yaml
index c4b8f482..85f72b73 100644
--- a/charts/kubeai/templates/configmap.yaml
+++ b/charts/kubeai/templates/configmap.yaml
@@ -27,6 +27,8 @@ data:
       {{- end}}
       serviceAccountName: {{ include "models.serviceAccountName" . }}
     modelAutoscaling:
-      {{- .Values.modelAutoscaling | toYaml | nindent 6 }}
+      interval: {{ .Values.modelAutoscaling.interval }}
+      timeWindow: {{ .Values.modelAutoscaling.timeWindow }}
+      stateConfigMapName: {{ include "models.autoscalerStateConfigMapName" . }}
     messaging:
       {{- .Values.messaging | toYaml | nindent 6 }}
\ No newline at end of file
diff --git a/charts/kubeai/values.yaml b/charts/kubeai/values.yaml
index 2e72f73c..8d8c7a52 100644
--- a/charts/kubeai/values.yaml
+++ b/charts/kubeai/values.yaml
@@ -107,6 +107,9 @@ modelAutoscaling:
   # Time window the autoscaling algorithm will consider when calculating
   # the desired number of replicas.
   timeWindow: 10m
+  # The name of the ConfigMap that stores the state of the autoscaler.
+  # Defaults to "{fullname}-autoscaler-state".
+  # stateConfigMapName: "kubeai-autoscaler-state"
 
 messaging:
   errorMaxBackoff: 30s
diff --git a/internal/config/system.go b/internal/config/system.go
index 69419024..84d6178e 100644
--- a/internal/config/system.go
+++ b/internal/config/system.go
@@ -62,6 +62,9 @@ func (s *System) DefaultAndValidate() error {
 	if s.ModelAutoscaling.TimeWindow.Duration == 0 {
 		s.ModelAutoscaling.TimeWindow.Duration = 10 * time.Minute
 	}
+	if s.ModelAutoscaling.StateConfigMapName == "" {
+		s.ModelAutoscaling.StateConfigMapName = "kubeai-autoscaler-state"
+	}
 
 	if s.LeaderElection.LeaseDuration.Duration == 0 {
 		s.LeaderElection.LeaseDuration.Duration = 15 * time.Second
@@ -116,6 +119,12 @@ type ModelAutoscaling struct {
 	// calculating the average number of requests.
 	// Defaults to 10 minutes.
 	TimeWindow Duration `json:"timeWindow" validate:"required"`
+	// StateConfigMapName is the name of the ConfigMap that will be used
+	// to store the state of the autoscaler. This ConfigMap ensures that
+	// the autoscaler can recover from crashes and restarts without losing
+	// its state.
+	// Defaults to "kubeai-autoscaler-state".
+	StateConfigMapName string `json:"stateConfigMapName" validate:"required"`
 }
 
 // RequiredConsecutiveScaleDowns returns the number of consecutive scale down
diff --git a/internal/manager/run.go b/internal/manager/run.go
index 1d879e52..e070f8fc 100644
--- a/internal/manager/run.go
+++ b/internal/manager/run.go
@@ -20,9 +20,11 @@ import (
 	"github.com/prometheus/client_golang/prometheus/promhttp"
 
 	"k8s.io/apimachinery/pkg/runtime"
+	"k8s.io/apimachinery/pkg/types"
 	clientgoscheme "k8s.io/client-go/kubernetes/scheme"
 	ctrl "sigs.k8s.io/controller-runtime"
 	"sigs.k8s.io/controller-runtime/pkg/cache"
+	"sigs.k8s.io/controller-runtime/pkg/client"
 	"sigs.k8s.io/controller-runtime/pkg/healthz"
 	metricsserver "sigs.k8s.io/controller-runtime/pkg/metrics/server"
 
@@ -171,6 +173,13 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
 		return fmt.Errorf("unable to create clientset: %w", err)
 	}
 
+	// Create a new client for the autoscaler to use. This client will not use
+	// a cache and will be ready to use immediately.
+	k8sClient, err := client.New(mgr.GetConfig(), client.Options{Scheme: mgr.GetScheme()})
+	if err != nil {
+		return fmt.Errorf("unable to create client: %w", err)
+	}
+
 	hostname, err := os.Hostname()
 	if err != nil {
 		return fmt.Errorf("unable to get hostname: %w", err)
@@ -181,7 +190,7 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
 		cfg.LeaderElection.RetryPeriod.Duration,
 	)
 
-	endpointManager, err := endpoints.NewResolver(mgr)
+	endpointResolver, err := endpoints.NewResolver(mgr)
 	if err != nil {
 		return fmt.Errorf("unable to setup model resolver: %w", err)
 	}
@@ -216,16 +225,22 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
 		return fmt.Errorf("unable to parse metrics port: %w", err)
 	}
 
-	modelAutoscaler := modelautoscaler.New(
+	modelAutoscaler, err := modelautoscaler.New(
+		ctx,
+		k8sClient,
 		leaderElection,
 		modelScaler,
-		endpointManager,
+		endpointResolver,
 		cfg.ModelAutoscaling,
 		metricsPort,
+		types.NamespacedName{Name: cfg.ModelAutoscaling.StateConfigMapName, Namespace: namespace},
 		cfg.FixedSelfMetricAddrs,
 	)
+	if err != nil {
+		return fmt.Errorf("unable to create model autoscaler: %w", err)
+	}
 
-	modelProxy := modelproxy.NewHandler(modelScaler, endpointManager, 3, nil)
+	modelProxy := modelproxy.NewHandler(modelScaler, endpointResolver, 3, nil)
 	openaiHandler := openaiserver.NewHandler(mgr.GetClient(), modelProxy)
 	mux := http.NewServeMux()
 	mux.Handle("/openai/", openaiHandler)
@@ -253,7 +268,7 @@ func Run(ctx context.Context, k8sCfg *rest.Config, cfg config.System) error {
 			stream.MaxHandlers,
 			cfg.Messaging.ErrorMaxBackoff.Duration,
 			modelScaler,
-			endpointManager,
+			endpointResolver,
 			httpClient,
 		)
 		if err != nil {
diff --git a/internal/modelautoscaler/autoscaler.go b/internal/modelautoscaler/autoscaler.go
index f70b71d8..e5aff60b 100644
--- a/internal/modelautoscaler/autoscaler.go
+++ b/internal/modelautoscaler/autoscaler.go
@@ -13,32 +13,56 @@ import (
 	"github.com/substratusai/kubeai/internal/leader"
 	"github.com/substratusai/kubeai/internal/modelscaler"
 	"github.com/substratusai/kubeai/internal/movingaverage"
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/controller-runtime/pkg/client"
 )
 
 func New(
+	ctx context.Context,
+	k8sClient client.Client,
 	leaderElection *leader.Election,
 	scaler *modelscaler.ModelScaler,
 	resolver *endpoints.Resolver,
 	cfg config.ModelAutoscaling,
 	metricsPort int,
+	stateConfigMapRef types.NamespacedName,
 	fixedSelfMetricAddrs []string,
-) *Autoscaler {
+) (*Autoscaler, error) {
 	a := &Autoscaler{
+		k8sClient:            k8sClient,
 		leaderElection:       leaderElection,
 		scaler:               scaler,
 		resolver:             resolver,
 		movingAvgByModel:     map[string]*movingaverage.Simple{},
 		cfg:                  cfg,
 		metricsPort:          metricsPort,
+		stateConfigMapRef:    stateConfigMapRef,
 		fixedSelfMetricAddrs: fixedSelfMetricAddrs,
 	}
-	return a
+	lastModelState, err := a.loadLastTotalModelState(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("loading last state of models: %w", err)
+	}
+	log.Printf("Loaded last state of models: %d total, last calculated on %s", len(lastModelState.Models), lastModelState.LastCalculationTime)
+	for m, s := range lastModelState.Models {
+		// Preload moving averages with the last known state.
+		// If the last known state was 5.5, the preloaded moving average
+		// would look like [5.5, 5.5, 5.5, ...].
+		preloaded := newPrefilledFloat64Slice(a.cfg.AverageWindowCount(), s.AverageActiveRequests)
+		a.movingAvgByModel[m] = movingaverage.NewSimple(preloaded)
+	}
+
+	return a, nil
 }
 
 // Autoscaler is responsible for making continuous adjustments to
 // the scale of the backend. It is not responsible for scale-from-zero.
 // Each deployment has its own autoscaler.
 type Autoscaler struct {
+	k8sClient client.Client
+
+	stateConfigMapRef types.NamespacedName
+
 	leaderElection *leader.Election
 
 	scaler *modelscaler.ModelScaler
@@ -76,6 +100,8 @@ func (a *Autoscaler) Start(ctx context.Context) {
 			continue
 		}
 
+		nextModelState := newTotalModelState()
+
 		var selfAddrs []string
 		if len(a.fixedSelfMetricAddrs) > 0 {
 			selfAddrs = a.fixedSelfMetricAddrs
@@ -121,17 +147,33 @@ func (a *Autoscaler) Start(ctx context.Context) {
 			log.Printf("Calculated target replicas for model %q: ceil(%v/%v) = %v, current requests: sum(%v) = %v, history: %v",
 				m.Name, avgActiveRequests, *m.Spec.TargetRequests, ceil, activeRequests, activeRequestSum, avg.History())
 			a.scaler.Scale(ctx, &m, int32(ceil), a.cfg.RequiredConsecutiveScaleDowns(*m.Spec.ScaleDownDelaySeconds))
+
+			nextModelState.Models[m.Name] = modelState{
+				AverageActiveRequests: avgActiveRequests,
+			}
+		}
+
+		if err := a.saveTotalModelState(ctx, nextModelState); err != nil {
+			log.Printf("Failed to save model state: %v", err)
 		}
 	}
 }
 
-func (r *Autoscaler) getMovingAvgActiveReqPerModel(model string) *movingaverage.Simple {
-	r.movingAvgByModelMtx.Lock()
-	a, ok := r.movingAvgByModel[model]
+func (a *Autoscaler) getMovingAvgActiveReqPerModel(model string) *movingaverage.Simple {
+	a.movingAvgByModelMtx.Lock()
+	avg, ok := a.movingAvgByModel[model]
 	if !ok {
-		a = movingaverage.NewSimple(make([]float64, r.cfg.AverageWindowCount()))
-		r.movingAvgByModel[model] = a
+		avg = movingaverage.NewSimple(make([]float64, a.cfg.AverageWindowCount()))
+		a.movingAvgByModel[model] = avg
+	}
+	a.movingAvgByModelMtx.Unlock()
+	return avg
+}
+
+func newPrefilledFloat64Slice(length int, value float64) []float64 {
+	s := make([]float64, length)
+	for i := range s {
+		s[i] = value
 	}
-	r.movingAvgByModelMtx.Unlock()
-	return a
+	return s
 }
diff --git a/internal/modelautoscaler/state.go b/internal/modelautoscaler/state.go
new file mode 100644
index 00000000..7514d069
--- /dev/null
+++ b/internal/modelautoscaler/state.go
@@ -0,0 +1,65 @@
+package modelautoscaler
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"log"
+	"time"
+
+	corev1 "k8s.io/api/core/v1"
+	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
+	"k8s.io/apimachinery/pkg/types"
+	"sigs.k8s.io/controller-runtime/pkg/client"
+)
+
+func newTotalModelState() totalModelState {
+	return totalModelState{
+		Models:              make(map[string]modelState),
+		LastCalculationTime: time.Now(),
+	}
+}
+
+type totalModelState struct {
+	Models              map[string]modelState `json:"models"`
+	LastCalculationTime time.Time             `json:"lastCalculationTime"`
+}
+
+type modelState struct {
+	AverageActiveRequests float64 `json:"averageActiveRequests"`
+}
+
+func (a *Autoscaler) loadLastTotalModelState(ctx context.Context) (totalModelState, error) {
+	cm := &corev1.ConfigMap{}
+	if err := a.k8sClient.Get(ctx, a.stateConfigMapRef, cm); err != nil {
+		return totalModelState{}, fmt.Errorf("get ConfigMap %q: %w", a.stateConfigMapRef, err)
+	}
+	const key = "models"
+	jsonState, ok := cm.Data[key]
+	if !ok {
+		log.Printf("Autoscaler state ConfigMap %q has no key %q, state not loaded", a.stateConfigMapRef, key)
+		return totalModelState{}, nil
+	}
+	tms := totalModelState{}
+	if err := json.Unmarshal([]byte(jsonState), &tms); err != nil {
+		return totalModelState{}, fmt.Errorf("unmarshalling state: %w", err)
+	}
+	return tms, nil
+}
+
+func (a *Autoscaler) saveTotalModelState(ctx context.Context, state totalModelState) error {
+	jsonState, err := json.Marshal(state)
+	if err != nil {
+		return fmt.Errorf("marshalling state: %w", err)
+	}
+	patch := fmt.Sprintf(`{"data":{"models":%q}}`, string(jsonState))
+	if err := a.k8sClient.Patch(ctx, &corev1.ConfigMap{
+		ObjectMeta: metav1.ObjectMeta{
+			Namespace: a.stateConfigMapRef.Namespace,
+			Name:      a.stateConfigMapRef.Name,
+		},
+	}, client.RawPatch(types.StrategicMergePatchType, []byte(patch))); err != nil {
+		return fmt.Errorf("patching ConfigMap %q: %w", a.stateConfigMapRef, err)
+	}
+	return nil
+}
diff --git a/test/integration/autoscaler_state_test.go b/test/integration/autoscaler_state_test.go
new file mode 100644
index 00000000..93748a4f
--- /dev/null
+++ b/test/integration/autoscaler_state_test.go
@@ -0,0 +1,71 @@
+package integration
+
+import (
+	"encoding/json"
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"github.com/substratusai/kubeai/internal/config"
+	corev1 "k8s.io/api/core/v1"
+	"k8s.io/apimachinery/pkg/types"
+	"k8s.io/utils/ptr"
+)
+
+// TestAutoscalerState tests the autoscaler's state management.
+func TestAutoscalerState(t *testing.T) {
+	m := modelForTest(t)
+	m.Spec.MaxReplicas = ptr.To[int32](100)
+	m.Spec.TargetRequests = ptr.To[int32](1)
+	m.Spec.ScaleDownDelaySeconds = ptr.To[int64](2)
+
+	ms := newTestMetricsServer(t, m.Name)
+
+	sysCfg := baseSysCfg(t)
+	sysCfg.ModelAutoscaling.TimeWindow = config.Duration{Duration: 1 * time.Second}
+	sysCfg.ModelAutoscaling.Interval = config.Duration{Duration: time.Second / 4}
+	sysCfg.FixedSelfMetricAddrs = []string{
+		ms.addr(),
+	}
+	t.Logf("FixedSelfMetricAddrs: %v", sysCfg.FixedSelfMetricAddrs)
+	initTest(t, sysCfg)
+
+	ms.activeRequests.Store(2)
+
+	// Create the Model object in the Kubernetes cluster.
+	require.NoError(t, testK8sClient.Create(testCtx, m))
+
+	requireModelReplicas(t, m, 2, "Replicas should be autoscaled", 15*time.Second)
+
+	// Assert that state was saved.
+	cm := &corev1.ConfigMap{}
+	require.EventuallyWithT(t, func(t *assert.CollectT) {
+		err := testK8sClient.Get(testCtx, types.NamespacedName{Name: sysCfg.ModelAutoscaling.StateConfigMapName, Namespace: testNS}, cm)
+		if !assert.NoError(t, err) {
+			return
+		}
+		if !assert.NotNil(t, cm.Data) {
+			return
+		}
+		data, ok := cm.Data["models"]
+		if !assert.True(t, ok) {
+			return
+		}
+		type modelState struct {
+			AverageActiveRequests float64 `json:"averageActiveRequests"`
+		}
+		var state struct {
+			Models map[string]modelState `json:"models"`
+		}
+		if !assert.NoError(t, json.Unmarshal([]byte(data), &state)) {
+			return
+		}
+		if !assert.Len(t, state.Models, 1) {
+			return
+		}
+		assert.Equal(t, state.Models[m.Name], modelState{
+			AverageActiveRequests: 2.0,
+		})
+	}, 15*time.Second, 1*time.Second)
+}
diff --git a/test/integration/autoscaling_ha_test.go b/test/integration/autoscaling_ha_test.go
index 4c041dd5..2168e598 100644
--- a/test/integration/autoscaling_ha_test.go
+++ b/test/integration/autoscaling_ha_test.go
@@ -25,7 +25,7 @@ func TestAutoscalingHA(t *testing.T) {
 	s2 := newTestMetricsServer(t, m.Name)
 	s3 := newTestMetricsServer(t, m.Name)
 
-	sysCfg := baseSysCfg()
+	sysCfg := baseSysCfg(t)
 	sysCfg.ModelAutoscaling.TimeWindow = config.Duration{Duration: 1 * time.Second}
 	sysCfg.ModelAutoscaling.Interval = config.Duration{Duration: time.Second / 4}
 	sysCfg.FixedSelfMetricAddrs = []string{
diff --git a/test/integration/main_test.go b/test/integration/main_test.go
index 00d20498..b6b3258a 100644
--- a/test/integration/main_test.go
+++ b/test/integration/main_test.go
@@ -6,10 +6,12 @@ import (
 	"log"
 	"net/http"
 	"os"
+	"strings"
 	"sync"
 	"testing"
 	"time"
 
+	"github.com/stretchr/testify/require"
 	"github.com/substratusai/kubeai/internal/config"
 	"github.com/substratusai/kubeai/internal/manager"
 	"gocloud.dev/pubsub"
@@ -127,6 +129,14 @@ func installCommonResources() error {
 // and shuts it down when the test is done.
 // SHOULD BE CALLED AT THE BEGINNING OF EACH TEST CASE.
 func initTest(t *testing.T, cfg config.System) {
+	autoscalerStateConfigMap := corev1.ConfigMap{
+		ObjectMeta: metav1.ObjectMeta{
+			Name:      cfg.ModelAutoscaling.StateConfigMapName,
+			Namespace: testNS,
+		},
+	}
+	require.NoError(t, testK8sClient.Create(testCtx, &autoscalerStateConfigMap))
+
 	var wg sync.WaitGroup
 	ctx, cancel := context.WithCancel(testCtx)
 	t.Cleanup(func() {
@@ -148,7 +158,7 @@ func initTest(t *testing.T, cfg config.System) {
 // baseSysCfg returns the System configuration for testing.
 // A function is used to avoid test cases accidentally modifying a global configuration variable
 // which would be tricky to debug.
-func baseSysCfg() config.System {
+func baseSysCfg(t *testing.T) config.System {
 	return config.System{
 		MetricsAddr:   "127.0.0.1:8080",
 		HealthAddress: "127.0.0.1:8081",
@@ -214,6 +224,9 @@ func baseSysCfg() config.System {
 				},
 			},
 		},
+		ModelAutoscaling: config.ModelAutoscaling{
+			StateConfigMapName: strings.ToLower(t.Name()) + "-autoscaler-state",
+		},
 		LeaderElection: config.LeaderElection{
 			// Speed up the election process for tests.
 			// This is important because the manager is restarted for each test case.
diff --git a/test/integration/messenger_test.go b/test/integration/messenger_test.go
index 5b456f02..a3e6485a 100644
--- a/test/integration/messenger_test.go
+++ b/test/integration/messenger_test.go
@@ -22,7 +22,7 @@ func TestMessenger(t *testing.T) {
 	testRequestsTopic, err = pubsub.OpenTopic(testCtx, memRequestsURL)
 	require.NoError(t, err)
 
-	sysCfg := baseSysCfg()
+	sysCfg := baseSysCfg(t)
 	sysCfg.Messaging = config.Messaging{
 		Streams: []config.MessageStream{
 			{
diff --git a/test/integration/model_default_test.go b/test/integration/model_default_test.go
index 16c43674..a1fed915 100644
--- a/test/integration/model_default_test.go
+++ b/test/integration/model_default_test.go
@@ -11,7 +11,7 @@ import (
 
 // TestModelDefaults tests that defaults are applied as expected.
 func TestModelDefaults(t *testing.T) {
-	initTest(t, baseSysCfg())
+	initTest(t, baseSysCfg(t))
 
 	// Construct a Model object with MinReplicas set to 0.
 	m := modelForTest(t)
diff --git a/test/integration/model_pod_recovery_test.go b/test/integration/model_pod_recovery_test.go
index bcb7b736..5a31d7e4 100644
--- a/test/integration/model_pod_recovery_test.go
+++ b/test/integration/model_pod_recovery_test.go
@@ -13,7 +13,7 @@ import (
 // TestModelPodRecovery tests that if a Pod is deleted, it is recreated
 // with the same name.
 func TestModelPodRecovery(t *testing.T) {
-	initTest(t, baseSysCfg())
+	initTest(t, baseSysCfg(t))
 
 	// Create a Model object.
 	m := modelForTest(t)
diff --git a/test/integration/model_pod_update_rollout_test.go b/test/integration/model_pod_update_rollout_test.go
index 19162c86..ac7f7cb2 100644
--- a/test/integration/model_pod_update_rollout_test.go
+++ b/test/integration/model_pod_update_rollout_test.go
@@ -15,7 +15,7 @@ import (
 // TestModelPodUpdateRollout tests that an update to a Model object triggers the recreation
 // of the corresponding Pods.
 func TestModelPodUpdateRollout(t *testing.T) {
-	initTest(t, baseSysCfg())
+	initTest(t, baseSysCfg(t))
 
 	// Create a Model object.
 	m := modelForTest(t)
diff --git a/test/integration/model_profiles_test.go b/test/integration/model_profiles_test.go
index f284c3a7..31b69a80 100644
--- a/test/integration/model_profiles_test.go
+++ b/test/integration/model_profiles_test.go
@@ -14,7 +14,8 @@ import (
 
 // TestModelProfiles tests that profiles are applied as expected.
 func TestModelProfiles(t *testing.T) {
-	initTest(t, baseSysCfg())
+	sysCfg := baseSysCfg(t)
+	initTest(t, sysCfg)
 
 	// Construct a Model object with MinReplicas set to 0.
 	m := modelForTest(t)
@@ -48,9 +49,9 @@ func TestModelProfiles(t *testing.T) {
 		container := mustFindPodContainerByName(t, pod, "server")
 		assert.Equal(t, expectedResources, container.Resources)
 		assert.Equal(t, ptr.To(cpuRuntimeClassName), pod.Spec.RuntimeClassName)
-		assert.Contains(t, pod.Spec.Tolerations, baseSysCfg().ResourceProfiles[resourceProfileCPU].Tolerations[0])
-		assert.Equal(t, baseSysCfg().ResourceProfiles[resourceProfileCPU].Affinity, pod.Spec.Affinity)
-		assert.Equal(t, baseSysCfg().ResourceProfiles[resourceProfileCPU].NodeSelector, pod.Spec.NodeSelector)
+		assert.Contains(t, pod.Spec.Tolerations, sysCfg.ResourceProfiles[resourceProfileCPU].Tolerations[0])
+		assert.Equal(t, sysCfg.ResourceProfiles[resourceProfileCPU].Affinity, pod.Spec.Affinity)
+		assert.Equal(t, sysCfg.ResourceProfiles[resourceProfileCPU].NodeSelector, pod.Spec.NodeSelector)
 	}, 5*time.Second, time.Second/10, "Resource profile should be applied to the model Pod object")
 
 	const userImage = "my-repo.com/my-repo/my-image:latest"
diff --git a/test/integration/model_scaling_bounds_test.go b/test/integration/model_scaling_bounds_test.go
index a10a9294..2008e6c9 100644
--- a/test/integration/model_scaling_bounds_test.go
+++ b/test/integration/model_scaling_bounds_test.go
@@ -15,7 +15,7 @@ import (
 // MinReplicas and MaxReplicas set in the Model spec.
 // NOTE: This does not test scale-from-zero or autoscaling code paths!
 func TestModelScalingBounds(t *testing.T) {
-	initTest(t, baseSysCfg())
+	initTest(t, baseSysCfg(t))
 
 	// Construct a Model object with MinReplicas set to 0.
 	m := modelForTest(t)
diff --git a/test/integration/proxy_test.go b/test/integration/proxy_test.go
index cd3abf02..573f2930 100644
--- a/test/integration/proxy_test.go
+++ b/test/integration/proxy_test.go
@@ -17,7 +17,7 @@ import (
 )
 
 func TestProxy(t *testing.T) {
-	sysCfg := baseSysCfg()
+	sysCfg := baseSysCfg(t)
 	sysCfg.ModelAutoscaling.TimeWindow = config.Duration{Duration: 6 * time.Second}
 	sysCfg.ModelAutoscaling.Interval = config.Duration{Duration: time.Second}
 	initTest(t, sysCfg)
diff --git a/test/integration/utils_test.go b/test/integration/utils_test.go
index 7e2b6526..57ef095d 100644
--- a/test/integration/utils_test.go
+++ b/test/integration/utils_test.go
@@ -55,10 +55,14 @@ func updateModel(t *testing.T, m *v1.Model, modify func(), msg string) {
 
 func requireModelReplicas(t *testing.T, m *v1.Model, expectedReplicas int32, msg string, after time.Duration) {
 	require.EventuallyWithT(t, func(t *assert.CollectT) {
-		assert.NoError(t, testK8sClient.Get(testCtx, client.ObjectKeyFromObject(m), m))
+		if !assert.NoError(t, testK8sClient.Get(testCtx, client.ObjectKeyFromObject(m), m)) {
+			return
+		}
 		//jsn, _ := json.MarshalIndent(m, "", "  ")
 		//fmt.Println(string(jsn))
-		assert.NotNil(t, m.Spec.Replicas)
+		if !assert.NotNil(t, m.Spec.Replicas) {
+			return
+		}
 		assert.Equal(t, expectedReplicas, *m.Spec.Replicas)
 	}, after, time.Second/10, "Model Replicas should match: "+msg)
 }
@@ -66,15 +70,21 @@ func requireModelReplicas(t *testing.T, m *v1.Model, expectedReplicas int32, msg
 func requireModelPods(t *testing.T, m *v1.Model, expectedPods int, msg string, after time.Duration) {
 	require.EventuallyWithT(t, func(t *assert.CollectT) {
 		podList := &corev1.PodList{}
-		assert.NoError(t, testK8sClient.List(testCtx, podList, client.InNamespace(testNS), client.MatchingLabels{"model": m.Name}))
-		assert.Len(t, podList.Items, expectedPods)
+		if !assert.NoError(t, testK8sClient.List(testCtx, podList, client.InNamespace(testNS), client.MatchingLabels{"model": m.Name})) {
+			return
+		}
+		if !assert.Len(t, podList.Items, expectedPods) {
+			return
+		}
 	}, after, time.Second/10, "Model Pods should match: "+msg)
 }
 
 func markAllModelPodsReady(t *testing.T, m *v1.Model) {
 	require.EventuallyWithT(t, func(t *assert.CollectT) {
 		podList := &corev1.PodList{}
-		require.NoError(t, testK8sClient.List(testCtx, podList, client.InNamespace(testNS), client.MatchingLabels{"model": m.Name}))
+		if !assert.NoError(t, testK8sClient.List(testCtx, podList, client.InNamespace(testNS), client.MatchingLabels{"model": m.Name})) {
+			return
+		}
 		for _, pod := range podList.Items {
 			pod.Status.Phase = corev1.PodRunning
 			pod.Status.Conditions = []corev1.PodCondition{{Type: corev1.PodReady, Status: corev1.ConditionTrue}}
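For reference, saveTotalModelState serializes the per-model averages as JSON and patches it under the "models" key of the state ConfigMap, which loadLastTotalModelState reads back on startup. Assuming the config default stateConfigMapName (kubeai-autoscaler-state), the stored object would look roughly like the sketch below; the model name, timestamp, and values are hypothetical, not taken from the diff:

    apiVersion: v1
    kind: ConfigMap
    metadata:
      name: kubeai-autoscaler-state
    data:
      # JSON-encoded totalModelState written by the autoscaler on each interval.
      models: '{"models":{"my-model":{"averageActiveRequests":2.0}},"lastCalculationTime":"2024-01-01T00:00:00Z"}'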