Skip to content

nsys-jax: add basic CI, support all-to-all and repeated thunks #3

nsys-jax: add basic CI, support all-to-all and repeated thunks

nsys-jax: add basic CI, support all-to-all and repeated thunks #3

Workflow file for this run

# Pull-request CI for the pure-Python nsys-jax tooling. The jobs in this
# workflow run mypy, execute the analysis notebook, and run ruff.
name: nsys-jax pure-Python CI

# Runs are grouped by workflow name plus the PR head branch, falling back to
# the unique run ID when head_ref is empty so such runs never share a group.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  # A newer run cancels an older one in the same group, except on main.
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

on:
  pull_request:
    types:
      - opened
      - reopened
      - ready_for_review
      - synchronize
    # Pull requests that only touch Markdown files do not trigger this workflow.
    paths-ignore:
      - '**.md'

# Least privilege: every job only checks out sources and runs local tools, so
# the GITHUB_TOKEN needs nothing beyond read access to repository contents.
permissions:
  contents: read

env:
  # Space-separated list of directories to analyse. The JAX-Toolbox/ prefix
  # matches the checkout path used by the mypy and ruff jobs. Shell steps
  # expand this variable unquoted on purpose so it splits into two paths.
  # The folded scalar joins the two lines below with a single space.
  NSYS_JAX_PYTHON_FILES: >-
    JAX-Toolbox/.github/container/nsys-jax
    JAX-Toolbox/.github/container/jax_nsys
jobs:
  # Static type checking of the nsys-jax sources, including the notebooks
  # after they have been converted to plain Python scripts.
  mypy:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          # Checked out into a subdirectory; NSYS_JAX_PYTHON_FILES uses the
          # same JAX-Toolbox/ prefix.
          path: JAX-Toolbox
          # Only the container directory is needed for these checks.
          sparse-checkout: |
            .github/container
      - name: Setup Python 3.10
        uses: actions/setup-python@v5
        with:
          # Quoted so YAML does not read the version as the float 3.1.
          python-version: '3.10'
      - name: Create virtual environment
        run: |
          pip install virtualenv
          virtualenv venv
      - name: Run nsys-jax-ensure-protobuf
        run: |
          . ./venv/bin/activate
          ./JAX-Toolbox/.github/container/jax_nsys/nsys-jax-ensure-protobuf
      - name: Install jax_nsys Python package
        run: ./venv/bin/pip install -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys
      - name: Install mypy
        run: ./venv/bin/pip install matplotlib mypy nbconvert types-protobuf
      - name: Fetch XLA .proto files
        uses: actions/checkout@v4
        with:
          path: xla
          # NOTE(review): no `ref` is given, so the default branch of
          # openxla/xla is fetched at run time; consider pinning a commit
          # if reproducible runs are required.
          repository: openxla/xla
          # Pattern-based sparse checkout, which requires cone mode to be off.
          sparse-checkout: |
            *.proto
          sparse-checkout-cone-mode: false
      - name: Compile .proto files
        shell: bash -x -e {0}
        run: |
          mkdir compiled_protos compiled_stubs protos
          mv -v xla/third_party/tsl/tsl protos/
          mv -v xla/xla protos/
          ./venv/bin/python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')"
          # Marker file so the generated stubs directory is treated as typed.
          touch compiled_stubs/py.typed
      - name: Convert .ipynb to .py
        shell: bash -x -e {0}
        run: |
          # find hands the notebook paths straight to nbconvert, so names that
          # contain whitespace or glob characters are not split by the shell.
          # find exits non-zero if nbconvert fails, so the step still fails.
          # ${NSYS_JAX_PYTHON_FILES} stays unquoted: it lists several roots.
          find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb' \
            -exec ./venv/bin/jupyter nbconvert --to script {} +
      - name: Run mypy checks
        shell: bash -x -e {0}
        run: |
          # Make the generated protobuf stubs visible to mypy.
          export MYPYPATH="${PWD}/compiled_stubs"
          ./venv/bin/mypy ${NSYS_JAX_PYTHON_FILES}

  # Runs the install script against bundled test data, then executes the
  # analysis notebook.
  # NOTE(review): this job runs on ubuntu-22.04 while mypy and ruff run on
  # ubuntu-24.04 - confirm the difference is intentional.
  notebook:
    runs-on: ubuntu-22.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: Mock up the structure of an extracted .zip file
        shell: bash -x -e {0}
        run: |
          # Get the actual test data from a real archive, minus the .nsys-rep file
          mv .github/workflows/nsys-jax/maxtext_fsdp4_test_data/* .github/container/jax_nsys/
      - name: Setup Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Run the install script, but skip launching Jupyter Lab
        shell: bash -x -e {0}
        run: |
          pip install virtualenv
          NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh
      - name: Test the Jupyter Lab installation and execute the notebook
        shell: bash -x -e {0}
        run: |
          # nsys_jax_venv is presumably created by install.sh - verify there.
          pushd .github/container/jax_nsys
          ./nsys_jax_venv/bin/python -m jupyterlab --version
          ./nsys_jax_venv/bin/ipython Analysis.ipynb

  # Lint and formatting checks.
  ruff:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          path: JAX-Toolbox
          sparse-checkout: |
            .github/container
      - name: Setup Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Install ruff
        run: pip install ruff
      - name: Run ruff checks
        # No `-e` here: both commands must run so that lint and formatting
        # problems are reported together; the step fails if either failed.
        shell: bash -x {0}
        run: |
          ruff check ${NSYS_JAX_PYTHON_FILES}
          check_status=$?
          ruff format --check ${NSYS_JAX_PYTHON_FILES}
          format_status=$?
          if [[ $format_status != 0 || $check_status != 0 ]]; then
            exit 1
          fi