Skip to content

Commit

Permalink
Added additional unit tests to increase coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
venkatajagannath committed Aug 26, 2024
1 parent 3337df6 commit bd6d847
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 0 deletions.
96 changes: 96 additions & 0 deletions tests/hooks/test_ray_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,3 +676,99 @@ def test_delete_ray_cluster_success(
mock_delete_daemon_set.assert_called_once()
mock_delete_custom_object.assert_called_once()
mock_uninstall_kuberay_operator.assert_called_once()

@patch("ray_provider.hooks.ray.JobSubmissionClient")
def test_ray_client_exception(self, mock_job_client, ray_hook):
mock_job_client.side_effect = Exception("Connection failed")
with pytest.raises(AirflowException) as exc_info:
ray_hook.ray_client()
assert str(exc_info.value) == "Failed to create Ray JobSubmissionClient: Connection failed"

@patch("ray_provider.hooks.ray.RayHook.get_custom_object")
@patch("ray_provider.hooks.ray.RayHook.create_custom_object")
def test_create_or_update_cluster_exception(self, mock_create, mock_get, ray_hook):
mock_get.side_effect = client.exceptions.ApiException(status=500, reason="Internal Server Error")
with pytest.raises(AirflowException) as exc_info:
ray_hook._create_or_update_cluster(
update_if_exists=False,
group="ray.io",
version="v1",
plural="rayclusters",
name="test-cluster",
namespace="default",
cluster_spec={},
)
assert "Error accessing Ray cluster 'test-cluster'" in str(exc_info.value)

@patch("ray_provider.hooks.ray.RayHook.get_custom_object")
@patch("ray_provider.hooks.ray.RayHook.custom_object_client")
def test_create_or_update_cluster_update(self, mock_client, mock_get, ray_hook):
mock_get.return_value = {"metadata": {"name": "test-cluster"}}
ray_hook._create_or_update_cluster(
update_if_exists=True,
group="ray.io",
version="v1",
plural="rayclusters",
name="test-cluster",
namespace="default",
cluster_spec={"spec": {"some": "config"}},
)
mock_client.patch_namespaced_custom_object.assert_called_once_with(
group="ray.io",
version="v1",
namespace="default",
plural="rayclusters",
name="test-cluster",
body={"spec": {"some": "config"}},
)

@patch("ray_provider.hooks.ray.RayHook._validate_yaml_file")
@patch("ray_provider.hooks.ray.RayHook.install_kuberay_operator")
@patch("ray_provider.hooks.ray.RayHook.load_yaml_content")
@patch("ray_provider.hooks.ray.RayHook._create_or_update_cluster")
@patch("ray_provider.hooks.ray.RayHook._setup_gpu_driver")
@patch("ray_provider.hooks.ray.RayHook._setup_load_balancer")
def test_setup_ray_cluster_exception(
self,
mock_setup_lb,
mock_setup_gpu,
mock_create_or_update,
mock_load_yaml,
mock_install_operator,
mock_validate_yaml,
ray_hook,
):
mock_create_or_update.side_effect = Exception("Cluster creation failed")
context = {"task_instance": MagicMock()}
with pytest.raises(AirflowException) as exc_info:
ray_hook.setup_ray_cluster(
context=context,
ray_cluster_yaml="test.yaml",
kuberay_version="1.0.0",
gpu_device_plugin_yaml="gpu.yaml",
update_if_exists=False,
)
assert "Failed to set up Ray cluster: Cluster creation failed" in str(exc_info.value)

@patch("ray_provider.hooks.ray.RayHook._validate_yaml_file")
@patch("ray_provider.hooks.ray.RayHook.load_yaml_content")
@patch("ray_provider.hooks.ray.RayHook.get_custom_object")
@patch("ray_provider.hooks.ray.RayHook.delete_custom_object")
@patch("ray_provider.hooks.ray.RayHook.get_daemon_set")
@patch("ray_provider.hooks.ray.RayHook.delete_daemon_set")
@patch("ray_provider.hooks.ray.RayHook.uninstall_kuberay_operator")
def test_delete_ray_cluster_exception(
self,
mock_uninstall_operator,
mock_delete_daemon_set,
mock_get_daemon_set,
mock_delete_custom_object,
mock_get_custom_object,
mock_load_yaml,
mock_validate_yaml,
ray_hook,
):
mock_delete_custom_object.side_effect = Exception("Cluster deletion failed")
with pytest.raises(AirflowException) as exc_info:
ray_hook.delete_ray_cluster(ray_cluster_yaml="test.yaml", gpu_device_plugin_yaml="gpu.yaml")
assert "Failed to delete Ray cluster: Cluster deletion failed" in str(exc_info.value)
42 changes: 42 additions & 0 deletions tests/operators/test_ray_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,45 @@ def test_template_fields(self):
"ray_cluster_yaml",
"job_timeout_seconds",
)

@patch("ray_provider.operators.ray.RayHook")
def test_setup_cluster_exception(self, mock_ray_hook, context):
operator = SubmitRayJob(
task_id="test_task",
conn_id="test_conn",
entrypoint="python script.py",
runtime_env={},
ray_cluster_yaml="cluster.yaml",
)

mock_hook = mock_ray_hook.return_value
operator.hook = mock_hook

mock_hook.setup_ray_cluster.side_effect = Exception("Cluster setup failed")

with pytest.raises(Exception) as exc_info:
operator._setup_cluster(context)

assert str(exc_info.value) == "Cluster setup failed"
mock_hook.setup_ray_cluster.assert_called_once()

@patch("ray_provider.operators.ray.RayHook")
def test_delete_cluster_exception(self, mock_ray_hook):
operator = SubmitRayJob(
task_id="test_task",
conn_id="test_conn",
entrypoint="python script.py",
runtime_env={},
ray_cluster_yaml="cluster.yaml",
)

mock_hook = mock_ray_hook.return_value
operator.hook = mock_hook

mock_hook.delete_ray_cluster.side_effect = Exception("Cluster deletion failed")

with pytest.raises(Exception) as exc_info:
operator._delete_cluster()

assert str(exc_info.value) == "Cluster deletion failed"
mock_hook.delete_ray_cluster.assert_called_once()

0 comments on commit bd6d847

Please sign in to comment.