Skip to content

Commit

Permalink
Remove virtualenv indirection
Browse files Browse the repository at this point in the history
Assume that google.protobuf/protoc are available in the environment that
nsys-jax and nsys-jax-combine run in.
  • Loading branch information
olupton committed Aug 21, 2024
1 parent dfab493 commit 0719d92
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 67 deletions.
2 changes: 1 addition & 1 deletion .github/container/Dockerfile.base
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/

ADD nsys-jax nsys-jax-combine /usr/local/bin/
ADD jax_nsys/ /opt/jax_nsys
ADD requirements-nsys-jax.in /opt/pip-tools.d/
RUN echo "-e /opt/jax_nsys/python/jax_nsys" > /opt/pip-tools.d/requirements-nsys-jax.in
RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/

###############################################################################
Expand Down
3 changes: 3 additions & 0 deletions .github/container/jax_nsys/python/jax_nsys/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ dependencies = [
"ipython",
"numpy",
"pandas",
"protobuf", # a compatible version of protoc needs to be installed out-of-band
"pyarrow",
"requests", # for install-protoc
"uncertainties", # communication analysis recipe
]
requires-python = ">= 3.10"
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def format_bandwidth(data, collective):
return "-" * width
return f"{data[collective]:>{width}S}"

for message_size, data in summary_data.items():
for message_size in sorted(summary_data.keys()):
data = summary_data[message_size]
print(
" | ".join(
[format_message_size(message_size)]
Expand Down
15 changes: 1 addition & 14 deletions .github/container/nsys-jax
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import subprocess
import sys
import tempfile
import time
import virtualenv # type: ignore
import zipfile


Expand Down Expand Up @@ -544,18 +543,6 @@ def execute_analysis_scripts(mirror_dir, analysis_scripts):
return [], 0

assert mirror_dir is not None
venv_dir = osp.join(mirror_dir, "venv")
virtualenv.cli_run([venv_dir, "--python", sys.executable, "--system-site-packages"])
subprocess.run(
[
osp.join(venv_dir, "bin", "pip"),
"--disable-pip-version-check",
"install",
"-e",
osp.join(mirror_dir, "python", "jax_nsys"),
],
check=True,
)
output = []
exit_code = 0
used_slugs = set()
Expand All @@ -574,7 +561,7 @@ def execute_analysis_scripts(mirror_dir, analysis_scripts):
candidates = list(filter(osp.exists, search))
assert len(candidates), f"Could not find analysis script, tried {search}"
args.append(mirror_dir)
analysis_command = [osp.join(venv_dir, "bin", "python"), candidates[0]] + args
analysis_command = [sys.executable, candidates[0]] + args
# Derive a unique name slug from the analysis script name
slug = osp.basename(candidates[0]).removesuffix(".py")
n, suffix = 1, ""
Expand Down
21 changes: 2 additions & 19 deletions .github/container/nsys-jax-combine
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import shutil
import subprocess
import sys
import tempfile
import virtualenv # type: ignore
import zipfile

parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -159,20 +158,6 @@ with zipfile.ZipFile(args.output, "w") as ofile:
write(dst_info)
if len(args.analysis):
assert mirror_dir is not None
venv_dir = mirror_dir / "venv"
virtualenv.cli_run(
[str(venv_dir), "--python", sys.executable, "--system-site-packages"]
)
subprocess.run(
[
venv_dir / "bin" / "pip",
"--disable-pip-version-check",
"install",
"-e",
mirror_dir / "python" / "jax_nsys",
],
check=True,
)
used_slugs = set()
for analysis in args.analysis:
# Execute post-processing recipes and add any outputs to `ofile`
Expand All @@ -185,9 +170,7 @@ with zipfile.ZipFile(args.output, "w") as ofile:
candidates = list(filter(lambda p: p.exists(), search))
assert len(candidates), f"Could not find analysis script, tried {search}"
analysis_command = (
[venv_dir / "bin" / "python", candidates[0]]
+ script_args
+ [mirror_dir]
[sys.executable, candidates[0]] + script_args + [mirror_dir]
)
# Derive a unique name slug from the analysis script name
slug = os.path.basename(candidates[0]).removesuffix(".py")
Expand All @@ -210,7 +193,7 @@ with zipfile.ZipFile(args.output, "w") as ofile:
# Gather output files of the scrpt
for path in working_dir.rglob("*"):
with open(working_dir / path, "rb") as src, ofile.open(
str(pathlib.Path("analysis") / slug / path), "w"
str(path.relative_to(mirror_dir)), "w"
) as dst:
# https://github.com/python/mypy/issues/15031 ?
shutil.copyfileobj(src, dst) # type: ignore
2 changes: 1 addition & 1 deletion .github/container/pip-finalize.sh
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,6 @@ pip-sync --pip-args '--no-deps --src /opt' requirements.txt

rm -rf ~/.cache/*

# protobuf will be installed at least due to requirements-nsys-jax.in in the base
# protobuf will be installed at least as a dependency of jax_nsys in the base
# image, but the installed version is likely to be influenced by other packages.
install-protoc /usr/local
7 changes: 0 additions & 7 deletions .github/container/requirements-nsys-jax.in

This file was deleted.

53 changes: 29 additions & 24 deletions .github/workflows/nsys-jax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,20 +38,11 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: "Create virtual environment"
run: |
pip install virtualenv
virtualenv venv
- name: "Install google.protobuf and protoc"
run: |
./venv/bin/pip install -r ./JAX-Toolbox/.github/container/requirements-nsys-jax.in
./venv/bin/python ./JAX-Toolbox/.github/container/jax_nsys/install-protoc ./venv
- name: "Install jax_nsys Python package"
run: ./venv/bin/pip install -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys
- name: "Install mypy"
run: ./venv/bin/pip install matplotlib mypy nbconvert types-protobuf
- name: "Install JAX for type-checking, this is a CPU-only build of the latest release"
run: ./venv/bin/pip install jax
# jax is just a CPU-only build of the latest release for type-checking purposes
- name: "Install jax / jax_nsys / mypy"
run: pip install jax -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys matplotlib mypy nbconvert types-protobuf
- name: "Install protoc"
run: ./JAX-Toolbox/.github/container/jax_nsys/install-protoc local_protoc
- name: "Fetch XLA .proto files"
uses: actions/checkout@v4
with:
Expand All @@ -66,19 +57,19 @@ jobs:
mkdir compiled_protos compiled_stubs protos
mv -v xla/third_party/tsl/tsl protos/
mv -v xla/xla protos/
./venv/bin/python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')"
PATH=${PWD}/local_protoc/bin:$PATH python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')"
touch compiled_stubs/py.typed
- name: "Convert .ipynb to .py"
shell: bash -x -e {0}
run: |
for notebook in $(find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb'); do
./venv/bin/jupyter nbconvert --to script ${notebook}
jupyter nbconvert --to script ${notebook}
done
- name: "Run mypy checks"
shell: bash -x -e {0}
run: |
export MYPYPATH="${PWD}/compiled_stubs"
./venv/bin/mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES}
mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES}
# Test nsys-jax-combine and notebook execution; in future perhaps upload the rendered
# notebook from here too. These input files were generated with something like
Expand All @@ -96,24 +87,38 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: '3.12'
# TODO: a modern nsys-jax-combine with old .zip input should probably produce a
# .zip with a modern jax_nsys/
- name: Add modern jax_nsys/ files to static .zip inputs
run: |
cd .github/container/jax_nsys
for zip in ../../workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip; do
zip -ur "${zip}" .
zipinfo "${zip}"
done
- name: Use nsys-jax-combine to merge profiles from multiple nsys processes
shell: bash -x -e {0}
run: |
pip install virtualenv
.github/container/nsys-jax-combine \
pip install -e .github/container/jax_nsys/python/jax_nsys
python .github/container/jax_nsys/install-protoc local_protoc
PATH=${PWD}/local_protoc/bin:$PATH .github/container/nsys-jax-combine \
--analysis summary \
--analysis communication \
-o .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc.zip \
-o pax_fsdp4_4proc.zip \
.github/workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip
- name: Mock up the structure of an extracted .zip file
run: unzip -d .github/container/jax_nsys/ .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc.zip
- name: Extract the output .zip file
run: |
mkdir combined/
unzip -d combined/ pax_fsdp4_4proc.zip
- name: Run the install script, but skip launching Jupyter Lab
shell: bash -x -e {0}
run: NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh
run: |
pip install virtualenv
NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./combined/install.sh
- name: Test the Jupyter Lab installation and execute the notebook
shell: bash -x -e {0}
run: |
pushd .github/container/jax_nsys
pushd combined/
./nsys_jax_venv/bin/python -m jupyterlab --version
# Run with ipython for the sake of getting a clear error message
./nsys_jax_venv/bin/ipython Analysis.ipynb
Expand Down

0 comments on commit 0719d92

Please sign in to comment.