From f546cbe77957a62c138359f3a60398aec0203dcf Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Tue, 22 Oct 2024 09:20:59 -0700 Subject: [PATCH 1/2] Update dbt-spark and pysql dependencies (#833) --- .github/workflows/main.yml | 11 ++++------- CHANGELOG.md | 3 ++- requirements.txt | 4 ++-- setup.py | 12 ++++++------ 4 files changed, 14 insertions(+), 16 deletions(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 9561c40c..3a1637c8 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -65,7 +65,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4.3.0 with: - python-version: "3.8" + python-version: "3.9" - name: Install python dependencies run: | @@ -86,7 +86,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] env: TOXENV: "unit" @@ -138,7 +138,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v4.3.0 with: - python-version: "3.8" + python-version: "3.9" - name: Install python dependencies run: | @@ -183,7 +183,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] - python-version: ["3.8", "3.9", "3.10", "3.11"] + python-version: ["3.9", "3.10", "3.11", "3.12"] steps: - name: Set up Python ${{ matrix.python-version }} @@ -207,9 +207,6 @@ jobs: - name: Install wheel distributions run: | find ./dist/*.whl -maxdepth 1 -type f | xargs python -m pip install --force-reinstall --find-links=dist/ - - name: Install dbt-core - run: | - python -m pip install dbt-core==1.8.0rc2 - name: Check wheel distributions run: | dbt --version diff --git a/CHANGELOG.md b/CHANGELOG.md index 8238716f..960fcf8e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,7 +21,8 @@ ### Under the Hood - Fix places where we were not properly closing cursors, and other test warnings ([713](https://github.com/databricks/dbt-databricks/pull/713)) -- Upgrade databricks-sql-connector dependency to 3.4.0 ([790](https://github.com/databricks/dbt-databricks/pull/790)) +- Drop support for Python 3.8 ([713](https://github.com/databricks/dbt-databricks/pull/713)) +- Upgrade databricks-sql-connector dependency to 3.5.0 ([833](https://github.com/databricks/dbt-databricks/pull/833)) ## dbt-databricks 1.8.7 (October 10, 2024) diff --git a/requirements.txt b/requirements.txt index 3df5cb12..fb9c7f54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -databricks-sql-connector>=3.4.0, <3.5.0 -dbt-spark~=1.8.0 +databricks-sql-connector>=3.5.0, <4.0 +dbt-spark>=1.9.0b1, <2.0 dbt-core>=1.9.0b1, <2.0 dbt-common>=1.10.0, <2.0 dbt-adapters>=1.7.0, <2.0 diff --git a/setup.py b/setup.py index a5b2ccc5..5c236324 100644 --- a/setup.py +++ b/setup.py @@ -3,9 +3,9 @@ import sys # require python 3.8 or newer -if sys.version_info < (3, 8): +if sys.version_info < (3, 9): print("Error: dbt does not support this version of Python.") - print("Please upgrade to Python 3.8 or higher.") + print("Please upgrade to Python 3.9 or higher.") sys.exit(1) @@ -54,11 +54,11 @@ def _get_plugin_version() -> str: packages=find_namespace_packages(include=["dbt", "dbt.*"]), include_package_data=True, install_requires=[ - "dbt-spark>=1.8.0, <2.0", + "dbt-spark>=1.9.0b1, <2.0", "dbt-core>=1.9.0b1, <2.0", "dbt-adapters>=1.7.0, <2.0", "dbt-common>=1.10.0, <2.0", - "databricks-sql-connector>=3.4.0, <3.5.0", + "databricks-sql-connector>=3.5.0, <4.0.0", "databricks-sdk==0.17.0", "keyring>=23.13.0", "pandas<2.2.0", @@ -71,10 +71,10 @@ def _get_plugin_version() -> str: "Operating System :: Microsoft :: Windows", "Operating System :: MacOS :: MacOS X", "Operating System :: POSIX :: Linux", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ], - python_requires=">=3.8", + python_requires=">=3.9", ) From 84fc024fb8e881ef5dc59d09a56bfcba99842091 Mon Sep 17 00:00:00 2001 From: Ben Cassell <98852248+benc-db@users.noreply.github.com> Date: Tue, 22 Oct 2024 15:50:55 -0700 Subject: [PATCH 2/2] Refactor python config handling (#830) --- CHANGELOG.md | 1 + .../databricks/python_models/python_config.py | 57 ++ .../python_models/python_submissions.py | 670 +++++++++++------- docs/workflow-job-submission.md | 134 ++-- .../adapter/python_model/fixtures.py | 2 +- tests/unit/python/test_python_config.py | 135 ++++ tests/unit/python/test_python_helpers.py | 79 +++ tests/unit/python/test_python_job_support.py | 185 +++++ tests/unit/python/test_python_submissions.py | 249 ------- tests/unit/python/test_python_submitters.py | 172 +++++ .../python/test_python_workflow_support.py | 142 ++++ 11 files changed, 1258 insertions(+), 568 deletions(-) create mode 100644 dbt/adapters/databricks/python_models/python_config.py create mode 100644 tests/unit/python/test_python_config.py create mode 100644 tests/unit/python/test_python_helpers.py create mode 100644 tests/unit/python/test_python_job_support.py delete mode 100644 tests/unit/python/test_python_submissions.py create mode 100644 tests/unit/python/test_python_submitters.py create mode 100644 tests/unit/python/test_python_workflow_support.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 960fcf8e..cb759c85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ ### Under the Hood +- Significant refactoring and increased testing of python_submissions ([830](https://github.com/databricks/dbt-databricks/pull/830)) - Fix places where we were not properly closing cursors, and other test warnings ([713](https://github.com/databricks/dbt-databricks/pull/713)) - Drop support for Python 3.8 ([713](https://github.com/databricks/dbt-databricks/pull/713)) - Upgrade databricks-sql-connector dependency to 3.5.0 ([833](https://github.com/databricks/dbt-databricks/pull/833)) diff --git a/dbt/adapters/databricks/python_models/python_config.py b/dbt/adapters/databricks/python_models/python_config.py new file mode 100644 index 00000000..6398397d --- /dev/null +++ b/dbt/adapters/databricks/python_models/python_config.py @@ -0,0 +1,57 @@ +from typing import Any, Dict, List, Optional +import uuid +from pydantic import BaseModel, Field + + +DEFAULT_TIMEOUT = 60 * 60 * 24 + + +class PythonJobConfig(BaseModel): + """Pydantic model for config found in python_job_config.""" + + name: Optional[str] = None + grants: Dict[str, List[Dict[str, str]]] = Field(exclude=True, default_factory=dict) + existing_job_id: str = Field("", exclude=True) + post_hook_tasks: List[Dict[str, Any]] = Field(exclude=True, default_factory=list) + additional_task_settings: Dict[str, Any] = Field(exclude=True, default_factory=dict) + + class Config: + extra = "allow" + + +class PythonModelConfig(BaseModel): + """ + Pydantic model for a Python model configuration. + Includes some job-specific settings that are not yet part of PythonJobConfig. + """ + + user_folder_for_python: bool = False + timeout: int = Field(DEFAULT_TIMEOUT, gt=0) + job_cluster_config: Dict[str, Any] = Field(default_factory=dict) + access_control_list: List[Dict[str, str]] = Field(default_factory=list) + packages: List[str] = Field(default_factory=list) + index_url: Optional[str] = None + additional_libs: List[Dict[str, Any]] = Field(default_factory=list) + python_job_config: PythonJobConfig = Field(default_factory=lambda: PythonJobConfig(**{})) + cluster_id: Optional[str] = None + http_path: Optional[str] = None + create_notebook: bool = False + + +class ParsedPythonModel(BaseModel): + """Pydantic model for a Python model parsed from a dbt manifest""" + + catalog: str = Field("hive_metastore", alias="database") + + # Schema is a reserved name in Pydantic + schema_: str = Field("default", alias="schema") + + identifier: str = Field(alias="alias") + config: PythonModelConfig + + @property + def run_name(self) -> str: + return f"{self.catalog}-{self.schema_}-{self.identifier}-{uuid.uuid4()}" + + class Config: + allow_population_by_field_name = True diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py index de02f473..4b564f1c 100644 --- a/dbt/adapters/databricks/python_models/python_submissions.py +++ b/dbt/adapters/databricks/python_models/python_submissions.py @@ -1,14 +1,17 @@ -import uuid +from abc import ABC, abstractmethod from typing import Any from typing import Dict from typing import List from typing import Optional from typing import Tuple +from attr import dataclass +from typing_extensions import override from dbt.adapters.base import PythonJobHelper -from dbt.adapters.databricks.api_client import CommandExecution +from dbt.adapters.databricks.api_client import CommandExecution, WorkflowJobApi from dbt.adapters.databricks.api_client import DatabricksApiClient from dbt.adapters.databricks.credentials import DatabricksCredentials +from dbt.adapters.databricks.python_models.python_config import ParsedPythonModel from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker from dbt_common.exceptions import DbtRuntimeError @@ -16,55 +19,137 @@ DEFAULT_TIMEOUT = 60 * 60 * 24 +class PythonSubmitter(ABC): + """Interface for submitting Python models to run on Databricks.""" + + @abstractmethod + def submit(self, compiled_code: str) -> None: + """Submit the compiled code to Databricks.""" + pass + + class BaseDatabricksHelper(PythonJobHelper): + """Base helper for python models on Databricks.""" + tracker = PythonRunTracker() - @property - def workflow_spec(self) -> Dict[str, Any]: + def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None: + self.credentials = credentials + self.credentials.validate_creds() + self.parsed_model = ParsedPythonModel(**parsed_model) + + self.api_client = DatabricksApiClient.create( + credentials, + self.parsed_model.config.timeout, + self.parsed_model.config.user_folder_for_python, + ) + self.validate_config() + + self.command_submitter = self.build_submitter() + + def validate_config(self) -> None: + """Perform any validation required to ensure submission method can proceed.""" + pass + + @abstractmethod + def build_submitter(self) -> PythonSubmitter: """ - The workflow gets modified throughout. Settings added through dbt are popped off - before the spec is sent to the Databricks API + Since we don't own instantiation of the Helper, we construct the submitter here, + after validation. """ - return self.parsed_model["config"].get("workflow_job_config", {}) + pass - @property - def cluster_spec(self) -> Dict[str, Any]: - return self.parsed_model["config"].get("job_cluster_config", {}) + def submit(self, compiled_code: str) -> None: + self.command_submitter.submit(compiled_code) - def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None: - self.credentials = credentials - self.identifier = parsed_model["alias"] - self.schema = parsed_model["schema"] - self.database = parsed_model.get("database") - self.parsed_model = parsed_model - use_user_folder = parsed_model["config"].get("user_folder_for_python", False) - self.check_credentials() +class PythonCommandSubmitter(PythonSubmitter): + """Submitter for Python models using the Command API.""" - self.api_client = DatabricksApiClient.create( - credentials, self.get_timeout(), use_user_folder - ) + def __init__( + self, api_client: DatabricksApiClient, tracker: PythonRunTracker, cluster_id: str + ) -> None: + self.api_client = api_client + self.tracker = tracker + self.cluster_id = cluster_id - self.job_grants: Dict[str, List[Dict[str, Any]]] = self.workflow_spec.pop("grants", {}) + @override + def submit(self, compiled_code: str) -> None: + context_id = self.api_client.command_contexts.create(self.cluster_id) + command_exec: Optional[CommandExecution] = None + try: + command_exec = self.api_client.commands.execute( + self.cluster_id, context_id, compiled_code + ) - def get_timeout(self) -> int: - timeout = self.parsed_model["config"].get("timeout", DEFAULT_TIMEOUT) - if timeout <= 0: - raise ValueError("Timeout must be a positive integer") - return timeout + self.tracker.insert_command(command_exec) + # poll until job finish + self.api_client.commands.poll_for_completion(command_exec) - def check_credentials(self) -> None: - self.credentials.validate_creds() + finally: + if command_exec: + self.tracker.remove_command(command_exec) + self.api_client.command_contexts.destroy(self.cluster_id, context_id) + + +class PythonNotebookUploader: + """Uploads a compiled Python model as a notebook to the Databricks workspace.""" + + def __init__(self, api_client: DatabricksApiClient, parsed_model: ParsedPythonModel) -> None: + self.api_client = api_client + self.catalog = parsed_model.catalog + self.schema = parsed_model.schema_ + self.identifier = parsed_model.identifier + + def upload(self, compiled_code: str) -> str: + """Upload the compiled code to the Databricks workspace.""" + workdir = self.api_client.workspace.create_python_model_dir(self.catalog, self.schema) + file_path = f"{workdir}{self.identifier}" + self.api_client.workspace.upload_notebook(file_path, compiled_code) + return file_path + + +@dataclass(frozen=True) +class PythonJobDetails: + """Details required to submit a Python job run to Databricks.""" - def _update_with_acls(self, cluster_dict: dict) -> dict: - acl = self.parsed_model["config"].get("access_control_list", None) - if acl: - cluster_dict.update({"access_control_list": acl}) - return cluster_dict + run_name: str + job_spec: Dict[str, Any] + additional_job_config: Dict[str, Any] + + +class PythonPermissionBuilder: + """Class for building access control list for Python jobs.""" + + def __init__( + self, + api_client: DatabricksApiClient, + ) -> None: + self.api_client = api_client + + def _get_job_owner_for_config(self) -> Tuple[str, str]: + """Get the owner of the job (and type) for the access control list.""" + curr_user = self.api_client.curr_user.get_username() + is_service_principal = self.api_client.curr_user.is_service_principal(curr_user) + + source = "service_principal_name" if is_service_principal else "user_name" + return curr_user, source + + @staticmethod + def _build_job_permission( + job_grants: List[Dict[str, Any]], permission: str + ) -> List[Dict[str, Any]]: + return [{**grant, **{"permission_level": permission}} for grant in job_grants] + + def build_job_permissions( + self, + job_grants: Dict[str, List[Dict[str, Any]]], + acls: List[Dict[str, str]], + ) -> List[Dict[str, Any]]: + """Build the access control list for the job.""" - def _build_job_permissions(self) -> List[Dict[str, Any]]: access_control_list = [] - owner, permissions_attribute = self._build_job_owner() + owner, permissions_attribute = self._get_job_owner_for_config() access_control_list.append( { permissions_attribute: owner, @@ -72,255 +157,355 @@ def _build_job_permissions(self) -> List[Dict[str, Any]]: } ) - for grant in self.job_grants.get("view", []): - acl_grant = grant.copy() - acl_grant.update( - { - "permission_level": "CAN_VIEW", - } - ) - access_control_list.append(acl_grant) - for grant in self.job_grants.get("run", []): - acl_grant = grant.copy() - acl_grant.update( - { - "permission_level": "CAN_MANAGE_RUN", - } - ) - access_control_list.append(acl_grant) - for grant in self.job_grants.get("manage", []): - acl_grant = grant.copy() - acl_grant.update( - { - "permission_level": "CAN_MANAGE", - } - ) - access_control_list.append(acl_grant) + access_control_list.extend( + self._build_job_permission(job_grants.get("view", []), "CAN_VIEW") + ) + access_control_list.extend( + self._build_job_permission(job_grants.get("run", []), "CAN_MANAGE_RUN") + ) + access_control_list.extend( + self._build_job_permission(job_grants.get("manage", []), "CAN_MANAGE") + ) - return access_control_list + return access_control_list + acls - def _build_job_owner(self) -> Tuple[str, str]: - """ - :return: a tuple of the user id and the ACL attribute it came from ie: - [user_name|group_name|service_principal_name] - For example: `("mateizaharia@databricks.com", "user_name")` - """ - curr_user = self.api_client.curr_user.get_username() - is_service_principal = self.api_client.curr_user.is_service_principal(curr_user) - if is_service_principal: - return curr_user, "service_principal_name" +def get_library_config( + packages: List[str], + index_url: Optional[str], + additional_libraries: List[Dict[str, Any]], +) -> Dict[str, Any]: + """Update the job configuration with the required libraries.""" + + libraries = [] + + for package in packages: + if index_url: + libraries.append({"pypi": {"package": package, "repo": index_url}}) else: - return curr_user, "user_name" + libraries.append({"pypi": {"package": package}}) + + for library in additional_libraries: + libraries.append(library) + + return {"libraries": libraries} + + +class PythonJobConfigCompiler: + """Compiles a Python model into a job configuration for Databricks.""" + + def __init__( + self, + api_client: DatabricksApiClient, + permission_builder: PythonPermissionBuilder, + parsed_model: ParsedPythonModel, + cluster_spec: Dict[str, Any], + ) -> None: + self.api_client = api_client + self.permission_builder = permission_builder + self.run_name = parsed_model.run_name + packages = parsed_model.config.packages + index_url = parsed_model.config.index_url + additional_libraries = parsed_model.config.additional_libs + library_config = get_library_config(packages, index_url, additional_libraries) + self.cluster_spec = {**cluster_spec, **library_config} + self.job_grants = parsed_model.config.python_job_config.grants + self.acls = parsed_model.config.access_control_list + self.additional_job_settings = parsed_model.config.python_job_config.dict() + + def compile(self, path: str) -> PythonJobDetails: - def _submit_job(self, path: str, cluster_spec: dict) -> str: job_spec: Dict[str, Any] = { "task_key": "inner_notebook", "notebook_task": { "notebook_path": path, }, } - job_spec.update(cluster_spec) # updates 'new_cluster' config - - # PYPI packages - packages = self.parsed_model["config"].get("packages", []) - - # custom index URL or default - index_url = self.parsed_model["config"].get("index_url", None) - - # additional format of packages - additional_libs = self.parsed_model["config"].get("additional_libs", []) - libraries = [] + job_spec.update(self.cluster_spec) # updates 'new_cluster' config - for package in packages: - if index_url: - libraries.append({"pypi": {"package": package, "repo": index_url}}) - else: - libraries.append({"pypi": {"package": package}}) - - for lib in additional_libs: - libraries.append(lib) - - job_spec.update({"libraries": libraries}) - run_name = f"{self.database}-{self.schema}-{self.identifier}-{uuid.uuid4()}" - - additional_job_config = self._build_additional_job_settings() - access_control_list = self._build_job_permissions() - additional_job_config["access_control_list"] = access_control_list - - run_id = self.api_client.job_runs.submit(run_name, job_spec, **additional_job_config) - self.tracker.insert_run_id(run_id) - return run_id - - def _build_additional_job_settings(self) -> Dict[str, Any]: - additional_configs = {} - attrs_to_add = [ - "email_notifications", - "webhook_notifications", - "notification_settings", - "timeout_seconds", - "health", - "environments", - ] - for attr in attrs_to_add: - if attr in self.workflow_spec: - additional_configs[attr] = self.workflow_spec[attr] - - return additional_configs - - def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> None: - workdir = self.api_client.workspace.create_python_model_dir( - self.database or "hive_metastore", self.schema + additional_job_config = self.additional_job_settings + access_control_list = self.permission_builder.build_job_permissions( + self.job_grants, self.acls ) - file_path = f"{workdir}{self.identifier}" + if access_control_list: + job_spec["access_control_list"] = access_control_list + + return PythonJobDetails(self.run_name, job_spec, additional_job_config) + + +class PythonNotebookSubmitter(PythonSubmitter): + """Submitter for Python models using the Job API.""" + + def __init__( + self, + api_client: DatabricksApiClient, + tracker: PythonRunTracker, + uploader: PythonNotebookUploader, + config_compiler: PythonJobConfigCompiler, + ) -> None: + self.api_client = api_client + self.tracker = tracker + self.uploader = uploader + self.config_compiler = config_compiler + + @staticmethod + def create( + api_client: DatabricksApiClient, + tracker: PythonRunTracker, + parsed_model: ParsedPythonModel, + cluster_spec: Dict[str, Any], + ) -> "PythonNotebookSubmitter": + notebook_uploader = PythonNotebookUploader(api_client, parsed_model) + permission_builder = PythonPermissionBuilder(api_client) + config_compiler = PythonJobConfigCompiler( + api_client, + permission_builder, + parsed_model, + cluster_spec, + ) + return PythonNotebookSubmitter(api_client, tracker, notebook_uploader, config_compiler) - self.api_client.workspace.upload_notebook(file_path, compiled_code) + @override + def submit(self, compiled_code: str) -> None: + file_path = self.uploader.upload(compiled_code) + job_config = self.config_compiler.compile(file_path) # submit job - run_id = self._submit_job(file_path, cluster_spec) + run_id = self.api_client.job_runs.submit( + job_config.run_name, job_config.job_spec, **job_config.additional_job_config + ) + self.tracker.insert_run_id(run_id) try: self.api_client.job_runs.poll_for_completion(run_id) finally: self.tracker.remove_run_id(run_id) - def submit(self, compiled_code: str) -> None: - raise NotImplementedError( - "BasePythonJobHelper is an abstract class and you should implement submit method." - ) - class JobClusterPythonJobHelper(BaseDatabricksHelper): - def check_credentials(self) -> None: - super().check_credentials() - if not self.parsed_model["config"].get("job_cluster_config", None): + """Top level helper for Python models using job runs on a job cluster.""" + + @override + def build_submitter(self) -> PythonSubmitter: + return PythonNotebookSubmitter.create( + self.api_client, + self.tracker, + self.parsed_model, + {"new_cluster": self.parsed_model.config.job_cluster_config}, + ) + + @override + def validate_config(self) -> None: + if not self.parsed_model.config.job_cluster_config: raise ValueError( "`job_cluster_config` is required for the `job_cluster` submission method." ) - def submit(self, compiled_code: str) -> None: - cluster_spec = {"new_cluster": self.parsed_model["config"]["job_cluster_config"]} - self._submit_through_notebook(compiled_code, self._update_with_acls(cluster_spec)) - class AllPurposeClusterPythonJobHelper(BaseDatabricksHelper): - @property - def cluster_id(self) -> Optional[str]: - return self.parsed_model["config"].get( - "cluster_id", - self.credentials.extract_cluster_id( - self.parsed_model["config"].get("http_path", self.credentials.http_path) - ), + """ + Top level helper for Python models using job runs or Command API on an all-purpose cluster. + """ + + def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None: + self.credentials = credentials + self.credentials.validate_creds() + self.parsed_model = ParsedPythonModel(**parsed_model) + + self.api_client = DatabricksApiClient.create( + credentials, + self.parsed_model.config.timeout, + self.parsed_model.config.user_folder_for_python, + ) + + config = self.parsed_model.config + self.create_notebook = config.create_notebook + self.cluster_id = config.cluster_id or self.credentials.extract_cluster_id( + config.http_path or self.credentials.http_path or "" ) + self.validate_config() + + self.command_submitter = self.build_submitter() + + @override + def build_submitter(self) -> PythonSubmitter: + if self.create_notebook: + return PythonNotebookSubmitter.create( + self.api_client, + self.tracker, + self.parsed_model, + {"existing_cluster_id": self.cluster_id}, + ) + else: + return PythonCommandSubmitter(self.api_client, self.tracker, self.cluster_id or "") - def check_credentials(self) -> None: - super().check_credentials() + @override + def validate_config(self) -> None: if not self.cluster_id: raise ValueError( "Databricks `http_path` or `cluster_id` of an all-purpose cluster is required " "for the `all_purpose_cluster` submission method." ) - def submit(self, compiled_code: str) -> None: - assert ( - self.cluster_id is not None - ), "cluster_id is required for all_purpose_cluster submission method." - if self.parsed_model["config"].get("create_notebook", False): - config = {} - if self.cluster_id: - config["existing_cluster_id"] = self.cluster_id - self._submit_through_notebook(compiled_code, self._update_with_acls(config)) - else: - context_id = self.api_client.command_contexts.create(self.cluster_id) - command_exec: Optional[CommandExecution] = None - try: - command_exec = self.api_client.commands.execute( - self.cluster_id, context_id, compiled_code - ) - - self.tracker.insert_command(command_exec) - # poll until job finish - self.api_client.commands.poll_for_completion(command_exec) - - finally: - if command_exec: - self.tracker.remove_command(command_exec) - self.api_client.command_contexts.destroy(self.cluster_id, context_id) - class ServerlessClusterPythonJobHelper(BaseDatabricksHelper): - def submit(self, compiled_code: str) -> None: - self._submit_through_notebook(compiled_code, {}) - - -class WorkflowPythonJobHelper(BaseDatabricksHelper): - - @property - def default_job_name(self) -> str: - return f"dbt__{self.database}-{self.schema}-{self.identifier}" - - @property - def notebook_path(self) -> str: - return f"{self.notebook_dir}/{self.identifier}" - - @property - def notebook_dir(self) -> str: - return self.api_client.workspace.user_api.get_folder(self.catalog, self.schema) - - @property - def catalog(self) -> str: - return self.database or "hive_metastore" - - def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None: - super().__init__(parsed_model, credentials) - - self.post_hook_tasks = self.workflow_spec.pop("post_hook_tasks", []) - self.additional_task_settings = self.workflow_spec.pop("additional_task_settings", {}) - - def check_credentials(self) -> None: - workflow_config = self.parsed_model["config"].get("workflow_job_config", None) - if not workflow_config: - raise ValueError( - "workflow_job_config is required for the `workflow_job_config` submission method." + """Top level helper for Python models using job runs on a serverless cluster.""" + + def build_submitter(self) -> PythonSubmitter: + return PythonNotebookSubmitter.create(self.api_client, self.tracker, self.parsed_model, {}) + + +class PythonWorkflowConfigCompiler: + """Compiles a Python model into a workflow configuration for Databricks.""" + + def __init__( + self, + task_settings: Dict[str, Any], + workflow_spec: Dict[str, Any], + existing_job_id: str, + post_hook_tasks: List[Dict[str, Any]], + ) -> None: + self.task_settings = task_settings + self.existing_job_id = existing_job_id + self.workflow_spec = workflow_spec + self.post_hook_tasks = post_hook_tasks + + @staticmethod + def create(parsed_model: ParsedPythonModel) -> "PythonWorkflowConfigCompiler": + cluster_settings = PythonWorkflowConfigCompiler.cluster_settings(parsed_model) + config = parsed_model.config + if config.python_job_config: + cluster_settings.update(config.python_job_config.additional_task_settings) + workflow_spec = config.python_job_config.dict() + workflow_spec["name"] = PythonWorkflowConfigCompiler.workflow_name(parsed_model) + existing_job_id = config.python_job_config.existing_job_id + post_hook_tasks = config.python_job_config.post_hook_tasks + return PythonWorkflowConfigCompiler( + cluster_settings, workflow_spec, existing_job_id, post_hook_tasks ) + else: + return PythonWorkflowConfigCompiler(cluster_settings, {}, "", []) + + @staticmethod + def workflow_name(parsed_model: ParsedPythonModel) -> str: + name: Optional[str] = None + if parsed_model.config.python_job_config: + name = parsed_model.config.python_job_config.name + return ( + name or f"dbt__{parsed_model.catalog}-{parsed_model.schema_}-{parsed_model.identifier}" + ) - def submit(self, compiled_code: str) -> None: - workflow_spec = self._build_job_spec() - self._submit_through_workflow(compiled_code, workflow_spec) + @staticmethod + def cluster_settings(parsed_model: ParsedPythonModel) -> Dict[str, Any]: + config = parsed_model.config + job_cluster_config = config.job_cluster_config - def _build_job_spec(self) -> Dict[str, Any]: - workflow_spec = dict(self.workflow_spec) - workflow_spec["name"] = self.workflow_spec.get("name", self.default_job_name) + cluster_settings: Dict[str, Any] = {} + if job_cluster_config: + cluster_settings["new_cluster"] = job_cluster_config + elif config.cluster_id: + cluster_settings["existing_cluster_id"] = config.cluster_id - # Undefined cluster settings defaults to serverless in the Databricks API - cluster_settings = {} - if self.cluster_spec: - cluster_settings["new_cluster"] = self.cluster_spec - elif "existing_cluster_id" in self.workflow_spec: - cluster_settings["existing_cluster_id"] = self.workflow_spec["existing_cluster_id"] + return cluster_settings + def compile(self, path: str) -> Tuple[Dict[str, Any], str]: notebook_task = { "task_key": "inner_notebook", "notebook_task": { - "notebook_path": self.notebook_path, + "notebook_path": path, "source": "WORKSPACE", }, } - notebook_task.update(cluster_settings) - notebook_task.update(self.additional_task_settings) + notebook_task.update(self.task_settings) - workflow_spec["tasks"] = [notebook_task] + self.post_hook_tasks - return workflow_spec + self.workflow_spec["tasks"] = [notebook_task] + self.post_hook_tasks + return self.workflow_spec, self.existing_job_id - def _submit_through_workflow(self, compiled_code: str, workflow_spec: Dict[str, Any]) -> None: - self.api_client.workspace.create_python_model_dir(self.catalog, self.schema) - self.api_client.workspace.upload_notebook(self.notebook_path, compiled_code) - job_id, is_new = self._get_or_create_job(workflow_spec) +class PythonWorkflowCreator: + """Manages the creation or updating of a workflow job on Databricks.""" - if not is_new: - self.api_client.workflows.update_job_settings(job_id, workflow_spec) + def __init__(self, workflows: WorkflowJobApi) -> None: + self.workflows = workflows - access_control_list = self._build_job_permissions() + def create_or_update( + self, + workflow_spec: Dict[str, Any], + existing_job_id: Optional[str], + ) -> str: + """ + :return: tuple of job_id and whether the job is new + """ + if not existing_job_id: + workflow_name = workflow_spec["name"] + response_jobs = self.workflows.search_by_name(workflow_name) + + if len(response_jobs) > 1: + raise DbtRuntimeError( + f"Multiple jobs found with name {workflow_name}. Use a" + " unique job name or specify the `existing_job_id` in the python_job_config." + ) + elif len(response_jobs) == 1: + existing_job_id = response_jobs[0]["job_id"] + else: + return self.workflows.create(workflow_spec) + + assert existing_job_id is not None + self.workflows.update_job_settings(existing_job_id, workflow_spec) + return existing_job_id + + +class PythonNotebookWorkflowSubmitter(PythonSubmitter): + """Submitter for Python models using the Workflow API.""" + + def __init__( + self, + api_client: DatabricksApiClient, + tracker: PythonRunTracker, + uploader: PythonNotebookUploader, + config_compiler: PythonWorkflowConfigCompiler, + permission_builder: PythonPermissionBuilder, + workflow_creater: PythonWorkflowCreator, + job_grants: Dict[str, List[Dict[str, str]]], + acls: List[Dict[str, str]], + ) -> None: + self.api_client = api_client + self.tracker = tracker + self.uploader = uploader + self.config_compiler = config_compiler + self.permission_builder = permission_builder + self.workflow_creater = workflow_creater + self.job_grants = job_grants + self.acls = acls + + @staticmethod + def create( + api_client: DatabricksApiClient, tracker: PythonRunTracker, parsed_model: ParsedPythonModel + ) -> "PythonNotebookWorkflowSubmitter": + uploader = PythonNotebookUploader(api_client, parsed_model) + config_compiler = PythonWorkflowConfigCompiler.create(parsed_model) + permission_builder = PythonPermissionBuilder(api_client) + workflow_creater = PythonWorkflowCreator(api_client.workflows) + return PythonNotebookWorkflowSubmitter( + api_client, + tracker, + uploader, + config_compiler, + permission_builder, + workflow_creater, + parsed_model.config.python_job_config.grants, + parsed_model.config.access_control_list, + ) + + @override + def submit(self, compiled_code: str) -> None: + file_path = self.uploader.upload(compiled_code) + + workflow_config, existing_job_id = self.config_compiler.compile(file_path) + job_id = self.workflow_creater.create_or_update(workflow_config, existing_job_id) + + access_control_list = self.permission_builder.build_job_permissions( + self.job_grants, self.acls + ) self.api_client.workflow_permissions.put(job_id, access_control_list) run_id = self.api_client.workflows.run(job_id, enable_queueing=True) @@ -331,23 +516,12 @@ def _submit_through_workflow(self, compiled_code: str, workflow_spec: Dict[str, finally: self.tracker.remove_run_id(run_id) - def _get_or_create_job(self, workflow_spec: Dict[str, Any]) -> Tuple[str, bool]: - """ - :return: tuple of job_id and whether the job is new - """ - existing_job_id = workflow_spec.pop("existing_job_id", "") - if existing_job_id: - return existing_job_id, False - response_jobs = self.api_client.workflows.search_by_name(workflow_spec["name"]) - - if len(response_jobs) > 1: - raise DbtRuntimeError( - f"""Multiple jobs found with name {workflow_spec['name']}. Use a unique job - name or specify the `existing_job_id` in the workflow_job_config.""" - ) +class WorkflowPythonJobHelper(BaseDatabricksHelper): + """Top level helper for Python models using workflow jobs on Databricks.""" - if len(response_jobs) == 1: - return response_jobs[0]["job_id"], False - else: - return self.api_client.workflows.create(workflow_spec), True + @override + def build_submitter(self) -> PythonSubmitter: + return PythonNotebookWorkflowSubmitter.create( + self.api_client, self.tracker, self.parsed_model + ) diff --git a/docs/workflow-job-submission.md b/docs/workflow-job-submission.md index b22abd3e..8e607801 100644 --- a/docs/workflow-job-submission.md +++ b/docs/workflow-job-submission.md @@ -1,7 +1,7 @@ ## Databricks Workflow Job Submission Use the `workflow_job` submission method to run your python model as a long-lived -Databricks Workflow. Models look the same as they would using the `job_cluster` submission +Databricks Workflow. Models look the same as they would using the `job_cluster` submission method, but allow for additional configuration. Some of that configuration can also be used for `job_cluster` models. @@ -31,55 +31,45 @@ The config for a model could look like: models: - name: my_model config: - workflow_job_config: + python_job_config: # This is also applied to one-time run models - email_notifications: { - on_failure: ["reynoldxin@databricks.com"] - } + email_notifications: { on_failure: ["reynoldxin@databricks.com"] } max_retries: 2 timeout_seconds: 18000 existing_cluster_id: 1234a-123-1234 # Use in place of job_cluster_config or null - + # Name must be unique unless existing_job_id is also defined - name: my_workflow_name + name: my_workflow_name existing_job_id: 12341234 - + # Override settings for your model's dbt task. For instance, you can # change the task key - additional_task_settings: { - "task_key": "my_dbt_task" - } - + additional_task_settings: { "task_key": "my_dbt_task" } + # Define tasks to run before/after the model - post_hook_tasks: [{ - "depends_on": [{ "task_key": "my_dbt_task" }], - "task_key": 'OPTIMIZE_AND_VACUUM', - "notebook_task": { - "notebook_path": "/my_notebook_path", - "source": "WORKSPACE", - }, - }] - + post_hook_tasks: + [ + { + "depends_on": [{ "task_key": "my_dbt_task" }], + "task_key": "OPTIMIZE_AND_VACUUM", + "notebook_task": + { "notebook_path": "/my_notebook_path", "source": "WORKSPACE" }, + }, + ] + # Also applied to one-time run models grants: - view: [ - {"group_name": "marketing-team"}, - ] - run: [ - {"user_name": "alighodsi@databricks.com"} - ] + view: [{ "group_name": "marketing-team" }] + run: [{ "user_name": "alighodsi@databricks.com" }] manage: [] - + # Reused for the workflow job cluster definition job_cluster_config: spark_version: "15.3.x-scala2.12" node_type_id: "rd-fleet.2xlarge" runtime_engine: "{{ var('job_cluster_defaults.runtime_engine') }}" data_security_mode: "{{ var('job_cluster_defaults.data_security_mode') }}" - autoscale: { - "min_workers": 1, - "max_workers": 4 - } + autoscale: { "min_workers": 1, "max_workers": 4 } ``` ### Configuration @@ -89,7 +79,7 @@ that can be set. #### Reuse in job_cluster submission method -If the following values are defined in `config.workflow_job_config`, they will be used even if +If the following values are defined in `config.python_job_config`, they will be used even if the model uses the job_cluster submission method. For example, you can define a job_cluster model to send an email notification on failure. @@ -109,8 +99,8 @@ dbt will generate a name based on the catalog, schema, and model identifier. #### Clusters - If defined, dbt will re-use the `config.job_cluster_config` to define a job cluster for the workflow tasks. -- If `config.workflow_job_config.existing_cluster_id` is defined, dbt will use that cluster -- Similarly, you can define a reusable job cluster for the workflow and tell the task to use that +- If `config.python_job_config.existing_cluster_id` is defined, dbt will use that cluster +- Similarly, you can define a reusable job cluster for the workflow and tell the task to use that - If none of those are in the configuration, the task cluster will be serverless ```yaml @@ -118,48 +108,52 @@ dbt will generate a name based on the catalog, schema, and model identifier. models: - name: my_model - + config: - workflow_job_config: - additional_task_settings: { - task_key: 'task_a', - job_cluster_key: 'cluster_a', - } - post_hook_tasks: [{ - depends_on: [{ "task_key": "task_a" }], - task_key: 'OPTIMIZE_AND_VACUUM', - job_cluster_key: 'cluster_a', - notebook_task: { - notebook_path: "/OPTIMIZE_AND_VACUUM", - source: "WORKSPACE", - base_parameters: { - database: "{{ target.database }}", - schema: "{{ target.schema }}", - table_name: "my_model" - } - }, - }] - job_clusters: [{ - job_cluster_key: 'cluster_a', - new_cluster: { - spark_version: "{{ var('dbr_versions')['lts_v14'] }}", - node_type_id: "{{ var('cluster_node_types')['large_job'] }}", - runtime_engine: "{{ var('job_cluster_defaults.runtime_engine') }}", - autoscale: { - "min_workers": 1, - "max_workers": 2 + python_job_config: + additional_task_settings: + { task_key: "task_a", job_cluster_key: "cluster_a" } + post_hook_tasks: + [ + { + depends_on: [{ "task_key": "task_a" }], + task_key: "OPTIMIZE_AND_VACUUM", + job_cluster_key: "cluster_a", + notebook_task: + { + notebook_path: "/OPTIMIZE_AND_VACUUM", + source: "WORKSPACE", + base_parameters: + { + database: "{{ target.database }}", + schema: "{{ target.schema }}", + table_name: "my_model", + }, + }, }, - } - }] + ] + job_clusters: + [ + { + job_cluster_key: "cluster_a", + new_cluster: + { + spark_version: "{{ var('dbr_versions')['lts_v14'] }}", + node_type_id: "{{ var('cluster_node_types')['large_job'] }}", + runtime_engine: "{{ var('job_cluster_defaults.runtime_engine') }}", + autoscale: { "min_workers": 1, "max_workers": 2 }, + }, + }, + ] ``` #### Grants You might want to give certain users or teams access to run your workflows outside of -dbt in an ad hoc way. You can define those permissions in the `workflow_job_config.grants`. +dbt in an ad hoc way. You can define those permissions in the `python_job_config.grants`. The owner will always be the user or service principal creating the workflows. -These grants will also be applied to one-time run models using the `job_cluster` submission +These grants will also be applied to one-time run models using the `job_cluster` submission method. The dbt rules correspond with the following Databricks permissions: @@ -181,6 +175,6 @@ grants: #### Post hooks -It is possible to add in python hooks by using the `config.workflow_job_config.post_hook_tasks` +It is possible to add in python hooks by using the `config.python_job_config.post_hook_tasks` attribute. You will need to define the cluster for each task, or use a reusable one from -`config.workflow_job_config.job_clusters`. \ No newline at end of file +`config.python_job_config.job_clusters`. diff --git a/tests/functional/adapter/python_model/fixtures.py b/tests/functional/adapter/python_model/fixtures.py index ee70339f..5ce51702 100644 --- a/tests/functional/adapter/python_model/fixtures.py +++ b/tests/functional/adapter/python_model/fixtures.py @@ -40,7 +40,7 @@ def model(dbt, spark): config: submission_method: workflow_job user_folder_for_python: true - workflow_job_config: + python_job_config: max_retries: 2 timeout_seconds: 500 additional_task_settings: { diff --git a/tests/unit/python/test_python_config.py b/tests/unit/python/test_python_config.py new file mode 100644 index 00000000..ef450afc --- /dev/null +++ b/tests/unit/python/test_python_config.py @@ -0,0 +1,135 @@ +from pydantic import ValidationError +import pytest +from dbt.adapters.databricks.python_models.python_config import ( + ParsedPythonModel, + PythonJobConfig, + PythonModelConfig, +) + + +class TestParsedPythonModel: + def test_parsed_model__default_database_schema(self): + parsed_model = { + "alias": "test", + "config": {}, + } + + model = ParsedPythonModel(**parsed_model) + assert model.catalog == "hive_metastore" + assert model.schema_ == "default" + assert model.identifier == "test" + + def test_parsed_model__empty_model_config(self): + parsed_model = { + "database": "database", + "schema": "schema", + "alias": "test", + "config": {}, + } + + model = ParsedPythonModel(**parsed_model) + assert model.catalog == "database" + assert model.schema_ == "schema" + assert model.identifier == "test" + config = model.config + assert config.user_folder_for_python is False + assert config.timeout == 86400 + assert config.job_cluster_config == {} + assert config.access_control_list == [] + assert config.packages == [] + assert config.index_url is None + assert config.additional_libs == [] + assert config.python_job_config.name is None + assert config.python_job_config.grants == {} + assert config.python_job_config.existing_job_id == "" + assert config.python_job_config.post_hook_tasks == [] + assert config.python_job_config.additional_task_settings == {} + assert config.cluster_id is None + assert config.http_path is None + assert config.create_notebook is False + + def test_parsed_model__valid_model_config(self): + parsed_model = { + "alias": "test", + "config": { + "user_folder_for_python": True, + "timeout": 100, + "job_cluster_config": {"key": "value"}, + "access_control_list": [{"key": "value"}], + "packages": ["package"], + "index_url": "index_url", + "additional_libs": [{"key": "value"}], + "python_job_config": {"name": "name"}, + "cluster_id": "cluster_id", + "http_path": "http_path", + "create_notebook": True, + }, + } + + model = ParsedPythonModel(**parsed_model) + config = model.config + assert config.user_folder_for_python is True + assert config.timeout == 100 + assert config.job_cluster_config == {"key": "value"} + assert config.access_control_list == [{"key": "value"}] + assert config.packages == ["package"] + assert config.index_url == "index_url" + assert config.additional_libs == [{"key": "value"}] + assert config.python_job_config.name == "name" + assert config.python_job_config.grants == {} + assert config.python_job_config.existing_job_id == "" + assert config.python_job_config.post_hook_tasks == [] + assert config.python_job_config.additional_task_settings == {} + assert config.cluster_id == "cluster_id" + assert config.http_path == "http_path" + assert config.create_notebook is True + + def test_parsed_model__extra_model_config(self): + parsed_model = { + "alias": "test", + "config": { + "python_job_config": {"foo": "bar"}, + }, + } + + model = ParsedPythonModel(**parsed_model) + assert model.config.python_job_config.foo == "bar" + + def test_parsed_model__run_name(self): + parsed_model = {"alias": "test", "config": {}} + model = ParsedPythonModel(**parsed_model) + assert model.run_name.startswith("hive_metastore-default-test-") + + def test_parsed_model__invalid_config(self): + parsed_model = {"alias": "test", "config": []} + with pytest.raises(ValidationError): + ParsedPythonModel(**parsed_model) + + +class TestPythonModelConfig: + def test_python_model_config__invalid_timeout(self): + config = {"timeout": -1} + with pytest.raises(ValidationError): + PythonModelConfig(**config) + + +class TestPythonJobConfig: + def test_python_job_config__dict_excludes_expected_fields(self): + config = { + "name": "name", + "grants": {"view": [{"user": "user"}]}, + "existing_job_id": "existing_job_id", + "post_hook_tasks": [{"task": "task"}], + "additional_task_settings": {"key": "value"}, + } + job_config = PythonJobConfig(**config).dict() + assert job_config == {"name": "name"} + + def test_python_job_config__extra_values(self): + config = { + "name": "name", + "existing_job_id": "existing_job_id", + "foo": "bar", + } + job_config = PythonJobConfig(**config).dict() + assert job_config == {"name": "name", "foo": "bar"} diff --git a/tests/unit/python/test_python_helpers.py b/tests/unit/python/test_python_helpers.py new file mode 100644 index 00000000..41cd3441 --- /dev/null +++ b/tests/unit/python/test_python_helpers.py @@ -0,0 +1,79 @@ +from unittest.mock import Mock +import pytest + +from dbt.adapters.databricks.python_models.python_submissions import ( + AllPurposeClusterPythonJobHelper, + JobClusterPythonJobHelper, + PythonCommandSubmitter, + PythonNotebookSubmitter, + PythonNotebookWorkflowSubmitter, + ServerlessClusterPythonJobHelper, + WorkflowPythonJobHelper, +) + + +@pytest.fixture +def parsed_model(): + return {"alias": "test", "config": {}} + + +@pytest.fixture +def credentials(): + c = Mock() + c.get_all_http_headers.return_value = {} + c.connection_parameters = {} + return c + + +class TestJobClusterPythonJobHelper: + def test_init__golden_path(self, parsed_model, credentials): + parsed_model["config"]["job_cluster_config"] = {"cluster_id": "test"} + helper = JobClusterPythonJobHelper(parsed_model, credentials) + assert isinstance(helper.command_submitter, PythonNotebookSubmitter) + assert helper.command_submitter.config_compiler.cluster_spec == { + "libraries": [], + "new_cluster": {"cluster_id": "test"}, + } + + def test_init__no_cluster_config(self, parsed_model, credentials): + with pytest.raises(ValueError) as exc: + JobClusterPythonJobHelper(parsed_model, credentials) + assert exc.match( + "`job_cluster_config` is required for the `job_cluster` submission method." + ) + + +class TestAllPurposeClusterPythonJobHelper: + def test_init__no_notebook_credential_http_path(self, parsed_model, credentials): + credentials.extract_cluster_id.return_value = "test" + helper = AllPurposeClusterPythonJobHelper(parsed_model, credentials) + assert isinstance(helper.command_submitter, PythonCommandSubmitter) + assert helper.cluster_id == "test" + + def test_init__notebook_cluster_id(self, parsed_model, credentials): + parsed_model["config"] = {"create_notebook": True, "cluster_id": "test"} + helper = AllPurposeClusterPythonJobHelper(parsed_model, credentials) + assert isinstance(helper.command_submitter, PythonNotebookSubmitter) + assert helper.cluster_id == "test" + + def test_init__no_cluster_id(self, parsed_model, credentials): + credentials.extract_cluster_id.return_value = None + with pytest.raises(ValueError) as exc: + AllPurposeClusterPythonJobHelper(parsed_model, credentials) + assert exc.match( + "Databricks `http_path` or `cluster_id` of an all-purpose cluster is required " + "for the `all_purpose_cluster` submission method." + ) + + +class TestServerlessClusterPythonJobHelper: + def test_build_submitter(self, parsed_model, credentials): + helper = ServerlessClusterPythonJobHelper(parsed_model, credentials) + assert isinstance(helper.command_submitter, PythonNotebookSubmitter) + assert helper.command_submitter.config_compiler.cluster_spec == {"libraries": []} + + +class TestWorkflowPythonJobHelper: + def test_init__golden_path(self, parsed_model, credentials): + helper = WorkflowPythonJobHelper(parsed_model, credentials) + assert isinstance(helper.command_submitter, PythonNotebookWorkflowSubmitter) diff --git a/tests/unit/python/test_python_job_support.py b/tests/unit/python/test_python_job_support.py new file mode 100644 index 00000000..0d7acb53 --- /dev/null +++ b/tests/unit/python/test_python_job_support.py @@ -0,0 +1,185 @@ +from unittest.mock import Mock +import pytest + +from dbt.adapters.databricks.python_models import python_submissions +from dbt.adapters.databricks.python_models.python_submissions import ( + PythonJobConfigCompiler, + PythonNotebookUploader, + PythonPermissionBuilder, +) + + +@pytest.fixture +def client(): + return Mock() + + +@pytest.fixture +def compiled_code(): + return "compiled_code" + + +@pytest.fixture +def parsed_model(): + return Mock() + + +class TestPythonNotebookUploader: + @pytest.fixture + def workdir(self): + return "workdir" + + @pytest.fixture + def identifier(self, parsed_model): + return "identifier" + + @pytest.fixture + def uploader(self, client, parsed_model, identifier): + parsed_model.identifier = identifier + return PythonNotebookUploader(client, parsed_model) + + def test_upload__golden_path(self, uploader, client, compiled_code, workdir, identifier): + client.workspace.create_python_model_dir.return_value = workdir + + file_path = uploader.upload(compiled_code) + assert file_path == f"{workdir}{identifier}" + client.workspace.upload_notebook.assert_called_once_with(file_path, compiled_code) + + +class TestPythonPermissionBuilder: + @pytest.fixture + def builder(self, client): + return PythonPermissionBuilder(client) + + def test_build_permission__no_grants_no_acls_user_owner(self, builder, client): + client.curr_user.get_username.return_value = "user" + client.curr_user.is_service_principal.return_value = False + acls = builder.build_job_permissions({}, []) + assert acls == [{"user_name": "user", "permission_level": "IS_OWNER"}] + + def test_build_permission__no_grants_no_acls_sp_owner(self, builder, client): + client.curr_user.get_username.return_value = "user" + client.curr_user.is_service_principal.return_value = True + acls = builder.build_job_permissions({}, []) + assert acls == [{"service_principal_name": "user", "permission_level": "IS_OWNER"}] + + def test_build_permission__grants_no_acls(self, builder, client): + grants = { + "view": [{"user_name": "user1"}], + "run": [{"user_name": "user2"}], + "manage": [{"user_name": "user3"}], + } + client.curr_user.get_username.return_value = "user" + client.curr_user.is_service_principal.return_value = False + + expected = [ + {"user_name": "user", "permission_level": "IS_OWNER"}, + {"user_name": "user1", "permission_level": "CAN_VIEW"}, + {"user_name": "user2", "permission_level": "CAN_MANAGE_RUN"}, + {"user_name": "user3", "permission_level": "CAN_MANAGE"}, + ] + + assert builder.build_job_permissions(grants, []) == expected + + def test_build_permission__grants_and_acls(self, builder, client): + grants = { + "view": [{"user_name": "user1"}], + } + acls = [{"user_name": "user2", "permission_level": "CAN_MANAGE_RUN"}] + client.curr_user.get_username.return_value = "user" + client.curr_user.is_service_principal.return_value = False + + expected = [ + {"user_name": "user", "permission_level": "IS_OWNER"}, + {"user_name": "user1", "permission_level": "CAN_VIEW"}, + {"user_name": "user2", "permission_level": "CAN_MANAGE_RUN"}, + ] + + assert builder.build_job_permissions(grants, acls) == expected + + +class TestGetLibraryConfig: + def test_get_library_config__no_packages_no_libraries(self): + config = python_submissions.get_library_config([], None, []) + assert config == {"libraries": []} + + def test_get_library_config__packages_no_index_no_libraries(self): + config = python_submissions.get_library_config(["package1", "package2"], None, []) + assert config == { + "libraries": [{"pypi": {"package": "package1"}}, {"pypi": {"package": "package2"}}] + } + + def test_get_library_config__packages_index_url_no_libraries(self): + index_url = "http://example.com" + config = python_submissions.get_library_config(["package1", "package2"], index_url, []) + assert config == { + "libraries": [ + {"pypi": {"package": "package1", "repo": index_url}}, + {"pypi": {"package": "package2", "repo": index_url}}, + ] + } + + def test_get_library_config__packages_libraries(self): + config = python_submissions.get_library_config( + ["package1", "package2"], None, [{"pypi": {"package": "package3"}}] + ) + assert config == { + "libraries": [ + {"pypi": {"package": "package1"}}, + {"pypi": {"package": "package2"}}, + {"pypi": {"package": "package3"}}, + ] + } + + +class TestPythonJobConfigCompiler: + @pytest.fixture + def permission_builder(self): + return Mock() + + @pytest.fixture + def run_name(self, parsed_model): + run_name = "run_name" + parsed_model.run_name = run_name + parsed_model.config.packages = [] + parsed_model.config.additional_libs = [] + return run_name + + def test_compile__empty_configs(self, client, permission_builder, parsed_model, run_name): + parsed_model.config.python_job_config.dict.return_value = {} + compiler = PythonJobConfigCompiler(client, permission_builder, parsed_model, {}) + permission_builder.build_job_permissions.return_value = [] + details = compiler.compile("path") + assert details.run_name == run_name + assert details.job_spec == { + "task_key": "inner_notebook", + "notebook_task": { + "notebook_path": "path", + }, + "libraries": [], + } + assert details.additional_job_config == {} + + def test_compile__nonempty_configs(self, client, permission_builder, parsed_model, run_name): + parsed_model.config.packages = ["foo"] + parsed_model.config.index_url = None + parsed_model.config.python_job_config.dict.return_value = {"foo": "bar"} + + permission_builder.build_job_permissions.return_value = [ + {"user_name": "user", "permission_level": "IS_OWNER"} + ] + compiler = PythonJobConfigCompiler( + client, permission_builder, parsed_model, {"cluster_id": "id"} + ) + details = compiler.compile("path") + assert details.run_name == run_name + assert details.job_spec == { + "task_key": "inner_notebook", + "notebook_task": { + "notebook_path": "path", + }, + "cluster_id": "id", + "libraries": [{"pypi": {"package": "foo"}}], + "access_control_list": [{"user_name": "user", "permission_level": "IS_OWNER"}], + } + assert details.additional_job_config == {"foo": "bar"} diff --git a/tests/unit/python/test_python_submissions.py b/tests/unit/python/test_python_submissions.py deleted file mode 100644 index 90283142..00000000 --- a/tests/unit/python/test_python_submissions.py +++ /dev/null @@ -1,249 +0,0 @@ -from mock import patch -from unittest.mock import Mock - -from dbt.adapters.databricks.credentials import DatabricksCredentials -from dbt.adapters.databricks.python_models.python_submissions import BaseDatabricksHelper -from dbt.adapters.databricks.python_models.python_submissions import WorkflowPythonJobHelper - - -# class TestDatabricksPythonSubmissions: -# def test_start_cluster_returns_on_receiving_running_state(self): -# session_mock = Mock() -# # Mock the start command -# post_mock = Mock() -# post_mock.status_code = 200 -# session_mock.post.return_value = post_mock -# # Mock the status command -# get_mock = Mock() -# get_mock.status_code = 200 -# get_mock.json.return_value = {"state": "RUNNING"} -# session_mock.get.return_value = get_mock - -# context = DBContext(Mock(), None, None, session_mock) -# context.start_cluster() - -# session_mock.get.assert_called_once() - - -class DatabricksTestHelper(BaseDatabricksHelper): - def __init__(self, parsed_model: dict, credentials: DatabricksCredentials): - self.parsed_model = parsed_model - self.credentials = credentials - self.job_grants = self.workflow_spec.get("grants", {}) - - -class TestAclUpdate: - def test_empty_acl_empty_config(self): - helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials()) - assert helper._update_with_acls({}) == {} - - def test_empty_acl_non_empty_config(self): - helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials()) - assert helper._update_with_acls({"a": "b"}) == {"a": "b"} - - def test_non_empty_acl_empty_config(self): - expected_access_control = { - "access_control_list": [ - {"user_name": "user2", "permission_level": "CAN_VIEW"}, - ] - } - helper = DatabricksTestHelper({"config": expected_access_control}, DatabricksCredentials()) - assert helper._update_with_acls({}) == expected_access_control - - def test_non_empty_acl_non_empty_config(self): - expected_access_control = { - "access_control_list": [ - {"user_name": "user2", "permission_level": "CAN_VIEW"}, - ] - } - helper = DatabricksTestHelper({"config": expected_access_control}, DatabricksCredentials()) - assert helper._update_with_acls({"a": "b"}) == { - "a": "b", - "access_control_list": expected_access_control["access_control_list"], - } - - -class TestJobGrants: - - @patch.object(BaseDatabricksHelper, "_build_job_owner") - def test_job_owner_user(self, mock_job_owner): - mock_job_owner.return_value = ("alighodsi@databricks.com", "user_name") - - helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials()) - helper.job_grants = {} - - assert helper._build_job_permissions() == [ - { - "permission_level": "IS_OWNER", - "user_name": "alighodsi@databricks.com", - } - ] - - @patch.object(BaseDatabricksHelper, "_build_job_owner") - def test_job_owner_service_principal(self, mock_job_owner): - mock_job_owner.return_value = ( - "9533b8cc-2d60-46dd-84f2-a39b3939e37a", - "service_principal_name", - ) - - helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials()) - helper.job_grants = {} - - assert helper._build_job_permissions() == [ - { - "permission_level": "IS_OWNER", - "service_principal_name": "9533b8cc-2d60-46dd-84f2-a39b3939e37a", - } - ] - - @patch.object(BaseDatabricksHelper, "_build_job_owner") - def test_job_grants(self, mock_job_owner): - mock_job_owner.return_value = ( - "9533b8cc-2d60-46dd-84f2-a39b3939e37a", - "service_principal_name", - ) - helper = DatabricksTestHelper( - { - "config": { - "workflow_job_config": { - "grants": { - "view": [ - {"user_name": "reynoldxin@databricks.com"}, - {"user_name": "alighodsi@databricks.com"}, - ], - "run": [{"group_name": "dbt-developers"}], - "manage": [{"group_name": "dbt-admins"}], - } - } - } - }, - DatabricksCredentials(), - ) - - actual = helper._build_job_permissions() - - expected_owner = { - "service_principal_name": "9533b8cc-2d60-46dd-84f2-a39b3939e37a", - "permission_level": "IS_OWNER", - } - expected_viewer_1 = { - "permission_level": "CAN_VIEW", - "user_name": "reynoldxin@databricks.com", - } - expected_viewer_2 = { - "permission_level": "CAN_VIEW", - "user_name": "alighodsi@databricks.com", - } - expected_runner = {"permission_level": "CAN_MANAGE_RUN", "group_name": "dbt-developers"} - expected_manager = {"permission_level": "CAN_MANAGE", "group_name": "dbt-admins"} - - assert expected_owner in actual - assert expected_viewer_1 in actual - assert expected_viewer_2 in actual - assert expected_runner in actual - assert expected_manager in actual - - -class TestWorkflowConfig: - def default_config(self): - return { - "alias": "test_model", - "database": "test_database", - "schema": "test_schema", - "config": { - "workflow_job_config": { - "email_notifications": "test@example.com", - "max_retries": 2, - "timeout_seconds": 500, - }, - "job_cluster_config": { - "spark_version": "15.3.x-scala2.12", - "node_type_id": "rd-fleet.2xlarge", - "autoscale": {"min_workers": 1, "max_workers": 2}, - }, - }, - } - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_default(self, mock_api_client): - job = WorkflowPythonJobHelper(self.default_config(), Mock()) - result = job._build_job_spec() - - assert result["name"] == "dbt__test_database-test_schema-test_model" - assert len(result["tasks"]) == 1 - - task = result["tasks"][0] - assert task["task_key"] == "inner_notebook" - assert task["new_cluster"]["spark_version"] == "15.3.x-scala2.12" - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_custom_name(self, mock_api_client): - config = self.default_config() - config["config"]["workflow_job_config"]["name"] = "custom_job_name" - job = WorkflowPythonJobHelper(config, Mock()) - result = job._build_job_spec() - - assert result["name"] == "custom_job_name" - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_existing_cluster(self, mock_api_client): - config = self.default_config() - config["config"]["workflow_job_config"]["existing_cluster_id"] = "cluster-123" - del config["config"]["job_cluster_config"] - - job = WorkflowPythonJobHelper(config, Mock()) - result = job._build_job_spec() - - task = result["tasks"][0] - assert task["existing_cluster_id"] == "cluster-123" - assert "new_cluster" not in task - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_serverless(self, mock_api_client): - config = self.default_config() - del config["config"]["job_cluster_config"] - - job = WorkflowPythonJobHelper(config, Mock()) - result = job._build_job_spec() - - task = result["tasks"][0] - assert "existing_cluster_id" not in task - assert "new_cluster" not in task - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_with_additional_task_settings(self, mock_api_client): - config = self.default_config() - config["config"]["workflow_job_config"]["additional_task_settings"] = { - "task_key": "my_dbt_task" - } - job = WorkflowPythonJobHelper(config, Mock()) - result = job._build_job_spec() - - task = result["tasks"][0] - assert task["task_key"] == "my_dbt_task" - - @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") - def test_build_job_spec_with_post_hooks(self, mock_api_client): - config = self.default_config() - config["config"]["workflow_job_config"]["post_hook_tasks"] = [ - { - "depends_on": [{"task_key": "inner_notebook"}], - "task_key": "task_b", - "notebook_task": { - "notebook_path": "/Workspace/Shared/test_notebook", - "source": "WORKSPACE", - }, - "new_cluster": { - "spark_version": "14.3.x-scala2.12", - "node_type_id": "rd-fleet.2xlarge", - "autoscale": {"min_workers": 1, "max_workers": 2}, - }, - } - ] - - job = WorkflowPythonJobHelper(config, Mock()) - result = job._build_job_spec() - - assert len(result["tasks"]) == 2 - assert result["tasks"][1]["task_key"] == "task_b" - assert result["tasks"][1]["new_cluster"]["spark_version"] == "14.3.x-scala2.12" diff --git a/tests/unit/python/test_python_submitters.py b/tests/unit/python/test_python_submitters.py new file mode 100644 index 00000000..ad37839d --- /dev/null +++ b/tests/unit/python/test_python_submitters.py @@ -0,0 +1,172 @@ +from unittest.mock import Mock +import pytest + +from dbt.adapters.databricks.python_models.python_submissions import ( + PythonCommandSubmitter, + PythonJobConfigCompiler, + PythonJobDetails, + PythonNotebookSubmitter, + PythonNotebookUploader, + PythonNotebookWorkflowSubmitter, + PythonPermissionBuilder, + PythonWorkflowConfigCompiler, + PythonWorkflowCreator, +) + + +@pytest.fixture +def client(): + return Mock() + + +@pytest.fixture +def tracker(): + return Mock() + + +@pytest.fixture +def compiled_code(): + return "compiled_code" + + +@pytest.fixture +def config_compiler(): + compiler = Mock() + compiler.compile.return_value = PythonJobDetails("name", {}, {}) + return compiler + + +@pytest.fixture +def uploader(): + return Mock() + + +class TestPythonCommandSubmitter: + @pytest.fixture + def cluster_id(self): + return "cluster_id" + + @pytest.fixture + def submitter(self, client, tracker, cluster_id, context_id): + client.command_contexts.create.return_value = context_id + return PythonCommandSubmitter(client, tracker, cluster_id) + + @pytest.fixture + def context_id(self): + return "context_id" + + def test_submit__golden_path( + self, submitter, compiled_code, client, cluster_id, context_id, tracker + ): + command_exec = client.commands.execute.return_value + submitter.submit(compiled_code) + client.commands.execute.assert_called_once_with(cluster_id, context_id, compiled_code) + client.commands.poll_for_completion.assert_called_once_with(command_exec) + client.command_contexts.destroy.assert_called_once_with(cluster_id, context_id) + tracker.remove_command.assert_called_once_with(command_exec) + + def test_submit__execute_fails__cleans_up( + self, submitter, compiled_code, client, cluster_id, context_id, tracker + ): + client.commands.execute.side_effect = Exception("error") + with pytest.raises(Exception): + submitter.submit(compiled_code) + client.command_contexts.destroy.assert_called_once_with(cluster_id, context_id) + tracker.remove_command.assert_not_called() + + def test_submit__poll_fails__cleans_up( + self, submitter, compiled_code, client, cluster_id, context_id, tracker + ): + command_exec = client.commands.execute.return_value + client.commands.poll_for_completion.side_effect = Exception("error") + with pytest.raises(Exception): + submitter.submit(compiled_code) + client.command_contexts.destroy.assert_called_once_with(cluster_id, context_id) + tracker.remove_command.assert_called_once_with(command_exec) + + +class TestPythonNotebookSubmitter: + @pytest.fixture + def submitter(self, client, tracker, uploader, config_compiler): + return PythonNotebookSubmitter(client, tracker, uploader, config_compiler) + + @pytest.fixture + def run_id(self, client): + return client.job_runs.submit.return_value + + def test_submit__golden_path(self, submitter, compiled_code, client, tracker, run_id): + submitter.submit(compiled_code) + tracker.insert_run_id.assert_called_once_with(run_id) + client.job_runs.poll_for_completion.assert_called_once_with(run_id) + tracker.remove_run_id.assert_called_once_with(run_id) + + def test_submit__poll_fails__cleans_up(self, submitter, compiled_code, client, tracker, run_id): + client.job_runs.poll_for_completion.side_effect = Exception("error") + with pytest.raises(Exception): + submitter.submit(compiled_code) + tracker.remove_run_id.assert_called_once_with(run_id) + + def test_create__golden_path(self, client, tracker): + parsed_model = Mock() + parsed_model.config.packages = [] + parsed_model.config.additional_libs = [] + cluster_spec = {} + submitter = PythonNotebookSubmitter.create(client, tracker, parsed_model, cluster_spec) + assert submitter.api_client == client + assert submitter.tracker == tracker + assert isinstance(submitter.uploader, PythonNotebookUploader) + assert isinstance(submitter.config_compiler, PythonJobConfigCompiler) + + +class TestPythonNotebookWorkflowSubmitter: + @pytest.fixture + def permission_builder(self): + return Mock() + + @pytest.fixture + def workflow_creater(self): + return Mock() + + @pytest.fixture + def submitter( + self, client, tracker, uploader, config_compiler, permission_builder, workflow_creater + ): + return PythonNotebookWorkflowSubmitter( + client, tracker, uploader, config_compiler, permission_builder, workflow_creater, {}, [] + ) + + def test_submit__golden_path(self, submitter): + submitter.uploader.upload.return_value = "upload_path" + submitter.config_compiler.compile.return_value = ({}, "existing_job_id") + submitter.workflow_creater.create_or_update.return_value = "existing_job_id" + submitter.permission_builder.build_permissions.return_value = [] + submitter.api_client.workflows.run.return_value = "run_id" + submitter.submit(compiled_code) + submitter.tracker.insert_run_id.assert_called_once_with("run_id") + submitter.api_client.job_runs.poll_for_completion.assert_called_once_with("run_id") + submitter.tracker.remove_run_id.assert_called_once_with("run_id") + + def test_submit__poll_fails__cleans_up(self, submitter): + submitter.uploader.upload.return_value = "upload_path" + submitter.config_compiler.compile.return_value = ({}, "existing_job_id") + submitter.workflow_creater.create_or_update.return_value = "existing_job_id" + submitter.permission_builder.build_permissions.return_value = [] + submitter.api_client.workflows.run.return_value = "run_id" + submitter.api_client.job_runs.poll_for_completion.side_effect = Exception("error") + with pytest.raises(Exception): + submitter.submit(compiled_code) + submitter.tracker.remove_run_id.assert_called_once_with("run_id") + + def test_create__golden_path(self, client, tracker): + parsed_model = Mock() + parsed_model.config.python_job_config.grants = {} + parsed_model.config.python_job_config.additional_task_settings = {} + parsed_model.config.python_job_config.dict.return_value = {} + parsed_model.config.access_control_list = [] + submitter = PythonNotebookWorkflowSubmitter.create(client, tracker, parsed_model) + assert submitter.api_client == client + assert submitter.tracker == tracker + assert isinstance(submitter.uploader, PythonNotebookUploader) + assert isinstance(submitter.config_compiler, PythonWorkflowConfigCompiler) + assert isinstance(submitter.permission_builder, PythonPermissionBuilder) + assert isinstance(submitter.workflow_creater, PythonWorkflowCreator) diff --git a/tests/unit/python/test_python_workflow_support.py b/tests/unit/python/test_python_workflow_support.py new file mode 100644 index 00000000..2e2e0941 --- /dev/null +++ b/tests/unit/python/test_python_workflow_support.py @@ -0,0 +1,142 @@ +from unittest.mock import Mock +import pytest + +from dbt.adapters.databricks.python_models.python_submissions import ( + PythonWorkflowConfigCompiler, + PythonWorkflowCreator, +) + + +class TestPythonWorkflowConfigCompiler: + @pytest.fixture + def parsed_model(self): + model = Mock() + model.catalog = "catalog" + model.schema_ = "schema" + model.identifier = "identifier" + return model + + def test_workflow_name__no_config(self, parsed_model): + parsed_model.config.python_job_config = None + assert ( + PythonWorkflowConfigCompiler.workflow_name(parsed_model) + == "dbt__catalog-schema-identifier" + ) + + def test_workflow_name__config_without_name(self, parsed_model): + parsed_model.config.python_job_config = {} + assert ( + PythonWorkflowConfigCompiler.workflow_name(parsed_model) + == "dbt__catalog-schema-identifier" + ) + + def test_workflow_name__config_with_name(self, parsed_model): + parsed_model.config.python_job_config.name = "test" + assert PythonWorkflowConfigCompiler.workflow_name(parsed_model) == "test" + + def test_cluster_settings__no_cluster_id(self, parsed_model): + parsed_model.config.job_cluster_config = None + parsed_model.config.cluster_id = None + assert PythonWorkflowConfigCompiler.cluster_settings(parsed_model) == {} + + def test_cluster_settings__no_job_cluster_config(self, parsed_model): + parsed_model.config.job_cluster_config = None + parsed_model.config.cluster_id = "test" + assert PythonWorkflowConfigCompiler.cluster_settings(parsed_model) == { + "existing_cluster_id": "test" + } + + def test_cluster_settings__job_cluster_config(self, parsed_model): + parsed_model.config.job_cluster_config = {"foo": "bar"} + assert PythonWorkflowConfigCompiler.cluster_settings(parsed_model) == { + "new_cluster": {"foo": "bar"} + } + + def test_compile__golden_path(self): + workflow_settings = {"foo": "bar"} + workflow_spec = {"baz": "qux"} + post_hook_tasks = [{"task_key": "post_hook"}] + compiler = PythonWorkflowConfigCompiler( + workflow_settings, workflow_spec, "existing_job_id", post_hook_tasks + ) + path = "path" + assert compiler.compile(path) == ( + { + "tasks": [ + { + "task_key": "inner_notebook", + "notebook_task": {"notebook_path": path, "source": "WORKSPACE"}, + "foo": "bar", + } + ] + + post_hook_tasks, + "baz": "qux", + }, + "existing_job_id", + ) + + def test_create__no_python_job_config(self, parsed_model): + parsed_model.config.python_job_config = None + parsed_model.config.job_cluster_config = None + parsed_model.config.cluster_id = "test" + compiler = PythonWorkflowConfigCompiler.create(parsed_model) + assert compiler.task_settings == {"existing_cluster_id": "test"} + assert compiler.workflow_spec == {} + assert compiler.existing_job_id == "" + assert compiler.post_hook_tasks == [] + + def test_create__python_job_config(self, parsed_model): + parsed_model.config.python_job_config.dict.return_value = {"bar": "baz"} + parsed_model.config.python_job_config.additional_task_settings = {"foo": "bar"} + parsed_model.config.python_job_config.existing_job_id = "test" + parsed_model.config.python_job_config.name = "name" + parsed_model.config.python_job_config.post_hook_tasks = [{"task_key": "post_hook"}] + parsed_model.config.job_cluster_config = None + parsed_model.config.cluster_id = None + compiler = PythonWorkflowConfigCompiler.create(parsed_model) + assert compiler.task_settings == {"foo": "bar"} + assert compiler.workflow_spec == {"name": "name", "bar": "baz"} + assert compiler.existing_job_id == "test" + assert compiler.post_hook_tasks == [{"task_key": "post_hook"}] + + +class TestPythonWorkflowCreator: + @pytest.fixture + def workflows(self): + return Mock() + + @pytest.fixture + def workflow_spec(self): + return {"name": "bar"} + + @pytest.fixture + def existing_job_id(self): + return "test" + + @pytest.fixture + def creator(self, workflows): + return PythonWorkflowCreator(workflows) + + def test_create_or_update__existing_job_id( + self, creator, workflows, workflow_spec, existing_job_id + ): + job_id = creator.create_or_update(workflow_spec, existing_job_id) + assert job_id == existing_job_id + workflows.update_job_settings.assert_called_once_with(existing_job_id, workflow_spec) + + def test_create_or_update__no_existing_job_creates_one(self, creator, workflows, workflow_spec): + workflows.search_by_name.return_value = [] + workflows.create.return_value = "job_id" + + job_id = creator.create_or_update(workflow_spec, "") + assert job_id == "job_id" + workflows.create.assert_called_once_with(workflow_spec) + + def test_create_or_update__existing_job(self, creator, workflows, workflow_spec): + workflows.search_by_name.return_value = [{"job_id": "job_id"}] + + job_id = creator.create_or_update(workflow_spec, "") + assert job_id == "job_id" + workflows.create.assert_not_called() + workflows.search_by_name.assert_called_once_with("bar") + workflows.update_job_settings.assert_called_once_with("job_id", workflow_spec)