diff --git a/.github/container/Dockerfile.base b/.github/container/Dockerfile.base index 6857d4267..023576cb5 100644 --- a/.github/container/Dockerfile.base +++ b/.github/container/Dockerfile.base @@ -205,7 +205,7 @@ COPY check-shm.sh /opt/nvidia/entrypoint.d/ ADD nsys-jax nsys-jax-combine /usr/local/bin/ ADD jax_nsys/ /opt/jax_nsys -ADD requirements-nsys-jax.in /opt/pip-tools.d/ +RUN echo "-e /opt/jax_nsys/python/jax_nsys" > /opt/pip-tools.d/requirements-nsys-jax.in RUN ln -s /opt/jax_nsys/install-protoc /usr/local/bin/ ############################################################################### diff --git a/.github/container/jax_nsys/python/jax_nsys/pyproject.toml b/.github/container/jax_nsys/python/jax_nsys/pyproject.toml index cc3f3981a..4c5ca9600 100644 --- a/.github/container/jax_nsys/python/jax_nsys/pyproject.toml +++ b/.github/container/jax_nsys/python/jax_nsys/pyproject.toml @@ -5,6 +5,9 @@ dependencies = [ "ipython", "numpy", "pandas", + "protobuf", # a compatible version of protoc needs to be installed out-of-band "pyarrow", + "requests", # for install-protoc + "uncertainties", # communication analysis recipe ] requires-python = ">= 3.10" diff --git a/.github/container/jax_nsys/python/jax_nsys_analysis/communication.py b/.github/container/jax_nsys/python/jax_nsys_analysis/communication.py index 8953b6413..ef02f5c1b 100644 --- a/.github/container/jax_nsys/python/jax_nsys_analysis/communication.py +++ b/.github/container/jax_nsys/python/jax_nsys_analysis/communication.py @@ -83,7 +83,8 @@ def format_bandwidth(data, collective): return "-" * width return f"{data[collective]:>{width}S}" - for message_size, data in summary_data.items(): + for message_size in sorted(summary_data.keys()): + data = summary_data[message_size] print( " | ".join( [format_message_size(message_size)] diff --git a/.github/container/nsys-jax b/.github/container/nsys-jax index 16d10d82d..104732306 100755 --- a/.github/container/nsys-jax +++ b/.github/container/nsys-jax @@ -16,7 +16,6 @@ import subprocess import sys import tempfile import time -import virtualenv # type: ignore import zipfile @@ -544,18 +543,6 @@ def execute_analysis_scripts(mirror_dir, analysis_scripts): return [], 0 assert mirror_dir is not None - venv_dir = osp.join(mirror_dir, "venv") - virtualenv.cli_run([venv_dir, "--python", sys.executable, "--system-site-packages"]) - subprocess.run( - [ - osp.join(venv_dir, "bin", "pip"), - "--disable-pip-version-check", - "install", - "-e", - osp.join(mirror_dir, "python", "jax_nsys"), - ], - check=True, - ) output = [] exit_code = 0 used_slugs = set() @@ -574,7 +561,7 @@ def execute_analysis_scripts(mirror_dir, analysis_scripts): candidates = list(filter(osp.exists, search)) assert len(candidates), f"Could not find analysis script, tried {search}" args.append(mirror_dir) - analysis_command = [osp.join(venv_dir, "bin", "python"), candidates[0]] + args + analysis_command = [sys.executable, candidates[0]] + args # Derive a unique name slug from the analysis script name slug = osp.basename(candidates[0]).removesuffix(".py") n, suffix = 1, "" diff --git a/.github/container/nsys-jax-combine b/.github/container/nsys-jax-combine index c196f9aa3..00e53efb6 100755 --- a/.github/container/nsys-jax-combine +++ b/.github/container/nsys-jax-combine @@ -9,7 +9,6 @@ import shutil import subprocess import sys import tempfile -import virtualenv # type: ignore import zipfile parser = argparse.ArgumentParser( @@ -159,20 +158,6 @@ with zipfile.ZipFile(args.output, "w") as ofile: write(dst_info) if len(args.analysis): assert mirror_dir is not None - venv_dir = mirror_dir / "venv" - virtualenv.cli_run( - [str(venv_dir), "--python", sys.executable, "--system-site-packages"] - ) - subprocess.run( - [ - venv_dir / "bin" / "pip", - "--disable-pip-version-check", - "install", - "-e", - mirror_dir / "python" / "jax_nsys", - ], - check=True, - ) used_slugs = set() for analysis in args.analysis: # Execute post-processing recipes and add any outputs to `ofile` @@ -185,9 +170,7 @@ with zipfile.ZipFile(args.output, "w") as ofile: candidates = list(filter(lambda p: p.exists(), search)) assert len(candidates), f"Could not find analysis script, tried {search}" analysis_command = ( - [venv_dir / "bin" / "python", candidates[0]] - + script_args - + [mirror_dir] + [sys.executable, candidates[0]] + script_args + [mirror_dir] ) # Derive a unique name slug from the analysis script name slug = os.path.basename(candidates[0]).removesuffix(".py") @@ -210,7 +193,7 @@ with zipfile.ZipFile(args.output, "w") as ofile: # Gather output files of the scrpt for path in working_dir.rglob("*"): with open(working_dir / path, "rb") as src, ofile.open( - str(pathlib.Path("analysis") / slug / path), "w" + str(path.relative_to(mirror_dir)), "w" ) as dst: # https://github.com/python/mypy/issues/15031 ? shutil.copyfileobj(src, dst) # type: ignore diff --git a/.github/container/pip-finalize.sh b/.github/container/pip-finalize.sh index 4828af35e..1149d7638 100755 --- a/.github/container/pip-finalize.sh +++ b/.github/container/pip-finalize.sh @@ -51,6 +51,6 @@ pip-sync --pip-args '--no-deps --src /opt' requirements.txt rm -rf ~/.cache/* -# protobuf will be installed at least due to requirements-nsys-jax.in in the base +# protobuf will be installed at least as a dependency of jax_nsys in the base # image, but the installed version is likely to be influenced by other packages. install-protoc /usr/local diff --git a/.github/container/requirements-nsys-jax.in b/.github/container/requirements-nsys-jax.in deleted file mode 100644 index d07e170f2..000000000 --- a/.github/container/requirements-nsys-jax.in +++ /dev/null @@ -1,7 +0,0 @@ -# No version constraint; a compatible version of protoc will be installed later -protobuf -# Used by install-protoc -requests -virtualenv -# Used by communication analysis recipe -uncertainties diff --git a/.github/workflows/nsys-jax.yaml b/.github/workflows/nsys-jax.yaml index 4f9f48098..aa91f870c 100644 --- a/.github/workflows/nsys-jax.yaml +++ b/.github/workflows/nsys-jax.yaml @@ -38,20 +38,11 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.10' - - name: "Create virtual environment" - run: | - pip install virtualenv - virtualenv venv - - name: "Install google.protobuf and protoc" - run: | - ./venv/bin/pip install -r ./JAX-Toolbox/.github/container/requirements-nsys-jax.in - ./venv/bin/python ./JAX-Toolbox/.github/container/jax_nsys/install-protoc ./venv - - name: "Install jax_nsys Python package" - run: ./venv/bin/pip install -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys - - name: "Install mypy" - run: ./venv/bin/pip install matplotlib mypy nbconvert types-protobuf - - name: "Install JAX for type-checking, this is a CPU-only build of the latest release" - run: ./venv/bin/pip install jax + # jax is just a CPU-only build of the latest release for type-checking purposes + - name: "Install jax / jax_nsys / mypy" + run: pip install jax -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys matplotlib mypy nbconvert types-protobuf + - name: "Install protoc" + run: ./JAX-Toolbox/.github/container/jax_nsys/install-protoc local_protoc - name: "Fetch XLA .proto files" uses: actions/checkout@v4 with: @@ -66,19 +57,19 @@ jobs: mkdir compiled_protos compiled_stubs protos mv -v xla/third_party/tsl/tsl protos/ mv -v xla/xla protos/ - ./venv/bin/python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')" + PATH=${PWD}/local_protoc/bin:$PATH python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')" touch compiled_stubs/py.typed - name: "Convert .ipynb to .py" shell: bash -x -e {0} run: | for notebook in $(find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb'); do - ./venv/bin/jupyter nbconvert --to script ${notebook} + jupyter nbconvert --to script ${notebook} done - name: "Run mypy checks" shell: bash -x -e {0} run: | export MYPYPATH="${PWD}/compiled_stubs" - ./venv/bin/mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES} + mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES} # Test nsys-jax-combine and notebook execution; in future perhaps upload the rendered # notebook from here too. These input files were generated with something like @@ -96,24 +87,38 @@ jobs: uses: actions/setup-python@v5 with: python-version: '3.12' + # TODO: a modern nsys-jax-combine with old .zip input should probably produce a + # .zip with a modern jax_nsys/ + - name: Add modern jax_nsys/ files to static .zip inputs + run: | + cd .github/container/jax_nsys + for zip in ../../workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip; do + zip -ur "${zip}" . + zipinfo "${zip}" + done - name: Use nsys-jax-combine to merge profiles from multiple nsys processes shell: bash -x -e {0} run: | - pip install virtualenv - .github/container/nsys-jax-combine \ + pip install -e .github/container/jax_nsys/python/jax_nsys + python .github/container/jax_nsys/install-protoc local_protoc + PATH=${PWD}/local_protoc/bin:$PATH .github/container/nsys-jax-combine \ --analysis summary \ --analysis communication \ - -o .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc.zip \ + -o pax_fsdp4_4proc.zip \ .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip - - name: Mock up the structure of an extracted .zip file - run: unzip -d .github/container/jax_nsys/ .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc.zip + - name: Extract the output .zip file + run: | + mkdir combined/ + unzip -d combined/ pax_fsdp4_4proc.zip - name: Run the install script, but skip launching Jupyter Lab shell: bash -x -e {0} - run: NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh + run: | + pip install virtualenv + NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./combined/install.sh - name: Test the Jupyter Lab installation and execute the notebook shell: bash -x -e {0} run: | - pushd .github/container/jax_nsys + pushd combined/ ./nsys_jax_venv/bin/python -m jupyterlab --version # Run with ipython for the sake of getting a clear error message ./nsys_jax_venv/bin/ipython Analysis.ipynb