Skip to content

Commit

Permalink
Support for %run magics and # MAGIC commands (#834)
Browse files Browse the repository at this point in the history
## Changes
<!-- Summary of your changes that are easy to understand -->

## Tests
<!-- How is this tested? -->
  • Loading branch information
kartikgupta-db authored Sep 5, 2023
1 parent f544d6b commit 65ab0c2
Showing 1 changed file with 112 additions and 23 deletions.
135 changes: 112 additions & 23 deletions packages/databricks-vscode/resources/python/00-databricks-init.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from contextlib import contextmanager
import functools
from typing import Union, List
import json
from typing import Any, Union, List
import os
import shlex
import warnings
import tempfile


def logError(function_name: str, e: Union[str, Exception]):
Expand Down Expand Up @@ -91,7 +94,7 @@ def __set__(self, instance, value):

@disposable
class LocalDatabricksNotebookConfig:
project_root = EnvLoader("DATABRICKS_PROJECT_ROOT", required=True)
project_root: str = EnvLoader("DATABRICKS_PROJECT_ROOT", required=True)
dataframe_display_limit: int = EnvLoader("DATABRICKS_DF_DISPLAY_LIMIT", 20)

def __new__(cls):
Expand All @@ -104,12 +107,6 @@ def __new__(cls):
@magics_class
@disposable
class DatabricksMagics(Magics):
@line_magic
def run(self, line):
raise NotImplementedError(
"%run is not supported for local Databricks Notebooks."
)

@needs_local_scope
@line_magic
def fs(self, line: str, local_ns):
Expand All @@ -131,10 +128,77 @@ def fs(self, line: str, local_ns):
return cmd(*args[1:])


def is_databricks_notebook(py_file: str) -> bool:
    """Return True if ``py_file`` is a Databricks notebook exported as source.

    The check looks for the text ``Databricks notebook source`` in the first
    line of the file.

    Args:
        py_file: Path of the file to inspect.

    Returns:
        ``True`` when the first line contains the marker; ``False`` when it
        does not, or when ``py_file`` is not an existing regular file.
    """
    # isfile (rather than exists) keeps directories from reaching open().
    if not os.path.isfile(py_file):
        return False
    # The marker is plain ASCII, so undecodable bytes elsewhere in the first
    # line must not make the check fail; decode as UTF-8 and replace them.
    with open(py_file, "r", encoding="utf-8", errors="replace") as f:
        return "Databricks notebook source" in f.readline()

def strip_hash_magic(lines: List[str]) -> List[str]:
    """Remove the ``# MAGIC`` prefix from the lines of a magic cell.

    A cell is treated as a magic cell only when its first line starts with
    ``# MAGIC``; any other input is returned unchanged.

    For a magic cell, each line that starts with ``# MAGIC`` loses that prefix
    together with the single space that separates it from the content, so
    ``"# MAGIC %run ./nb"`` becomes ``"%run ./nb"`` and the result can be
    matched with ``startswith("%")``. Indentation beyond that one space and
    the line ending are preserved. Lines that do not start with the prefix
    are kept as they are instead of being discarded.

    Args:
        lines: The lines of one cell.

    Returns:
        The lines with the ``# MAGIC`` prefix removed where applicable.
    """
    prefix = "# MAGIC"
    if not lines or not lines[0].startswith(prefix):
        return lines

    def strip_prefix(line: str) -> str:
        if not line.startswith(prefix):
            return line
        rest = line[len(prefix):]
        # Drop only the one separator space so nested indentation survives.
        return rest[1:] if rest.startswith(" ") else rest

    return [strip_prefix(line) for line in lines]

def convert_databricks_notebook_to_ipynb(py_file: str) -> str:
    """Convert a Databricks source-format notebook into ipynb JSON text.

    The file is split into cells on the ``# COMMAND ----------`` separator.
    Each cell is stripped of surrounding whitespace, passed through
    ``strip_hash_magic`` and emitted as a code cell. An extra first cell
    changes the working directory to the directory containing ``py_file``.

    Args:
        py_file: Path to the Databricks notebook source file.

    Returns:
        A JSON string describing an nbformat 4.2 notebook.
    """

    def code_cell(source: str) -> dict:
        return {
            "cell_type": "code",
            "source": source,
            "metadata": {},
            "outputs": [],
            "execution_count": None,
        }

    # abspath: a bare file name has an empty dirname, and os.chdir('') raises.
    notebook_dir = os.path.dirname(os.path.abspath(py_file))
    # repr() yields a valid Python string literal even when the path contains
    # backslashes (Windows) or quote characters.
    cells: List[dict[str, Any]] = [
        code_cell(f"import os\nos.chdir({notebook_dir!r})\n")
    ]

    with open(py_file, encoding="utf-8") as file:
        text = file.read()

    for cell in text.split("# COMMAND ----------"):
        cell_lines = cell.strip().splitlines(keepends=True)
        cells.append(code_cell("".join(strip_hash_magic(cell_lines))))

    return json.dumps(
        {
            "cells": cells,
            "metadata": {},
            "nbformat": 4,
            "nbformat_minor": 2,
        }
    )


