Skip to content

Commit

Permalink
added ut for resources/nodes.go
Browse files Browse the repository at this point in the history
  • Loading branch information
smritidahal653 committed Nov 5, 2023
1 parent 098ecd1 commit a65d705
Showing 1 changed file with 130 additions and 0 deletions.
130 changes: 130 additions & 0 deletions pkg/resources/nodes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
package resources

import (
"context"
"errors"
"testing"

"github.com/azure/kaito/pkg/utils"
"github.com/stretchr/testify/mock"
"gotest.tools/assert"
corev1 "k8s.io/api/core/v1"
"sigs.k8s.io/controller-runtime/pkg/client"
)

func TestUpdateNodeWithLabel(t *testing.T) {
testcases := map[string]struct {
callMocks func(c *utils.MockClient)
expectedError error
}{
"Fail to update node because it cannot be retrieved": {
callMocks: func(c *utils.MockClient) {
c.On("Get", mock.IsType(context.Background()), client.ObjectKey{Name: "mockNode"}, mock.IsType(&corev1.Node{}), mock.Anything).Return(errors.New("Cannot retrieve node"))
},
expectedError: errors.New("Cannot retrieve node"),
},
"Fail to update node because node cannot be updated": {
callMocks: func(c *utils.MockClient) {
c.On("Get", mock.IsType(context.Background()), client.ObjectKey{Name: "mockNode"}, mock.Anything, mock.Anything).Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&corev1.Node{}), mock.Anything).Return(errors.New("Cannot update node"))
},
expectedError: errors.New("Cannot update node"),
},
"Successfully updates node": {
callMocks: func(c *utils.MockClient) {
c.On("Get", mock.IsType(context.Background()), client.ObjectKey{Name: "mockNode"}, mock.IsType(&corev1.Node{}), mock.Anything).Return(nil)
c.On("Update", mock.IsType(context.Background()), mock.IsType(&corev1.Node{}), mock.Anything).Return(nil)
},
expectedError: nil,
},
}

for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
mockClient := utils.NewClient()
tc.callMocks(mockClient)

err := UpdateNodeWithLabel(context.Background(), "mockNode", "fakeKey", "fakeVal", mockClient)
if tc.expectedError == nil {
assert.Check(t, err == nil, "Not expected to return error")
} else {
assert.Equal(t, tc.expectedError.Error(), err.Error())
}
})
}
}

func TestListNodes(t *testing.T) {
testcases := map[string]struct {
callMocks func(c *utils.MockClient)
expectedError error
}{
"Fails to list nodes": {
callMocks: func(c *utils.MockClient) {
c.On("List", mock.IsType(context.Background()), mock.IsType(&corev1.NodeList{}), mock.Anything).Return(errors.New("Cannot retrieve node list"))
},
expectedError: errors.New("Cannot retrieve node list"),
},
"Successfully lists all nodes": {
callMocks: func(c *utils.MockClient) {
nodeList := utils.MockNodeList
relevantMap := c.CreateMapWithType(nodeList)
//insert node objects into the map
for _, obj := range utils.MockNodeList.Items {
n := obj
objKey := client.ObjectKeyFromObject(&n)

relevantMap[objKey] = &n
}

c.On("List", mock.IsType(context.Background()), mock.IsType(&corev1.NodeList{}), mock.Anything).Return(nil)
},
expectedError: nil,
},
}

for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
mockClient := utils.NewClient()
tc.callMocks(mockClient)

labelSelector := client.MatchingLabels{}
nodeList, err := ListNodes(context.Background(), mockClient, labelSelector)
if tc.expectedError == nil {
assert.Check(t, err == nil, "Not expected to return error")
assert.Check(t, nodeList != nil, "Response node list should not be nil")
assert.Check(t, nodeList.Items != nil, "Response node list items should not be nil")
assert.Check(t, len(nodeList.Items) == 3, "Response should contain 3 nodes")

} else {
assert.Equal(t, tc.expectedError.Error(), err.Error())
}
})
}
}

func TestCheckNvidiaPlugin(t *testing.T) {
testcases := map[string]struct {
nodeObj *corev1.Node
isNvidiaPlugin bool
}{
"Is not nvidia plugin": {
nodeObj: &utils.MockNodeList.Items[1],
isNvidiaPlugin: false,
},
"Is nvidia plugin": {
nodeObj: &utils.MockNodeList.Items[0],
isNvidiaPlugin: true,
},
}

for k, tc := range testcases {
t.Run(k, func(t *testing.T) {
result := CheckNvidiaPlugin(context.Background(), tc.nodeObj)

assert.Equal(t, result, tc.isNvidiaPlugin)
})
}
}

0 comments on commit a65d705

Please sign in to comment.