diff --git a/packages/databricks-vscode/resources/python/00-databricks-init.py b/packages/databricks-vscode/resources/python/00-databricks-init.py
index 57d9be4e0..be9c1ee9d 100644
--- a/packages/databricks-vscode/resources/python/00-databricks-init.py
+++ b/packages/databricks-vscode/resources/python/00-databricks-init.py
@@ -3,6 +3,7 @@
 import json
 from typing import Any, Union, List
 import os
+import time
 import shlex
 import warnings
 import tempfile
@@ -24,6 +25,7 @@ def logError(function_name: str, e: Union[str, Exception]):
 
 try:
     from IPython import get_ipython
+    from IPython.display import display
     from IPython.core.magic import magics_class, Magics, line_magic, needs_local_scope
 except Exception as e:
     logError("Ipython Imports", e)
@@ -357,6 +359,83 @@ def df_html(df):
 
     html_formatter.for_type(SparkConnectDataframe, df_html)
     html_formatter.for_type(DataFrame, df_html)
+@logErrorAndContinue
+@disposable
+def register_spark_progress(spark):
+    """Show Spark query progress as an ipywidgets progress bar.
+
+    Clears the progress handlers on ``spark`` and registers a single
+    ``JupyterProgressHandler``. Returns without doing anything when
+    pyspark's ``Progress`` helper or ``ipywidgets`` cannot be imported.
+    """
+    try:
+        from pyspark.sql.connect.shell.progress import Progress
+        import ipywidgets as widgets
+    except Exception:
+        # Progress reporting is optional: skip it when the dependencies
+        # are unavailable.
+        return
+
+    class JupyterProgressHandler:
+        """Progress callback that shows one bar and status label per operation."""
+
+        def __init__(self):
+            self.op_id = ""
+            self.p = None
+
+        def reset(self):
+            """Create a fresh Progress tracker and a fresh set of widgets."""
+            # Built with enabled=False and no handlers; presumably this keeps
+            # Progress from printing its own bar (only its counters are read).
+            self.p = Progress(enabled=False, handlers=[])
+            self.init_ui()
+
+        def init_ui(self):
+            """Create and display the progress bar next to a status label."""
+            self.w_progress = widgets.IntProgress(
+                value=0,
+                min=0,
+                max=100,
+                bar_style='success',
+                orientation='horizontal'
+            )
+            self.w_status = widgets.Label(value="")
+            display(widgets.HBox([self.w_progress, self.w_status]))
+
+        def __call__(self,
+            stages,
+            inflight_tasks: int,
+            operation_id,
+            done: bool
+        ):
+            """Update the widgets from one progress notification; ``done`` is unused."""
+            if not stages:
+                return
+
+            # A different operation id gets new widgets instead of reusing the old ones.
+            if self.op_id != operation_id or self.p is None:
+                self.op_id = operation_id
+                self.reset()
+
+            self.p.update_ticks(stages, inflight_tasks, operation_id)
+            self.output()
+
+        def output(self) -> None:
+            """Refresh the bar and the label from the tracked Progress state."""
+            p = self.p
+            # A missing or zero total would make the percentage undefined.
+            if p._tick is None or not p._ticks:
+                return
+            percent_complete = (p._tick / p._ticks) * 100
+            elapsed = int(time.time() - p._started)
+            scanned = p._bytes_to_string(p._bytes_read)
+            running = p._running
+            self.w_progress.value = int(percent_complete)
+            self.w_status.value = f"{percent_complete:.2f}% Complete ({running} Tasks running, {elapsed}s, Scanned {scanned})"
+
+    spark.clearProgressHandlers()
+    spark.registerProgressHandler(JupyterProgressHandler())
+
 
 @logErrorAndContinue
 @disposable
@@ -385,6 +464,7 @@ def make_matplotlib_inline():
     create_and_register_databricks_globals()
     register_magics(cfg)
     register_formatters(cfg)
+    register_spark_progress(globals()["spark"])
     update_sys_path(cfg)
     make_matplotlib_inline()
 