diff --git a/CHANGELOG.md b/CHANGELOG.md index 4cbe3659e..4c4e5b2f7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ - Allow for the use of custom constraints, using the `custom` constraint type with an `expression` as the constraint (thanks @roydobbe). ([792](https://github.com/databricks/dbt-databricks/pull/792)) - Add "use_info_schema_for_columns" behavior flag to turn on use of information_schema to get column info where possible. This may have more latency but will not truncate complex data types the way that 'describe' can. ([808](https://github.com/databricks/dbt-databricks/pull/808)) - Add support for table_format: iceberg. This uses UniForm under the hood to provide iceberg compatibility for tables or incrementals. ([815](https://github.com/databricks/dbt-databricks/pull/815)) +- Add a new `workflow_job` submission method for python, which creates a long-lived Databricks Workflow instead of a one-time run (thanks @kdazzle!) ([762](https://github.com/databricks/dbt-databricks/pull/762)) +- Allow for additional options to be passed to the Databricks Job API when using other python submission methods. For example, enable email_notifications (thanks @kdazzle!) ([762](https://github.com/databricks/dbt-databricks/pull/762)) ### Under the Hood diff --git a/dbt/adapters/databricks/api_client.py b/dbt/adapters/databricks/api_client.py index 7928880e7..893c09255 100644 --- a/dbt/adapters/databricks/api_client.py +++ b/dbt/adapters/databricks/api_client.py @@ -3,9 +3,11 @@ from abc import ABC from abc import abstractmethod from dataclasses import dataclass +import re from typing import Any from typing import Callable from typing import Dict +from typing import List from typing import Optional from typing import Set @@ -41,6 +43,11 @@ def post( ) -> Response: return self.session.post(f"{self.prefix}{suffix}", json=json, params=params) + def put( + self, suffix: str = "", json: Optional[Any] = None, params: Optional[Dict[str, Any]] = None + ) -> Response: + return self.session.put(f"{self.prefix}{suffix}", json=json, params=params) + class DatabricksApi(ABC): def __init__(self, session: Session, host: str, api: str): @@ -142,20 +149,38 @@ def get_folder(self, _: str, schema: str) -> str: return f"/Shared/dbt_python_models/{schema}/" -# Switch to this as part of 2.0.0 release -class UserFolderApi(DatabricksApi, FolderApi): +class CurrUserApi(DatabricksApi): + def __init__(self, session: Session, host: str): super().__init__(session, host, "/api/2.0/preview/scim/v2") self._user = "" - def get_folder(self, catalog: str, schema: str) -> str: - if not self._user: - response = self.session.get("/Me") + def get_username(self) -> str: + if self._user: + return self._user - if response.status_code != 200: - raise DbtRuntimeError(f"Error getting user folder.\n {response.content!r}") - self._user = response.json()["userName"] - folder = f"/Users/{self._user}/dbt_python_models/{catalog}/{schema}/" + response = self.session.get("/Me") + if response.status_code != 200: + raise DbtRuntimeError(f"Error getting current user.\n {response.content!r}") + + username = response.json()["userName"] + self._user = username + return username + + def is_service_principal(self, username: str) -> bool: + uuid_pattern = r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" + return bool(re.match(uuid_pattern, username, re.IGNORECASE)) + + +# Switch to this as part of 2.0.0 release +class UserFolderApi(DatabricksApi, FolderApi): + def __init__(self, session: Session, host: str, user_api: 
CurrUserApi): + super().__init__(session, host, "/api/2.0/preview/scim/v2") + self.user_api = user_api + + def get_folder(self, catalog: str, schema: str) -> str: + username = self.user_api.get_username() + folder = f"/Users/{username}/dbt_python_models/{catalog}/{schema}/" logger.debug(f"Using python model folder '{folder}'") return folder @@ -302,9 +327,11 @@ class JobRunsApi(PollableApi): def __init__(self, session: Session, host: str, polling_interval: int, timeout: int): super().__init__(session, host, "/api/2.1/jobs/runs", polling_interval, timeout) - def submit(self, run_name: str, job_spec: Dict[str, Any]) -> str: + def submit( + self, run_name: str, job_spec: Dict[str, Any], **additional_job_settings: Dict[str, Any] + ) -> str: submit_response = self.session.post( - "/submit", json={"run_name": run_name, "tasks": [job_spec]} + "/submit", json={"run_name": run_name, "tasks": [job_spec], **additional_job_settings} ) if submit_response.status_code != 200: raise DbtRuntimeError(f"Error creating python run.\n {submit_response.content!r}") @@ -357,6 +384,87 @@ def cancel(self, run_id: str) -> None: raise DbtRuntimeError(f"Cancel run {run_id} failed.\n {response.content!r}") +class JobPermissionsApi(DatabricksApi): + def __init__(self, session: Session, host: str): + super().__init__(session, host, "/api/2.0/permissions/jobs") + + def put(self, job_id: str, access_control_list: List[Dict[str, Any]]) -> None: + request_body = {"access_control_list": access_control_list} + + response = self.session.put(f"/{job_id}", json=request_body) + logger.debug(f"Workflow permissions update response={response.json()}") + + if response.status_code != 200: + raise DbtRuntimeError(f"Error updating Databricks workflow.\n {response.content!r}") + + def get(self, job_id: str) -> Dict[str, Any]: + response = self.session.get(f"/{job_id}") + + if response.status_code != 200: + raise DbtRuntimeError( + f"Error fetching Databricks workflow permissions.\n {response.content!r}" + ) + + return response.json() + + +class WorkflowJobApi(DatabricksApi): + + def __init__(self, session: Session, host: str): + super().__init__(session, host, "/api/2.1/jobs") + + def search_by_name(self, job_name: str) -> List[Dict[str, Any]]: + response = self.session.get("/list", json={"name": job_name}) + + if response.status_code != 200: + raise DbtRuntimeError(f"Error fetching job by name.\n {response.content!r}") + + return response.json().get("jobs", []) + + def create(self, job_spec: Dict[str, Any]) -> str: + """ + :return: the job_id + """ + response = self.session.post("/create", json=job_spec) + + if response.status_code != 200: + raise DbtRuntimeError(f"Error creating Workflow.\n {response.content!r}") + + job_id = response.json()["job_id"] + logger.info(f"New workflow created with job id {job_id}") + return job_id + + def update_job_settings(self, job_id: str, job_spec: Dict[str, Any]) -> None: + request_body = { + "job_id": job_id, + "new_settings": job_spec, + } + logger.debug(f"Job settings: {request_body}") + response = self.session.post("/reset", json=request_body) + + if response.status_code != 200: + raise DbtRuntimeError(f"Error updating Workflow.\n {response.content!r}") + + logger.debug(f"Workflow update response={response.json()}") + + def run(self, job_id: str, enable_queueing: bool = True) -> str: + request_body = { + "job_id": job_id, + "queue": { + "enabled": enable_queueing, + }, + } + response = self.session.post("/run-now", json=request_body) + + if response.status_code != 200: + raise 
DbtRuntimeError(f"Error triggering run for workflow.\n {response.content!r}") + + response_json = response.json() + logger.info(f"Workflow trigger response={response_json}") + + return response_json["run_id"] + + class DatabricksApiClient: def __init__( self, @@ -368,13 +476,16 @@ def __init__( ): self.clusters = ClusterApi(session, host) self.command_contexts = CommandContextApi(session, host, self.clusters) + self.curr_user = CurrUserApi(session, host) if use_user_folder: - self.folders: FolderApi = UserFolderApi(session, host) + self.folders: FolderApi = UserFolderApi(session, host, self.curr_user) else: self.folders = SharedFolderApi() self.workspace = WorkspaceApi(session, host, self.folders) self.commands = CommandApi(session, host, polling_interval, timeout) self.job_runs = JobRunsApi(session, host, polling_interval, timeout) + self.workflows = WorkflowJobApi(session, host) + self.workflow_permissions = JobPermissionsApi(session, host) @staticmethod def create( diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 6efddb072..24117c130 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -55,6 +55,9 @@ from dbt.adapters.databricks.python_models.python_submissions import ( ServerlessClusterPythonJobHelper, ) +from dbt.adapters.databricks.python_models.python_submissions import ( + WorkflowPythonJobHelper, +) from dbt.adapters.databricks.relation import DatabricksRelation from dbt.adapters.databricks.relation import DatabricksRelationType from dbt.adapters.databricks.relation import KEY_TABLE_PROVIDER @@ -635,6 +638,7 @@ def python_submission_helpers(self) -> Dict[str, Type[PythonJobHelper]]: "job_cluster": JobClusterPythonJobHelper, "all_purpose_cluster": AllPurposeClusterPythonJobHelper, "serverless_cluster": ServerlessClusterPythonJobHelper, + "workflow_job": WorkflowPythonJobHelper, } @available diff --git a/dbt/adapters/databricks/python_models/python_submissions.py b/dbt/adapters/databricks/python_models/python_submissions.py index eb017fc23..de02f4731 100644 --- a/dbt/adapters/databricks/python_models/python_submissions.py +++ b/dbt/adapters/databricks/python_models/python_submissions.py @@ -1,13 +1,16 @@ import uuid from typing import Any from typing import Dict +from typing import List from typing import Optional +from typing import Tuple from dbt.adapters.base import PythonJobHelper from dbt.adapters.databricks.api_client import CommandExecution from dbt.adapters.databricks.api_client import DatabricksApiClient from dbt.adapters.databricks.credentials import DatabricksCredentials from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker +from dbt_common.exceptions import DbtRuntimeError DEFAULT_TIMEOUT = 60 * 60 * 24 @@ -16,6 +19,18 @@ class BaseDatabricksHelper(PythonJobHelper): tracker = PythonRunTracker() + @property + def workflow_spec(self) -> Dict[str, Any]: + """ + The workflow gets modified throughout. 
Settings added through dbt are popped off + before the spec is sent to the Databricks API + """ + return self.parsed_model["config"].get("workflow_job_config", {}) + + @property + def cluster_spec(self) -> Dict[str, Any]: + return self.parsed_model["config"].get("job_cluster_config", {}) + def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None: self.credentials = credentials self.identifier = parsed_model["alias"] @@ -30,6 +45,8 @@ def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> No credentials, self.get_timeout(), use_user_folder ) + self.job_grants: Dict[str, List[Dict[str, Any]]] = self.workflow_spec.pop("grants", {}) + def get_timeout(self) -> int: timeout = self.parsed_model["config"].get("timeout", DEFAULT_TIMEOUT) if timeout <= 0: @@ -45,6 +62,57 @@ def _update_with_acls(self, cluster_dict: dict) -> dict: cluster_dict.update({"access_control_list": acl}) return cluster_dict + def _build_job_permissions(self) -> List[Dict[str, Any]]: + access_control_list = [] + owner, permissions_attribute = self._build_job_owner() + access_control_list.append( + { + permissions_attribute: owner, + "permission_level": "IS_OWNER", + } + ) + + for grant in self.job_grants.get("view", []): + acl_grant = grant.copy() + acl_grant.update( + { + "permission_level": "CAN_VIEW", + } + ) + access_control_list.append(acl_grant) + for grant in self.job_grants.get("run", []): + acl_grant = grant.copy() + acl_grant.update( + { + "permission_level": "CAN_MANAGE_RUN", + } + ) + access_control_list.append(acl_grant) + for grant in self.job_grants.get("manage", []): + acl_grant = grant.copy() + acl_grant.update( + { + "permission_level": "CAN_MANAGE", + } + ) + access_control_list.append(acl_grant) + + return access_control_list + + def _build_job_owner(self) -> Tuple[str, str]: + """ + :return: a tuple of the user id and the ACL attribute it came from ie: + [user_name|group_name|service_principal_name] + For example: `("mateizaharia@databricks.com", "user_name")` + """ + curr_user = self.api_client.curr_user.get_username() + is_service_principal = self.api_client.curr_user.is_service_principal(curr_user) + + if is_service_principal: + return curr_user, "service_principal_name" + else: + return curr_user, "user_name" + def _submit_job(self, path: str, cluster_spec: dict) -> str: job_spec: Dict[str, Any] = { "task_key": "inner_notebook", @@ -76,10 +144,30 @@ def _submit_job(self, path: str, cluster_spec: dict) -> str: job_spec.update({"libraries": libraries}) run_name = f"{self.database}-{self.schema}-{self.identifier}-{uuid.uuid4()}" - run_id = self.api_client.job_runs.submit(run_name, job_spec) + additional_job_config = self._build_additional_job_settings() + access_control_list = self._build_job_permissions() + additional_job_config["access_control_list"] = access_control_list + + run_id = self.api_client.job_runs.submit(run_name, job_spec, **additional_job_config) self.tracker.insert_run_id(run_id) return run_id + def _build_additional_job_settings(self) -> Dict[str, Any]: + additional_configs = {} + attrs_to_add = [ + "email_notifications", + "webhook_notifications", + "notification_settings", + "timeout_seconds", + "health", + "environments", + ] + for attr in attrs_to_add: + if attr in self.workflow_spec: + additional_configs[attr] = self.workflow_spec[attr] + + return additional_configs + def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> None: workdir = self.api_client.workspace.create_python_model_dir( self.database or 
"hive_metastore", self.schema @@ -162,3 +250,104 @@ def submit(self, compiled_code: str) -> None: class ServerlessClusterPythonJobHelper(BaseDatabricksHelper): def submit(self, compiled_code: str) -> None: self._submit_through_notebook(compiled_code, {}) + + +class WorkflowPythonJobHelper(BaseDatabricksHelper): + + @property + def default_job_name(self) -> str: + return f"dbt__{self.database}-{self.schema}-{self.identifier}" + + @property + def notebook_path(self) -> str: + return f"{self.notebook_dir}/{self.identifier}" + + @property + def notebook_dir(self) -> str: + return self.api_client.workspace.user_api.get_folder(self.catalog, self.schema) + + @property + def catalog(self) -> str: + return self.database or "hive_metastore" + + def __init__(self, parsed_model: Dict, credentials: DatabricksCredentials) -> None: + super().__init__(parsed_model, credentials) + + self.post_hook_tasks = self.workflow_spec.pop("post_hook_tasks", []) + self.additional_task_settings = self.workflow_spec.pop("additional_task_settings", {}) + + def check_credentials(self) -> None: + workflow_config = self.parsed_model["config"].get("workflow_job_config", None) + if not workflow_config: + raise ValueError( + "workflow_job_config is required for the `workflow_job_config` submission method." + ) + + def submit(self, compiled_code: str) -> None: + workflow_spec = self._build_job_spec() + self._submit_through_workflow(compiled_code, workflow_spec) + + def _build_job_spec(self) -> Dict[str, Any]: + workflow_spec = dict(self.workflow_spec) + workflow_spec["name"] = self.workflow_spec.get("name", self.default_job_name) + + # Undefined cluster settings defaults to serverless in the Databricks API + cluster_settings = {} + if self.cluster_spec: + cluster_settings["new_cluster"] = self.cluster_spec + elif "existing_cluster_id" in self.workflow_spec: + cluster_settings["existing_cluster_id"] = self.workflow_spec["existing_cluster_id"] + + notebook_task = { + "task_key": "inner_notebook", + "notebook_task": { + "notebook_path": self.notebook_path, + "source": "WORKSPACE", + }, + } + notebook_task.update(cluster_settings) + notebook_task.update(self.additional_task_settings) + + workflow_spec["tasks"] = [notebook_task] + self.post_hook_tasks + return workflow_spec + + def _submit_through_workflow(self, compiled_code: str, workflow_spec: Dict[str, Any]) -> None: + self.api_client.workspace.create_python_model_dir(self.catalog, self.schema) + self.api_client.workspace.upload_notebook(self.notebook_path, compiled_code) + + job_id, is_new = self._get_or_create_job(workflow_spec) + + if not is_new: + self.api_client.workflows.update_job_settings(job_id, workflow_spec) + + access_control_list = self._build_job_permissions() + self.api_client.workflow_permissions.put(job_id, access_control_list) + + run_id = self.api_client.workflows.run(job_id, enable_queueing=True) + self.tracker.insert_run_id(run_id) + + try: + self.api_client.job_runs.poll_for_completion(run_id) + finally: + self.tracker.remove_run_id(run_id) + + def _get_or_create_job(self, workflow_spec: Dict[str, Any]) -> Tuple[str, bool]: + """ + :return: tuple of job_id and whether the job is new + """ + existing_job_id = workflow_spec.pop("existing_job_id", "") + if existing_job_id: + return existing_job_id, False + + response_jobs = self.api_client.workflows.search_by_name(workflow_spec["name"]) + + if len(response_jobs) > 1: + raise DbtRuntimeError( + f"""Multiple jobs found with name {workflow_spec['name']}. 
Use a unique job + name or specify the `existing_job_id` in the workflow_job_config.""" + ) + + if len(response_jobs) == 1: + return response_jobs[0]["job_id"], False + else: + return self.api_client.workflows.create(workflow_spec), True diff --git a/docs/workflow-job-submission.md b/docs/workflow-job-submission.md new file mode 100644 index 000000000..b22abd3ef --- /dev/null +++ b/docs/workflow-job-submission.md @@ -0,0 +1,186 @@ +## Databricks Workflow Job Submission + +Use the `workflow_job` submission method to run your python model as a long-lived +Databricks Workflow. Models look the same as they would using the `job_cluster` submission +method, but allow for additional configuration. + +Some of that configuration can also be used for `job_cluster` models. + +```python +# my_model.py +import pyspark.sql.types as T +import pyspark.sql.functions as F + + +def model(dbt, session): + dbt.config( + materialized='incremental', + submission_method='workflow_job' + ) + + output_schema = T.StructType([ + T.StructField("id", T.StringType(), True), + T.StructField("timestamp", T.TimestampType(), True), + ]) + return spark.createDataFrame(data=spark.sparkContext.emptyRDD(), schema=output_schema) +``` + +The config for a model could look like: + +```yaml +models: + - name: my_model + config: + workflow_job_config: + # This is also applied to one-time run models + email_notifications: { + on_failure: ["reynoldxin@databricks.com"] + } + max_retries: 2 + timeout_seconds: 18000 + existing_cluster_id: 1234a-123-1234 # Use in place of job_cluster_config or null + + # Name must be unique unless existing_job_id is also defined + name: my_workflow_name + existing_job_id: 12341234 + + # Override settings for your model's dbt task. For instance, you can + # change the task key + additional_task_settings: { + "task_key": "my_dbt_task" + } + + # Define tasks to run before/after the model + post_hook_tasks: [{ + "depends_on": [{ "task_key": "my_dbt_task" }], + "task_key": 'OPTIMIZE_AND_VACUUM', + "notebook_task": { + "notebook_path": "/my_notebook_path", + "source": "WORKSPACE", + }, + }] + + # Also applied to one-time run models + grants: + view: [ + {"group_name": "marketing-team"}, + ] + run: [ + {"user_name": "alighodsi@databricks.com"} + ] + manage: [] + + # Reused for the workflow job cluster definition + job_cluster_config: + spark_version: "15.3.x-scala2.12" + node_type_id: "rd-fleet.2xlarge" + runtime_engine: "{{ var('job_cluster_defaults.runtime_engine') }}" + data_security_mode: "{{ var('job_cluster_defaults.data_security_mode') }}" + autoscale: { + "min_workers": 1, + "max_workers": 4 + } +``` + +### Configuration + +All config values are optional. See the Databricks Jobs API for the full list of attributes +that can be set. + +#### Reuse in job_cluster submission method + +If the following values are defined in `config.workflow_job_config`, they will be used even if +the model uses the job_cluster submission method. For example, you can define a job_cluster model +to send an email notification on failure. + +- grants +- email_notifications +- webhook_notifications +- notification_settings +- timeout_seconds +- health +- environments + +#### Workflow name + +The name of the workflow must be unique unless you also define an existing job id. By default, +dbt will generate a name based on the catalog, schema, and model identifier. + +#### Clusters + +- If defined, dbt will re-use the `config.job_cluster_config` to define a job cluster for the workflow tasks. 
+- If `config.workflow_job_config.existing_cluster_id` is defined, dbt will run the task on that existing cluster.
+- Similarly, you can define reusable job clusters under `config.workflow_job_config.job_clusters` and point each task at one with `job_cluster_key`.
+- If none of these are configured, the task cluster will be serverless.
+
+```yaml
+# Reusable job cluster config example
+
+models:
+  - name: my_model
+
+    config:
+      workflow_job_config:
+        additional_task_settings: {
+          task_key: 'task_a',
+          job_cluster_key: 'cluster_a',
+        }
+        post_hook_tasks: [{
+          depends_on: [{ "task_key": "task_a" }],
+          task_key: 'OPTIMIZE_AND_VACUUM',
+          job_cluster_key: 'cluster_a',
+          notebook_task: {
+            notebook_path: "/OPTIMIZE_AND_VACUUM",
+            source: "WORKSPACE",
+            base_parameters: {
+              database: "{{ target.database }}",
+              schema: "{{ target.schema }}",
+              table_name: "my_model"
+            }
+          },
+        }]
+        job_clusters: [{
+          job_cluster_key: 'cluster_a',
+          new_cluster: {
+            spark_version: "{{ var('dbr_versions')['lts_v14'] }}",
+            node_type_id: "{{ var('cluster_node_types')['large_job'] }}",
+            runtime_engine: "{{ var('job_cluster_defaults.runtime_engine') }}",
+            autoscale: {
+              "min_workers": 1,
+              "max_workers": 2
+            },
+          }
+        }]
+```
+
+#### Grants
+
+You might want to give certain users or teams access to run your workflows outside of
+dbt in an ad hoc way. You can define those permissions in `workflow_job_config.grants`.
+The owner is always the user or service principal that dbt authenticates as when it
+creates the workflow.
+
+These grants are also applied to one-time run models submitted through the other
+submission methods (for example, `job_cluster`).
+
+The dbt grant keys map to the following Databricks permission levels:
+
+- view: `CAN_VIEW`
+- run: `CAN_MANAGE_RUN`
+- manage: `CAN_MANAGE`
+
+```yaml
+grants:
+  view: [
+    {"group_name": "marketing-team"},
+  ]
+  run: [
+    {"user_name": "alighodsi@databricks.com"}
+  ]
+  manage: []
+```
+
+#### Post hooks
+
+You can add follow-up tasks to the workflow with the `config.workflow_job_config.post_hook_tasks`
+attribute. Each post-hook task needs its own cluster definition, or can reference a reusable
+cluster from `config.workflow_job_config.job_clusters`; a sketch of the inline-cluster variant
+follows.
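+For reference, here is a minimal sketch of a post-hook task that brings its own inline cluster,
+loosely based on the post-hook fixture in this adapter's unit tests. The notebook path, task key,
+and cluster values are placeholders; `inner_notebook` is the default task key dbt assigns to the
+model task unless you override it via `additional_task_settings`.
+
+```yaml
+# Post-hook task that defines its own cluster inline (illustrative values)
+
+models:
+  - name: my_model
+    config:
+      workflow_job_config:
+        post_hook_tasks: [{
+          depends_on: [{ "task_key": "inner_notebook" }],
+          task_key: 'task_b',
+          notebook_task: {
+            notebook_path: "/Workspace/Shared/my_post_hook_notebook",
+            source: "WORKSPACE",
+          },
+          new_cluster: {
+            spark_version: "14.3.x-scala2.12",
+            node_type_id: "rd-fleet.2xlarge",
+            autoscale: { "min_workers": 1, "max_workers": 2 },
+          },
+        }]
+```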
\ No newline at end of file diff --git a/tests/functional/adapter/python_model/fixtures.py b/tests/functional/adapter/python_model/fixtures.py index 9e048d285..fc4e451b9 100644 --- a/tests/functional/adapter/python_model/fixtures.py +++ b/tests/functional/adapter/python_model/fixtures.py @@ -33,6 +33,21 @@ def model(dbt, spark): identifier: source """ +workflow_schema = """version: 2 + +models: + - name: my_workflow_model + config: + submission_method: workflow_job + user_folder_for_python: true + workflow_job_config: + max_retries: 2 + timeout_seconds: 500 + additional_task_settings: { + "task_key": "my_dbt_task" + } +""" + simple_python_model_v2 = """ import pandas diff --git a/tests/functional/adapter/python_model/test_python_model.py b/tests/functional/adapter/python_model/test_python_model.py index e20f11346..bf1bd1f4b 100644 --- a/tests/functional/adapter/python_model/test_python_model.py +++ b/tests/functional/adapter/python_model/test_python_model.py @@ -144,3 +144,22 @@ def test_expected_handling_of_complex_config(self, project): fetch="all", ) assert results[0][0] == "This is a python table" + + +@pytest.mark.python +@pytest.mark.skip_profile("databricks_cluster", "databricks_uc_sql_endpoint") +class TestWorkflowJob: + @pytest.fixture(scope="class") + def models(self): + return { + "schema.yml": override_fixtures.workflow_schema, + "my_workflow_model.py": override_fixtures.simple_python_model, + } + + def test_workflow_run(self, project): + util.run_dbt(["run", "-s", "my_workflow_model"]) + + sql_results = project.run_sql( + "SELECT * FROM {database}.{schema}.my_workflow_model", fetch="all" + ) + assert len(sql_results) == 10 diff --git a/tests/unit/api_client/test_user_folder_api.py b/tests/unit/api_client/test_user_folder_api.py index 98e5f47ee..0006c3d11 100644 --- a/tests/unit/api_client/test_user_folder_api.py +++ b/tests/unit/api_client/test_user_folder_api.py @@ -1,15 +1,17 @@ import pytest from dbt.adapters.databricks.api_client import UserFolderApi +from dbt.adapters.databricks.api_client import CurrUserApi from tests.unit.api_client.api_test_base import ApiTestBase class TestUserFolderApi(ApiTestBase): @pytest.fixture def api(self, session, host): - return UserFolderApi(session, host) + user_api = CurrUserApi(session, host) + return UserFolderApi(session, host, user_api) def test_get_folder__already_set(self, api): - api._user = "me" + api.user_api._user = "me" assert "/Users/me/dbt_python_models/catalog/schema/" == api.get_folder("catalog", "schema") def test_get_folder__non_200(self, api, session): @@ -20,7 +22,7 @@ def test_get_folder__200(self, api, session, host): session.get.return_value.json.return_value = {"userName": "me@gmail.com"} folder = api.get_folder("catalog", "schema") assert folder == "/Users/me@gmail.com/dbt_python_models/catalog/schema/" - assert api._user == "me@gmail.com" + assert api.user_api._user == "me@gmail.com" session.get.assert_called_once_with( f"https://{host}/api/2.0/preview/scim/v2/Me", json=None, params=None ) diff --git a/tests/unit/python/test_python_submissions.py b/tests/unit/python/test_python_submissions.py index f2a94cbb2..902831427 100644 --- a/tests/unit/python/test_python_submissions.py +++ b/tests/unit/python/test_python_submissions.py @@ -1,5 +1,9 @@ +from mock import patch +from unittest.mock import Mock + from dbt.adapters.databricks.credentials import DatabricksCredentials from dbt.adapters.databricks.python_models.python_submissions import BaseDatabricksHelper +from 
dbt.adapters.databricks.python_models.python_submissions import WorkflowPythonJobHelper # class TestDatabricksPythonSubmissions: @@ -25,6 +29,7 @@ class DatabricksTestHelper(BaseDatabricksHelper): def __init__(self, parsed_model: dict, credentials: DatabricksCredentials): self.parsed_model = parsed_model self.credentials = credentials + self.job_grants = self.workflow_spec.get("grants", {}) class TestAclUpdate: @@ -56,3 +61,189 @@ def test_non_empty_acl_non_empty_config(self): "a": "b", "access_control_list": expected_access_control["access_control_list"], } + + +class TestJobGrants: + + @patch.object(BaseDatabricksHelper, "_build_job_owner") + def test_job_owner_user(self, mock_job_owner): + mock_job_owner.return_value = ("alighodsi@databricks.com", "user_name") + + helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials()) + helper.job_grants = {} + + assert helper._build_job_permissions() == [ + { + "permission_level": "IS_OWNER", + "user_name": "alighodsi@databricks.com", + } + ] + + @patch.object(BaseDatabricksHelper, "_build_job_owner") + def test_job_owner_service_principal(self, mock_job_owner): + mock_job_owner.return_value = ( + "9533b8cc-2d60-46dd-84f2-a39b3939e37a", + "service_principal_name", + ) + + helper = DatabricksTestHelper({"config": {}}, DatabricksCredentials()) + helper.job_grants = {} + + assert helper._build_job_permissions() == [ + { + "permission_level": "IS_OWNER", + "service_principal_name": "9533b8cc-2d60-46dd-84f2-a39b3939e37a", + } + ] + + @patch.object(BaseDatabricksHelper, "_build_job_owner") + def test_job_grants(self, mock_job_owner): + mock_job_owner.return_value = ( + "9533b8cc-2d60-46dd-84f2-a39b3939e37a", + "service_principal_name", + ) + helper = DatabricksTestHelper( + { + "config": { + "workflow_job_config": { + "grants": { + "view": [ + {"user_name": "reynoldxin@databricks.com"}, + {"user_name": "alighodsi@databricks.com"}, + ], + "run": [{"group_name": "dbt-developers"}], + "manage": [{"group_name": "dbt-admins"}], + } + } + } + }, + DatabricksCredentials(), + ) + + actual = helper._build_job_permissions() + + expected_owner = { + "service_principal_name": "9533b8cc-2d60-46dd-84f2-a39b3939e37a", + "permission_level": "IS_OWNER", + } + expected_viewer_1 = { + "permission_level": "CAN_VIEW", + "user_name": "reynoldxin@databricks.com", + } + expected_viewer_2 = { + "permission_level": "CAN_VIEW", + "user_name": "alighodsi@databricks.com", + } + expected_runner = {"permission_level": "CAN_MANAGE_RUN", "group_name": "dbt-developers"} + expected_manager = {"permission_level": "CAN_MANAGE", "group_name": "dbt-admins"} + + assert expected_owner in actual + assert expected_viewer_1 in actual + assert expected_viewer_2 in actual + assert expected_runner in actual + assert expected_manager in actual + + +class TestWorkflowConfig: + def default_config(self): + return { + "alias": "test_model", + "database": "test_database", + "schema": "test_schema", + "config": { + "workflow_job_config": { + "email_notifications": "test@example.com", + "max_retries": 2, + "timeout_seconds": 500, + }, + "job_cluster_config": { + "spark_version": "15.3.x-scala2.12", + "node_type_id": "rd-fleet.2xlarge", + "autoscale": {"min_workers": 1, "max_workers": 2}, + }, + }, + } + + @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") + def test_build_job_spec_default(self, mock_api_client): + job = WorkflowPythonJobHelper(self.default_config(), Mock()) + result = job._build_job_spec() + + assert result["name"] == 
"dbt__test_database-test_schema-test_model" + assert len(result["tasks"]) == 1 + + task = result["tasks"][0] + assert task["task_key"] == "inner_notebook" + assert task["new_cluster"]["spark_version"] == "15.3.x-scala2.12" + + @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") + def test_build_job_spec_custom_name(self, mock_api_client): + config = self.default_config() + config["config"]["workflow_job_config"]["name"] = "custom_job_name" + job = WorkflowPythonJobHelper(config, Mock()) + result = job._build_job_spec() + + assert result["name"] == "custom_job_name" + + @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") + def test_build_job_spec_existing_cluster(self, mock_api_client): + config = self.default_config() + config["config"]["workflow_job_config"]["existing_cluster_id"] = "cluster-123" + del config["config"]["job_cluster_config"] + + job = WorkflowPythonJobHelper(config, Mock()) + result = job._build_job_spec() + + task = result["tasks"][0] + assert task["existing_cluster_id"] == "cluster-123" + assert "new_cluster" not in task + + @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") + def test_build_job_spec_serverless(self, mock_api_client): + config = self.default_config() + del config["config"]["job_cluster_config"] + + job = WorkflowPythonJobHelper(config, Mock()) + result = job._build_job_spec() + + task = result["tasks"][0] + assert "existing_cluster_id" not in task + assert "new_cluster" not in task + + @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") + def test_build_job_spec_with_additional_task_settings(self, mock_api_client): + config = self.default_config() + config["config"]["workflow_job_config"]["additional_task_settings"] = { + "task_key": "my_dbt_task" + } + job = WorkflowPythonJobHelper(config, Mock()) + result = job._build_job_spec() + + task = result["tasks"][0] + assert task["task_key"] == "my_dbt_task" + + @patch("dbt.adapters.databricks.python_models.python_submissions.DatabricksApiClient") + def test_build_job_spec_with_post_hooks(self, mock_api_client): + config = self.default_config() + config["config"]["workflow_job_config"]["post_hook_tasks"] = [ + { + "depends_on": [{"task_key": "inner_notebook"}], + "task_key": "task_b", + "notebook_task": { + "notebook_path": "/Workspace/Shared/test_notebook", + "source": "WORKSPACE", + }, + "new_cluster": { + "spark_version": "14.3.x-scala2.12", + "node_type_id": "rd-fleet.2xlarge", + "autoscale": {"min_workers": 1, "max_workers": 2}, + }, + } + ] + + job = WorkflowPythonJobHelper(config, Mock()) + result = job._build_job_spec() + + assert len(result["tasks"]) == 2 + assert result["tasks"][1]["task_key"] == "task_b" + assert result["tasks"][1]["new_cluster"]["spark_version"] == "14.3.x-scala2.12"