Skip to content

Commit

Permalink
Inline code from pyspark
Browse files Browse the repository at this point in the history
  • Loading branch information
fjakobs committed Sep 16, 2024
1 parent 24b15a9 commit b1ceb52
Showing 1 changed file with 56 additions and 24 deletions.
80 changes: 56 additions & 24 deletions packages/databricks-vscode/resources/python/00-databricks-init.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,13 +368,18 @@ def register_spark_progress(spark):
except Exception as e:
return

class JupyterProgressHandler:
def __init__(self):
self.op_id = ""
self.p = None

def reset(self):
self.p = Progress(enabled=False, handlers=[])
class Progress:
SI_BYTE_SIZES = (1 << 60, 1 << 50, 1 << 40, 1 << 30, 1 << 20, 1 << 10, 1)
SI_BYTE_SUFFIXES = ("EiB", "PiB", "TiB", "GiB", "MiB", "KiB", "B")

def __init__(
self,
) -> None:
self._ticks = None
self._tick = None
self._started = time.time()
self._bytes_read = 0
self._running = 0
self.init_ui()

def init_ui(self):
Expand All @@ -388,37 +393,64 @@ def init_ui(self):
self.w_status = widgets.Label(value="")
display(widgets.HBox([self.w_progress, self.w_status]))

def update_ticks(
self,
stages,
inflight_tasks: int
) -> None:
total_tasks = sum(map(lambda x: x.num_tasks, stages))
completed_tasks = sum(map(lambda x: x.num_completed_tasks, stages))
if total_tasks > 0:
self._ticks = total_tasks
self._tick = completed_tasks
self._bytes_read = sum(map(lambda x: x.num_bytes_read, stages))
if self._tick is not None and self._tick >= 0:
self.output()
self._running = inflight_tasks

def output(self) -> None:
if self._tick is not None and self._ticks is not None:
percent_complete = (self._tick / self._ticks) * 100
elapsed = int(time.time() - self._started)
scanned = self._bytes_to_string(self._bytes_read)
running = self._running
self.w_progress.value = percent_complete
self.w_status.value = f"{percent_complete:.2f}% Complete ({running} Tasks running, {elapsed}s, Scanned {scanned})"

@staticmethod
def _bytes_to_string(size: int) -> str:
"""Helper method to convert a numeric bytes value into a human-readable representation"""
i = 0
while i < len(Progress.SI_BYTE_SIZES) - 1 and size < 2 * Progress.SI_BYTE_SIZES[i]:
i += 1
result = float(size) / Progress.SI_BYTE_SIZES[i]
return f"{result:.1f} {Progress.SI_BYTE_SUFFIXES[i]}"


class ProgressHandler:
def __init__(self):
self.op_id = ""

def reset(self):
self.p = Progress()

def __call__(self,
stages,
inflight_tasks: int,
operation_id,
done: bool
):
# print("ProgressHandler", stages, inflight_tasks, operation_id, done)
if len(stages) == 0:
return

if self.op_id != operation_id or self.p is None:
if self.op_id != operation_id:
self.op_id = operation_id
self.reset()

self.p.update_ticks(stages, inflight_tasks, operation_id)
self.output()
if done:
self.wip = False

def output(self) -> None:
p = self.p
if p._tick is not None and p._ticks is not None:
percent_complete = (p._tick / p._ticks) * 100
elapsed = int(time.time() - p._started)
scanned = p._bytes_to_string(p._bytes_read)
running = p._running
self.w_progress.value = percent_complete
self.w_status.value = f"{percent_complete:.2f}% Complete ({running} Tasks running, {elapsed}s, Scanned {scanned})"
self.p.update_ticks(stages, inflight_tasks)

spark.clearProgressHandlers()
spark.registerProgressHandler(JupyterProgressHandler())
spark.registerProgressHandler(ProgressHandler())


@logErrorAndContinue
Expand Down

0 comments on commit b1ceb52

Please sign in to comment.