@contextmanager
def databricks_notebook_exec_env(project_root: str, py_file: str):
    """Prepare the environment for running ``py_file`` and yield the path to run.

    While the context is active, ``project_root`` and the directory of
    ``py_file`` are appended to ``sys.path``. If ``py_file`` is a Databricks
    notebook it is converted to a temporary ``.ipynb`` file and that file's
    path is yielded; otherwise ``py_file`` itself is yielded.

    On exit the temporary file (if any) is removed, and ``sys.path`` and the
    current working directory are restored to their values at entry.

    Args:
        project_root: Directory appended to ``sys.path``.
        py_file: Path of the file to run.

    Yields:
        The path of the file that should be executed.
    """
    import sys

    # Snapshot a copy: holding a reference to the live list would make the
    # restore below a no-op, leaking the appended entries.
    saved_sys_path = list(sys.path)
    saved_cwd = os.getcwd()

    sys.path.append(project_root)
    sys.path.append(os.path.dirname(py_file))

    temp_notebook = None
    try:
        if is_databricks_notebook(py_file):
            notebook = convert_databricks_notebook_to_ipynb(py_file)
            # delete=False and closing before yielding: on Windows a
            # NamedTemporaryFile cannot be reopened by name while it is open.
            with tempfile.NamedTemporaryFile(suffix=".ipynb", delete=False) as f:
                temp_notebook = f.name
                f.write(notebook.encode())
            yield temp_notebook
        else:
            yield py_file
    finally:
        if temp_notebook is not None:
            try:
                os.remove(temp_notebook)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask an
                # error raised by the executed notebook.
                pass
        # Restore in place so other references to sys.path stay valid.
        sys.path[:] = saved_sys_path
        os.chdir(saved_cwd)


@logErrorAndContinue
@disposable
def register_magics():
def warn_for_dbr_alternative(magic):
def register_magics(cfg: LocalDatabricksNotebookConfig):
def warn_for_dbr_alternative(magic: str):
# Magics that are not supported on Databricks but work in jupyter notebooks.
# We show a warning, prompting users to use a databricks equivalent instead.
local_magic_dbr_alternative = {"%%sh": "%sh"}
Expand All @@ -147,7 +211,7 @@ def warn_for_dbr_alternative(magic):
+ " instead."
)

def throw_if_not_supported(magic):
def throw_if_not_supported(magic: str):
# These are magics that are supported on dbr but not locally.
unsupported_dbr_magics = ["%r", "%scala"]
if magic in unsupported_dbr_magics:
Expand All @@ -157,21 +221,34 @@ def throw_if_not_supported(magic):
)

def is_cell_magic(lines: List[str]):
def get_cell_magic(lines: List[str]):
if len(lines) == 0:
return
if lines[0].startswith("%%"):
return lines[0].split(" ")[0].strip()

def handle(lines: List[str]):
cell_magic = is_cell_magic(lines)
cell_magic = get_cell_magic(lines)
if cell_magic is None:
return lines
warn_for_dbr_alternative(cell_magic)
throw_if_not_supported(cell_magic)
return lines

is_cell_magic.handle = handle
if len(lines) == 0:
return
if lines[0].startswith("%%"):
return lines[0].split(" ")[0].strip()
return get_cell_magic(lines) is not None

def is_line_magic(lines: List[str]):
def get_line_magic(lines: List[str]):
if len(lines) == 0:
return
if lines[0].startswith("%"):
return lines[0].split(" ")[0].strip().strip("%")

def handle(lines: List[str]):
lmagic = is_line_magic(lines)
lmagic = get_line_magic(lines)
if lmagic is None:
return lines
warn_for_dbr_alternative(lmagic)
throw_if_not_supported(lmagic)

Expand Down Expand Up @@ -200,18 +277,30 @@ def handle(lines: List[str]):

if lmagic == "python":
return lines[1:]


if lmagic == "run":
rest = lines[0].strip().split(" ")[1:]
filename = ""
for arg in rest:
if arg.endswith((".py", ".ipy", ".ipynb")):
filename = arg
break

return [
f"with databricks_notebook_exec_env('{cfg.project_root}', '{filename}') as file:\n",
"\t%run -i {file} " + lines[0].partition('%run')[2].partition(filename)[2] + "\n"
]

return lines

is_line_magic.handle = handle
if len(lines) == 0:
return
if lines[0].startswith("%"):
return lines[0].split(" ")[0].strip().strip("%")
return get_line_magic(lines) is not None


def parse_line_for_databricks_magics(lines: List[str]):
if len(lines) == 0:
return lines
lines = strip_hash_magic(lines)
for magic_check in [is_cell_magic, is_line_magic]:
if magic_check(lines):
return magic_check.handle(lines)
Expand Down Expand Up @@ -260,7 +349,7 @@ def make_matplotlib_inline():
print(sys.modules[__name__])
cfg = LocalDatabricksNotebookConfig()
create_and_register_databricks_globals()
register_magics()
register_magics(cfg)
register_formatters(cfg)
update_sys_path(cfg)
make_matplotlib_inline()
Expand Down

0 comments on commit 65ab0c2

Please sign in to comment.