Skip to content

Commit

Permalink
Show progress bar for spark jobs from DB Connect
Browse files Browse the repository at this point in the history
  • Loading branch information
fjakobs committed Sep 13, 2024
1 parent d7cbb50 commit 24b15a9
Showing 1 changed file with 64 additions and 0 deletions.
64 changes: 64 additions & 0 deletions packages/databricks-vscode/resources/python/00-databricks-init.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from typing import Any, Union, List
import os
import time
import shlex
import warnings
import tempfile
Expand All @@ -24,6 +25,7 @@ def logError(function_name: str, e: Union[str, Exception]):

try:
from IPython import get_ipython
from IPython.display import display
from IPython.core.magic import magics_class, Magics, line_magic, needs_local_scope
except Exception as e:
logError("Ipython Imports", e)
Expand Down Expand Up @@ -357,6 +359,67 @@ def df_html(df):
html_formatter.for_type(SparkConnectDataframe, df_html)
html_formatter.for_type(DataFrame, df_html)

@logErrorAndContinue
@disposable
def register_spark_progress(spark):
try:
from pyspark.sql.connect.shell.progress import Progress
import ipywidgets as widgets
except Exception as e:
return

class JupyterProgressHandler:
def __init__(self):
self.op_id = ""
self.p = None

def reset(self):
self.p = Progress(enabled=False, handlers=[])
self.init_ui()

def init_ui(self):
self.w_progress = widgets.IntProgress(
value=0,
min=0,
max=100,
bar_style='success',
orientation='horizontal'
)
self.w_status = widgets.Label(value="")
display(widgets.HBox([self.w_progress, self.w_status]))

def __call__(self,
stages,
inflight_tasks: int,
operation_id,
done: bool
):
# print("ProgressHandler", stages, inflight_tasks, operation_id, done)
if len(stages) == 0:
return

if self.op_id != operation_id or self.p is None:
self.op_id = operation_id
self.reset()

self.p.update_ticks(stages, inflight_tasks, operation_id)
self.output()
if done:
self.wip = False

def output(self) -> None:
p = self.p
if p._tick is not None and p._ticks is not None:
percent_complete = (p._tick / p._ticks) * 100
elapsed = int(time.time() - p._started)
scanned = p._bytes_to_string(p._bytes_read)
running = p._running
self.w_progress.value = percent_complete
self.w_status.value = f"{percent_complete:.2f}% Complete ({running} Tasks running, {elapsed}s, Scanned {scanned})"

spark.clearProgressHandlers()
spark.registerProgressHandler(JupyterProgressHandler())


@logErrorAndContinue
@disposable
Expand Down Expand Up @@ -385,6 +448,7 @@ def make_matplotlib_inline():
create_and_register_databricks_globals()
register_magics(cfg)
register_formatters(cfg)
register_spark_progress(globals()["spark"])
update_sys_path(cfg)
make_matplotlib_inline()

Expand Down

0 comments on commit 24b15a9

Please sign in to comment.