Skip to content

Commit

Permalink
feat: add support for GPU only node template field (#195)
Browse files Browse the repository at this point in the history
Co-authored-by: Laimonas Rastenis <laimonas@cast.ai>
  • Loading branch information
laimonasr and Laimonas Rastenis authored Jul 24, 2023
1 parent 68941f9 commit 8e8aed5
Show file tree
Hide file tree
Showing 8 changed files with 120 additions and 53 deletions.
8 changes: 4 additions & 4 deletions castai/resource_autoscaler.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ func resourceCastaiAutoscalerUpdate(ctx context.Context, data *schema.ResourceDa
return nil
}

func getCurrentPolicies(ctx context.Context, client *sdk.ClientWithResponses, clusterId sdk.ClusterId) ([]byte, error) {
func getCurrentPolicies(ctx context.Context, client *sdk.ClientWithResponses, clusterId string) ([]byte, error) {
log.Printf("[INFO] Getting cluster autoscaler information.")

resp, err := client.PoliciesAPIGetClusterPolicies(ctx, clusterId)
Expand Down Expand Up @@ -157,7 +157,7 @@ func updateAutoscalerPolicies(ctx context.Context, data *schema.ResourceData, me
return upsertPolicies(ctx, meta, clusterId, changedPoliciesJSON)
}

func upsertPolicies(ctx context.Context, meta interface{}, clusterId sdk.ClusterId, changedPoliciesJSON string) error {
func upsertPolicies(ctx context.Context, meta interface{}, clusterId string, changedPoliciesJSON string) error {
client := meta.(*ProviderConfig).api

resp, err := client.PoliciesAPIUpsertClusterPoliciesWithBodyWithResponse(ctx, clusterId, "application/json", bytes.NewReader([]byte(changedPoliciesJSON)))
Expand Down Expand Up @@ -192,7 +192,7 @@ func readAutoscalerPolicies(ctx context.Context, data *schema.ResourceData, meta
return nil
}

func getChangedPolicies(ctx context.Context, data *schema.ResourceData, meta interface{}, clusterId sdk.ClusterId) ([]byte, error) {
func getChangedPolicies(ctx context.Context, data *schema.ResourceData, meta interface{}, clusterId string) ([]byte, error) {
policyChangesJSON, found := data.GetOk(FieldAutoscalerPoliciesJSON)
if !found {
log.Printf("[DEBUG] policies json not provided. Skipping autoscaler policies changes")
Expand Down Expand Up @@ -222,7 +222,7 @@ func getChangedPolicies(ctx context.Context, data *schema.ResourceData, meta int
return policies, nil
}

func getClusterId(data *schema.ResourceData) sdk.ClusterId {
func getClusterId(data *schema.ResourceData) string {
value, found := data.GetOk(FieldClusterId)
if !found {
return ""
Expand Down
22 changes: 19 additions & 3 deletions castai/resource_node_template.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ func resourceNodeTemplate() *schema.Resource {
Default: false,
Description: "Compute optimized instance constraint - will only pick compute optimized nodes if true.",
},
"is_gpu_only": {
Type: schema.TypeBool,
Optional: true,
Default: false,
Description: "GPU instance constraint - will only pick nodes with GPU if true",
},
"instance_families": {
Type: schema.TypeList,
MaxItems: 1,
Expand Down Expand Up @@ -381,6 +387,9 @@ func flattenConstraints(c *sdk.NodetemplatesV1TemplateConstraints) ([]map[string
if c.Spot != nil {
out["spot"] = c.Spot
}
if c.IsGpuOnly != nil {
out["is_gpu_only"] = c.IsGpuOnly
}

if c.UseSpotFallbacks != nil {
out["use_spot_fallbacks"] = c.UseSpotFallbacks
Expand Down Expand Up @@ -600,12 +609,14 @@ func resourceNodeTemplateCreate(ctx context.Context, d *schema.ResourceData, met
return resourceNodeTemplateRead(ctx, d, meta)
}

func getNodeTemplateByName(ctx context.Context, data *schema.ResourceData, meta any, clusterID sdk.ClusterId) (*sdk.NodetemplatesV1NodeTemplate, error) {
func getNodeTemplateByName(ctx context.Context, data *schema.ResourceData, meta any, clusterID string) (*sdk.NodetemplatesV1NodeTemplate, error) {
client := meta.(*ProviderConfig).api
nodeTemplateName := data.Id()

log.Printf("[INFO] Getting current node templates")
resp, err := client.NodeTemplatesAPIListNodeTemplatesWithResponse(ctx, clusterID)
resp, err := client.NodeTemplatesAPIListNodeTemplatesWithResponse(ctx, clusterID, &sdk.NodeTemplatesAPIListNodeTemplatesParams{
IncludeDefault: lo.ToPtr(false),
})
notFound := fmt.Errorf("node templates for cluster %q not found at CAST AI", clusterID)
if err != nil {
return nil, err
Expand Down Expand Up @@ -657,7 +668,9 @@ func nodeTemplateStateImporter(ctx context.Context, d *schema.ResourceData, meta

// Find node templates
client := meta.(*ProviderConfig).api
resp, err := client.NodeTemplatesAPIListNodeTemplatesWithResponse(ctx, clusterID)
resp, err := client.NodeTemplatesAPIListNodeTemplatesWithResponse(ctx, clusterID, &sdk.NodeTemplatesAPIListNodeTemplatesParams{
IncludeDefault: lo.ToPtr(false),
})
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -795,6 +808,9 @@ func toTemplateConstraints(obj map[string]any) *sdk.NodetemplatesV1TemplateConst
if v, ok := obj["architectures"].([]any); ok {
out.Architectures = toPtr(toStringList(v))
}
if v, ok := obj["is_gpu_only"].(bool); ok {
out.IsGpuOnly = toPtr(v)
}

return out
}
Expand Down
14 changes: 11 additions & 3 deletions castai/resource_node_template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"fmt"
"github.com/samber/lo"
"io"
"net/http"
"testing"
Expand Down Expand Up @@ -98,7 +99,9 @@ func TestNodeTemplateResourceReadContext(t *testing.T) {
}
`)))
mockClient.EXPECT().
NodeTemplatesAPIListNodeTemplates(gomock.Any(), clusterId).
NodeTemplatesAPIListNodeTemplates(gomock.Any(), clusterId, gomock.Eq(&sdk.NodeTemplatesAPIListNodeTemplatesParams{
IncludeDefault: lo.ToPtr(false),
})).
Return(&http.Response{StatusCode: 200, Body: body, Header: map[string][]string{"Content-Type": {"json"}}}, nil)

resource := resourceNodeTemplate()
Expand Down Expand Up @@ -139,6 +142,7 @@ constraints.0.instance_families.0.exclude.4 = g5g
constraints.0.instance_families.0.exclude.5 = g5
constraints.0.instance_families.0.exclude.6 = g3
constraints.0.instance_families.0.include.# = 0
constraints.0.is_gpu_only = false
constraints.0.max_cpu = 10000
constraints.0.max_memory = 0
constraints.0.min_cpu = 10
Expand Down Expand Up @@ -180,7 +184,9 @@ func TestNodeTemplateResourceReadContextEmptyList(t *testing.T) {
clusterId := "b6bfc074-a267-400f-b8f1-db0850c369b1"
body := io.NopCloser(bytes.NewReader([]byte(`{"items": []}`)))
mockClient.EXPECT().
NodeTemplatesAPIListNodeTemplates(gomock.Any(), clusterId).
NodeTemplatesAPIListNodeTemplates(gomock.Any(), clusterId, gomock.Eq(&sdk.NodeTemplatesAPIListNodeTemplatesParams{
IncludeDefault: lo.ToPtr(false),
})).
Return(&http.Response{StatusCode: 200, Body: body, Header: map[string][]string{"Content-Type": {"json"}}}, nil)

resource := resourceNodeTemplate()
Expand Down Expand Up @@ -370,7 +376,9 @@ func testAccCheckNodeTemplateDestroy(s *terraform.State) error {

id := rs.Primary.ID
clusterID := rs.Primary.Attributes["cluster_id"]
response, err := client.NodeTemplatesAPIListNodeTemplatesWithResponse(ctx, clusterID)
response, err := client.NodeTemplatesAPIListNodeTemplatesWithResponse(ctx, clusterID, &sdk.NodeTemplatesAPIListNodeTemplatesParams{
IncludeDefault: lo.ToPtr(false),
})
if err != nil {
return err
}
Expand Down
77 changes: 49 additions & 28 deletions castai/sdk/api.gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 8e8aed5

Please sign in to comment.