From 58db5efaf2de781b347291e7aa912f87bf8bc466 Mon Sep 17 00:00:00 2001 From: Nick Ross Date: Thu, 23 Nov 2023 11:02:01 +0000 Subject: [PATCH] fix: tests --- .../tests/applications/test_views.py | 22 ++++-- .../tests/datasets/test_models.py | 3 +- .../tests/datasets/test_pipelines.py | 73 ++++++++++++++++--- 3 files changed, 81 insertions(+), 17 deletions(-) diff --git a/dataworkspace/dataworkspace/tests/applications/test_views.py b/dataworkspace/dataworkspace/tests/applications/test_views.py index bfd313f976..62e5fe0cc5 100644 --- a/dataworkspace/dataworkspace/tests/applications/test_views.py +++ b/dataworkspace/dataworkspace/tests/applications/test_views.py @@ -316,7 +316,10 @@ def test_unapprove_visualisation_successfully(self): class TestDataVisualisationUIDatasetsPage: - def test_shows_app_schema_pipelines_on_datasets_page_and_no_others(self, staff_client): + @mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") + def test_shows_app_schema_pipelines_on_datasets_page_and_no_others( + self, mock_s3_client, staff_client + ): visualisation = factories.VisualisationCatalogueItemFactory.create( short_description="summary", published=False, @@ -344,7 +347,8 @@ def test_shows_app_schema_pipelines_on_datasets_page_and_no_others(self, staff_c assert b"table_3" in response.content assert response.status_code == 200 - def test_shows_parsed_schemas_and_tables(self, staff_client): + @mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") + def test_shows_parsed_schemas_and_tables(self, mock_s3_client, staff_client): visualisation = factories.VisualisationCatalogueItemFactory.create( short_description="summary", published=False, @@ -372,7 +376,10 @@ def test_shows_parsed_schemas_and_tables(self, staff_client): @pytest.mark.django_db @mock.patch("dataworkspace.apps.applications.views.save_pipeline_to_dataflow") - def test_can_save_sql_if_reducing_tables(self, save_pipeline_to_dataflow, staff_client): + @mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") + def test_can_save_sql_if_reducing_tables( + self, mock_s3_client, save_pipeline_to_dataflow, staff_client + ): visualisation = factories.VisualisationCatalogueItemFactory.create( short_description="summary", published=False, @@ -417,7 +424,8 @@ def test_can_save_sql_if_reducing_tables(self, save_pipeline_to_dataflow, staff_ "INSERT INTO my_table(col_a) VALUES ('a', 'b');", ), ) - def test_cannot_save_sql_if_not_select(self, staff_client, sql): + @mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") + def test_cannot_save_sql_if_not_select(self, mock_s3_client, staff_client, sql): visualisation = factories.VisualisationCatalogueItemFactory.create( short_description="summary", published=False, @@ -454,7 +462,8 @@ def test_cannot_save_sql_if_not_select(self, staff_client, sql): assert response.status_code == 200 @pytest.mark.django_db - def test_cannot_save_sql_if_from_different_table(self, staff_client): + @mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") + def test_cannot_save_sql_if_from_different_table(self, mock_s3_client, staff_client): visualisation = factories.VisualisationCatalogueItemFactory.create( short_description="summary", published=False, @@ -492,7 +501,8 @@ def test_cannot_save_sql_if_from_different_table(self, staff_client): assert response.status_code == 200 @pytest.mark.django_db - def test_cannot_save_sql_if_from_different_visualisation(self, staff_client): + @mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") + def test_cannot_save_sql_if_from_different_visualisation(self, mock_s3_client, staff_client): visualisation_a = factories.VisualisationCatalogueItemFactory.create( short_description="summary", published=False, diff --git a/dataworkspace/dataworkspace/tests/datasets/test_models.py b/dataworkspace/dataworkspace/tests/datasets/test_models.py index 9baabb0887..8bde80554b 100644 --- a/dataworkspace/dataworkspace/tests/datasets/test_models.py +++ b/dataworkspace/dataworkspace/tests/datasets/test_models.py @@ -279,7 +279,8 @@ def test_preview_csv(self, mock_client): @pytest.mark.django_db -def test_pipeline_versions(): +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") +def test_pipeline_versions(mock_s3_client): pipeline = factories.PipelineFactory.create( table_name="schema.original_table_name", config={"sql": "SELECT 1, 2, 3"} ) diff --git a/dataworkspace/dataworkspace/tests/datasets/test_pipelines.py b/dataworkspace/dataworkspace/tests/datasets/test_pipelines.py index f7adf78bfe..5bdacce628 100644 --- a/dataworkspace/dataworkspace/tests/datasets/test_pipelines.py +++ b/dataworkspace/dataworkspace/tests/datasets/test_pipelines.py @@ -1,4 +1,5 @@ import pytest +from django.conf import settings from django.urls import reverse from mock import mock @@ -35,8 +36,16 @@ ), ) @mock.patch("dataworkspace.apps.datasets.pipelines.views.save_pipeline_to_dataflow") +@mock.patch("dataworkspace.apps.datasets.pipelines.views.list_pipelines") +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") def test_create_sql_pipeline( - mock_sync, table_name, expected_output, added_pipelines, staff_client + mock_s3_client, + mock_list, + mock_sync, + table_name, + expected_output, + added_pipelines, + staff_client, ): pipeline_count = Pipeline.objects.count() staff_client.post(reverse("admin:index"), follow=True) @@ -47,10 +56,17 @@ def test_create_sql_pipeline( ) assert expected_output in resp.content.decode(resp.charset) assert pipeline_count + added_pipelines == Pipeline.objects.count() + if added_pipelines > 0: + pipeline = Pipeline.objects.latest("created_date") + mock_s3_client().put_object.assert_called_once_with( + Body=mock.ANY, Bucket=settings.AWS_UPLOADS_BUCKET, Key=pipeline.get_config_file_path() + ) @mock.patch("dataworkspace.apps.datasets.pipelines.views.save_pipeline_to_dataflow") -def test_create_sharepoint_pipeline(mock_sync, staff_client): +@mock.patch("dataworkspace.apps.datasets.pipelines.views.list_pipelines") +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") +def test_create_sharepoint_pipeline(mock_s3_client, mock_list, mock_sync, staff_client): pipeline_count = Pipeline.objects.count() staff_client.post(reverse("admin:index"), follow=True) resp = staff_client.post( @@ -66,6 +82,10 @@ def test_create_sharepoint_pipeline(mock_sync, staff_client): ) assert "Pipeline created successfully" in resp.content.decode(resp.charset) assert pipeline_count + 1 == Pipeline.objects.count() + pipeline = Pipeline.objects.latest("created_date") + mock_s3_client().put_object.assert_called_once_with( + Body=mock.ANY, Bucket=settings.AWS_UPLOADS_BUCKET, Key=pipeline.get_config_file_path() + ) @mock.patch("dataworkspace.apps.datasets.pipelines.views.save_pipeline_to_dataflow") @@ -150,7 +170,9 @@ def test_create_pipeline_validates_duplicate_column_names(mock_sync, staff_clien @pytest.mark.django_db @mock.patch("dataworkspace.apps.datasets.pipelines.views.save_pipeline_to_dataflow") -def test_edit_sql_pipeline(mock_sync, staff_client): +@mock.patch("dataworkspace.apps.datasets.pipelines.views.list_pipelines") +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") +def test_edit_sql_pipeline(mock_s3_client, mock_list, mock_sync, staff_client): pipeline = factories.PipelineFactory.create(config={"sql": "SELECT 1"}) staff_client.post(reverse("admin:index"), follow=True) resp = staff_client.post( @@ -166,6 +188,9 @@ def test_edit_sql_pipeline(mock_sync, staff_client): assert "Pipeline updated successfully" in resp.content.decode(resp.charset) pipeline.refresh_from_db() assert pipeline.config["sql"] == "SELECT 2" + mock_s3_client().put_object.assert_called_with( + Body=mock.ANY, Bucket=settings.AWS_UPLOADS_BUCKET, Key=pipeline.get_config_file_path() + ) @pytest.mark.django_db @@ -187,7 +212,11 @@ def test_edit_sql_pipeline(mock_sync, staff_client): ), ) @mock.patch("dataworkspace.apps.datasets.pipelines.views.save_pipeline_to_dataflow") +@mock.patch("dataworkspace.apps.datasets.pipelines.views.list_pipelines") +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") def test_edit_sql_pipeline_with_notes( + mock_s3_client, + mock_list, mock_sync, staff_client, intial_notes, @@ -210,11 +239,16 @@ def test_edit_sql_pipeline_with_notes( pipeline.refresh_from_db() assert pipeline.config["sql"] == "SELECT 2" assert pipeline.notes == edited_notes + mock_s3_client().put_object.assert_called_with( + Body=mock.ANY, Bucket=settings.AWS_UPLOADS_BUCKET, Key=pipeline.get_config_file_path() + ) @pytest.mark.django_db @mock.patch("dataworkspace.apps.datasets.pipelines.views.save_pipeline_to_dataflow") -def test_edit_sharepoint_pipeline(mock_sync, staff_client): +@mock.patch("dataworkspace.apps.datasets.pipelines.views.list_pipelines") +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") +def test_edit_sharepoint_pipeline(mock_s3_client, mock_list, mock_sync, staff_client): pipeline = factories.PipelineFactory.create( type="sharepoint", config={"site_name": "site1", "list_name": "list1"} ) @@ -233,6 +267,9 @@ def test_edit_sharepoint_pipeline(mock_sync, staff_client): assert "Pipeline updated successfully" in resp.content.decode(resp.charset) pipeline.refresh_from_db() assert pipeline.config == {"site_name": "site2", "list_name": "list2"} + mock_s3_client().put_object.assert_called_with( + Body=mock.ANY, Bucket=settings.AWS_UPLOADS_BUCKET, Key=pipeline.get_config_file_path() + ) @pytest.mark.django_db @@ -254,7 +291,11 @@ def test_edit_sharepoint_pipeline(mock_sync, staff_client): ), ) @mock.patch("dataworkspace.apps.datasets.pipelines.views.save_pipeline_to_dataflow") -def test_edit_sharepoint_pipeline_with_notes(mock_sync, staff_client, intial_notes, edited_notes): +@mock.patch("dataworkspace.apps.datasets.pipelines.views.list_pipelines") +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") +def test_edit_sharepoint_pipeline_with_notes( + mock_s3_client, mock_list, mock_sync, staff_client, intial_notes, edited_notes +): pipeline = factories.PipelineFactory.create( type="sharepoint", config={"site_name": "site1", "list_name": "list1", "notes": intial_notes}, @@ -276,11 +317,16 @@ def test_edit_sharepoint_pipeline_with_notes(mock_sync, staff_client, intial_not pipeline.refresh_from_db() assert pipeline.config == {"site_name": "site2", "list_name": "list2"} assert pipeline.notes == edited_notes + mock_s3_client().put_object.assert_called_with( + Body=mock.ANY, Bucket=settings.AWS_UPLOADS_BUCKET, Key=pipeline.get_config_file_path() + ) @pytest.mark.django_db @mock.patch("dataworkspace.apps.datasets.pipelines.views.delete_pipeline_from_dataflow") -def test_delete_pipeline(mock_delete, staff_client): +@mock.patch("dataworkspace.apps.datasets.pipelines.views.list_pipelines") +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") +def test_delete_pipeline(mock_s3_client, mock_list, mock_delete, staff_client): pipeline = factories.PipelineFactory.create() pipeline_count = Pipeline.objects.count() staff_client.post(reverse("admin:index"), follow=True) @@ -290,11 +336,15 @@ def test_delete_pipeline(mock_delete, staff_client): ) assert "Pipeline deleted successfully" in resp.content.decode(resp.charset) assert pipeline_count - 1 == Pipeline.objects.count() + mock_s3_client().delete_object.assert_called_once_with( + Bucket=settings.AWS_UPLOADS_BUCKET, Key=pipeline.get_config_file_path() + ) @pytest.mark.django_db @mock.patch("dataworkspace.apps.datasets.pipelines.views.run_pipeline") -def test_run_pipeline(mock_run, staff_client): +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") +def test_run_pipeline(mock_s3_client, mock_run, staff_client): pipeline = factories.PipelineFactory.create(type="sql") staff_client.post(reverse("admin:index"), follow=True) resp = staff_client.post( @@ -306,7 +356,8 @@ def test_run_pipeline(mock_run, staff_client): @pytest.mark.django_db @mock.patch("dataworkspace.apps.datasets.pipelines.views.stop_pipeline") -def test_stop_pipeline(mock_stop, staff_client): +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") +def test_stop_pipeline(mock_s3_client, mock_stop, staff_client): pipeline = factories.PipelineFactory.create(type="sharepoint") staff_client.post(reverse("admin:index"), follow=True) resp = staff_client.post( @@ -317,7 +368,8 @@ def test_stop_pipeline(mock_stop, staff_client): @pytest.mark.django_db -def test_pipeline_log_success(staff_client, mocker): +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") +def test_pipeline_log_success(mock_s3_client, staff_client, mocker): pipeline = factories.PipelineFactory.create(type="sharepoint") _return_value = [ { @@ -336,7 +388,8 @@ def test_pipeline_log_success(staff_client, mocker): @pytest.mark.django_db -def test_pipeline_log_failure(staff_client, mocker): +@mock.patch("dataworkspace.apps.core.boto3_client.boto3.client") +def test_pipeline_log_failure(mock_s3_client, staff_client, mocker): pipeline = factories.PipelineFactory.create() mocker.patch( "dataworkspace.apps.datasets.pipelines.views.get_pipeline_logs",