diff --git a/.github/container/jax_nsys/Analysis.ipynb b/.github/container/jax_nsys/Analysis.ipynb index 3078edd0b..b5de7b68e 100644 --- a/.github/container/jax_nsys/Analysis.ipynb +++ b/.github/container/jax_nsys/Analysis.ipynb @@ -430,7 +430,9 @@ " # program, there may be different sub-groupings that are participating in smaller\n", " # collectives in the strict/NCCL sense. TODO: it would be better to identify those\n", " # sub-groupings and group them, but we currently lack the relevant information.\n", - " collective_df = df.groupby([\"ProgramId\", \"Name\", \"ModuleExecution\"])\n", + " collective_df = df.groupby(\n", + " [\"ProgramId\", \"Name\", \"ModuleExecution\", \"ThunkExecution\"]\n", + " )\n", " # Take the fastest device kernel as a proxy for the actual bandwidth of the\n", " # collective.\n", " bandwidth_df = collective_df.agg(\n", @@ -468,30 +470,6 @@ "axs[2].set_xscale(\"log\")\n", "axs[2].set_yscale(\"log\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "710100ec-55e6-4b6f-b8aa-5e6bda15fda1", - "metadata": {}, - "outputs": [], - "source": [ - "compute_times = (\n", - " thunk_df[~thunk_df[\"Communication\"]]\n", - " .groupby([\"ProgramId\", \"Name\"])\n", - " .agg({\"ProjDurNs\": [\"mean\", \"std\"]})\n", - " .sort_values((\"ProjDurNs\", \"mean\"))\n", - ")\n", - "plt.plot(\n", - " compute_times[(\"ProjDurNs\", \"mean\")],\n", - " compute_times[(\"ProjDurNs\", \"std\")] / compute_times[(\"ProjDurNs\", \"mean\")],\n", - " \"o\",\n", - ")\n", - "# plt.errorbar([\"{}:{}\".format(*x) for x in compute_times.index], compute_times[(\"ProjDurNs\", \"mean\")], compute_times[(\"ProjDurNs\", \"std\")], marker=\"o\")\n", - "# plt.xlabel(\n", - "plt.xscale(\"log\")\n", - "# compute_times[(\"ProjDurNs\", \"std\")]" - ] } ], "metadata": { diff --git a/.github/container/jax_nsys/install.sh b/.github/container/jax_nsys/install.sh index ba2fd82e2..c0de00207 100755 --- a/.github/container/jax_nsys/install.sh +++ b/.github/container/jax_nsys/install.sh @@ -10,6 +10,7 @@ # The expectation is that those archives will be copied and extracted on a # laptop or workstation, and this installation script will be run there, while # the `nsys-jax` wrapper is executed on a remote GPU cluster. +set -ex SCRIPT_DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd) VIRTUALENV="${SCRIPT_DIR}/nsys_jax_venv" if [[ ! -d "${VIRTUALENV}" ]]; then @@ -18,12 +19,17 @@ if [[ ! -d "${VIRTUALENV}" ]]; then . "${VIRTUALENV}/bin/activate" python -m pip install -U pip "${SCRIPT_DIR}/nsys-jax-ensure-protobuf" - python -m pip install jupyterlab + # matplotlib is a dependency of Analysis.ipynb but not jax_nsys + python -m pip install jupyterlab matplotlib python -m pip install -e "${SCRIPT_DIR}/python/jax_nsys" curl -o "${VIRTUALENV}/bin/flamegraph.pl" https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl chmod 755 "${VIRTUALENV}/bin/flamegraph.pl" else echo "Virtual environment already exists, not installing anything..." 
fi -echo "Launching: cd ${SCRIPT_DIR} && ${VIRTUALENV}/bin/python -m jupyterlab Analysis.ipynb" -cd "${SCRIPT_DIR}" && "${VIRTUALENV}/bin/python" -m jupyterlab Analysis.ipynb +if [ -z ${NSYS_JAX_INSTALL_SKIP_LAUNCH+x} ]; then + echo "Launching: cd ${SCRIPT_DIR} && ${VIRTUALENV}/bin/python -m jupyterlab Analysis.ipynb" + cd "${SCRIPT_DIR}" && "${VIRTUALENV}/bin/python" -m jupyterlab Analysis.ipynb +else + echo "Skipping launch of jupyterlab due to NSYS_JAX_INSTALL_SKIP_LAUNCH" +fi diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py index 4b30af181..f69e70fe6 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/analysis.py @@ -43,6 +43,9 @@ def _collective_correction(kind: str, size: int) -> tuple[float, float]: return (size, (size - 1) / size) case "all-reduce": return (1, 2 * (size - 1) / size) + case "all-to-all": + # https://github.com/NVIDIA/nccl-tests/blob/a1efb427e764241bc43d2d91be875c9f55da03a5/src/alltoall.cu#L44 + return (1, (size - 1) / size) case "collective-broadcast": return (1, 1) case "collective-permute": @@ -71,6 +74,7 @@ def get_message_size(program_id: int, instruction: str) -> pd.Series: in { "all-gather-start", "all-reduce-start", + "all-to-all", "collective-broadcast", "collective-permute-start", "reduce-scatter", diff --git a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py index e30587b59..365850bb5 100644 --- a/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py +++ b/.github/container/jax_nsys/python/jax_nsys/jax_nsys/data_loaders.py @@ -247,6 +247,14 @@ def clean_data_frame(d, extra_columns=[]): value=r"\2", regex=True, ) + # Add a new column describing which (0th, 1st, ...) execution of the thunk + # within the given module execution this is. For example, while loops in the + # HLO can lead to the same thunk being executed multiple times within the same + # module execution. 
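The added lines immediately below derive this index with pandas' groupby(...).cumcount(). As a minimal sketch of the behaviour being relied on (toy values, hypothetical, not part of the patch):

import pandas as pd

# Toy stand-in for thunk_df: one thread (TID), one program, two executions of the
# module, and a "while_body" thunk that runs twice per module execution.
toy = pd.DataFrame(
    {
        "TID": [1, 1, 1, 1, 1, 1],
        "ProgramId": [7, 7, 7, 7, 7, 7],
        "Name": ["while_body", "while_body", "fusion", "while_body", "while_body", "fusion"],
        "ModuleExecution": [0, 0, 0, 1, 1, 1],
    }
)
# cumcount() numbers repeated rows within each group: 0, 1, 0, 0, 1, 0 here, so the two
# executions of "while_body" inside one module execution become distinguishable.
toy["ThunkExecution"] = toy.groupby(
    ["TID", "ProgramId", "Name", "ModuleExecution"]
).cumcount()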
+ thunk_df["ThunkExecution"] = thunk_df.groupby( + ["TID", "ProgramId", "Name", "ModuleExecution"] + ).cumcount() + # Classify thunks as communication/computation and save to output output["thunk"] = _classify_comms(thunk_df, prefix) diff --git a/.github/workflows/nsys-jax.yaml b/.github/workflows/nsys-jax.yaml index ae8239935..016a563fb 100644 --- a/.github/workflows/nsys-jax.yaml +++ b/.github/workflows/nsys-jax.yaml @@ -70,6 +70,33 @@ jobs: run: | export MYPYPATH="${PWD}/compiled_stubs" ./venv/bin/mypy ${NSYS_JAX_PYTHON_FILES} + + notebook: + runs-on: ubuntu-22.04 + steps: + - name: Check out the repository under ${GITHUB_WORKSPACE} + uses: actions/checkout@v4 + - name: Mock up the structure of an extracted .zip file + shell: bash -x -e {0} + run: | + # Get the actual test data from a real archive, minus the .nsys-rep file + mv .github/workflows/nsys-jax/maxtext_fsdp4_test_data/* .github/container/jax_nsys/ + - name: "Setup Python 3.10" + uses: actions/setup-python@v5 + with: + python-version: '3.10' + - name: Run the install script, but skip launching Jupyter Lab + shell: bash -x -e {0} + run: | + pip install virtualenv + NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh + - name: Test the Jupyter Lab installation and execute the notebook + shell: bash -x -e {0} + run: | + pushd .github/container/jax_nsys + ./nsys_jax_venv/bin/python -m jupyterlab --version + ./nsys_jax_venv/bin/ipython Analysis.ipynb + ruff: runs-on: ubuntu-24.04 steps: diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.autotune_results.pbtxt.xz new file mode 100644 index 000000000..b5663b07d Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..6f58047d9 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.sm_9.0_gpu_after_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..e9faca4cf Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0000.jit_convert_element_type.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.autotune_results.pbtxt.xz 
b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.autotune_results.pbtxt.xz new file mode 100644 index 000000000..b5663b07d Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..b745071a0 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.sm_9.0_gpu_after_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..8f47e37af Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0001.jit__threefry_seed.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.autotune_results.pbtxt.xz new file mode 100644 index 000000000..b5663b07d Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..f78b74618 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.sm_9.0_gpu_after_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..f11d89d87 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0002.jit_concatenate.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git 
a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.autotune_results.pbtxt.xz new file mode 100644 index 000000000..b5663b07d Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..038654e9d Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.sm_9.0_gpu_after_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..498863c48 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0003.jit__unnamed_wrapped_function_.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.autotune_results.pbtxt.xz new file mode 100644 index 000000000..b5663b07d Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..1651df7ff Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.sm_9.0_gpu_after_optimizations.hlo.pb.xz 
b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..fe9bd0bf9 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0012.jit_raw_generate_synthetic_data.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.autotune_results.pbtxt.xz new file mode 100644 index 000000000..b5663b07d Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..2c9e0eca3 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.sm_9.0_gpu_after_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..3d46498f3 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0013.jit_fold_in.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.autotune_results.pbtxt.xz new file mode 100644 index 000000000..bdaa857c8 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..ff751bf09 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.sm_9.0_gpu_after_optimizations.hlo.pb.xz 
b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..4b814f31e Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0016.jit_train_step.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.autotune_results.pbtxt.xz new file mode 100644 index 000000000..bdaa857c8 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..c8eab9f56 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.sm_9.0_gpu_after_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..1699b49e8 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0033.jit_cos.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.autotune_results.pbtxt.xz new file mode 100644 index 000000000..bdaa857c8 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..a849d8919 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.sm_9.0_gpu_after_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..39459188b Binary 
files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0034.jit_add.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.autotune_results.pbtxt.xz new file mode 100644 index 000000000..bdaa857c8 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..8fcae1fac Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.sm_9.0_gpu_after_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..b11ead536 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0035.jit_multiply.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.autotune_results.pbtxt.xz new file mode 100644 index 000000000..bdaa857c8 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..a48bdda93 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.sm_9.0_gpu_after_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..81f31b4b3 Binary files /dev/null and 
b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0036.jit_subtract.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.autotune_results.pbtxt.xz new file mode 100644 index 000000000..bdaa857c8 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..b03fa798a Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.sm_9.0_gpu_after_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..5fc4360c7 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0037.jit_add.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.autotune_results.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.autotune_results.pbtxt.xz new file mode 100644 index 000000000..bdaa857c8 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.autotune_results.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.before_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.before_optimizations.hlo.pb.xz new file mode 100644 index 000000000..029cb3c26 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.before_optimizations.hlo.pb.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.gpu_target_config.pbtxt.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.gpu_target_config.pbtxt.xz new file mode 100644 index 000000000..dcb89aa9b Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.gpu_target_config.pbtxt.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.sm_9.0_gpu_after_optimizations.hlo.pb.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.sm_9.0_gpu_after_optimizations.hlo.pb.xz new file mode 100644 index 000000000..3c18451e1 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/dump/module_0038.jit__where.sm_9.0_gpu_after_optimizations.hlo.pb.xz differ diff --git 
a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/nvtx_gpu_proj_trace/trace.parquet b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/nvtx_gpu_proj_trace/trace.parquet new file mode 100644 index 000000000..b23d01fe5 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/nvtx_gpu_proj_trace/trace.parquet differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profile.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profile.proto new file mode 100644 index 000000000..27aa904c4 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profile.proto @@ -0,0 +1,71 @@ +// This proto intends to match format expected by pprof tool. +syntax = "proto3"; + +package tensorflow.tfprof.pprof; + +message Profile { + repeated ValueType sample_type = 1; + repeated Sample sample = 2; + repeated Mapping mapping = 3; + repeated Location location = 4; + repeated Function function = 5; + repeated string string_table = 6; + int64 drop_frames = 7; + int64 keep_frames = 8; + int64 time_nanos = 9; + int64 duration_nanos = 10; + ValueType period_type = 11; + int64 period = 12; + repeated int64 comment = 13; + int64 default_sample_type = 14; +} + +message ValueType { + int64 type = 1; + int64 unit = 2; +} + +message Sample { + repeated uint64 location_id = 1; + repeated int64 value = 2; + repeated Label label = 3; +} + +message Label { + int64 key = 1; + int64 str = 2; + int64 num = 3; +} + +message Mapping { + uint64 id = 1; + uint64 memory_start = 2; + uint64 memory_limit = 3; + uint64 file_offset = 4; + int64 filename = 5; + int64 build_id = 6; + bool has_functions = 7; + bool has_filenames = 8; + bool has_line_numbers = 9; + bool has_inline_frames = 10; +} + +message Location { + uint64 id = 1; + uint64 mapping_id = 2; + uint64 address = 3; + repeated Line line = 4; +} + +message Line { + uint64 function_id = 1; + int64 line = 2; +} + +message Function { + uint64 id = 1; + int64 name = 2; + int64 system_name = 3; + int64 filename = 4; + int64 start_line = 5; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiled_instructions.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiled_instructions.proto new file mode 100644 index 000000000..277b9f4be --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiled_instructions.proto @@ -0,0 +1,33 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto3"; + +package tensorflow.profiler; + +// Next ID: 3 +message ProfiledInstructionsProto { + message InstructionCost { + string name = 1; + double cost_us = 2; + } + message Latency { + string source = 1; + string target = 2; + double latency_us = 3; + } + repeated InstructionCost costs = 1; + repeated Latency latencies = 2; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_analysis.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_analysis.proto new file mode 100644 index 000000000..16d938673 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_analysis.proto @@ -0,0 +1,81 @@ +syntax = "proto3"; + +package tensorflow; + +import "tsl/profiler/protobuf/profiler_service.proto"; + +message NewProfileSessionRequest { + ProfileRequest request = 1; + // The place where we will dump profile data. We will normally use + // MODEL_DIR/plugins/profile as the repository root. + string repository_root = 2; + repeated string hosts = 3; // host or host:port, port will be ignored. + string session_id = 4; +} + +message NewProfileSessionResponse { + // Auxiliary error_message. + string error_message = 1; + + // Whether all hosts had returned a empty trace. + bool empty_trace = 2; +} + +message EnumProfileSessionsAndToolsRequest { + string repository_root = 1; +} + +message ProfileSessionInfo { + string session_id = 1; + // Which tool data is available for consumption. + repeated string available_tools = 2; +} + +message EnumProfileSessionsAndToolsResponse { + // Auxiliary error_message. + string error_message = 1; + // If success, the returned sessions information are stored here. + repeated ProfileSessionInfo sessions = 2; +} + +message ProfileSessionDataRequest { + // The place where we will read profile data. We will normally use + // MODEL_DIR/plugins/profile as the repository root. + string repository_root = 1; + string session_id = 2; + // Which host the data is associated. if empty, data from all hosts are + // aggregated. + string host_name = 5; + // Which tool + string tool_name = 3; + // Tool's specific parameters. e.g. TraceViewer's viewport etc + map<string, string> parameters = 4; +} + +message ProfileSessionDataResponse { + // Auxiliary error_message. + string error_message = 1; + + // Output format. e.g. "json" or "proto" or "blob" + string output_format = 2; + + // TODO(jiesun): figure out whether to put bytes or oneof tool specific proto. + bytes output = 3; +} +//////////////////////////////////////////////////////////////////////////////// +// ProfileAnalysis service provide entry point for profiling TPU and for +// serving profiled data to TensorBoard through GRPC +//////////////////////////////////////////////////////////////////////////////// +service ProfileAnalysis { + // Starts a profiling session, blocks until it completes. + // TPUProfileAnalysis service delegate this to TPUProfiler service. + // Populate the profiled data in repository, then return status to caller. + rpc NewSession(NewProfileSessionRequest) returns (NewProfileSessionResponse) { + } + // Enumerate existing sessions and return available profile tools. + rpc EnumSessions(EnumProfileSessionsAndToolsRequest) + returns (EnumProfileSessionsAndToolsResponse) {} + // Retrieve specific tool's data for specific session.
+ rpc GetSessionToolData(ProfileSessionDataRequest) + returns (ProfileSessionDataResponse) {} +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_options.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_options.proto new file mode 100644 index 000000000..687a2f101 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_options.proto @@ -0,0 +1,88 @@ +syntax = "proto3"; + +package tensorflow; + +// Next ID: 11 +message ProfileOptions { + // Some default value of option are not proto3 default value. Use this version + // to determine if we should use default option value instead of proto3 + // default value. + uint32 version = 5; + + enum DeviceType { + UNSPECIFIED = 0; + CPU = 1; + GPU = 2; + TPU = 3; + PLUGGABLE_DEVICE = 4; + } + + // Device type to profile/trace: (version >= 1) + // DeviceType::UNSPECIFIED: All registered device profiler will be enabled. + // DeviceType::CPU: only CPU will be profiled. + // DeviceType::GPU: only CPU/GPU will be profiled. + // DeviceType::TPU: only CPU/TPU will be profiled. + // DeviceType::PLUGGABLE_DEVICE: only CPU/pluggable devices with profilers + // will be profiled. + DeviceType device_type = 6; + + // We don't collect the dataset ops by default for better trace-viewer + // scalability. The caller can manually set this field to include the ops. + bool include_dataset_ops = 1; + + // Levels of host tracing: (version >= 1) + // - Level 0 is used to disable host traces. + // - Level 1 enables tracing of only user instrumented (or default) TraceMe. + // - Level 2 enables tracing of all level 1 TraceMe(s) and instrumented high + // level program execution details (expensive TF ops, XLA ops, etc). + // This is the default. + // - Level 3 enables tracing of all level 2 TraceMe(s) and more verbose + // (low-level) program execution details (cheap TF ops, etc). + uint32 host_tracer_level = 2; + + // Levels of device tracing: (version >= 1) + // - Level 0 is used to disable device traces. + // - Level 1 is used to enable device traces. + // - More levels might be defined for specific device for controlling the + // verbosity of the trace. + uint32 device_tracer_level = 3; + + // Whether enable python function calls tracing. Runtime overhead ensues if + // enabled. Default off. (version >= 1) + uint32 python_tracer_level = 4; + + // Whether serialize hlo_proto when XLA is used. (version >= 1) + bool enable_hlo_proto = 7; + + // The local profiler starts profiling at this Unix timestamp in nanoseconds. + uint64 start_timestamp_ns = 8; + + // The local profiler collects `duration_ms` milliseconds of data. If the + // value is 0, profiling continues until interrupted. + uint64 duration_ms = 9; + + // Directory to save profile data to. No-op when empty. + string repository_path = 10; +} + +// Options for remote profiler session manager. +// Next ID: 6 +message RemoteProfilerSessionManagerOptions { + // Options for each local profiler. + ProfileOptions profiler_options = 1; + + // List of servers to profile. Supported formats: host:port. + repeated string service_addresses = 2; + + // Unix timestamp of when the session was started. + uint64 session_creation_timestamp_ns = 3; + + // Maximum time (in milliseconds) a profiling session manager waits for all + // profilers to finish after issuing gRPC request. If value is 0, session + // continues until interrupted. 
Otherwise, value must be greater than + profiler_options.duration_ms. + uint64 max_session_duration_ms = 4; + + // Start of profiling is delayed by this much (in milliseconds). + uint64 delay_ms = 5; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_service.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_service.proto new file mode 100644 index 000000000..67b747bb8 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_service.proto @@ -0,0 +1,121 @@ +syntax = "proto3"; + +package tensorflow; + +import "tsl/profiler/protobuf/profiler_options.proto"; +import "tsl/profiler/protobuf/profiler_service_monitor_result.proto"; + +// The ProfilerService service retrieves performance information about +// the programs running on connected devices over a period of time. +service ProfilerService { + // Starts a profiling session, blocks until it completes, and returns data. + rpc Profile(ProfileRequest) returns (ProfileResponse) {} + // Signal to terminate the Profile rpc for a on-going profiling session, + // The Profile rpc will return successfully and prematurely without timeout. + // This is used by programmatic mode to end the session in workers. + rpc Terminate(TerminateRequest) returns (TerminateResponse) {} + // Collects profiling data and returns user-friendly metrics. + rpc Monitor(MonitorRequest) returns (MonitorResponse) {} +} + +message ToolRequestOptions { + // Required formats for the tool, it should be one of "json", "proto", "raw" + // etc. If not specified (backward compatible), use default format, i.e. most + // tools use json format. + string output_formats = 2; + + // Whether save the result directly to repository or pass it back to caller. + // Default to false for backward compatibilities. + bool save_to_repo = 3; +} + +// Next-ID: 9 +message ProfileRequest { + // In future, the caller will be able to customize when profiling starts and + // stops. For now, it collects `duration_ms` milliseconds worth of data. + uint64 duration_ms = 1; + + // The maximum number of events to return. By default (value 0), return all + // events. + uint64 max_events = 2; + + // Required profiling tools name such as "input_pipeline_analyzer" etc + repeated string tools = 3; + + // Specifies the requirement for each tools. + map<string, ToolRequestOptions> tool_options = 8; + + // Optional profiling options that control how a TF session will be profiled. + ProfileOptions opts = 4; + + // The place where we will dump profile data. We will normally use + // MODEL_DIR/plugins/profile/ as the repository root. + string repository_root = 5; + + // The user provided profile session identifier. + string session_id = 6; + + // The hostname of system where the profile should happen. + // We use it as identifier in part of our output filename. + string host_name = 7; + + // In future, the caller will indicate which TF session is being profiled, and + // only data relating to that program will be returned. For now, we assume + // all activity during the profiling period is relevant. +} + +message ProfileToolData { + // The file name which this data is associated (e.g. "input_pipeline.json", + // "cluster_xxx.memory_viewer.json"). + string name = 1; + + // The data payload (likely json) for the specific tool. + bytes data = 2; +} + +// Next-ID: 8 +message ProfileResponse { + // Data payload for each required tools.
+ repeated ProfileToolData tool_data = 6; + + // When we write profiling data directly to repository directory, we need a + // way to figure out whether the captured trace is empty. + bool empty_trace = 7; + + reserved 1, 2, 3, 4, 5; +} + +message TerminateRequest { + // Which session id to terminate. + string session_id = 1; +} + +message TerminateResponse {} + +// Next-ID: 4 +message MonitorRequest { + // Duration for which to profile between each update. + uint64 duration_ms = 1; + + // Indicates the level at which we want to monitor. Currently, two levels are + // supported: + // Level 1: An ultra lightweight mode that captures only some utilization + // metrics. + // Level 2: More verbose than level 1. Collects utilization metrics, device + // information, step time information, etc. Do not use this option if the TPU + // host is being very heavily used. + int32 monitoring_level = 2; + // True to display timestamp in monitoring result. + bool timestamp = 3; +} + +// Next-ID: 11 +message MonitorResponse { + // Properly formatted string data that can be directly returned back to user. + string data = 1; + + // A collection of monitoring results for each field show in data. + ProfilerServiceMonitorResult monitor_result = 10; + + reserved 2, 3, 4, 5, 6, 7, 8, 9; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_service_monitor_result.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_service_monitor_result.proto new file mode 100644 index 000000000..48ec2113e --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/profiler_service_monitor_result.proto @@ -0,0 +1,39 @@ +syntax = "proto3"; + +package tensorflow; + +message ProfilerServiceMonitorResult { + // Represents the different types of responses from the profiling service. + enum ResponseType { + // No result is returned from the profiling service. + EMPTY_RESULT = 0; + // Only device utilization is available. + UTIL_ONLY = 1; + // Both device utilization and device idle time are available. + UTIL_IDLE = 2; + // Device utilization, device idle time, step time, and infeed percentage + // are all available. + UTIL_IDLE_STEP = 3; + } + + // Type of profiling responses. + ResponseType response_type = 1; + // Percentage of time when device is idle. + double device_idle_time_percent = 2; + // TPU matrix unit utilization percentage. + double matrix_unit_utilization_percent = 3; + // Average step time in millisecond. + double step_time_ms_avg = 4; + // Minimum step time in millisecond. + double step_time_ms_min = 5; + // Maximum step time in millisecond. + double step_time_ms_max = 6; + // Average infeed percentage. + double infeed_percent_avg = 7; + // Minimum infeed percentage. + double infeed_percent_min = 8; + // Maximum infeed percentage. 
+ double infeed_percent_max = 9; + + // next-field: 10 +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/trace_events.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/trace_events.proto new file mode 100644 index 000000000..2f7b3075f --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/trace_events.proto @@ -0,0 +1,72 @@ +syntax = "proto3"; + +package tsl.profiler; + +option cc_enable_arenas = true; +option java_outer_classname = "TraceEventsProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/core_protos_go_proto"; + +// A 'Trace' contains metadata for the individual traces of a system. +message Trace { + // The devices that this trace has information about. Maps from device_id to + // more data about the specific device. + map<uint32, Device> devices = 1; + + // All trace events capturing in the profiling period. + repeated TraceEvent trace_events = 4; +} + +// A 'device' is a physical entity in the system and is comprised of several +// resources. +message Device { + // The name of the device. + string name = 1; + + // The id of this device, unique in a single trace. + uint32 device_id = 2; + + // The resources on this device, keyed by resource_id; + map<uint32, Resource> resources = 3; +} + +// A 'resource' generally is a specific computation component on a device. These +// can range from threads on CPUs to specific arithmetic units on hardware +// devices. +message Resource { + // The name of the resource. + string name = 1; + + // The id of the resource. Unique within a device. + uint32 resource_id = 2; + + // The sort index of the resource. Resources within a device are ordered by + // this value. if absent, use resource id as sort index. + uint32 sort_index = 3; +} + +message TraceEvent { + // The id of the device that this event occurred on. The full dataset should + // have this device present in the Trace object. + uint32 device_id = 1; + + // The id of the resource that this event occurred on. The full dataset should + // have this resource present in the Device object of the Trace object. A + // resource_id is unique on a specific device, but not necessarily within the + // trace. + uint32 resource_id = 2; + + // The name of this trace event. + string name = 3; + + // The timestamp that this event occurred at (in picos since tracing started). + uint64 timestamp_ps = 9; + + // The duration of the event in picoseconds if applicable. + // Events without duration are called instant events. + uint64 duration_ps = 10; + + // Extra arguments that will be displayed in trace view. + map<string, string> args = 11; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/xplane.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/xplane.proto new file mode 100644 index 000000000..f410ac114 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/profiler/protobuf/xplane.proto @@ -0,0 +1,156 @@ +syntax = "proto3"; + +package tensorflow.profiler; + +option cc_enable_arenas = true; + +// A container of parallel XPlanes, generated by one or more profiling sources. +// Next ID: 5 +message XSpace { + repeated XPlane planes = 1; + // Errors (if any) in the generation of planes.
+ repeated string errors = 2; + // Warnings (if any) in the generation of planes; + repeated string warnings = 3; + // List of hostnames that XPlanes are generated from. + repeated string hostnames = 4; +} + +// An XPlane is a container of parallel timelines (XLines), generated by a +// profiling source or by post-processing one or more XPlanes. +// Next ID: 7 +message XPlane { + int64 id = 1; + + // Name of this XPlane. + string name = 2; + + // Parallel timelines grouped in this plane. XLines with the same id + // are effectively the same timeline. + repeated XLine lines = 3; + + // XEventMetadata map, each entry uses the XEventMetadata.id as key. This map + // should be used for events that share the same ID over the whole XPlane. + map<int64, XEventMetadata> event_metadata = 4; + + // XStatMetadata map, each entry uses the XStatMetadata.id as key. This map + // should be used for stats that share the same ID over the whole XPlane. + map<int64, XStatMetadata> stat_metadata = 5; + + // XStats associated with this plane, e.g. device capabilities. + // Each of these XStats should have a different metadata_id. + repeated XStat stats = 6; +} + +// An XLine is a timeline of trace events (XEvents). +// Next ID: 12 +message XLine { + // Id of this line, can be repeated within an XPlane. All XLines with the + // same id are effectively the same timeline. + int64 id = 1; + + // Display id of this line. Multiple lines with the same display_id are + // grouped together in the same trace viewer row. + int64 display_id = 10; + + // Name of this XLine. + string name = 2; + + // Name of this XLine to display in trace viewer. + string display_name = 11; + + // Start time of this line in nanoseconds since the UNIX epoch. + // XEvent.offset_ps is relative to this timestamp. + int64 timestamp_ns = 3; + + // Profiling duration for this line in picoseconds. + int64 duration_ps = 9; + + // XEvents within the same XLine should not overlap in time, but they can be + // nested. + repeated XEvent events = 4; + + reserved 5, 6, 7, 8; +} + +// An XEvent is a trace event, optionally annotated with XStats. +// Next ID: 6 +message XEvent { + // XEventMetadata.id of corresponding metadata. + int64 metadata_id = 1; + + oneof data { + // Start time of the event in picoseconds, as offset from + // XLine.timestamp_ns(). + int64 offset_ps = 2; + + // Number of occurrences of the event, if aggregated. + int64 num_occurrences = 5; + } + + // Duration of the event in picoseconds. Can be zero for an instant event. + int64 duration_ps = 3; + + // XStats associated with the event. + // Each of these XStats should have a different metadata_id. + repeated XStat stats = 4; +} + +// An XStat is a named value associated with an XEvent, e.g., a performance +// counter value, a metric computed by a formula applied over nested XEvents +// and XStats. +// Next ID: 8 +message XStat { + // XStatMetadata.id of corresponding metadata. + int64 metadata_id = 1; + + // Value of this stat. + oneof value { + double double_value = 2; + uint64 uint64_value = 3; + int64 int64_value = 4; + string str_value = 5; + bytes bytes_value = 6; + // A string value that stored in XStatMetadata::name. + uint64 ref_value = 7; + } +} + +// Metadata for an XEvent, corresponds to an event type and is shared by +// all XEvents with the same metadata_id. +// Next ID: 7 +message XEventMetadata { + // XPlane.event_metadata map key. + int64 id = 1; + + // Name of the event. + string name = 2; + + // Name of the event shown in trace viewer. + string display_name = 4; + + // Additional metadata in serialized format.
+ bytes metadata = 3; + + // XStats that are constant for all XEvents with the same metadata_id. + // Each of these XStats should have a different metadata_id. + repeated XStat stats = 5; + + // XPlane.event_metadata map key for children events. + repeated int64 child_id = 6; +} + +// Metadata for an XStat, corresponds to a stat type and is shared by all +// XStats with the same metadata_id. +// Next ID: 4 +message XStatMetadata { + // XPlane.stat_metadata map key. + int64 id = 1; + + // Name of the stat (should be short). + // Two XStatMetadata with different id should have different names. + string name = 2; + + // Description of the stat (might be long). + string description = 3; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/bfc_memory_map.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/bfc_memory_map.proto new file mode 100644 index 000000000..bca45cbf3 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/bfc_memory_map.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; + +package tensorflow; + +option go_package = "github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto"; + +// Some of the data from AllocatorStats +message MemAllocatorStats { + int64 num_allocs = 1; + int64 bytes_in_use = 2; + int64 peak_bytes_in_use = 3; + int64 largest_alloc_size = 4; + float fragmentation_metric = 5; +} + +message MemChunk { + uint64 address = 1; + int64 size = 2; + int64 requested_size = 3; + int32 bin = 4; + string op_name = 5; + uint64 freed_at_count = 6; + uint64 action_count = 7; + bool in_use = 8; + uint64 step_id = 9; +} + +message BinSummary { + int32 bin = 1; + int64 total_bytes_in_use = 2; + int64 total_bytes_in_bin = 3; + int64 total_chunks_in_use = 4; + int64 total_chunks_in_bin = 5; +} + +message SnapShot { + uint64 action_count = 1; + int64 size = 2; +} + +message MemoryDump { + string allocator_name = 1; + repeated BinSummary bin_summary = 2; + repeated MemChunk chunk = 3; + repeated SnapShot snap_shot = 4; + MemAllocatorStats stats = 5; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/coordination_config.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/coordination_config.proto new file mode 100644 index 000000000..035a49e6f --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/coordination_config.proto @@ -0,0 +1,70 @@ +syntax = "proto3"; + +package tensorflow; + +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; + +// Represents a job type and the number of tasks under this job. +// For example, ("worker", 20) implies that there will be 20 worker tasks. +message CoordinatedJob { + string name = 1; + int32 num_tasks = 2; +} + +// Coordination service configuration parameters. +// The system picks appropriate values for fields that are not set. +message CoordinationServiceConfig { + // Type of coordination service implementation to enable. + // For example, setting the service type as "standalone" starts a service + // instance on the leader task to provide the coordination services such as + // heartbeats and consistent key-value store. + string service_type = 1; + + // Address where the coordination service instance is hosted. + string service_leader = 2; + + // Whether to enable the health check mechanism. + bool enable_health_check = 3; + + // Maximum wait time for all members in the cluster to be registered. 
+ int64 cluster_register_timeout_in_ms = 4; + + // Heartbeat timeout, if a task does not record heartbeat in this time + // window, it will be considered disconnected. + // Note: This is also used as a grace period to accept any heartbeats after + // the agent has disconnected, to account for the lag time between the service + // recording the state change and the agent stopping heartbeats. + int64 heartbeat_timeout_in_ms = 5; + + // The list of `CoordinatedJob`s that will register in coordination service. + reserved 6; + repeated CoordinatedJob coordinated_job_list = 10; + + // Denotes how long to wait for all coordination agents to reach the barriers + // (after the first shutdown request) before disconnecting together. If + // set to 0, no barrier is imposed upon shutdown and each worker can + // disconnect individually. + int64 shutdown_barrier_timeout_in_ms = 7; + + // If set, agents do not make an explicit Shutdown() call. Service will only + // find out about the disconnecte agent via stale heartbeats. Used for + // testing. + bool agent_destruction_without_shutdown = 8; + + // The list of jobs which are recoverable. If a task in this list fails, + // it will not propagate error to other tasks. + // If empty, no jobs will be recoverable and every task failure will cause + // error propagation to other tasks. + repeated string recoverable_jobs = 9; + + // If a task restarts with a new incarnation, we may allow it to reconnect + // silently. This is useful when we know that a task can immediately resume + // work upon re-connecting to the service. + bool allow_new_incarnation_to_reconnect = 11; + + // Disables coordination service. + // Some libraries enable coordination service by default even if the user did + // not specify any config. This field allows users to explicitly disable + // coordination service under all situations. + bool force_disable = 12; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/coordination_service.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/coordination_service.proto new file mode 100644 index 000000000..2f7cc804c --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/coordination_service.proto @@ -0,0 +1,345 @@ +syntax = "proto3"; + +package tensorflow; + +import "google/protobuf/any.proto"; + +option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto"; + +// Represents a remote worker task, specified by job name and task id. +message CoordinatedTask { + string job_name = 1; + int32 task_id = 2; +} + +// Represents the state of a remote worker +enum CoordinatedTaskState { + // TASKSTATE_UNSPECIFIED is an invalid state such that indicates a bug. + TASKSTATE_UNSPECIFIED = 0; + // TASKSTATE_UNINITIALIZED is an agent-only state. While the agent is + // disconnected, the service has no way of knowing if the task is + // initialized/uninitialized. + TASKSTATE_UNINITIALIZED = 1; + TASKSTATE_DISCONNECTED = 2; + TASKSTATE_CONNECTED = 3; + TASKSTATE_ERROR = 4; +} + +// Status payload for all coordination service errors. +// Note: an empty proto may be set if the error is triggered by the task's own +// agent calls (i.e. not propagated by the service from another remote task). +message CoordinationServiceError { + // Removed fields which used to specify the error origin. + reserved 1, 2; + // If true, error is reported via the agent API by the user (and not an + // internal service error). 
+ bool is_reported_error = 3; + // Denotes which task hit the error. If unset, the error originated from the + // same task that is processing this error. + CoordinatedTask source_task = 4; +} + +message CoordinatedTaskStateInfo { + CoordinatedTask task = 1; + CoordinatedTaskState state = 2; + int32 error_code = 3; + string error_message = 4; + CoordinationServiceError error_payload = 5; +} + +// Placeholder message to be extended by other runtimes' device representations. +message DeviceInfo { + repeated google.protobuf.Any device = 1; +} + +// Request and response messages for registering a task to the cluster leader. +// A task is uniquely represented by its `job_name`, `task_id` and +// `incarnation`. Leader responds with its `incarnation` to identify a leader +// process. +message RegisterTaskRequest { + // Removed fields which used to specify the task. + reserved 1, 2; + fixed64 incarnation = 3; + // Moved the field `local_device_attributes` from this request message to + // WaitForAllTasksRequest defined below. + reserved 4; + CoordinatedTask source_task = 5; +} + +message RegisterTaskResponse { + fixed64 leader_incarnation = 1; +} + +// Request and response messages for sending heartbeats. +message HeartbeatRequest { + // Removed fields which used to specify the remote task. + reserved 1, 2; + fixed64 incarnation = 3; + CoordinatedTask source_task = 4; +} + +message HeartbeatResponse { + fixed64 leader_incarnation = 1; + // If there are failures in cluster, use additional metadata in response to + // broadcast error code and message to other tasks. +} + +// Request and response messages for waiting for all tasks. +message WaitForAllTasksRequest { + // Removed fields which used to specify the remote task. + reserved 1, 2; + // Removed field that specifically used TF device info. + reserved 3, 4; + CoordinatedTask source_task = 5; + // All local device attributes on the request sender; + DeviceInfo device_info = 6; +} + +message WaitForAllTasksResponse { + fixed64 leader_incarnation = 1; + // Removed field that specifically used TF device info. + reserved 2, 3; + // All devices in the cluster. + DeviceInfo device_info = 4; +} + +// Request and response messages for disconnecting a task from the service. +message ShutdownTaskRequest { + CoordinatedTask source_task = 1; +} + +message ShutdownTaskResponse {} + +// Request and response messages for resetting a task state in the service. +message ResetTaskRequest { + CoordinatedTask source_task = 1; +} + +message ResetTaskResponse {} + +// Request and response messages for reporting errors to task. +message ReportErrorToTaskRequest { + int32 error_code = 1; + string error_message = 2; + // Removed fields that are embedded in payload. + reserved 3, 4; + CoordinationServiceError error_payload = 5; +} + +message ReportErrorToTaskResponse {} + +// Request and response messages for reporting errors to service instance. +message ReportErrorToServiceRequest { + int32 error_code = 1; + string error_message = 2; + // Removed fields which used to specify the error origin. + reserved 3, 4; + CoordinatedTask error_origin = 5; +} + +message ReportErrorToServiceResponse {} + +// Request and response messages for getting state of a remote task. +message GetTaskStateRequest { + repeated CoordinatedTask source_task = 1; +} + +message GetTaskStateResponse { + repeated CoordinatedTaskStateInfo task_state = 1; +} + +// Message for configuration key value. 
+// Key is structured like Unix file system, with multiple levels of directory +// names separated by the slash ('/') characters. +message KeyValueEntry { + string key = 1; + bytes value = 2; +} + +// Request and response messages for inserting configuration key-value data. +message InsertKeyValueRequest { + KeyValueEntry kv = 1; + bool allow_overwrite = 2; +} + +message InsertKeyValueResponse {} + +// Request and response messages for getting configuration key-value data. +message GetKeyValueRequest { + string key = 1; +} + +message GetKeyValueResponse { + KeyValueEntry kv = 1; +} + +message TryGetKeyValueRequest { + string key = 1; +} + +message TryGetKeyValueResponse { + KeyValueEntry kv = 1; +} + +message GetKeyValueDirRequest { + string directory_key = 1; +} + +message GetKeyValueDirResponse { + string directory_key = 1; + repeated KeyValueEntry kv = 2; +} + +// Request and response messages for deleting configuration key-value data. +// When is_directory is true, delete key-values recursively under `key`. +message DeleteKeyValueRequest { + string key = 1; + bool is_directory = 2; +} + +message DeleteKeyValueResponse {} + +// Request and response messages for generic sync barriers. +message BarrierRequest { + string barrier_id = 1; + int64 barrier_timeout_in_ms = 2; + // Denotes list of tasks that will wait for the barrier. If unspecified, it + // implies that the entire cluster is participating in the barrier. + repeated CoordinatedTask tasks = 3; + // Task that is making the request. + CoordinatedTask source_task = 4; +} + +message BarrierResponse {} + +// Request and response messages for cancelling generic sync barriers. +message CancelBarrierRequest { + string barrier_id = 1; + // Task that is making the request. + CoordinatedTask source_task = 2; +} + +message CancelBarrierResponse {} + +// Coordination Service defines a TensorFlow service that controls and +// coordinates distributed execution in a cluster of multiple tasks. +// +// The service keeps track of the cluster configuration and the state of cluster +// members or the leader depending on the role of the current task. The +// distributed runtime leverages this service to coordinate and perform cluster +// initialization, check the healthiness of tasks, and propagate error +// messages to the cluster. +service CoordinationService { + // Register task to coordination service so that the service starts to track + // liveness of the task. RPC blocks and returns only when it registers to + // the service successfully, or error happens in the registering process. + rpc RegisterTask(RegisterTaskRequest) returns (RegisterTaskResponse) { + // [AUTOMATION]: Internal rpc option goes here. + } + + // Heartbeat message from task to coordination service. Heartbeat is sent from + // a task to refresh its timestamp on leader to avoid it becoming stale. + // RPC responds immediately after refreshing the timestamp on leader. + rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse) { + // [AUTOMATION]: Internal rpc option goes here. + } + + // Wait for all tasks in the cluster to be up and running. The RPC request + // only gets responded when all tasks have registered, or some error occurs. + rpc WaitForAllTasks(WaitForAllTasksRequest) returns (WaitForAllTasksResponse); + + // Disconnects task from the service. If `shutdown_barrier_timeout_in_ms` is + // specified in the config, blocks until all tasks reach the barrier before + // disconnecting together. 
If the barrier times out, tasks at the barrier will + // still disconnect, while an error is reported to tasks that did not reach + // the barrier on time. + rpc ShutdownTask(ShutdownTaskRequest) returns (ShutdownTaskResponse) { + // [AUTOMATION]: Internal rpc option goes here. + } + + // Disconnects task from the service if it is in an ERROR state, thereby + // allowing it to reconnect via RegisterTask() in the future. + rpc ResetTask(ResetTaskRequest) returns (ResetTaskResponse); + + // Report error to the task. RPC sets the receiving instance of coordination + // service agent to error state permanently. + // TODO(b/195990880): Consider splitting this into a different RPC service. + rpc ReportErrorToTask(ReportErrorToTaskRequest) + returns (ReportErrorToTaskResponse); + + // Report task error to coordination service. RPC sets the service-side task + // state to error, and propagate the error to other tasks in the cluster. + rpc ReportErrorToService(ReportErrorToServiceRequest) + returns (ReportErrorToServiceResponse); + + // Get the state of a remote task. Specifically, RPC returns a + // CoordinatedTaskState, and if the task is in an error status, returns a + // non-OK error code, non-empty error message and error payload. + rpc GetTaskState(GetTaskStateRequest) returns (GetTaskStateResponse); + + // Insert configuration key-value that will be accessible to all cluster + // tasks. The key can be formatted as Unix file path with hierarchy. The + // coordination service key-value store should only be used for cluster + // configuration data. + rpc InsertKeyValue(InsertKeyValueRequest) returns (InsertKeyValueResponse) { + // [AUTOMATION]: Internal rpc option goes here. + } + + // Get configuration key-value. The request blocks until the key-value data + // becomes available (i.e., set by a task in the cluster). + rpc GetKeyValue(GetKeyValueRequest) returns (GetKeyValueResponse) { + // [AUTOMATION]: Internal rpc option goes here. + } + + // Get configuration key-value. The request does not block, but returns an + // error if the requested key does not exist. + rpc TryGetKeyValue(TryGetKeyValueRequest) returns (TryGetKeyValueResponse); + + // Same as GetKeyValue, but returns all values that have keys which are + // prefixed with the directory key. + rpc GetKeyValueDir(GetKeyValueDirRequest) returns (GetKeyValueDirResponse) { + // [AUTOMATION]: Internal rpc option goes here. + } + + // Delete configuration key-value. If is_directory is set in request, + // recursively clean up all key-values under the path specified by `key`. + rpc DeleteKeyValue(DeleteKeyValueRequest) returns (DeleteKeyValueResponse); + + // Blocks until all (or a subset of) tasks are at the barrier or the barrier + // fails. + // + // `barrier_id` should be unique across barriers. Once the barrier has passed + // or failed, subsequent calls will not block, and immediately respond with + // the previous response. + // + // The first WaitAtBarrier() call received by the service for a particular + // barrier id is special in that it determines the barrier deadline based on + // timeout duration. + // However, if subsequent calls by different agents specify a different set of + // `tasks` for the same `barrier_id`, the barrier will fail instantly. + // + // If no tasks are specified (default), the barrier will block for all the + // connected tasks. + // + // Possible service errors: + // - DeadlineExceeded: Timed out waiting for specified tasks at the barrier. 
+ // Deadline is determined by the server timestamp when it receives the
+ // first WaitAtBarrier() + timeout duration.
+ // - Cancelled: One of the tasks called CancelBarrier().
+ // - Aborted: Service is shutting down.
+ // - Internal: Any participating task is in ERROR state.
+ // - InvalidArgument: (1) Conflicting tasks specified by different agents
+ // for the same barrier, (2) one of the participating tasks is not in
+ // the cluster, or (3) task making the request is not included in the
+ // list of participating tasks.
+ rpc Barrier(BarrierRequest) returns (BarrierResponse) {
+ // [AUTOMATION]: Internal rpc option goes here.
+ }
+
+ // Aborts the barrier if it is ongoing.
+ // Current and future WaitAtBarrier() calls with the same id will return a
+ // CANCELLED error status.
+ // Possible service errors:
+ // - FailedPrecondition: Barrier has already been passed.
+ rpc CancelBarrier(CancelBarrierRequest) returns (CancelBarrierResponse);
+}
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/distributed_runtime_payloads.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/distributed_runtime_payloads.proto
new file mode 100644
index 000000000..3a2aecdd2
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/distributed_runtime_payloads.proto
@@ -0,0 +1,24 @@
+syntax = "proto3";
+
+package tensorflow.distributed_runtime;
+
+option cc_enable_arenas = true;
+option go_package = "github.com/tsl/tsl/go/core/protobuf/for_core_protos_go_proto";
+
+// Used to serialize and transmit tensorflow::Status payloads through
+// grpc::Status `error_details` since grpc::Status lacks payload API.
+// TODO(b/204231601): Use GRPC API once supported.
+message GrpcPayloadContainer {
+ map<string, bytes> payloads = 1;
+}
+
+// If included as a payload, this message flags the Status to have lost payloads
+// during the GRPC transmission.
+// URI: "type.googleapis.com/tensorflow.distributed_runtime.GrpcPayloadsLost"
+message GrpcPayloadsLost {}
+
+// If included as a payload, this message flags the Status to be a possible
+// outcome of a worker restart.
+// URI:
+// "type.googleapis.com/tensorflow.distributed_runtime.WorkerPossiblyRestarted"
+message WorkerPossiblyRestarted {}
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/dnn.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/dnn.proto
new file mode 100644
index 000000000..695db935f
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/dnn.proto
@@ -0,0 +1,203 @@
+// LINT: LEGACY_NAMES
+syntax = "proto3";
+
+package stream_executor.dnn;
+
+import "google/protobuf/wrappers.proto";
+
+option go_package = "github.com/google/tsl/tsl/go/stream_executor";
+
+// Specifies the data type used by an operation.
+enum DataType {
+ kFloat = 0;
+ kDouble = 1;
+ kHalf = 2;
+ kInt8 = 3;
+ kInt32 = 4;
+ kComplexFloat = 5;
+ kComplexDouble = 6;
+ kBF16 = 7;
+ kF8E5M2 = 8;
+ kF8E4M3FN = 9;
+ kF8E5M2FNUZ = 10;
+ kF8E4M3FNUZ = 11;
+ kInt64 = 12;
+}
+
+// Describes how a convolution input or output layer's data is formatted.
+enum DataLayout {
+ // Naming convention:
+ // Y <-> row or height
+ // X <-> column or width
+ // Batch <-> batch, or N
+ // Depth <-> feature, or channel
+ // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
+ //
+ // Note: In cudnn, kBatchDepthYX4 and kBatchDepthYX32 are the same layout
+ // (namely, NCHW_VECT_C). It differentiates between these two by using a
+ // different data type (int8x4 vs int8x32). In StreamExecutor we use
+ // different layouts for these, because we don't usually pass an explicit data
+ // type to StreamExecutor functions.
+ kYXDepthBatch = 0;
+ kYXBatchDepth = 1;
+ kBatchYXDepth = 2; // cuDNN's NHWC layout
+ kBatchDepthYX = 3; // cuDNN's NCHW layout
+ kBatchDepthYX4 = 4; // cuDNN's NCHW_VECT_C with 4-elem vectors (e.g. int8x4)
+ kBatchDepthYX32 = 5; // cuDNN's NCHW_VECT_C with 32-elem vects (e.g. int8x32)
+}
+
+// Describes how a convolution filter is laid out in the memory.
+enum FilterLayout {
+ // Naming convention:
+ // Y <-> row or height
+ // X <-> column or width
+ // Output <-> output feature, or N
+ // Input <-> input feature, or N
+ // TODO(timshen): turn them into cuDNN names, e.g. kNCHW.
+ kOutputInputYX = 0; // cuDNN's NCHW layout
+ kOutputYXInput = 1; // cuDNN's NHWC layout
+ kOutputInputYX4 = 2; // cuDNN's NCHW_VECT_C layout with 4-elem vectors
+ kOutputInputYX32 = 5; // cuDNN's NCHW_VECT_C layout with 32-elem vectors
+ // cuDNN-specific filter reordering (using `cudnnReorderFilterAndBias`)
+ // When the filter is reordered, so is the bias (if present).
+ kOutputInputYX32_CudnnReordered = 6;
+ kInputYXOutput = 3;
+ kYXInputOutput = 4;
+}
+
+// Describes a kind of non-linearity (threshold-like mathematical function).
+enum ActivationMode {
+ kNone = 0;
+ kSigmoid = 1;
+ // Rectified linear activation: f(x) = x < 0 ? 0 : x
+ kRelu = 2;
+ // Rectified linear activation; where upper maximum is 6.0.
+ kRelu6 = 3;
+ // Rectified linear activation; where upper maximum specified by
+ // BatchDescriptor::value_max().
+ kReluX = 4;
+ kTanh = 5;
+ // Like ReluX; but passes all values in the range [-X,X].
+ kBandPass = 6;
+ // Exponential linear activation: f(x) = x < 0 ? e^x - 1 : x
+ kElu = 7;
+ // Leaky Rectified linear activation: f(x) = x < 0 ? alpha * x : x
+ kLeakyRelu = 8;
+ // Gaussian Error linear unit activation:
+ // x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2))), where P(X) ~ N(0, 1).
+ kGeluExact = 9;
+}
+
+// Describe the math definition for the conv op. The popular behavior is
+// actually called cross-correlation in math, despite the operation often
+// being referred to as convolution. See cuDNN cudnnConvolutionMode_t.
+enum ConvolutionMode {
+ CROSS_CORRELATION = 0;
+ CONVOLUTION = 1;
+}
+
+enum ConvolutionKind {
+ INVALID = 0;
+ FORWARD = 1;
+ BACKWARD_FILTER = 2;
+ BACKWARD_DATA = 3;
+ FORWARD_BIAS_ACTIVATION = 4;
+ FORWARD_GRAPH = 5;
+}
+
+// Generic tensor representation.
+message TensorDescriptorProto {
+ repeated int64 dimensions = 1;
+ DataType data_type = 2;
+ oneof layout_oneof {
+ DataLayout data_layout = 3;
+ FilterLayout filter_layout = 4;
+ }
+}
+
+// Generic algorithm representation.
+message AlgorithmProto {
+ enum MathType {
+ DEFAULT_MATH = 0;
+ // The GPU may operate 4x4 matrix FMA.
+ // See cuDNN's documentation for CUDNN_TENSOR_OP_MATH.
+ TENSOR_OP_MATH = 1;
+ }
+ int64 algo_id = 1;
+ MathType math_type = 2;
+ reserved 3;
+
+ map<int64, int64> tuning_knobs = 4;
+ // Legacy algorithm enums and cuDNN Frontend engine numbers need to coexist in
+ // the same proto medium-term, until we can be confident of no longer needing
+ // the legacy cuDNN convolution API. Once the migration is complete, we can
+ // stop producing legacy algorithm enums and remove this field.
+ bool is_cudnn_frontend = 5; + + // For ROCm only, it's impossible to re-query the required workspace size + // after running the algorithm search, so we must store the workspace size + // along with the choice of algorithm. For consistency and convenience, + // cuDNN uses this field in the same way, even though it would be possible to + // re-query the workspace size from cuDNN at each use. + // + // Since this message is persisted in files, we need to be able to distinguish + // 0 workspace size from unknown workspace size in an old message, so this is + // a message field. + google.protobuf.UInt64Value workspace_size = 6; +} + +// Proto definition of AlgorithmConfig in "dnn.h". +// TODO(ruochengw): After cl/380702564 is submitted, add support for algorithm +// configs with cuDNN Frontend APIs. +message AlgorithmConfigProto { + // Use oneof to emulate optional semantics in proto2 since older + // version of proto3 cannot distinguish "unset field" and "default field". + oneof optional_algorithm { + AlgorithmProto algorithm = 1; + } + oneof optional_algorithm_no_scratch { + AlgorithmProto algorithm_no_scratch = 2; + } + oneof optional_scratch_size { + int64 scratch_size = 3; + } +} + +// Convolution-specific parameters. +message ConvolutionDescriptorProto { + repeated int64 paddings = 1; + repeated int64 strides = 2; + repeated int64 dilations = 3; + // The "accumulator" type. For example, use F32 as an accumulator for F16 + // convolutions. + // See cuDNN's cudnnConvolutionMode_t. + DataType compute_mode = 4; + // See cuDNN's group count. + int32 group_count = 5; + ConvolutionMode convolution_mode = 6; + // Tensorflow node name, same as in NodeDef, for debugging purposes. + string name = 7; +} + +// NormKind kind +enum NormKind { + LAYER_FWD_INFER = 0; + LAYER_FWD_TRAIN = 1; + LAYER_BWD = 2; +} + +// FusedMHAKind kind +enum FusedMHAKind { + BMM1_OUTPUT_UNKNOWN = 0; + BMM1_OUTPUT_INPUT_TYPE = 1; + BMM1_OUTPUT_FLOAT = 2; +} + +// FusedMHAMaskKind kind +enum FMHAMaskKind { + NO_MASK = 0; + PADDING = 1; + CAUSAL = 2; + PADDING_CAUSAL = 3; + ALIBI = 4; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/error_codes.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/error_codes.proto new file mode 100644 index 000000000..c873d5588 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/error_codes.proto @@ -0,0 +1,155 @@ +syntax = "proto3"; + +// TODO(b/247876220): Change package and java_package once we figure out how to +// migrate. + +package tensorflow.error; + +option cc_enable_arenas = true; +option java_outer_classname = "ErrorCodesProtos"; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto"; + +// The canonical error codes for TensorFlow APIs. +// +// Warnings: +// +// - Do not change any numeric assignments. +// - Changes to this list should only be made if there is a compelling +// need that can't be satisfied in another way. Such changes +// must be approved by at least two OWNERS. +// - These error codes must match gRPC and protobuf error codes (except for +// DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_). +// +// Sometimes multiple error codes may apply. Services should return +// the most specific error code that applies. For example, prefer +// OUT_OF_RANGE over FAILED_PRECONDITION if both codes apply. 
+// Similarly prefer NOT_FOUND or ALREADY_EXISTS over FAILED_PRECONDITION. +enum Code { + // Not an error; returned on success + OK = 0; + + // The operation was cancelled (typically by the caller). + CANCELLED = 1; + + // Unknown error. An example of where this error may be returned is + // if a Status value received from another address space belongs to + // an error-space that is not known in this address space. Also + // errors raised by APIs that do not return enough error information + // may be converted to this error. + UNKNOWN = 2; + + // Client specified an invalid argument. Note that this differs + // from FAILED_PRECONDITION. INVALID_ARGUMENT indicates arguments + // that are problematic regardless of the state of the system + // (e.g., a malformed file name). + INVALID_ARGUMENT = 3; + + // Deadline expired before operation could complete. For operations + // that change the state of the system, this error may be returned + // even if the operation has completed successfully. For example, a + // successful response from a server could have been delayed long + // enough for the deadline to expire. + DEADLINE_EXCEEDED = 4; + + // Some requested entity (e.g., file or directory) was not found. + // For privacy reasons, this code *may* be returned when the client + // does not have the access right to the entity. + NOT_FOUND = 5; + + // Some entity that we attempted to create (e.g., file or directory) + // already exists. + ALREADY_EXISTS = 6; + + // The caller does not have permission to execute the specified + // operation. PERMISSION_DENIED must not be used for rejections + // caused by exhausting some resource (use RESOURCE_EXHAUSTED + // instead for those errors). PERMISSION_DENIED must not be + // used if the caller can not be identified (use UNAUTHENTICATED + // instead for those errors). + PERMISSION_DENIED = 7; + + // The request does not have valid authentication credentials for the + // operation. + UNAUTHENTICATED = 16; + + // Some resource has been exhausted, perhaps a per-user quota, or + // perhaps the entire file system is out of space. + RESOURCE_EXHAUSTED = 8; + + // Operation was rejected because the system is not in a state + // required for the operation's execution. For example, directory + // to be deleted may be non-empty, an rmdir operation is applied to + // a non-directory, etc. + // + // A litmus test that may help a service implementor in deciding + // between FAILED_PRECONDITION, ABORTED, and UNAVAILABLE: + // (a) Use UNAVAILABLE if the client can retry just the failing call. + // (b) Use ABORTED if the client should retry at a higher-level + // (e.g., restarting a read-modify-write sequence). + // (c) Use FAILED_PRECONDITION if the client should not retry until + // the system state has been explicitly fixed. E.g., if an "rmdir" + // fails because the directory is non-empty, FAILED_PRECONDITION + // should be returned since the client should not retry unless + // they have first fixed up the directory by deleting files from it. + // (d) Use FAILED_PRECONDITION if the client performs conditional + // REST Get/Update/Delete on a resource and the resource on the + // server does not match the condition. E.g., conflicting + // read-modify-write on the same resource. + FAILED_PRECONDITION = 9; + + // The operation was aborted, typically due to a concurrency issue + // like sequencer check failures, transaction aborts, etc. + // + // See litmus test above for deciding between FAILED_PRECONDITION, + // ABORTED, and UNAVAILABLE. 
+ ABORTED = 10; + + // Operation tried to iterate past the valid input range. E.g., seeking or + // reading past end of file. + // + // Unlike INVALID_ARGUMENT, this error indicates a problem that may + // be fixed if the system state changes. For example, a 32-bit file + // system will generate INVALID_ARGUMENT if asked to read at an + // offset that is not in the range [0,2^32-1], but it will generate + // OUT_OF_RANGE if asked to read from an offset past the current + // file size. + // + // There is a fair bit of overlap between FAILED_PRECONDITION and + // OUT_OF_RANGE. We recommend using OUT_OF_RANGE (the more specific + // error) when it applies so that callers who are iterating through + // a space can easily look for an OUT_OF_RANGE error to detect when + // they are done. + OUT_OF_RANGE = 11; + + // Operation is not implemented or not supported/enabled in this service. + UNIMPLEMENTED = 12; + + // Internal errors. Means some invariant expected by the underlying + // system has been broken. If you see one of these errors, + // something is very broken. + INTERNAL = 13; + + // The service is currently unavailable. This is a most likely a + // transient condition and may be corrected by retrying with + // a backoff. + // + // See litmus test above for deciding between FAILED_PRECONDITION, + // ABORTED, and UNAVAILABLE. + UNAVAILABLE = 14; + + // Unrecoverable data loss or corruption. + DATA_LOSS = 15; + + // An extra enum entry to prevent people from writing code that + // fails to compile when a new code is added. + // + // Nobody should ever reference this enumeration entry. In particular, + // if you write C++ code that switches on this enumeration, add a default: + // case instead of a case that mentions this enumeration entry. + // + // Nobody should rely on the value (currently 20) listed here. It + // may change in the future. + DO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_ = 20; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/histogram.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/histogram.proto new file mode 100644 index 000000000..2a5f6d936 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/histogram.proto @@ -0,0 +1,26 @@ +syntax = "proto3"; + +package tensorflow; + +option cc_enable_arenas = true; +option java_multiple_files = true; +option java_package = "org.tensorflow.framework"; +option go_package = "github.com/google/tsl/tsl/go/core/protobuf/summary_go_proto"; + +// Serialization format for histogram module in +// tsl/lib/histogram/histogram.h +message HistogramProto { + double min = 1; + double max = 2; + double num = 3; + double sum = 4; + double sum_squares = 5; + + // Parallel arrays encoding the bucket boundaries and the bucket values. + // bucket(i) is the count for the bucket i. The range for + // a bucket is: + // i == 0: -DBL_MAX .. bucket_limit(0) + // i != 0: bucket_limit(i-1) .. 
bucket_limit(i)
+ repeated double bucket_limit = 6 [packed = true];
+ repeated double bucket = 7 [packed = true];
+}
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/rpc_options.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/rpc_options.proto
new file mode 100644
index 000000000..35c5dbe3b
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/rpc_options.proto
@@ -0,0 +1,41 @@
+syntax = "proto3";
+
+package tensorflow;
+
+option go_package = "github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto";
+
+// RPC options for distributed runtime.
+message RPCOptions {
+ // If true, always use RPC to contact the session target.
+ //
+ // If false (the default option), TensorFlow may use an optimized
+ // transport for client-master communication that avoids the RPC
+ // stack. This option is primarily used for testing the RPC stack.
+ bool use_rpc_for_inprocess_master = 1;
+
+ // The compression algorithm to be used. One of "deflate", "gzip".
+ string compression_algorithm = 2;
+
+ // If compression_algorithm is set, the compression level to be used.
+ // From 0 (no compression), up to 3.
+ int32 compression_level = 3;
+
+ // Setting cache_rpc_response to true will enable sender side caching of
+ // response for RecvTensorAsync and RecvBufAsync to allow receiver to retry
+ // requests. This is only necessary when the network fabric is experiencing a
+ // significant error rate. Without it we'll fail a step on a network error,
+ // while with it we'll be able to complete long steps (like complex
+ // initializations) in the face of some network errors during RecvTensor.
+ bool cache_rpc_response = 4;
+
+ // Disables TCP connection sharing when opening a new RPC channel.
+ bool disable_session_connection_sharing = 5;
+
+ // Setting num_channels_per_target > 0 allows uses of multiple channels to
+ // communicate to the same target. This can be used to improve the aggregate
+ // throughput on high speed links (e.g 100G) where single connection is not
+ // sufficient to maximize link utilization. Note that a single RPC only goes
+ // on a single channel, this only helps in situations where there are multiple
+ // transfers to the same target overlapping in time.
+ int32 num_channels_per_target = 6;
+}
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/status.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/status.proto
new file mode 100644
index 000000000..09d722189
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/status.proto
@@ -0,0 +1,20 @@
+syntax = "proto3";
+
+package tensorflow;
+
+import "tsl/protobuf/error_codes.proto";
+
+option cc_enable_arenas = true;
+option java_multiple_files = true;
+option java_package = "org.tensorflow.framework";
+option go_package = "github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto";
+
+// Wire-format for Status.
+// Next tag: 3
+message StatusProto {
+ // Status code as defined in tensorflow/tsl/protobuf/error_codes.proto.
+ error.Code code = 1;
+
+ // Detail error message.
+ string message = 2;
+}
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/test_log.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/test_log.proto
new file mode 100644
index 000000000..6d3af02e6
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/tsl/protobuf/test_log.proto
@@ -0,0 +1,223 @@
+// Protocol messages for describing the results of benchmarks and unit tests.
+syntax = "proto3";
+
+package tensorflow;
+
+import "google/protobuf/any.proto";
+import "google/protobuf/wrappers.proto";
+
+option cc_enable_arenas = true;
+option java_outer_classname = "TestLogProtos";
+option java_multiple_files = true;
+option java_package = "org.tensorflow.util.testlog";
+
+message EntryValue {
+ oneof kind {
+ double double_value = 1;
+ string string_value = 2;
+ }
+}
+
+message MetricEntry {
+ // Metric name
+ string name = 1;
+
+ // Metric value
+ double value = 2;
+
+ // The minimum acceptable value for the metric if specified
+ google.protobuf.DoubleValue min_value = 3;
+
+ // The maximum acceptable value for the metric if specified
+ google.protobuf.DoubleValue max_value = 4;
+}
+
+// Each unit test or benchmark in a test or benchmark run provides
+// some set of information. Here we provide some reasonable keys
+// one would expect to see, with optional key/value pairs for things
+// we haven't considered.
+//
+// This BenchmarkEntry should be emitted by each unit test or benchmark
+// reporter.
+message BenchmarkEntry {
+ // The name of the specific benchmark or test
+ // (e.g. BM_AdjustContrast_gpu_B_W_H)
+ string name = 1;
+
+ // If a benchmark, how many iterations it was run for
+ int64 iters = 2;
+
+ // Total cpu time used for all iterations (in seconds)
+ double cpu_time = 3;
+
+ // Total wall time used for all iterations (in seconds)
+ double wall_time = 4;
+
+ // Throughput (in MB/s)
+ double throughput = 5;
+
+ // Generic map from result key to value.
+ map<string, EntryValue> extras = 6;
+
+ // Metric name, value and expected range. This can include accuracy metrics
+ // typically used to determine whether the accuracy test has passed
+ repeated MetricEntry metrics = 7;
+}
+
+message BenchmarkEntries {
+ repeated BenchmarkEntry entry = 1;
+}
+
+message BuildConfiguration {
+ string mode = 1; // opt, dbg, etc
+ repeated string cc_flags = 2; // CC compiler flags, if known
+ repeated string opts = 3; // Bazel compilation options, if known
+}
+
+message CommitId {
+ oneof kind {
+ // Submitted changelist.
+ int64 changelist = 1;
+ string hash = 2;
+ }
+ // Hash of intermediate change between hash/changelist and what was tested.
+ // Not used if the build is from a commit without modifications.
+ string snapshot = 3;
+ // Changelist tested if the change list is not already submitted.
+ int64 pending_changelist = 4;
+}
+
+message CPUInfo {
+ int64 num_cores = 1;
+
+ int64 num_cores_allowed = 2;
+
+ // How fast are these cpus?
+ double mhz_per_cpu = 3;
+
+ // Additional cpu information. For example,
+ // Intel Ivybridge with HyperThreading (24 cores) dL1:32KB dL2:256KB dL3:30MB
+ string cpu_info = 4;
+
+ // What kind of cpu scaling is enabled on the host.
+ // Examples include "performance", "ondemand", "conservative", "mixed".
+ string cpu_governor = 5;
+
+ // Cache sizes (in bytes), e.g. "L2": 262144 (for 256KB)
+ map<string, int64> cache_size = 6;
+}
+
+message MemoryInfo {
+ int64 total = 1; // Total virtual memory in bytes
+ int64 available = 2; // Immediately available memory in bytes
+}
+
+message GPUInfo {
+ string model = 1; // e.g. "Tesla K40c"
+ string uuid = 2; // Final entry in output of "nvidia-smi -L"
+ string bus_id = 3; // e.g. "0000:04:00.0"
+}
+
+message PlatformInfo {
+ string bits = 1; // e.g. '64bit'
+ string linkage = 2; // e.g. 'ELF'
+ string machine = 3; // e.g. 'i386'
+ string release = 4; // e.g. '3.13.0-76-generic'
+ string system = 5; // e.g. 'Linux'
+ string version = 6; // e.g. '#120-Ubuntu SMP Mon Jan 18 15:59:10 UTC 2016'
+}
+
+message AvailableDeviceInfo { // Matches DeviceAttributes
+ string name = 1; // Device name.
+ string type = 2; // Device type, e.g. 'CPU' or 'GPU'.
+ int64 memory_limit = 3; // Memory capacity in bytes.
+ string physical_description = 4; // The physical description of this device.
+}
+
+message MachineConfiguration {
+ // Host name of machine that ran the benchmark.
+ string hostname = 1;
+
+ // Unique serial number of the machine.
+ string serial_identifier = 7;
+
+ // Additional platform information.
+ PlatformInfo platform_info = 2;
+
+ // CPU Information.
+ CPUInfo cpu_info = 3;
+
+ // Other devices that are attached and relevant (e.g. GPUInfo).
+ repeated google.protobuf.Any device_info = 4;
+
+ // Devices accessible to the test (e.g. as given by list_local_devices).
+ repeated AvailableDeviceInfo available_device_info = 5;
+
+ MemoryInfo memory_info = 6;
+}
+
+// Run-specific items such as arguments to the test / benchmark.
+message RunConfiguration {
+ repeated string argument = 1;
+ // Environment variables used to run the test/benchmark.
+ map<string, string> env_vars = 2;
+}
+
+// The output of one benchmark / test run. Each run contains a list of
+// tests or benchmarks, stored as BenchmarkEntry messages.
+//
+// This message should be emitted by the reporter (which runs the
+// test / BM in a subprocess and then reads the emitted BenchmarkEntry messages;
+// usually from a serialized json file, finally collecting them along
+// with additional information about the test run.
+message TestResults {
+ // The target of the run, e.g.:
+ // //tensorflow/core:kernels_adjust_contrast_op_benchmark_test
+ string target = 1;
+
+ // The list of tests or benchmarks in this run.
+ BenchmarkEntries entries = 2;
+
+ // The configuration of the build (compiled opt? with cuda? any copts?)
+ BuildConfiguration build_configuration = 3;
+
+ // The commit id (git hash or changelist)
+ CommitId commit_id = 4;
+
+ // The time the run started (in seconds of UTC time since Unix epoch)
+ int64 start_time = 5;
+
+ // The amount of time the total run took (wall time in seconds)
+ double run_time = 6;
+
+ // Machine-specific parameters (Platform and CPU info)
+ MachineConfiguration machine_configuration = 7;
+
+ // Run-specific parameters (arguments, etc)
+ RunConfiguration run_configuration = 8;
+
+ // Benchmark target identifier.
+ string name = 9;
+
+ // The type of benchmark.
+ enum BenchmarkType {
+ UNKNOWN = 0; // Fallback for protos written before Type was introduced.
+ CPP_MICROBENCHMARK = 1;
+ PYTHON_BENCHMARK = 2;
+ ANDROID_BENCHMARK = 3;
+ EDGE_BENCHMARK = 4;
+ IOS_BENCHMARK = 5;
+ }
+ BenchmarkType benchmark_type = 10;
+
+ // Used for differentiating between continuous and debug builds.
+ // Must be one of:
+ // * cbuild: results from continuous build.
+ // * presubmit: results from oneshot requests.
+ // * culprit: results from culprit finder rerun.
+ string run_mode = 11; + + // TensorFlow version this benchmark runs against. + // This can be either set to full version or just the major version. + string tf_version = 12; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/autotune_results.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/autotune_results.proto new file mode 100644 index 000000000..cf3ddcc1a --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/autotune_results.proto @@ -0,0 +1,52 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +import "xla/autotuning.proto"; + +// A collection of algorithms for particular dot/convs. Usually this is "the +// best" algorithm for the particular dot/conv, although that's not strictly +// required. +// +// Users don't interact with this proto directly. It's used internally to +// facilitate ahead-of-time autotuning -- The string used by +// xla::{Serialize,Load}AutotuneResults is, internally, a serialization of this +// proto. +// +// LINT.IfChange +message AutotuneResults { + message Entry { + string device = 1; + string hlo = 2; + AutotuneResult result = 3; + } + + int32 version = 1; + reserved 2; // dots + reserved 3; // convs + repeated Entry results = 4; +} +// LINT.ThenChange( +// "service/gpu/autotuner_util.cc:version" +// ) + +message AutotuningLogs { + repeated AutotuningLog logs = 1; + + // Next ID: 2 +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/autotuning.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/autotuning.proto new file mode 100644 index 000000000..9d6de133f --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/autotuning.proto @@ -0,0 +1,123 @@ +// This file defines protos that store the results of autotuning various +// operations. +// +// They are in proto format because we want to log them structured. They offer +// tremendous statistical, testing, and debugging value. +syntax = "proto3"; + +package xla; + +import "google/protobuf/any.proto"; +import "google/protobuf/duration.proto"; +import "tsl/protobuf/dnn.proto"; + +option go_package = "github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto"; + +message CudnnVersion { + int32 major = 1; + int32 minor = 2; + int32 patch = 3; +} + +message ComputeCapability { + int32 major = 1; + int32 minor = 2; +} + +message AutotuneResult { + enum FailureKind { + UNKNOWN = 0; + + // Algorithm wrote memory outside its output buffers. + REDZONE_MODIFIED = 1; + + // Algorithm gave a different result from a reference algorithm. + WRONG_RESULT = 2; + + // Algorithm was rejected for failing to run or for known bugs. + DISQUALIFIED = 3; + } + + message FailureResult { + FailureKind kind = 1; + string msg = 2; + + // For failure_kind == WRONG_RESULT, this field indicates the reference + // configuration that we compared against. 
+ // + // Note that the reference algorithm isn't always correct. However, + // empirically it's more correct, as it's "algo 0", less fancy than the + // compared one. + oneof key { + ConvKey reference_conv = 11; + GemmKey reference_gemm = 12; + CudaConvPlanKey reference_cuda_conv_plan = 14; + stream_executor.dnn.AlgorithmProto reference_algorithm = 15; + } + + int64 buffer_address = 13; + } + + // Legacy and unused in new data; superseded by AlgorithmProto. + message ConvKey { + int64 algorithm = 1; + bool tensor_ops_enabled = 2; + } + + message GemmKey { + int64 algorithm = 1; + } + + // Legacy and unused in new data; superseded by AlgorithmProto. + message CudaConvPlanKey { + string exec_plan_id = 1; + } + + // If you don't need a proto in your code, please use TritonGemmConfig instead + // of using this proto directly. + message TritonGemmKey { + int64 block_m = 1; + int64 block_n = 2; + int64 block_k = 3; + int64 split_k = 4; + int64 num_stages = 5; + int64 num_warps = 6; + int64 num_ctas = 7; + } + + int64 scratch_bytes = 8; + google.protobuf.Duration run_time = 9; + + FailureResult failure = 7; + + oneof key { + ConvKey conv = 5; + GemmKey gemm = 6; + TritonGemmKey triton = 17; + CudaConvPlanKey cuda_conv_plan = 15; + stream_executor.dnn.AlgorithmProto algorithm = 16; + } + + // Next ID: 17 +} + +message AutotuningLog { + google.protobuf.Any instr = 1; + + // Records all auto-tuning results per algorithm. + repeated AutotuneResult results = 2; + + CudnnVersion cudnn_version = 3; + ComputeCapability compute_capability = 4; + + // stream_executor::DeviceDescription::pci_bus_id. + string device_pci_bus_id = 5; + + string blas_version = 6; + + string fusion_name = 7; + + int64 fusion_count = 8; + + // Next ID: 9 +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/hlo/experimental/auto_sharding/auto_sharding.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/hlo/experimental/auto_sharding/auto_sharding.proto new file mode 100644 index 000000000..ad370117c --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/hlo/experimental/auto_sharding/auto_sharding.proto @@ -0,0 +1,86 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +message AutoShardingSolverRequest { + message Pair { + int64 first = 1; + int64 second = 2; + } + message Costs { + repeated double costs = 1; + } + message Nodes { + repeated int64 nodes = 1; + } + message Edges { + repeated int64 edges = 1; + } + message Names { + repeated string names = 1; + } + message SolverTimeout { + int64 solver_timeout_in_seconds = 1; + } + message Coeff { + double coeff = 1; + } + message Group { + repeated int64 prims = 1; // Node or edge primitive indices. 
+ } + + int64 num_nodes = 1; + int64 memory_budget = 2; + repeated int64 s_len = 3; + repeated int64 s_follow = 4; + repeated int64 s_hint = 5; + repeated int64 peak_times = 35; + repeated Pair edges = 6; + repeated Nodes live = 7; + repeated Edges live_edges = 28; + repeated Pair node_intervals = 36; + repeated Pair edge_intervals = 37; + repeated Group node_groups = 38; + repeated Group edge_groups = 39; + repeated Costs computation_costs = 8; + repeated Costs communication_costs = 9; + repeated Costs memory_costs = 10; + repeated Costs memory_edge_costs = 29; + repeated Costs departure_costs = 11; + repeated Costs resharding_costs = 12; + repeated Costs duration_costs = 13; + repeated Pair aliases = 14; + repeated Costs value_costs = 15; + repeated string instruction_names = 16; + repeated string opcodes = 33; + repeated Names strategy_names = 32; + optional SolverTimeout solver_timeout = 17; + optional Coeff overbudget_coeff = 18; + optional Coeff makespan_coeff = 19; + optional Coeff max_departures = 20; + optional Coeff max_cost = 25; + optional Coeff coeff_limit = 26; + bool crash_at_infinity_costs_check = 21; + bool compute_iis = 22; + double saltiplier = 23; + bool deterministic_mode = 27; + string module_name = 24; + string request_name = 30; + bool enable_output = 31; + bool enable_memory_edge_costs = 34; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/mlir/tools/mlir_replay/public/compiler_trace.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/mlir/tools/mlir_replay/public/compiler_trace.proto new file mode 100644 index 000000000..b8f75dd1a --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/mlir/tools/mlir_replay/public/compiler_trace.proto @@ -0,0 +1,31 @@ +/* Copyright 2022 The OpenXLA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mlir.interpreter; + +message MlirCompilationTraceEntry { + // The name of the pass that was previously executed. + optional string after_pass = 1; + + // MLIR module IR of the state after the pass. + optional string mlir_module = 2; +} + +message MlirCompilationTrace { + // MLIR modules corresponding to each stage of the compilation pipeline. + repeated MlirCompilationTraceEntry passes = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/mlir/tools/mlir_replay/public/execution_trace.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/mlir/tools/mlir_replay/public/execution_trace.proto new file mode 100644 index 000000000..c4ff5ecd4 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/mlir/tools/mlir_replay/public/execution_trace.proto @@ -0,0 +1,72 @@ +/* Copyright 2022 The OpenXLA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package mlir.interpreter; + +message TracedValue { + // The shape - includes vector dimensions. + // TODO(jreiffers): Model vector dimensions separately. + repeated int64 shape = 1 [packed = true]; + optional bool is_scalar = 2; + + enum ElementType { + UNKNOWN = 0; + INTEGRAL = 1; + UNSIGNED = 2; + FLOAT = 3; + COMPLEX = 4; + TUPLE = 5; + } + + optional int32 bit_width = 3; + optional ElementType element_type = 4; + + repeated float floats = 5 [packed = true]; + repeated double doubles = 6 [packed = true]; + repeated int64 ints = 7 [packed = true]; + repeated uint64 uints = 8 [packed = true]; + repeated TracedValue tuple_elements = 9; +} + +message InstructionTrace { + optional string name = 1; + repeated TracedValue args = 2; + repeated TracedValue results = 3; + // TODO(jreiffers): Model side effects (e.g. memref.store). + + repeated RegionTrace regions = 4; +} + +message RegionTrace { + // The number of the region that is being executed (within the parent op). + // For example: '1' for an scf.while's `after` region. + optional int32 region_number = 1; + // The arguments that were passed to the region. + repeated TracedValue bbargs = 2; + // One instruction per instruction in the region. + repeated InstructionTrace instructions = 3; + repeated TracedValue results = 4; +} + +message ExecutionTrace { + // The IR that was executed. Note: this should always be filled in the generic + // format. + optional string ir = 1; + + // The trace of the entry function execution. + optional RegionTrace trace = 2; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/compile_options.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/compile_options.proto new file mode 100644 index 000000000..4ea4af933 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/compile_options.proto @@ -0,0 +1,133 @@ +syntax = "proto3"; + +package xla; + +import "xla/stream_executor/device_description.proto"; +import "xla/xla.proto"; +import "xla/xla_data.proto"; + +// A serialization of xla::ExecutableBuildOptions. +// Next id: 19. +message ExecutableBuildOptionsProto { + // If set, this is the device to build the computation for. Valid + // device_ordinal values are: 0 to # of devices - 1. These values are + // identical to the device ordinal values used by StreamExecutor. The built + // executable will be executable on any device equivalent to the specified + // device as determined by Backend::devices_equivalent(). A value of -1 + // indicates this option has not been set. + int64 device_ordinal = 1; + + // If set, this specifies the layout of the result of the computation. If not + // set, the service will chose the layout of the result. A Shape is used to + // store the layout to accommodate tuple result shapes. A value of nullptr + // indicates the option has not been set. + xla.ShapeProto result_layout = 2; + + // Expose access to the XLA compilation environments, which will be passed to + // the compilation process. 
+ xla.CompilationEnvironmentsProto comp_envs = 13; + + // Expose access to the XLA debug options which will be passed to the + // compilation process. + xla.DebugOptions debug_options = 3; + + // The number of replicas of this computation that are to be executed. + // Defaults to 1. + int64 num_replicas = 4; + + // The number of partitions in this computation. Defaults to 1. + int64 num_partitions = 5; + + // Indicates whether to use SPMD (true) or MPMD (false) partitioning when + // num_partitions > 1 and XLA is requested to partition the input program. + bool use_spmd_partitioning = 6; + + // Whether to automatically generate XLA shardings for SPMD partitioner. + bool use_auto_spmd_partitioning = 7; + + // Whether HLOs should be deduplicated. + bool deduplicate_hlo = 8; + + // If set, this specifies a static device assignment for the computation. + // Otherwise, the computation will be compiled generically and can be run with + // any device assignment compatible with the computation's replica and + // partition counts. + xla.DeviceAssignmentProto device_assignment = 9; + + // Whether input and output buffers are aliased if the associated parameter is + // passed-through XLA modules without being changed. + bool alias_passthrough_params = 10; + + // By default, XLA builds an executable by invoking standard compilation, i.e. + // running Compiler::Compile, or both Compiler::RunHloPasses and + // Compiler::RunBackend. When run_backend_only is set to true, XLA builds an + // executable by invoking only RunBackend and skip invoking RunHloPasses, + // which can be used to compile post-optimizations HLO modules. + bool run_backend_only = 11; + + // Allows sharding propagation to propagate to the parameters. This changes + // the input shape of the computation (which is undesirable), but it can be + // used to allow to run partial compilation to determine what would be the + // input sharding of a computation if XLA would be allowed to propagate the + // sharding which can be used by higher level framework as a way to query + // intermediate sharding of operations when multiple computation would be + // chained and merged together. + // This is a vector of bool, because the user can control which parameters can + // have the sharding substituted. If only one boolean value is passed in the + // vector that is interpreted as the value to be applied for every parameter. + repeated bool allow_spmd_sharding_propagation_to_parameters = 18; + + // Allows sharding propagation to propagate to the outputs. This changes the + // output shape of the computation (which is undesirable), but it can be used + // to allow to run partial compilation to determine what would be the output + // sharding of a computation if XLA would be allowed to propagate the sharding + // which can be used by higher level framework as a way to query intermediate + // sharding of operations when multiple computation would be chained and + // merged together. + // This is a vector of bool, because the user can control (if the output of + // the computation is a tuple) which elements of the tuple can have the + // sharding substituted and which don't. If only one boolean value is passed + // in the vector that's interpreted as the value to be applied for every + // single element of the output tuple. One value per element of the tuple + // means that each value is attached to one of the output elements. + repeated bool allow_spmd_sharding_propagation_to_output = 12; + + // Opaque profile data for any feedback directed optimizations. 
+ bytes fdo_profile = 14;
+
+ int64 device_memory_size = 15;
+
+ // Mesh shape in auto sharding options.
+ repeated int64 auto_spmd_partitioning_mesh_shape = 16;
+
+ // Mesh ids in auto sharding options.
+ repeated int64 auto_spmd_partitioning_mesh_ids = 17;
+}
+
+message OptionOverrideProto {
+ oneof value {
+ string string_field = 1;
+ bool bool_field = 2;
+ int64 int_field = 3;
+ double double_field = 4;
+ }
+}
+
+message CompileOptionsProto {
+ // Refer to CompileOptions for documentation of fields.
+ repeated ShapeProto argument_layouts = 1;
+ bool parameter_is_tupled_arguments = 2;
+ ExecutableBuildOptionsProto executable_build_options = 3;
+ bool compile_portable_executable = 4;
+ int64 profile_version = 5;
+ bytes serialized_multi_slice_config = 6;
+ map<string, OptionOverrideProto> env_option_overrides = 7;
+
+ stream_executor.GpuTargetConfigProto target_config = 8;
+}
+
+// Helper for serializing opaque executables alongside CompileOptions.
+message ExecutableAndOptionsProto {
+ bytes serialized_executable = 1;
+ CompileOptionsProto compile_options = 2;
+}
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/cpu/cpu_topology.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/cpu/cpu_topology.proto
new file mode 100644
index 000000000..85167fc5f
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/cpu/cpu_topology.proto
@@ -0,0 +1,14 @@
+syntax = "proto3";
+
+package xla;
+
+// A proto used to serialize CpuTopology instances.
+message CpuTopologyProto {
+ message CpuDevice {
+ int32 id = 1;
+ int32 process_index = 2;
+ int32 local_hardware_id = 3;
+ }
+ repeated CpuDevice cpu_devices = 1;
+ repeated string machine_attributes = 4;
+}
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/distributed/protocol.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/distributed/protocol.proto
new file mode 100644
index 000000000..955a537d1
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/distributed/protocol.proto
@@ -0,0 +1,66 @@
+// Copyright 2020 The OpenXLA Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// ==============================================================================
+//
+// Distributed XLA service protocol.
+//
+// This is a minimal distributed protocol intended for a small set of purposes
+// * barriers to wait for all clients to start up or shut down
+// * health checking to detect when clients vanish
+// * for sharing GPU topology and NCCL communicator state between distributed
+// hosts.
+//
+// The intention is that a service is started during cluster initialization and
+// persists for the lifetime of the cluster.
+
+syntax = "proto3";
+
+package xla;
+
+option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/compiler/"
+ "xla/pjrt/distributed/protocol_go_proto";
+
+// Describes a device local to a host.
+message DeviceProto { + int32 local_device_ordinal = 1; + string name = 2; + string vendor = 3; + + // The following fields are present in the GlobalTopologyProto message + // returned by EnumerateDevices() but not in the LocalTopologyProto messages + // passed to EnumerateDevices(). In other words, the coordinator node + // determines the global device IDs during EnumerateDevices(). + int32 global_device_id = 4; // Globally unique ID number. + // Devices with the same slice_index are connected by fast network, e.g. + // NVLink on GPUs. + int32 slice_index = 5; + + // Store vendor-specific compute capability. + string compute_capability = 6; + + // The number of cores (e.g. SMs on GPUs) on the device. + int32 core_count = 7; +}; + +message LocalTopologyProto { + int32 node_id = 1; + // Unique per OS kernel restart to uniquely identify a host. + // See /proc/sys/kernel/random/boot_id. + string boot_id = 2; + repeated DeviceProto devices = 3; +} + +message GlobalTopologyProto { + repeated LocalTopologyProto nodes = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/executable_metadata.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/executable_metadata.proto new file mode 100644 index 000000000..db308d57a --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/executable_metadata.proto @@ -0,0 +1,23 @@ +syntax = "proto3"; + +package xla; + +import "xla/service/hlo.proto"; + +// Mirror of xla::CompiledMemoryStats. +message CompiledMemoryStatsProto { + // Device default memory (e.g., HBM for GPU/TPU) usage stats. + int64 generated_code_size_in_bytes = 1; + int64 argument_size_in_bytes = 2; + int64 output_size_in_bytes = 3; + int64 alias_size_in_bytes = 4; + int64 temp_size_in_bytes = 5; + xla.HloProto hlo_proto = 6; + + // Host memory usage stats. + int64 host_generated_code_size_in_bytes = 7; + int64 host_argument_size_in_bytes = 8; + int64 host_output_size_in_bytes = 9; + int64 host_alias_size_in_bytes = 10; + int64 host_temp_size_in_bytes = 11; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/execute_options.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/execute_options.proto new file mode 100644 index 000000000..dd9cb661f --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/execute_options.proto @@ -0,0 +1,21 @@ +syntax = "proto3"; + +package xla; + +enum ExecutionModeProto { + EXECUTION_MODE_UNSPECIFIED = 0; + EXECUTION_MODE_DEFAULT = 1; + EXECUTION_MODE_SYNCHRONOUS = 2; + EXECUTION_MODE_ASYNCHRONOUS = 3; +} + +// Mirrors `xla::ExecuteOptions`. +message ExecuteOptionsProto { + bool arguments_are_tupled = 1; + bool untuple_result = 2; + int32 launch_id = 3; + bool strict_shape_checking = 4; + bool use_major_to_minor_data_layout_for_callbacks = 8; + ExecutionModeProto execution_mode = 6; + repeated int32 non_donatable_input_indices = 7; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/gpu/gpu_topology.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/gpu/gpu_topology.proto new file mode 100644 index 000000000..0bb3c5b34 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/gpu/gpu_topology.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +package xla; + +enum GpuVersionProto { + GPU_VERSION_UNSPECIFIED = 0; + GPU_VERSION_A100 = 1; + GPU_VERSION_H100 = 2; +} + +// A proto used to serialize GpuTopology instances. 
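The DeviceProto comments in protocol.proto describe fields that only exist in the global topology: the coordinator assigns global_device_id values during EnumerateDevices(), and devices connected by a fast network share a slice_index. The sketch below illustrates that assembly with plain dicts standing in for LocalTopologyProto/GlobalTopologyProto; treating one boot_id (one host) as one slice is an assumed policy, and the real logic lives in XLA's distributed runtime rather than in this patch.

# Minimal sketch of coordinator-side topology assembly, under the assumptions
# stated above. Each local topology is a dict with "node_id", "boot_id" and a
# list of per-device dicts.
def build_global_topology(local_topologies):
    slice_index_by_boot_id = {}
    next_global_id = 0
    nodes = []
    for local in sorted(local_topologies, key=lambda t: t["node_id"]):
        # Devices reported under the same boot_id are assumed to share a slice.
        slice_index = slice_index_by_boot_id.setdefault(
            local["boot_id"], len(slice_index_by_boot_id)
        )
        devices = []
        for device in local["devices"]:
            devices.append(
                dict(device, global_device_id=next_global_id, slice_index=slice_index)
            )
            next_global_id += 1
        nodes.append(dict(local, devices=devices))
    return {"nodes": nodes}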
+message GpuTopologyProto { + // TODO(b/331224674): Remove this field once all uses are removed. + repeated int32 device_ids = 1; + + GpuVersionProto gpu_version = 2; + + // Name for the GPU version, e.g., "NVIDIA A100-SXM4-40GB". Returned as + // "device_kind" of a GPU device in the PJRT client API. + string platform_version = 3; + + // The number of slices. + // Devices on the same slice are connected by the fast network via NVLinks, + // which could be within a host or span across multiple hosts. + int32 num_slices = 4; + + // The number of hosts for each slice. + int32 num_hosts_per_slice = 5; + + // The number of devices for each host. + int32 num_devices_per_host = 6; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/stream_executor_executable.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/stream_executor_executable.proto new file mode 100644 index 000000000..c9572d9aa --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/pjrt/stream_executor_executable.proto @@ -0,0 +1,29 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +import "xla/pjrt/compile_options.proto"; + +message StreamExecutorExecutableProto { + CompileOptionsProto compile_options = 1; + repeated bytes executables = 2; + int32 num_replicas = 3; + int32 num_partitions = 4; + string name = 5; + string fingerprint = 6; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/array_spec.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/array_spec.proto new file mode 100644 index 000000000..6d61b71a0 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/array_spec.proto @@ -0,0 +1,29 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "xla/python/ifrt/dtype.proto"; +import "xla/python/ifrt/shape.proto"; +import "xla/python/ifrt/sharding.proto"; + +// Proto equivalent of C++ `ArraySpec`. 
+message ArraySpecProto { + DTypeProto dtype = 1; + ShapeProto shape = 2; + ShardingProto sharding = 3; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/custom_call_program.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/custom_call_program.proto new file mode 100644 index 000000000..e631b2935 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/custom_call_program.proto @@ -0,0 +1,31 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "xla/python/ifrt/array_spec.proto"; +import "xla/python/ifrt/device.proto"; + +// Proto equivalent of C++ `CustomCallProgram`. +message CustomCallProgramProto { + string type = 1; + string name = 2; + bytes serialized_program_text = 3; + DeviceListProto devices = 4; + repeated ArraySpecProto input_specs = 5; + repeated ArraySpecProto output_specs = 6; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/device.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/device.proto new file mode 100644 index 000000000..53e384522 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/device.proto @@ -0,0 +1,25 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +// Proto equivalent of C++ `DeviceList`. +message DeviceListProto { + // Serialization and deserialization are expected to ensure that device ids + // are stable across proto construction and consumption. + repeated int32 device_ids = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/dtype.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/dtype.proto new file mode 100644 index 000000000..eadfd42a3 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/dtype.proto @@ -0,0 +1,77 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +// Proto equivalent of C++ `DType`. +message DTypeProto { + // LINT.IfChange(DTypeProtoKind) + enum Kind { + KIND_UNSPECIFIED = 0; + + // Predicates are two-state booleans. + KIND_PRED = 1; + + // Signed integral values of fixed width. + KIND_S2 = 26; + KIND_S4 = 21; + KIND_S8 = 2; + KIND_S16 = 3; + KIND_S32 = 4; + KIND_S64 = 5; + + // Unsigned integral values of fixed width. + KIND_U2 = 27; + KIND_U4 = 22; + KIND_U8 = 6; + KIND_U16 = 7; + KIND_U32 = 8; + KIND_U64 = 9; + + // Floating-point values of fixed width. + KIND_F16 = 10; + KIND_F32 = 11; + KIND_F64 = 12; + + // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit + // floating-point format, but uses 1 bit for the sign, 8 bits for the + // exponent and 7 bits for the mantissa. + KIND_BF16 = 16; + + // Complex values of fixed width. + KIND_C64 = 15; // Paired F32 (real, imag), as in std::complex. + KIND_C128 = 18; // Paired F64 (real, imag), as in std::complex. + + // A token type threaded between side-effecting operations. Shapes of this + // dtype will have empty dimensions. + KIND_TOKEN = 17; + + KIND_F8E4M3FN = 20; + KIND_F8E4M3B11FNUZ = 23; + KIND_F8E4M3FNUZ = 25; + KIND_F8E5M2 = 19; + KIND_F8E5M2FNUZ = 24; + + // Variable-length string represented as raw bytes, as in `bytes` in Python, + // i.e., no encoding enforcement. String is not support in XLA. DType.Kind + // needs to match xla.PrimitiveType enum, so choose a large enum to avoid + // collision. + KIND_STRING = 99; + } + // LINT.ThenChange() + Kind kind = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/remap_plan.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/remap_plan.proto new file mode 100644 index 000000000..3de2690ca --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/remap_plan.proto @@ -0,0 +1,40 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "xla/python/ifrt/array_spec.proto"; + +// Wire format for `RemapPlan`. See `RemapPlan` for the semantics of the proto +// fields. +message RemapPlanProto { + message MappingProto { + int32 in_array = 1; + int32 out_array = 2; + // Transposed lists for `from` and `to`. This transposition makes a + // serialized proto smaller when `from` or `to` has many elements. 
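DTypeProto.Kind (in dtype.proto above) enumerates the element types an IFRT array can carry. For analysis-side tooling it can be handy to map the NumPy-representable subset onto NumPy dtypes; the table below is such a sketch (assuming NumPy is available) and deliberately omits kinds like KIND_BF16, the FP8 variants, KIND_TOKEN and KIND_STRING that plain NumPy cannot express.

import numpy as np

# Illustrative subset only: not an API defined by this patch.
DTYPE_KIND_TO_NUMPY = {
    "KIND_PRED": np.bool_,
    "KIND_S8": np.int8,
    "KIND_S16": np.int16,
    "KIND_S32": np.int32,
    "KIND_S64": np.int64,
    "KIND_U8": np.uint8,
    "KIND_U16": np.uint16,
    "KIND_U32": np.uint32,
    "KIND_U64": np.uint64,
    "KIND_F16": np.float16,
    "KIND_F32": np.float32,
    "KIND_F64": np.float64,
    "KIND_C64": np.complex64,
    "KIND_C128": np.complex128,
}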
+ repeated int64 from_start = 3; + repeated int64 from_end = 4; + repeated int64 from_step = 5; + repeated int64 to_start = 6; + repeated int64 to_end = 7; + repeated int64 to_step = 8; + } + repeated ArraySpecProto input_specs = 1; + repeated ArraySpecProto output_specs = 2; + repeated MappingProto mappings = 3; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/serdes.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/serdes.proto new file mode 100644 index 000000000..4693b645f --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/serdes.proto @@ -0,0 +1,24 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +// Wire format for objects serialized from `Serializable`. +message Serialized { + string type_name = 1; + bytes data = 2; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/shape.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/shape.proto new file mode 100644 index 000000000..cd5d26e13 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/shape.proto @@ -0,0 +1,38 @@ +/* Copyright 2024 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +// Proto equivalent of C++ `Shape`. Currently support static shapes with all +// dimension sizes greater than or equal to 0. +message ShapeProto { + repeated int64 dims = 1; +} + +// Proto equivalent of C++ `BoundedDynamicShapeTag`. +message BoundedDynamicShapeTagProto { + repeated bool is_dynamic_dims = 1; +} + +// Proto equivalent of C++ `DynamicShape`. Currently only support bounded +// dynamic shape. +message DynamicShapeProto { + ShapeProto shape = 1; + oneof tag { + BoundedDynamicShapeTagProto bounded_dynamic_shape_tag = 2; + } +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/sharding.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/sharding.proto new file mode 100644 index 000000000..a6033bb34 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/sharding.proto @@ -0,0 +1,26 @@ +/* Copyright 2024 The OpenXLA Authors. 
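RemapPlanProto.MappingProto stores the logical (start, end, step) triples of `from` and `to` as three parallel arrays, trading one nested message per triple for a shorter wire encoding. A small sketch of that transposition, with illustrative helper names; the real (de)serialization is done by RemapPlan in C++.

def transpose_slices(slices):
    """[(start, end, step), ...] -> {"start": [...], "end": [...], "step": [...]}."""
    starts, ends, steps = zip(*slices) if slices else ((), (), ())
    return {"start": list(starts), "end": list(ends), "step": list(steps)}

def untranspose_slices(columns):
    """Inverse of transpose_slices: rebuild the list of (start, end, step) triples."""
    return list(zip(columns["start"], columns["end"], columns["step"]))

assert untranspose_slices(transpose_slices([(0, 8, 1), (8, 16, 1)])) == [
    (0, 8, 1),
    (8, 16, 1),
]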
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "xla/python/ifrt/serdes.proto"; + +// Proto equivalent of C++ `Sharding`. A suitable serializer and deserializer +// implementation must be registered. +message ShardingProto { + xla.ifrt.Serialized serialized_sharding = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/sharding_serdes.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/sharding_serdes.proto new file mode 100644 index 000000000..b470b04fe --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt/sharding_serdes.proto @@ -0,0 +1,56 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "xla/python/ifrt/device.proto"; +import "xla/python/ifrt/shape.proto"; + +// Proto equivalent of C++ `SingleDeviceSharding`. +message SingleDeviceShardingProto { + // Serialization and deserialization are expected to ensure that device ids + // are stable across proto construction and consumption. + int32 device_id = 1; + optional string memory_kind = 2; +} + +// Proto equivalent of C++ `OpaqueSharding`. +message OpaqueShardingProto { + DeviceListProto devices = 1; + optional string memory_kind = 2; +} + +// Proto equivalent of C++ `ConcreteSharding`. +message ConcreteShardingProto { + DeviceListProto devices = 1; + optional string memory_kind = 4; + oneof shape_or_dynamic_shape { + ShapeProto shape = 2; + DynamicShapeProto dynamic_shape = 5; + } + repeated ShapeProto shard_shapes = 3; + repeated DynamicShapeProto shard_dynamic_shapes = 6; +} + +// Proto equivalent of C++ `ConcreteEvenSharding`. 
+message ConcreteEvenShardingProto { + DeviceListProto devices = 1; + optional string memory_kind = 4; + ShapeProto shape = 2; + ShapeProto shard_shape = 3; + bool is_fully_replicated = 5; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt_proxy/common/grpc_ifrt_service.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt_proxy/common/grpc_ifrt_service.proto new file mode 100644 index 000000000..6741e5d98 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt_proxy/common/grpc_ifrt_service.proto @@ -0,0 +1,107 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package xla.ifrt.proxy; + +import "xla/python/ifrt_proxy/common/ifrt_service.proto"; + +service GrpcIfrtService { + // Returns the IFRT Proxy version that both the client and the server + // supports. Returns an error if there's no such version. + rpc GetVersion(GrpcGetVersionRequest) returns (GrpcGetVersionResponse) {} + + // IfrtSession is a stream of IFRT requests (from the client) and responses + // from the server. + // + // Clients can optionally start the stream with an InitRequest to configure + // startup options and to retrieve basic run-time system details such as the + // number and handles of the available devices (see InitResponse). But clients + // that are fine with the default options and do not immediately need the info + // from the InitResponse can start with any other request. + // + // TODO(b/282757875): Investigate if there are useful details that client + // should supply to the server even before the first InitRequest message - may + // be via gRPC metadata. + rpc IfrtSession(stream IfrtRequest) returns (stream IfrtResponse) {} + + // Sends a host buffer from the client to the server. Uses client-side + // streaming to allow sending buffers that exceed the 2GiB protobuf + // serialization limit. + rpc HostBufferStore(stream GrpcHostBufferStoreRequest) + returns (GrpcHostBufferStoreResponse); + + // Reads a host buffer from the server to the client. Uses server-side + // streaming to allow >2GiB host buffer transfer. + rpc HostBufferLookup(GrpcHostBufferLookupRequest) + returns (stream GrpcHostBufferLookupResponse); + + // Deletes a host buffer from the server. + rpc HostBufferDelete(GrpcHostBufferDeleteRequest) + returns (GrpcHostBufferDeleteResponse); +} + +message GrpcGetVersionRequest { + IfrtProxyVersion min_version = 1; + IfrtProxyVersion max_version = 2; +} + +message GrpcGetVersionResponse { + IfrtProxyVersion version = 1; +} + +// Metadata for `IfrtSession` requests, sent as client metadata associated with +// key "ifrt-proxy-grpc-ifrt-session-metadata-bin". +message GrpcIfrtSessionMetadata { + IfrtProxyVersion version = 1; +} + +// Metadata for `Store` requests, sent as client metadata associated with key +// "ifrt-proxy-grpc-host-buffer-store-metadata-bin". 
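GrpcGetVersionRequest advertises the client's supported [min_version, max_version] protocol range, and the server answers with a version both sides accept, or an error if the ranges do not overlap. Below is a sketch of one plausible negotiation policy (preferring the newest common version); the helper name and the bare-integer representation of IfrtProxyVersion are assumptions.

def negotiate_protocol_version(client_min, client_max, server_min, server_max):
    # Intersect the two supported ranges.
    low, high = max(client_min, server_min), min(client_max, server_max)
    if low > high:
        raise ValueError("no common IFRT proxy protocol version")
    return high  # assumed policy: pick the newest mutually supported version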
+message GrpcHostBufferStoreMetadata { + fixed64 session_id = 1; + fixed64 handle = 2; + int64 buffer_size = 3; +} + +// `Store` request that contains actual data, potentially chunked. All requests +// in a transfer must be sent in order and the server simply concatenate `bytes` +// in the response under this assumption. +message GrpcHostBufferStoreRequest { + bytes data = 1; // copybara_removed [ctype = STRING_PIECE] +} + +message GrpcHostBufferStoreResponse {} + +// `Lookup` request that specifies which host buffer in the server to read. +message GrpcHostBufferLookupRequest { + fixed64 session_id = 1; + fixed64 handle = 2; +} + +// `Lookup` response that returns the (potentially chunked) host buffer +// contents. As in `GrpcHostBufferStoreRequest`, all responses must be sent in +// order and the client simply concatenates `data`. +message GrpcHostBufferLookupResponse { + bytes data = 1; // copybara_removed [ctype = STRING_PIECE] +} + +// `Delete` request that specifies the host buffer to delete. +message GrpcHostBufferDeleteRequest { + fixed64 session_id = 1; + fixed64 handle = 2; +} + +message GrpcHostBufferDeleteResponse {} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt_proxy/common/ifrt_service.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt_proxy/common/ifrt_service.proto new file mode 100644 index 000000000..6c3745197 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt_proxy/common/ifrt_service.proto @@ -0,0 +1,506 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
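The HostBufferStore messages in grpc_ifrt_service.proto spell out the chunking contract: the client streams ordered GrpcHostBufferStoreRequest messages whose `data` fields the server simply concatenates, which is what lets a transfer exceed the 2 GiB protobuf serialization limit. A minimal sketch with plain dicts in place of the generated gRPC/protobuf classes; the 1 MiB chunk size is an arbitrary assumption.

def chunk_host_buffer(data: bytes, chunk_bytes: int = 1024 * 1024):
    # Client side: yield ordered chunks small enough to serialize comfortably.
    for offset in range(0, len(data), chunk_bytes):
        yield {"data": data[offset : offset + chunk_bytes]}

def reassemble_host_buffer(requests) -> bytes:
    # Server side: requests arrive in order, so concatenation recovers the buffer.
    return b"".join(request["data"] for request in requests)

payload = b"x" * (3 * 1024 * 1024 + 17)
assert reassemble_host_buffer(chunk_host_buffer(payload)) == payload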
+ +syntax = "proto3"; + +package xla.ifrt.proxy; + +import "google/protobuf/any.proto"; +import "xla/pjrt/execute_options.proto"; +import "xla/python/ifrt/dtype.proto"; +import "xla/python/ifrt/remap_plan.proto"; +import "xla/python/ifrt/serdes.proto"; +import "xla/python/ifrt/shape.proto"; +import "xla/python/ifrt/sharding.proto"; +import "xla/python/ifrt_proxy/common/types.proto"; +import "xla/xla_data.proto"; +import "tsl/protobuf/status.proto"; + +option cc_enable_arenas = true; + +message IfrtProxyVersion { + int32 protocol_version = 1; +} + +message IfrtRequest { + RequestMetadata request_metadata = 1; + + oneof request { + InitRequest init_request = 2; + + // ===== Future ===== + CheckFutureRequest check_future_request = 3; + + // ===== Array ===== + MakeArrayFromHostBufferRequest make_array_from_host_buffer_request = 4; + AssembleArrayFromSingleDeviceArraysRequest + assemble_array_from_single_device_arrays_request = 5; + RemapArraysRequest remap_arrays_request = 23; + CopyToHostBufferRequest copy_to_host_buffer_request = 6; + DisassembleIntoSingleDeviceArraysRequest + disassemble_into_single_device_arrays_request = 7; + CheckArrayReadyRequest check_array_ready_request = 8; + DeleteArrayRequest delete_array_request = 9; + ReshardRequest reshard_request = 10; + FullyReplicatedShardRequest fully_replicated_shard_request = 20; + IsArrayDeletedRequest is_array_deleted_request = 11; + DestructArrayRequest destruct_array_request = 12; + + // ==== Compiler ==== + CompileRequest compile_request = 13; + + // ===== LoadedExecutable ===== + LoadedExecutableMetadataRequest loaded_executable_metadata_request = 14; + LoadedExecutableExecuteRequest loaded_executable_execute_request = 15; + LoadedExecutableDeleteRequest loaded_executable_delete_request = 16; + LoadedExecutableIsDeletedRequest loaded_executable_is_deleted_request = 17; + LoadedExecutableDestructRequest loaded_executable_destruct_request = 18; + + // ===== LoadedHostCallback ===== + LoadedHostCallbackPollRequest loaded_host_callback_poll_request = 21; + LoadedHostCallbackReturnRequest loaded_host_callback_return_request = 22; + + // ===== Client ===== + GetDefaultDeviceAssignmentRequest get_default_device_assignment_request = + 19; + } +} + +message IfrtResponse { + ResponseMetadata response_metadata = 1; + + oneof response { + InitResponse init_response = 2; + + // ===== Future ===== + CheckFutureResponse check_future_response = 3; + + // ===== Array ===== + MakeArrayFromHostBufferResponse make_array_from_host_buffer_response = 4; + AssembleArrayFromSingleDeviceArraysResponse + assemble_array_from_single_device_arrays_response = 5; + RemapArraysResponse remap_arrays_response = 23; + CopyToHostBufferResponse copy_to_host_buffer_response = 6; + DisassembleIntoSingleDeviceArraysResponse + disassemble_into_single_device_arrays_response = 7; + CheckArrayReadyResponse check_array_ready_response = 8; + DeleteArrayResponse delete_array_response = 9; + ReshardResponse reshard_response = 10; + FullyReplicatedShardResponse fully_replicated_shard_response = 20; + IsArrayDeletedResponse is_array_deleted_response = 11; + DestructArrayResponse destruct_array_response = 12; + + // ===== Compiler ===== + CompileResponse compile_response = 13; + + // ===== LoadedExecutable ===== + LoadedExecutableMetadataResponse loaded_executable_metadata_response = 14; + LoadedExecutableExecuteResponse loaded_executable_execute_response = 15; + LoadedExecutableDeleteResponse loaded_executable_delete_response = 16; + LoadedExecutableIsDeletedResponse 
loaded_executable_is_deleted_response = + 17; + LoadedExecutableDestructResponse loaded_executable_destruct_response = 18; + + // ===== LoadedHostCallback ===== + LoadedHostCallbackPollResponse loaded_host_callback_poll_response = 21; + LoadedHostCallbackReturnResponse loaded_host_callback_return_response = 22; + + // ===== Client ===== + GetDefaultDeviceAssignmentResponse get_default_device_assignment_response = + 19; + } +} + +// Metadata of an IFRT Request. +message RequestMetadata { + // Identifies a logical IFRT Operation (equivalent to an IFRT API call). + // + // For the operations that require chunking (e.g.: MakeArrayFromHostBuffer) + // all the request proto messages share the same op_id. + // + // Must be unique and monotonically increasing across the life of a client - + // may stretch across multiple successive IfrtSessions used to reconnect and + // resync after transient connectivity failures. + fixed64 op_id = 1; + + // List of one or more prior ops this current op is "dependent" + // upon. Currently this allows the client to define the order in which the + // server starts the execution of requests. Future versions may add other + // types of dependencies. For instance, a separate list of dependencies that + // must *complete* executing before the current one can start to execute. + // + // An op_id that has not yet been seen by the server is treated as an error + // that fails the op. + repeated fixed64 dependencies = 2; + + // UserContext is a basic provenance mechanism that allows the server-side + // actions and artifacts (say, allocating a buffer) to be associated with the + // corresponding client-side context that triggered those actions. + // + // The optional UserContextId is generated by the client and are used as an + // opaque label by the server and the run-time systems behind it. + // TODO(b/282757875): Add a pointer to Usercontext bugs/design doc. + fixed64 user_context_id = 3; + + // Additional implementation-specific payloads. + repeated google.protobuf.Any payloads = 4; +} + +// Metadata of an IFRT Response. + +message ResponseMetadata { + // ID of the operation this response belongs to. + fixed64 op_id = 1; + + // absl::Status of the operation. + // + // In case of "chunked" responses (i.e., the full logical response is + // spread across a sequence of IfrtResponse protos), the actual sequence of + // IfrtResponse messages will follow only if this absl::Status is OK in the + // very first message. That is, in case of errors, server sends a single + // IfrtResponse with the appropriate error included. + // + // In case of "batched" operations (i.e., where the response is carrying + // the outcomes of multiple requests that were "batched" in the same + // IfrtRequest proto - such as deleting a bunch of Arrays) this + // absl::Status field provides a way to quickly check if none of the + // individual operations encountered errors. Clients should not rely on + // specific error type or string when this is not OK, they should check the + // response message for individual absl::Statuses. + tensorflow.StatusProto status = 2; +} + +// InitRequest allows the client to specify the optional startup configuration +// parameters such as an idle timeout for this `IfrtSession`, backend servers +// addresses, and whether to turn on tracing, etc. +// +// Initialization of a a session is optional, but if a client chooses to do it, +// it must be the very first op i.e., the InitRequest must be the very first +// request of the session. 
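RequestMetadata requires op_ids that are unique and monotonically increasing for the life of a client, and `dependencies` may only reference op_ids the server has already seen. A small client-side bookkeeping sketch follows; the class and method names are illustrative, not part of the proxy client.

import itertools

class OpIdAllocator:
    """Hands out monotonically increasing op_ids and validates dependencies."""

    def __init__(self):
        self._counter = itertools.count(1)
        self._issued = set()

    def next_metadata(self, dependencies=()):
        unknown = [dep for dep in dependencies if dep not in self._issued]
        if unknown:
            raise ValueError(f"dependencies reference unissued op_ids: {unknown}")
        op_id = next(self._counter)
        self._issued.add(op_id)
        return {"op_id": op_id, "dependencies": list(dependencies)}

alloc = OpIdAllocator()
first = alloc.next_metadata()
second = alloc.next_metadata(dependencies=[first["op_id"]])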
+message InitRequest {} + +// InitResponse contains basic runtime system info (such as the available +// devices, and name and type of the platform) that most clients can immediately +// make use of. It may also carry the status for whether the optional +// configuration requested by the InitRequest has been successfully applied. +message InitResponse { + uint64 session_id = 8; + + string platform_name = 1; // == ifrt::Client::platform_name() + string platform_version = 2; // == ifrt::Client::platform_version() + uint64 platform_id = 3; // == ifrt::Client::platform_id() + uint64 process_index = 4; // == ifrt::Client::process_index() + string runtime_type = 5; // == ifrt::Client::runtime_type() + + message Device { + int32 id = 1; + int32 local_device_id = 9; + int32 local_hardware_id = 2; + string device_kind = 3; + optional int32 default_memory_id = 7; + repeated int32 memory_ids = 8; + string debug_string = 4; + string to_string = 5; + map attributes = 6; + } + + repeated Device devices = 6; // == ifrt::Client::devices() + repeated int32 addressable_device_ids = + 7; // == ifrt::Client::addressable_devices() + + message Memory { + int32 id = 1; + string memory_space_kind = 2; + int32 kind_id = 6; + repeated int32 device_ids = 3; + string debug_string = 4; + string to_string = 5; + } + + repeated Memory memories = 9; +} + +// ================ Future-related operations ================ + +// Checks if the given Futures are ready on the server. This is a destructive +// read, i.e., the given future will no longer be able to be referenced. +message CheckFutureRequest { + fixed64 future_handle = 1; +} +message CheckFutureResponse {} + +// ================ Array-related operations ================ + +// In the current context of the IFRT proxy service, the term `Host` in the +// proto names below refers to the host where the proxy client and the user code +// (e.g.: a Jax application) are running. + +// Makes an IFRT Array from the contents of a HostBuffer. +// Equivalent to `ifrt::Client::MakeArrayFromHostBuffer`. +message MakeArrayFromHostBufferRequest { + DTypeProto dtype = 1; + ShapeProto shape = 2; + ShardingProto sharding = 3; + fixed64 host_buffer_handle = 4; + optional proto.ByteStrides byte_strides = 5; +} +message MakeArrayFromHostBufferResponse { + fixed64 array_handle = 1; +} + +// Makes an IFRT Array from a set of single-device Arrays. +// Equivalent to ifrt::Client::AssembleArrayFromSingleDeviceArrays. +message AssembleArrayFromSingleDeviceArraysRequest { + ShapeProto shape = 1; + ShardingProto sharding = 2; + repeated fixed64 single_device_array_handles = 3; + proto.ArrayCopySemantics copy_semantics = 4; +} +message AssembleArrayFromSingleDeviceArraysResponse { + fixed64 array_handle = 1; +} + +// Remaps the shards of given IFRT arrays to new IFRT arrays. +// Equivalent to ifrt::Client::RemapArrays. +message RemapArraysRequest { + RemapPlanProto plan = 1; + repeated fixed64 array_handles = 2; + proto.ArrayCopySemantics copy_semantics = 3; +} +message RemapArraysResponse { + repeated fixed64 array_handles = 1; +} + +// Reads the contents of a given IFRT Array. +// Equivalent to ifrt::Array::CopyToHostBuffer. +message CopyToHostBufferRequest { + fixed64 array_handle = 2; + optional proto.ByteStrides byte_strides = 3; + fixed64 host_buffer_handle = 1; +} +message CopyToHostBufferResponse {} + +// Breaks the given Array into its constituent per-device Arrays. +// Equivalent to ifrt::Array::DisassmebleIntoSingleDeviceArrays. 
+message DisassembleIntoSingleDeviceArraysRequest { + fixed64 array_handle = 1; + proto.ArrayCopySemantics copy_semantics = 2; +} +message DisassembleIntoSingleDeviceArraysResponse { + repeated fixed64 single_device_array_handles = 1; +} + +message ReshardRequest { + fixed64 array_handle = 1; + ShardingProto sharding = 2; + proto.ArrayCopySemantics copy_semantics = 3; +} +message ReshardResponse { + fixed64 array_handle = 1; +} + +message FullyReplicatedShardRequest { + fixed64 array_handle = 1; + proto.ArrayCopySemantics copy_semantics = 2; +} +message FullyReplicatedShardResponse { + fixed64 array_handle = 1; +} + +// Checks if the given Arrays are ready on the server. +message CheckArrayReadyRequest { + fixed64 array_handle = 1; +} +message CheckArrayReadyResponse {} + +// Deletes the given Array. Response contains the handle for a Future that +// becomes ready when the deletion completes. +message DeleteArrayRequest { + fixed64 array_handle = 1; +} +message DeleteArrayResponse { + fixed64 deletion_future_handle = 1; +} + +message IsArrayDeletedRequest { + fixed64 array_handle = 1; +} +message IsArrayDeletedResponse { + bool deleted = 1; +} + +message DestructArrayRequest { + fixed64 array_handle = 1; +} +message DestructArrayResponse {} + +// ================ Compiler-related operations ================ + +// Modeled after `xla::PjRtLoadedExecutable::LogicalDeviceIds`. +// +// TODO(hyeontaek): this XLA-specific type is temporary and will be removed when +// `addressable_device_logical_ids()` is removed from `LoadedExecutable` or +// moved to a type-erased proto field. +message LogicalDeviceIds { + int32 replica = 1; + int32 partition = 2; +} + +// Compiles `mlir_module` and returns a `LoadedExecutable`. +message CompileRequest { + xla.ifrt.Serialized program = 1; + xla.ifrt.Serialized compile_options = 2; + repeated bytes host_callbacks = 3; +} +message CompileResponse { + fixed64 loaded_executable_handle = 1; + repeated fixed64 loaded_host_callback_handles = 8; + + // A subset of LoadedExecutable's fields that are cheap to calculate. See + // `LoadedExecutableMetadataResponse` for the rest of metadata. + string name = 2; + int32 num_devices = 3; + repeated LogicalDeviceIds addressable_device_logical_ids = 4; + repeated int32 addressable_device_ids = 5; + oneof fingerprint { + bytes fingerprint_value = 6; + tensorflow.StatusProto fingerprint_error = 7; + } + fixed64 ready_future_handle = 9; +} + +// ================ LoadedExecutable-related operations ================ + +// Reads `LoadedExecutable`'s metadata that's typically available only after +// compilation. Metadata fields that are cheaper to calculate are available +// immediately as part of `CompileResponse`. 
+message LoadedExecutableMetadataRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableMetadataResponse { + message ShardingList { + repeated xla.OpSharding shardings = 1; + } + + optional ShardingList parameter_shardings = 1; + optional ShardingList output_shardings = 2; + + message LayoutList { + repeated xla.LayoutProto layouts = 1; + } + + oneof parameter_layouts { + LayoutList parameter_layouts_list = 4; + tensorflow.StatusProto parameter_layouts_error = 5; + } + oneof output_layouts { + LayoutList output_layouts_list = 6; + tensorflow.StatusProto output_layouts_error = 7; + } + + message MemoryKindList { + repeated string memory_kinds = 1; + } + + message OutputMemoryKind { + tensorflow.StatusProto status = 1; + repeated MemoryKindList memory_kind_lists = 2; + } + + OutputMemoryKind output_memory_kinds = 3; +} + +// Mirrors `LoadedExecutable::Execute`. Returns output array handles and a +// future handle that becomes ready when the execution completes. The latter can +// be checked by issuing `CheckFutureRequest`. +message LoadedExecutableExecuteRequest { + fixed64 loaded_executable_handle = 1; + repeated fixed64 args_handles = 2; + xla.ExecuteOptionsProto execute_options = 3; + repeated int32 device_ids = 4; +} +message LoadedExecutableExecuteResponse { + fixed64 status_handle = 1; + + message Output { + DTypeProto dtype = 1; + ShapeProto shape = 2; + ShardingProto sharding = 3; + fixed64 array_handle = 4; + } + + repeated Output outputs = 2; +} + +// Mirrors `LoadedExecutable::Delete`. Returns a handle of a future that becomes +// ready when the deletion completes. +message LoadedExecutableDeleteRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableDeleteResponse { + fixed64 future_handle = 1; +} + +// Mirrors `LoadedExecutable::IsDeleted`. +message LoadedExecutableIsDeletedRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableIsDeletedResponse { + bool is_deleted = 1; +} + +// Mirrors `LoadedExecutable::~LoadedExecutable`. The LoadedExecutable handle +// becomes unusable after this request. +message LoadedExecutableDestructRequest { + fixed64 loaded_executable_handle = 1; +} +message LoadedExecutableDestructResponse {} + +// ================ LoadedHostCallback-related operations ================ + +// Waits for the given host callback on the server to have any pending execution +// and retrieves its execution identifier and operands. The server serializes +// all operands, concatenates them in the argument order, stores it as a single +// host buffer assocatiated with the given handle. +message LoadedHostCallbackPollRequest { + fixed64 loaded_host_callback_handle = 1; + fixed64 operand_host_buffer_handle = 2; +} +message LoadedHostCallbackPollResponse { + optional fixed64 host_callback_execution_handle = 1; +} + +// Returns the results of a client-side host callback execution, requested by +// `LoadedHostCallbackPollResponse`. The client concatenates all serialized +// results and stores them as a single host buffer associated with the given +// handle. +message LoadedHostCallbackReturnRequest { + fixed64 host_callback_execution_handle = 1; + oneof result { + fixed64 result_host_buffer_handle = 3; + tensorflow.StatusProto error = 2; + } +} +message LoadedHostCallbackReturnResponse {} + +// ============= Operations supported by the IFRT `Client` class ============= + +// Mirrors Client::GetDefaultDeviceAssignment. 
+message GetDefaultDeviceAssignmentRequest { + fixed64 num_replicas = 1; + fixed64 num_partitions = 2; +} +message GetDefaultDeviceAssignmentResponse { + xla.DeviceAssignmentProto device_assignment = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt_proxy/common/types.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt_proxy/common/types.proto new file mode 100644 index 000000000..ca3829891 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/ifrt_proxy/common/types.proto @@ -0,0 +1,43 @@ +// Copyright 2023 The OpenXLA Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +syntax = "proto3"; + +package xla.ifrt.proto; + +// Mirrors `xla::PjRtValueType`, which is used in IFRT to model +// polymorphic-typed values, e.g., `xla::ifrt::Executable::CostAnalysisValue`. +message Variant { + message Int64List { + repeated sfixed64 values = 1; + } + + oneof value { + bytes string_value = 1; + sfixed64 int64_value = 2; + Int64List int64_list = 3; + float float_value = 4; + } +} + +enum ArrayCopySemantics { + ARRAY_COPY_SEMANTICS_UNSPECIFIED = 0; + ARRAY_COPY_SEMANTICS_ALWAYS_COPY = 1; + ARRAY_COPY_SEMANTICS_REUSE_INPUT = 2; + ARRAY_COPY_SEMANTICS_DONATE_INPUT = 3; +} + +message ByteStrides { + repeated int64 strides = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pjrt_ifrt/xla_compiler.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pjrt_ifrt/xla_compiler.proto new file mode 100644 index 000000000..f9f97b100 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pjrt_ifrt/xla_compiler.proto @@ -0,0 +1,24 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "xla/pjrt/compile_options.proto"; + +message XlaCompileOptionsProto { + xla.CompileOptionsProto compile_options = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pjrt_ifrt/xla_host_callback.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pjrt_ifrt/xla_host_callback.proto new file mode 100644 index 000000000..e583c5f94 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pjrt_ifrt/xla_host_callback.proto @@ -0,0 +1,53 @@ +/* Copyright 2023 The OpenXLA Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "google/protobuf/any.proto"; +import "xla/xla_data.proto"; + +// Represents a host callback in an XLA computation. +// +// XLA computation may use XLA send/recv to represent communication between the +// host and the device. This can be used to implement "host callbacks", where +// host-side computation is invoked in the middle of XLA computation. This +// message contains information that is necessary to instantiate host callbacks +// that are marshalled from the client. +// +// Modeled after `xla::HostCallback`. +message XlaHostCallbackProto { + message ArgInfo { + // The channel id associated with this value in HLO. Declared as `uint32` + // even though `xla::HostCallbackArgInfo::channel_id` is `uint16_t` because + // protobuf doesn't have a 16-bit integer type. + uint32 channel_id = 1; + + // The host shape for the value. + xla.ShapeProto shape = 2; + } + + // The metadata (e.g. channel_id, shape) for the operands and results. + repeated ArgInfo operands = 1; + repeated ArgInfo results = 2; + + // Serialized host callback. + google.protobuf.Any serialized_callback = 3; + + // See comment for PJRT + // ExecuteOptions::use_major_to_minor_data_layout_for_callbacks. + bool use_major_to_minor_data_layout_for_callbacks = 4; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pjrt_ifrt/xla_sharding.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pjrt_ifrt/xla_sharding.proto new file mode 100644 index 000000000..6867d9713 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pjrt_ifrt/xla_sharding.proto @@ -0,0 +1,28 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.ifrt; + +import "xla/python/ifrt/device.proto"; +import "xla/xla_data.proto"; + +// Wire format for `HloSharding`. 
+message HloShardingProto { + DeviceListProto devices = 1; + optional string memory_kind = 3; + xla.OpSharding xla_op_sharding = 2; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/py_host_callback.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/py_host_callback.proto new file mode 100644 index 000000000..f91e122af --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/py_host_callback.proto @@ -0,0 +1,25 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +// Represents a JAX host callback that is serialized using the 'cloudpickle' +// Python library. Typically used for +// `xla.ifrt.XlaHostCallbackProto.serialized_callback`. +message PyHostCallbackProto { + bytes callable = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pytree.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pytree.proto new file mode 100644 index 000000000..73c087ef5 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/python/pytree.proto @@ -0,0 +1,32 @@ +syntax = "proto3"; + +package jax; + +enum PyTreeNodeType { + PY_TREE_KIND_INVALID = 0; + PY_TREE_KIND_LEAF = 1; + PY_TREE_KIND_LIST = 2; + PY_TREE_KIND_NONE = 3; + PY_TREE_KIND_TUPLE = 4; + PY_TREE_KIND_DICT = 5; +} + +message DictKeysProto { + repeated uint32 str_id = 1; +} + +message PyTreeNodeDefProto { + // Recovers the tree structure. + uint32 arity = 1; + // Node type. + PyTreeNodeType type = 2; + // Only set when type == DICT. + DictKeysProto dict_keys = 3; +} + +// A Pytree. +message PyTreeDefProto { + repeated PyTreeNodeDefProto nodes = 1; + // Extra strings. + repeated string interned_strings = 2; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/buffer_assignment.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/buffer_assignment.proto new file mode 100644 index 000000000..98d9287bd --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/buffer_assignment.proto @@ -0,0 +1,105 @@ +/* Copyright 2023 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +syntax = "proto3"; + +package xla.buffer_assignment; + +// This defines the buffer isolation configuration, which is a debugging tool to +// move buffers around to isolate them. The idea is to isolate buffers from the +// heap-simulator-packed assignments one at a time, with an optional padding, to +// help debug buffer corruption issues using bisection and compiler fuel. +// Consider the following heap-simulator-packed assignments: +// +// space +// ^ +// 7 | +--------+ +// 6 | +----+ | C | +// 5 | | | +--------+ +// 4 | | A | +// 3 | | | +------+ +// 2 | +--+----+--+ | B | +// 1 | | D | | | +// 0 +-------+----------+-+------+----> time +// +// A offset: 2, size: 4 +// B offset: 0, size: 3 +// C offset: 5, size: 2 +// D offset: 0, size: 2 +// Total size of heap-simulator-packed buffers: 7 +// +// *base_offset_bytes* sets the base offset of all buffers. For example, +// BufferIsolationConfig(base_offset_bytes=2, isolation_fuel=0) will produce the +// following: +// +// A offset: 4, size: 4 +// B offset: 2, size: 3 +// C offset: 7, size: 2 +// D offset: 2, size: 2 +// +// *isolation_fuel* controls how many buffers to isolate, on top of the +// heap-simulator-allocated buffers. We will use a deterministic pseudo-random +// order, using the isolation_order_salt value to ensure determinism. Assuming +// the salt value we picked happen to respect the alphabetical order of buffers, +// BufferIsolationConfig(base_offset_bytes=2, isolation_fuel=2) will produce the +// following: +// +// A offset: 9, size: 4 (isolated) +// B offset: 13, size: 3 (isolated) +// C offset: 7, size: 2 (not isolated) +// D offset: 2, size: 2 (not isolated) +// +// In contrast, BufferIsolationConfig(base_offset_bytes=2, isolation_fuel=4) +// will produce the following: +// +// A offset: 9, size: 4 (isolated) +// B offset: 13, size: 3 (isolated) +// C offset: 16, size: 2 (isolated) +// D offset: 18, size: 2 (isolated) +// +// *isolation_padding_bytes* controls extra padding between the isolated +// buffers. For example, BufferIsolationConfig(base_offset_bytes=2, +// isolation_fuel=2, isolation_padding_bytes=1) will produce the following: +// +// A offset: 10, size: 4 (isolated) +// B offset: 15, size: 3 (isolated) +// C offset: 7, size: 2 (not isolated) +// D offset: 2, size: 2 (not isolated) +// +// *isolation_order_salt* allows us to pick a deterministic pseudo-random order +// in isolating buffers. This is helpful when isolating every buffer may not be +// feasible due to out-of-memory issues, but we still want to find a scenario +// where isolating some buffers makes an error go away. For this scenario, we +// would try different salt and fuel values to find allocations where we are not +// out-of-memory and the error goes away. Then we can bisect using the fuel to +// find the first buffer to isolate that makes the error go away. For example, +// assuming the salt value of 10 happens to order buffers by reverse +// alphabetical order, BufferIsolationConfig(base_offset_bytes=2, +// isolation_fuel=2, isolation_order_salt=10) will produce the following: +// +// A offset: 4, size: 4 (not isolated) +// B offset: 2, size: 3 (not isolated) +// C offset: 11, size: 2 (isolated) +// D offset: 9, size: 2 (isolated) +// +// *isolation_colors* picks which buffer colors would be isolated. 
+// +message BufferIsolationConfig { + int64 base_offset_bytes = 1; + int64 isolation_fuel = 2; + int64 isolation_padding_bytes = 3; + uint64 isolation_order_salt = 4; + repeated int32 isolation_colors = 5; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/cpu/backend_config.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/cpu/backend_config.proto new file mode 100644 index 000000000..08cfbf222 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/cpu/backend_config.proto @@ -0,0 +1,54 @@ +syntax = "proto3"; + +package xla.cpu; + +// Backend config for XLA:CPU. +message BackendConfig { + // Number of partitions per outer dimension (in order, starting with + // outer-most dimension first). Used by the parallel cpu backend to partition + // HLOs into parallel tasks. + repeated int64 outer_dimension_partitions = 1; + // Configuration to be used by oneDNN matmul + OneDnnMatMulConfig onednn_matmul_config = 2; + OneDnnLayerNormConfig onednn_layer_norm_config = 3; +} + +message OneDnnMatMulConfig { + bool transpose_a = 1; + bool transpose_b = 2; + // These enum needs to be mapped to oneDNN enum for post_op algorithm. + // TODO(intel-tf): Add kinds supported by oneDNN. + enum FusionKind { + UNDEFINED = 0; + BIAS = 1; + RELU = 2; + TANH = 3; + GELU_ERF = 4; + GELU_TANH = 5; + BINARY_ADD = 6; + LINEAR = 7; + ELU = 8; + RELU6 = 9; + SIGMOID = 10; + } + repeated FusionKind fused_ops = 3; + bool bias_broadcast = 4; + // To avoid protobuf failures for specific decimal values, + // the original float value alpha is type-casted to int32. + int32 alpha_typecast = 5; + bool weights_prepacked = 6; + bool user_scratchpad = 7; +} + +message OneDnnLayerNormConfig { + // These enum needs to be mapped to oneDNN enum for post_op algorithm. + // TODO(intel-tf): Add kinds supported by oneDNN. + enum FusionKind { + UNDEFINED = 0; + SCALE = 1; + SHIFT = 2; + SCALE_AND_SHIFT = 3; + } + FusionKind fused_ops = 1; + int32 epsilon_typecast = 2; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/cpu/executable.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/cpu/executable.proto new file mode 100644 index 000000000..2c48a51f4 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/cpu/executable.proto @@ -0,0 +1,34 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
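The BufferIsolationConfig commentary in buffer_assignment.proto above walks through concrete offsets for buffers A, B, C and D. The sketch below reproduces that arithmetic: every buffer shifts by base_offset_bytes, and the first isolation_fuel buffers in the salt-determined order are instead stacked above the heap-packed region, each preceded by isolation_padding_bytes. The isolation order is passed in explicitly here, whereas the real pass derives it pseudo-randomly from isolation_order_salt; the helper name is illustrative.

def isolate_buffers(buffers, order, base_offset, fuel, padding=0):
    # buffers: {name: (heap_offset, size)}; order: iterable of names.
    heap_total = max(offset + size for offset, size in buffers.values())
    isolated_cursor = base_offset + heap_total
    result = {}
    for name in order:
        offset, size = buffers[name]
        if fuel > 0:
            isolated_cursor += padding
            result[name] = isolated_cursor
            isolated_cursor += size
            fuel -= 1
        else:
            result[name] = base_offset + offset
    return result

# Matches the documented example: base=2, fuel=2, padding=1.
buffers = {"A": (2, 4), "B": (0, 3), "C": (5, 2), "D": (0, 2)}
assert isolate_buffers(buffers, "ABCD", base_offset=2, fuel=2, padding=1) == {
    "A": 10,
    "B": 15,
    "C": 7,
    "D": 2,
}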
+==============================================================================*/ + +syntax = "proto3"; + +package xla.cpu; + +import "xla/service/cpu/xla_framework.proto"; +import "xla/service/hlo.proto"; +import "xla/xla.proto"; + +message XlaRuntimeCpuExecutableProto { + optional XlaRuntimeExecutableProto xla_runtime_executable = 1; + optional XlaFrameworkMappingProto xla_framework_mapping = 2; +} + +message CompilationResultProto { + HloModuleProtoWithConfig hlo_module = 1; + BufferAssignmentProto buffer_assignment = 2; + string entry_function_name = 3; + bytes obj_file = 4; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/cpu/xla_framework.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/cpu/xla_framework.proto new file mode 100644 index 000000000..ce4b3874d --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/cpu/xla_framework.proto @@ -0,0 +1,25 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto2"; + +package xla.cpu; + +message XlaFrameworkMappingProto { + repeated int64 inputs = 1 [packed = true]; + repeated int64 flattened_outputs = 2 [packed = true]; + optional int64 result = 3 [default = -1]; + optional bool output_is_tuple = 4; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/backend_configs.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/backend_configs.proto new file mode 100644 index 000000000..18264a767 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/backend_configs.proto @@ -0,0 +1,285 @@ +syntax = "proto3"; + +package xla.gpu; + +import "xla/autotuning.proto"; +import "xla/xla_data.proto"; +import "tsl/protobuf/dnn.proto"; + +// Backend configs for XLA:GPU. +// +// These are metadata that the GPU backend attaches to HloInstructions and later +// uses during e.g. codegen. +// +// GpuBackendConfig serves as a parent config for all backend configs so +// configs won't overwrite each other. Any new backend config proto +// should be added to and used in GpuBackendConfig. +// +// Remember that proto3 doesn't give clients a way to tell the difference +// between a field not being present and a field having the default value. +// Choose your defaults carefully. +// +// No guarantee is made about the stability of these protos. +// +// See HloInstruction::backend_config() for more info. + +// Backend config for a convolution that runs through cudnn. +message CudnnConvBackendConfig { + reserved 1, 2; + + // Opaque algorithm number and tuning knobs chosen for this conv. + stream_executor.dnn.AlgorithmProto algorithm = 6; + + // The scaling factor multiplied with the convolution result. + double conv_result_scale = 4; + + // Below are the fields related to cuDNN's fused convolution. Refer to + // GpuConvParams for their meanings. 
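+  //
+  // Roughly, the fused operation described by conv_result_scale together with
+  // the fields below computes
+  //   activation(conv_result_scale * conv(input, filter)
+  //              + side_input_scale * side_input + bias),
+  // in the style of cuDNN's fused convolution-bias-activation call.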
+
+  // The requested activation (e.g. relu) after the convolution.
+  stream_executor.dnn.ActivationMode activation_mode = 3;
+
+  // The scaling factor multiplied with the side input. If no side input buffer
+  // is provided, this field must be 0.
+  double side_input_scale = 5;
+
+  // The negative slope coefficient alpha for leaky_relu activation, used only
+  // when activation_mode is kLeakyRelu.
+  //
+  // leakyrelu(x) is defined as x > 0 ? x : alpha * x.
+  //
+  // Since this is a proto3 proto, leakyrelu_alpha is 0 if not specified (in
+  // which case the leakyrelu activation is equivalent to relu).
+  double leakyrelu_alpha = 8;
+
+  // If the filter (and bias, if present) have been reordered, set this flag.
+  // It's placed into a `oneof` block to skip the serialization when not set.
+  oneof filter_and_bias_reordering_oneof {
+    // cuDNN int8x32 vectorized convolutions (NCHW_VECT_C data layout) can be
+    // optimized by reordering the filter and bias (if present). The logical
+    // layout stays the same, but the data is shuffled in a way that is
+    // compatible with NVidia's IMMA instruction (sm75+).
+    bool reordered_int8_nchw_vect = 7;
+  }
+
+  // Serialization of the graph described by the convolution and adjacent
+  // pointwise ops.
+  optional string serialized_graph = 9;
+}
+
+// Backend config for the GEMM operation running through cuBLAS.
+message GemmBackendConfig {
+  // Opaque optional algorithm number. No chosen number indicates that a
+  // different cuBLAS API will be used, which does not allow for choosing an
+  // algorithm.
+  oneof algorithm {
+    int64 selected_algorithm = 1;
+  }
+
+  double alpha_real = 2;
+  double alpha_imag = 9;
+
+  double beta = 3;
+
+  xla.DotDimensionNumbers dot_dimension_numbers = 7;
+
+  xla.PrecisionConfig precision_config = 12;
+
+  // cublasLt matmul epilogue.
+  enum Epilogue {
+    DEFAULT = 0;
+    BIAS = 1;
+    RELU = 2;
+    BIAS_RELU = 3;
+    GELU = 4;
+    GELU_AUX = 5;
+    BIAS_GELU = 6;
+    BIAS_GELU_AUX = 7;
+  }
+
+  Epilogue epilogue = 13;
+
+  optional int64 lhs_stride = 14;
+  optional int64 rhs_stride = 15;
+
+  optional bool grad_x = 16;
+  optional bool grad_y = 17;
+  bool damax_output = 18;
+}
+
+// Backend config for bitcast operation generated from MLIR MHLO dialect.
+message BitcastBackendConfig {
+  LayoutProto source_layout = 1;
+  LayoutProto result_layout = 2;
+}
+
+// Backend config for async collective operations. Note that is_sync will be
+// false by default, so even if a backend config is not explicitly attached to
+// the HLOInstruction, getting the backend_config will yield a default valued
+// proto which will have is_sync = false. Attribute no_parallel_custom_call
+// asserts that an asynchronous collective operation does not execute in
+// parallel with custom-calls, which can trigger device synchronization. This
+// attribute will also be false by default and should lead to conservative
+// runtime behavior.
+message CollectiveBackendConfig {
+  bool is_sync = 1;
+  bool no_parallel_custom_call = 2;
+}
+
+// Backend config for cost model estimates.
+message ReificationCost {
+  // Total execution time of the reified op.
+  double end_to_end_cycles = 1;
+
+  // Estimated overall kernel execution time in microseconds.
+  //
+  // GPU Cost Model estimates compute and memory access time separately. Exec
+  // time is a combined metric of the two.
+  double exec_time_us = 2;
+
+  // Estimate for compute time in microseconds.
+  double compute_time_us = 3;
+
+  // Estimate for memory access (read+write) time in microseconds.
+ double memory_access_time_us = 4; +} + +// Backend config for a custom fusion (pre-compiled device kernel implementing a +// fusion computation). +message CustomFusionConfig { + string name = 1; +} + +message CuDnnFusionConfig { + int64 plan_id = 1; +} + +message FusionBackendConfig { + // kLoop, kInput, or kOutput (from HloInstruction::FusionKind), or your own + // custom string. + // + // Don't put "kCustom" in here -- just put a string describing the custom + // fusion, like "__triton_gemm". + // + // This is somewhat redundant with HloInstruction::fusion_kind(). We need it + // here because LMHLO does not have the concept of a fusion kind, and we use + // this same backend-config proto for both HLO and LMHLO. + string kind = 1; + + // Only valid when kind == "__triton_gemm". Even then it's optional: If not + // present, we use the default Triton config. + AutotuneResult.TritonGemmKey triton_gemm_config = 2; + + // Only valid when kind == "__custom_fusion". + CustomFusionConfig custom_fusion_config = 4; + + // Cost model prediction. + ReificationCost reification_cost = 3; + + CuDnnFusionConfig cudnn_fusion_config = 5; +} + +// Backed config for norm executed by cuDNN. +message CudnnNormBackendConfig { + // Epsilon parameter. + double epsilon = 1; + + // Opaque algorithm number. + stream_executor.dnn.AlgorithmProto algorithm = 2; + + // Norm type. + enum Kind { + LAYER_FWD_INFER = 0; + LAYER_FWD_TRAIN = 1; + LAYER_BWD = 2; + } + Kind kind = 3; +} + +// Backend config for a fused Multi-Headed Attention (fMHA) that runs through +// cudnn. +message CudnnfMHABackendConfig { + // Opaque algorithm number and tuning knobs chosen for this fMHA. + stream_executor.dnn.AlgorithmProto algorithm = 8; + + // The scaling factor multiplied with the BMM1 result. fmha_scale is 1 if the + // MHA pattern has no scaling. + double fmha_scale = 10; + + // Dropout factor in MHA + double dropout_rate = 13; + + // Configs for mha bmms in the forward graph + xla.DotDimensionNumbers bmm1_dot_dimension_numbers = 11; + xla.DotDimensionNumbers bmm2_dot_dimension_numbers = 12; + + xla.ShapeProto intermediate_tensor_shape = 14; + + // Configs for mha bmms in the backward graph + xla.DotDimensionNumbers bmm1_grad_gemm1_dot_dimension_numbers = 16; + xla.DotDimensionNumbers bmm1_grad_gemm2_dot_dimension_numbers = 17; + xla.DotDimensionNumbers bmm2_grad_gemm1_dot_dimension_numbers = 18; + xla.DotDimensionNumbers bmm2_grad_gemm2_dot_dimension_numbers = 19; + + // Random seed used by dropout + int64 seed = 15; + + // Is flash attention + bool is_flash_attention = 20; + + // Is causal mask + bool is_causal_mask = 21; + + // mask type + enum MaskType { + NO_MASK = 0; + PADDING = 1; + CAUSAL = 2; + PADDING_CAUSAL = 3; + ALIBI = 4; + } + MaskType mask_type = 22; +} + +// Generic backend config for XLA:GPU +message GpuBackendConfig { + // Specifies which operation queue the current instruction will run on. + // A backend may have multiple operation queues to run instructions + // concurrently, use this to signal the backend which queue to dispatch to. + // The backend should keep a mapping of + // operation_queue_id->actual_hardware_queue_id if runtime will create + // different IDs. + int64 operation_queue_id = 1; + + // Specifies which operation queues to await for data when running with + // multiple operation queues. 
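+  // For example, an instruction with operation_queue_id = 1 and
+  // wait_on_operation_queues = [0] is expected to be dispatched to queue 1
+  // only once the data it consumes from queue 0 is available.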
+ repeated int64 wait_on_operation_queues = 2; + + oneof backend_config { + CudnnConvBackendConfig cudnn_conv_backend_config = 3; + + GemmBackendConfig gemm_backend_config = 4; + + BitcastBackendConfig bitcast_backend_config = 5; + + CollectiveBackendConfig collective_backend_config = 6; + + FusionBackendConfig fusion_backend_config = 7; + + CudnnNormBackendConfig cudnn_norm_backend_config = 8; + + CudnnfMHABackendConfig cudnn_fmha_backend_config = 9; + } + + // This attribute instructs the latency-hiding scheduler to + // schedule this particular instruction to the earliest position. + // Note that setting this to true will make this instruction scheduled + // at the very beginning of the parent computation before + // every other nodes. + // An example use case would be deciding to schedule between collective + // or an async compute. LHS might put either one at the first place + // depending on the cost, but it'd be more beneficial if the collective + // is always scheduled first as it's not SM-heavy. + // In this case we can use this flag to enforce the ordering. + bool force_earliest_schedule = 10; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/executable.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/executable.proto new file mode 100644 index 000000000..f4c2a479b --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/executable.proto @@ -0,0 +1,29 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.gpu; + +import "xla/service/hlo.proto"; +import "xla/xla.proto"; + +message CompilationResultProto { + HloModuleProtoWithConfig hlo_module_with_config = 1; + BufferAssignmentProto buffer_assignment = 2; + string asm_text = 3; + bytes binary = 4; + map dnn_compiled_graphs = 5; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/fusion_process_dump.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/fusion_process_dump.proto new file mode 100644 index 000000000..0c52edb46 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/fusion_process_dump.proto @@ -0,0 +1,60 @@ +syntax = "proto3"; + +package xla.gpu; + +import "xla/stream_executor/device_description.proto"; + +message FusionStep { + message Fusion { + // Name of the resulting fusion. Can be the same as producer or consumer. + string fusion_name = 1; + + // Name of the producer instruction before fusion. + string producer_name = 2; + + // Name of the consumer instruction before fusion. + string consumer_name = 3; + } + + message UpdatePriority { + // The name of the producer whose priority was updated. + string producer_name = 1; + // The names of all of the producers' consumers. 
+ repeated string consumer_names = 2; + + // The time to execute the epilogue of each consumer (consisting of the + // producer's HLO) and read the producer's inputs from each consumer. + float us_fused = 3; + // The time to execute the producer and read the producer's outputs from + // the consumers when unfused. + float us_unfused = 4; + } + + message ProducerIneligible { + // The name of the producer. + string producer_name = 1; + // The reason why this producer cannot be fused. + string reason = 2; + } + + oneof step { + Fusion fusion = 4; + ProducerIneligible producer_ineligible = 5; + UpdatePriority update_priority = 6; + } + + reserved 1 to 3; +} + +message FusionProcessDumpProto { + repeated FusionStep fusion_steps = 1; + + stream_executor.GpuDeviceInfoProto gpu_device_info = 2; + + // HLO module before fusion in short parsable string format. The string + // represantation is compacter than HloModuleProto in this case, especially + // when the fusion process dump is stored as text proto. + // + // TODO: Consider using base64 or gzip to decrease the size of the string. + string hlo_module_before_fusion = 3; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/gpu_autotuning.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/gpu_autotuning.proto new file mode 100644 index 000000000..bcba33bf1 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/gpu_autotuning.proto @@ -0,0 +1,32 @@ +// This is used for convolution logging. +syntax = "proto3"; + +package xla.gpu; + +import "xla/autotuning.proto"; +import "xla/service/hlo.proto"; +import "xla/xla_data.proto"; + +message ConvInstructionLog { + xla.HloInstructionProto instruction = 1; + repeated xla.ShapeProto operand_shapes = 2; + repeated uint64 result_addresses = 3; + repeated uint64 operand_addresses = 4; +} + +message DenylistedAlgorithm { + int64 id = 1; + bool tensor_ops = 2; +} + +message AlgorithmDenylistEntry { + string hlo = 1; + ComputeCapability cc = 2; + CudnnVersion cudnn_version = 3; + string blas_version = 5; + repeated DenylistedAlgorithm algos = 4; +} + +message AlgorithmDenylist { + repeated AlgorithmDenylistEntry entries = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/model/hlo_op_profile.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/model/hlo_op_profile.proto new file mode 100644 index 000000000..5a0843b5a --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/gpu/model/hlo_op_profile.proto @@ -0,0 +1,18 @@ +syntax = "proto3"; + +package xla.gpu; + +import "xla/service/hlo.proto"; + +message HloInstructionProfile { + xla.HloInstructionProto instruction = 1; + int64 clock_cycles = 2; +} + +message HloInstructionProfileList { + repeated HloInstructionProfile entries = 1; +} + +message DeviceHloInstructionProfiles { + map entries = 2; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/hlo.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/hlo.proto new file mode 100644 index 000000000..b79805ec3 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/hlo.proto @@ -0,0 +1,819 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This proto file defines messages which represent the HLO module. This is a +// full fidelity serialization of the c++ HLO constructs. +// +// Many of the protos below are simple 1-to-1 serializations of the +// corresponding C++ classes, e.g., HloModule, HloComputation, and +// HloInstruction. +// +// FIELD NAMES ARE IMPORTANT +// +// Unlike most protos, you can't safely change the names of fields, even if you +// keep the numeric ids the same. This is because we sometimes serialize these +// protos as JSON, which includes the field names in the serialization. + +syntax = "proto3"; + +package xla; + +import "google/protobuf/any.proto"; +import "xla/xla_data.proto"; + +option cc_enable_arenas = true; + +enum CustomCallSchedule { + SCHEDULE_NONE = 0; + SCHEDULE_LATEST = 1; + SCHEDULE_EARLIEST = 2; +} + +// The version of the API used by the custom call function. The signatures for +// each version are given below. +// TODO(b/189822916): Remove this enum when all clients are migrated to the +// status-returning API. +enum CustomCallApiVersion { + API_VERSION_UNSPECIFIED = 0; + + // The first version of the API, with the following signatures: + // + // CPU: + // void do_custom_call(void* out, const void** in); + // + // GPU: + // void do_custom_call(CUstream stream, void** buffers, + // const char* opaque, size_t opaque_len); + API_VERSION_ORIGINAL = 1; + + // When the ability to return success/failure status was added: + // + // CPU: + // void do_custom_call(void* out, const void** in, + // XlaCustomCallStatus* status); + // + // GPU: + // void do_custom_call(CUstream stream, void** buffers, + // const char* opaque, size_t opaque_len, + // XlaCustomCallStatus* status); + // + API_VERSION_STATUS_RETURNING = 2; + + // Fixes the API signatures on the CPU side of the version STATUS_RETURNING by + // adding the opaque string so that the custom call API is consistent across + // CPUs and GPUs. For GPUs, the behaviors invoked by + // API_VERSION_STATUS_RETURNING and API_VERSION_STATUS_RETURNING_UNIFIED are + // the same. + // + // CPU: + // void do_custom_call(void* out, const void** in, + // const char* opaque, size_t opaque_len, + // XlaCustomCallStatus* status); + // + // GPU: + // void do_custom_call(CUstream stream, void** buffers, + // const char* opaque, size_t opaque_len, + // XlaCustomCallStatus* status); + // + API_VERSION_STATUS_RETURNING_UNIFIED = 3; + + // Api version implementing XLA runtime custom call calling convention. These + // custom calls can be registered as an XLA runtime custom call (1) or as XLA + // runtime FFI binding (2). + // + // This type of custom call uses custom ABI to pass type information along + // with custom call arguments. Also it passes buffer arguments together with + // data type, sizes and strides. 
+ // + // Example: (XLA runtime custom call) + // + // absl::Status DoCustomCall(StridedMemrefView arg, float attr); + // + // CustomCall::Bind("custom_call") + // .Arg() + // .Attr("attr") + // .To(DoCustomCall); + // + // (1) xla/runtime/custom_call.h + // (2) xla/runtime/ffi/ffi.h + API_VERSION_TYPED_FFI = 4; +} + +// Serialization of HloInstruction. +// Next ID: 87 +message HloInstructionProto { + reserved 10; + reserved "parameter_name"; + reserved 12; + reserved "fused_instructions_computation"; + reserved 4; + reserved "operand_names"; + reserved 5; + reserved "control_predecessor_names"; + reserved 6; + reserved "called_computation_names"; + reserved 44; + reserved "replica_group_ids"; + // Use backend_config instead for custom_call_opaque. + reserved 53; + reserved "custom_call_opaque"; + // Use backend_config instead for all_reduce_barrier. + reserved 46; + reserved "all_reduce_barrier"; + + string name = 1; + string opcode = 2; + xla.ShapeProto shape = 3; + + xla.OpMetadata metadata = 7; + + // Literal, only present for kConstant. + xla.LiteralProto literal = 8; + + // Parameter number is only present for kParameter. + int64 parameter_number = 9; + + // Fusion state, only present for kFusion. + string fusion_kind = 11; + + // Index for kGetTupleElement. + int64 tuple_index = 13; + + // Dimensions present for some operations that require reshaping or + // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. + repeated int64 dimensions = 14; + + // Describes the window in a windowed operation such as convolution. + xla.Window window = 15; + + // Describes the dimension numbers used for a convolution. + xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16; + + // The number of feature groups. Used for a convolution. Must be a divisor of + // the input feature dimension and output feature dimension. If not specified, + // it will use a default value of 1. + int64 feature_group_count = 50; + + int64 batch_group_count = 58; + + // Describes the [begin, end) index range and stride for slices. + message SliceDimensions { + int64 start = 1; + int64 limit = 2; + int64 stride = 3; + } + repeated SliceDimensions slice_dimensions = 17; + + // The bit sizes for a reduce-precision operation. + int32 exponent_bits = 18; + int32 mantissa_bits = 19; + + // Describes the [start, start + size) range size for a dynamic slice + // ('start' is specified dynamically in the second operand of the operation). + repeated int64 dynamic_slice_sizes = 20; + + // The padding configuration that describes the edge padding and interior + // padding of this pad instruction. Only set for pad instructions. + xla.PaddingConfig padding_config = 21; + + // Outfeed configuration information, only present for kOutfeed. + bytes outfeed_config = 22; + + // The distribution requested for random number generation. + // Only present for kRng. + xla.RandomDistribution distribution = 23; + + // A small float number added to the variance to avoid divide-by-zero error. + // Only present for kBatchNormTraining, kBatchNormInference, and + // kBatchNormGrad. + float epsilon = 24; + + // An integer value representing the index of the feature dimension. + // Only present for kBatchNormTraining, kBatchNormInference, and + // kBatchNormGrad. + int64 feature_index = 25; + + // Represents a unique identifier for each Send/Recv instruction pair or + // optionally for collective instructions (AllReduce, CollectivePermute, + // AllToAll). Non-positive channel_id is equivalent to no channel id. 
+ int64 channel_id = 26; + + // The string representation of the infeed configuration. + bytes infeed_config = 27; + + // Name of a external target (eg, global symbol) to call, only present for + // kCustomCall. + string custom_call_target = 28; + + // Shape of outfeed request. + xla.ShapeProto outfeed_shape = 29; + + // Describes the dimension numbers used for a dot operation + xla.DotDimensionNumbers dot_dimension_numbers = 30; + + // FFT type (FFT, IFFT, etc). + xla.FftType fft_type = 31; + + // FFT length. + repeated int64 fft_length = 32; + + // Comparison direction only used for kCompare. + string comparison_direction = 63; + + // Gather dimension numbers. + xla.GatherDimensionNumbers gather_dimension_numbers = 33; + repeated int64 gather_slice_sizes = 34; + + // Used to be compute host-related fields. + reserved 41; + reserved 42; + + // The id of this instruction. + int64 id = 35; + + repeated int64 operand_ids = 36; + repeated int64 control_predecessor_ids = 37; + repeated int64 called_computation_ids = 38; + + xla.OpSharding sharding = 40; + + // Backend configuration for the instruction. Has backend-specific meaning. + bytes backend_config = 43; + + // Cross replica op fields. + repeated ReplicaGroup replica_groups = 49; + // Deprecated, but keeping it for backward compatibility. Use channel_id. + // Non-positive all_reduce_id is equivalent to no all_reduce_id. + int64 all_reduce_id = 45 [deprecated = true]; + + // If true, interprets ids in ReplicaGroup as global device ids, which is + // a linearized id of `replica_id * partition_count + partition_id`. + bool use_global_device_ids = 71; + + // Whether this Send/Recv instruction transfers data to/from the host. Only + // present for Send and Recv instructions and their SendDone and RecvDone + // partners. + bool is_host_transfer = 47; + + // Whether this Sort instruction should be stable. + bool is_stable = 60; + + xla.ScatterDimensionNumbers scatter_dimension_numbers = 48; + + // Precision configuration for the instruction. Has backend-specific meaning. + xla.PrecisionConfig precision_config = 51; + + // Collective permute field. + repeated SourceTarget source_target_pairs = 52; + + // Sharding for kDomain instructions. + xla.OpSharding domain_entry_sharding = 54; + xla.OpSharding domain_exit_sharding = 55; + + // For custom call this indicates that the layouts are constrained. If + // constrain_layout is true then the 'shape' field must contain a layout, and + // 'operand_shapes_with_layout' must contain a shape with layout for each + // operand. + bool constrain_layout = 56; + repeated xla.ShapeProto operand_shapes_with_layout = 57; + + // Options for TriangularSolve + xla.TriangularSolveOptions triangular_solve_options = 59; + + // Options for Cholesky + xla.CholeskyOptions cholesky_options = 62; + + // Describes how parameters behave with regards to replicas. + xla.ParameterReplication parameter_replication = 61; + + reserved 64; + + // Whether the kCustomCall instruction has side-effects, only present for + // kCustomCall. + bool custom_call_has_side_effect = 65; + + // A list of OutputOperandAliasing pairs that specifies aliasing buffers + // between output and operands for kCustomCall and kFusion. + repeated xla.OutputOperandAliasing output_operand_aliasing = 74; + + // Specifies the desired schedule for the custom-call. The field is only + // present for custom-call. + CustomCallSchedule custom_call_schedule = 76; + + // The delta value for kRngGetAndUpdateState. 
+ int64 delta = 66; + + // Specifies if the gather/scatter indices are guaranteed to be sorted by the + // caller. + bool indices_are_sorted = 67; + + // Frontend attributes to pass to the XLA backend. + xla.FrontendAttributes frontend_attributes = 68; + + // Specifies if all elements updated are guaranteed to be unique by + // the caller. + bool unique_indices = 69; + + // RNG algorithm used by kRngBitGenerator. + xla.RandomAlgorithm rng_algorithm = 70; + + // The comparison type used for kCompare. + string comparison_type = 72; + + // Specifies if this is a cross-program-prefetch, used by kCopyStart. + // Deprecated and replaced by optional_cross_program_prefetch_index. + bool is_cross_program_prefetch = 73 [deprecated = true]; + + // Specifies the cross-program-prefetch index used by kCopyStart. Uses oneof + // to emulate the 'optional' keyword for proto3 versions before v3.15.0 + // released 2021/2/18. + oneof optional_cross_program_prefetch_index { + int32 cross_program_prefetch_index = 80; + } + + // If a convolution is dynamic, a dynamic padding type will be specified. + xla.PaddingType padding_type = 75; + + // The API version used by the custom call function. This field is only + // present for custom-call. + // TODO(b/189822916): Remove this field when all clients are migrated to the + // status-returning API. + CustomCallApiVersion custom_call_api_version = 77; + + // Used to be async_group_id. + reserved 78; + + // Represents a unique execution thread name for one or more async groups. + // Each HLO module may contain a main thread and one or more parallel threads. + // Empty async_execution_thread is equivalent to main thread. + string async_execution_thread = 79; + + // Represents the K value for top-k. + int64 k = 81; + + // Represents the largest flag for top-k. + bool largest = 85; + + // Represents the information for tracking propagation of values within HLO + // graph. + xla.StatisticsViz statistics_viz = 82; + + // Used to be operation_queue_id. + reserved 83; + // Used to be wait_on_operation_queues. + reserved 84; + + // Sparsity descriptor for dot operation. + repeated xla.SparsityDescriptor dot_sparsity = 86; +} + +// Serialization of HloComputation. +message HloComputationProto { + reserved 3; + reserved "root_name"; + + string name = 1; + + // The array of instructions is always in a valid dependency order, where + // operands appear before their users. + repeated HloInstructionProto instructions = 2; + + // The program shape (with layout) of this computation. + + xla.ProgramShapeProto program_shape = 4; + + // The id of this computation. + int64 id = 5; + + // The id of the root of the computation. + int64 root_id = 6; + + // Whether this is a fusion computation. Fusion computations should use this + // to determine whether they are a fusion in CreateFromProto since the + // parent fusion_instruction_ may get removed and be nullptr. + bool is_fusion_computation = 7; + + // The name of execution thread this computation belongs to. + string execution_thread = 8; +} + +// Serialization of an HLO schedule. An HLO schedule contains a total order of +// instructions for each non-fusion computation in the module. +message HloScheduleProto { + message InstructionSequence { + repeated int64 instruction_ids = 1; + } + + // Map from computation id to sequence. + map sequences = 1; +} + +enum Kind { + // Define a UNDEFINED_ALIAS equal to zero to get around the default-0 proto3 + // behavior and missing has_*() APIs. 
+ UNDEFINED_ALIAS = 0; + // The buffers may or may not alias at runtime. + MAY_ALIAS = 1; + // The buffers must alias at runtime. + MUST_ALIAS = 2; +} + +message HloInputOutputAliasProto { + // The following proto describes a pair of aliased an input + // (described by parameter number and a ShapeIndex of the parameter) + // and an output (described by a ShapeIndex of the root + // instruction). For example: + // + // entry = { + // output_shape_index={1}, + // parameter_number=0, + // parameter_shape_index={1, 2}, + // } + // + // This entry indicates that the first parameter's {1, 2} element is + // aliased with the {1} element of the root instruction. + message AliasEntryProto { + // ShapeIndex of the root hlo. + repeated int64 output_shape_index = 1; + // Number of the parameter in entry computation. + int64 parameter_number = 2; + // ShapeIndex of the parameter instruction. + repeated int64 parameter_shape_index = 3; + // The kind of alias to be setup. + Kind kind = 4; + } + + repeated AliasEntryProto entries = 1; +} + +message HloBufferDonorProto { + // The following proto describes an input (described by parameter number and a + // ShapeIndex of the parameter) that can donate its butter to any output + // tensor. It is similar to HloInputOutputAliasProto, but without a paired + // output. For example: + // + // entry = { + // parameter_number=0, + // parameter_shape_index={1, 2}, + // } + // + // This entry indicates that the first parameter's {1, 2} element can donate + // its buffer. + message BufferDonorEntryProto { + // Number of the parameter in entry computation. + int64 parameter_number = 1; + // ShapeIndex of the parameter instruction. + repeated int64 parameter_shape_index = 2; + } + + repeated BufferDonorEntryProto entries = 1; +} + +message CrossProgramPrefetch { + int64 parameter = 1; + repeated int64 index = 2; + int64 offset = 3; +} + +// Serialization of stack frames index representations. +// Stack frames index presented in four flat arrays: +// 1. File names array. +// 2. Function names array. +// 3. File location array. +// 4. Frame array. +// All reference ids in sub-protos are 1-based positions of the +// entity in the flat array. +// Ids are 1-based to keep 0 value as representation of non-set property. +message StackFrameIndexProto { + // Serialization of file position. + message FileLocation { + // 1-based position of file name. + int32 file_name_id = 1; + // 1-based position of function name. + int32 function_name_id = 2; + // Line number. + int32 line = 3; + // Column number. + int32 column = 4; + } + + // Serialization of frame. + message StackFrame { + // 1-based position of file location. + int32 file_location_id = 1; + // 1-based position of the parent frame. + int32 parent_frame_id = 2; + } + + // Flat index array of file names. + repeated string file_names = 1; + // Flat index array of function names. + repeated string function_names = 2; + // Flat index array of file locations. + repeated FileLocation file_locations = 3; + // Flat index array of frames. + repeated StackFrame stack_frames = 4; +} + +// Serialization of HloModule. +message HloModuleProto { + string name = 1; + string entry_computation_name = 2; + int64 entry_computation_id = 6; + + // The array of computations is always in a valid dependency order, where + // callees appear before their callers. + repeated HloComputationProto computations = 3; + + // The host program shape (with layout) of the entry computation. 
+ xla.ProgramShapeProto host_program_shape = 4; + + // The id of this module. + int64 id = 5; + + // The schedule for this module. + HloScheduleProto schedule = 7; + + // Describes alias information between inputs and outputs. + HloInputOutputAliasProto input_output_alias = 8; + + // Describes the information of input buffer donors. + HloBufferDonorProto buffer_donor = 18; + + repeated CrossProgramPrefetch cross_program_prefetches = 10; + + // True if the module contains dynamic computation. + bool is_dynamic = 11; + + xla.OpSharding spmd_output_sharding = 12; + + repeated xla.OpSharding spmd_parameters_shardings = 14; + + // Uses AutoSharding pass or not. + bool use_auto_spmd_partitioning = 16; + + // The type of optimization profile in use for module-level optimizations. + enum ProfileType { + INVALID = 0; + FLAG = 1; + FUSION = 2; + LAYOUT = 3; + DOT = 4; + } + + // Information about the optimization profile that this module contains. + message ProfileInfo { + // The optimization profiles that this module contains. + ProfileType profile_type = 1; + // Speedup of tuned config compared to default config. + double relative_speedup = 2; + // The source of the optimization profile that this module contains. + xla.ProfileSource profile_source = 3; + // The compilation event that triggered the use of the profile. + xla.CompilationEvent compilation_event = 4; + // The fingerprint of the unoptimized module this profile was applied to. + string fingerprint = 5; + } + + // Profile information for the HLO module. + repeated ProfileInfo profile_info = 13; + + // DeviceAssignment object information. + DeviceAssignmentProto device_assignment = 15; + + // Stack frames index. + StackFrameIndexProto stack_frame_index = 17; + + // Frontend attributes to pass to the XLA backend. + xla.FrontendAttributes frontend_attributes = 19; + + reserved 9; + reserved "dynamic_parameter_binding"; +} + +// Serialization of LogicalBuffer. +message LogicalBufferProto { + // Location represents an instruction and its shape index, which uniquely + // identifies a point where a buffer is needed. + message Location { + // TODO(b/239098765): Remove instruction_name and computation_name. + string instruction_name = 2 [deprecated = true]; + int64 instruction_id = 4; + repeated int64 shape_index = 3; + + reserved 1; + } + + int64 id = 1; + int64 size = 2; + + // The location where the buffer is defined. + Location defined_at = 3; + + int64 color = 4; +} + +// Serialization of BufferAllocation. +message BufferAllocationProto { + // Assigned represents a single LogicalBuffer that is assigned to this + // BufferAllocation. + message Assigned { + int64 logical_buffer_id = 1; + int64 offset = 2; + int64 size = 3; + } + + int64 index = 1; + int64 size = 2; + bool is_thread_local = 3; + bool is_tuple = 11; + bool is_entry_computation_parameter = 5; + bool is_constant = 12; + int64 parameter_number = 6; + repeated int64 parameter_shape_index = 10; + bool maybe_live_out = 7; + int64 color = 8; + repeated Assigned assigned = 9; +} + +// A trace of a HeapSimulator run. +message HeapSimulatorTrace { + // The trace includes a list of events, where each event describes one action + // performed by the heap simulator. + message Event { + enum Kind { + ALLOC = 0; // A memory region was allocated for the buffer. + FREE = 1; // A memory region was freed for the buffer. + + // A buffer was shared with another (canonical) buffer. 
This is similar to + // ALLOC, except that instead of allocating a new region of memory, the + // memory region of the canonical buffer is directly re-used. Multiple + // buffers may share with the same canonical buffer. The lifetime of the + // canonical buffer is extended to the union of all lifetimes. + SHARE_WITH = 2; + } + Kind kind = 1; + + // The id of the LogicalBuffer that the event applies to. + int64 buffer_id = 2; + + // The HloInstruction that the simulation was processing that caused this + // event to occur, identified by its computation and instruction name. E.g. + // buffers defined by instruction A are allocated when processing A. + string computation_name = 3; + string instruction_name = 4; + + // The id of the canonical LogicalBuffer that the buffer shares with. Only + // set for SHARE_WITH events. + int64 share_with_canonical_id = 5; + } + repeated Event events = 1; + bool whole_module_simulation = 2; + int64 buffer_allocation_index = 3; +} + +// An abstraction representing a set of HLO module built to run concurrently +// across different devices. +message HloModuleGroupProto { + string name = 1; + repeated HloModuleProto hlo_modules = 2; +} + +// Serialization of BufferAssignment. +message BufferAssignmentProto { + // Alias represents a source LogicalBuffer, and the buffer location that + // aliases it. + message BufferAlias { + int64 source_buffer_id = 1; + LogicalBufferProto.Location location = 2; + } + + repeated LogicalBufferProto logical_buffers = 1; + repeated BufferAlias buffer_aliases = 2; + repeated BufferAllocationProto buffer_allocations = 3; + repeated HeapSimulatorTrace heap_simulator_traces = 4; +} + +// Grouping message that contains all of the information above. +message HloProto { + reserved 2; + reserved "hlo_ordering"; + + HloModuleProto hlo_module = 1; + BufferAssignmentProto buffer_assignment = 3; +} + +// Encapsulates HloProto together with the arguments, result, and +// execution_platform. This message is used for purposes such as +// analysis/replay/file-storage. +message HloSnapshot { + // The hlo graph. + HloProto hlo = 1; + + // The arguments passed to the graph. + repeated LiteralProto arguments = 2; + + // The result of the graph. + LiteralProto result = 3; + + // The name of the platform used to run the graph. + string execution_platform = 4; +} + +// Metadata for an HLO module. Dumped after HLO passes and before LLO lowering +// with filename module_####.metadata.textproto, where #### is +// canonical_module_id. +message HloModuleMetadataProto { + // Uniquely identifies an HloModuleMetadata. Equal to the first unique_id + // of the module (a module may go through multiple unique_ids). If a module + // is partitioned into multiple modules, those modules will each have a new + // HloModuleMetadata with a different canonical_module_id. + int64 canonical_module_id = 1; + + // Name of the module group that the module is part of. + string module_group_name = 2; + + // The canonical module id of the module that this one is partitioned from, + // if applicable. + int64 original_module_id = 3; + + // The canonical module ids of the modules that this one is partitioned into, + // if applicable. + repeated int64 partitioned_module_ids = 4; + + // Metadata for the HLO passes that are run on the module. + repeated HloPassMetadata pass_metadata = 5; +} + +// Metadata for one run of an HLO pass on a module. Provides more information +// when processing debug dumps of HloProtos about the order of HLO passes and +// various other stats like duration. 
`pass_id` may also be used to identify a +// particular run of a pass in debug info that propagates through stages of +// compilation. +message HloPassMetadata { + // For a given module, pass_id uniquely identifies a run of an HLO pass on + // that module. Note that a pass_id may not always refer to the same pass + // because the order of passes during compilation may change. For finding + // metadata for a particular pass, pass_name and pipeline_name would be more + // reliable, although note that they may not be unique. + int64 pass_id = 1; + string pass_name = 2; + string pipeline_name = 3; + + // Filenames of the dumps of the module after this pass ran. Module may be + // dumped in multiple formats, and the order of formats in this field will + // stay consistent across passes. + repeated string dump_filenames = 4; + + // Return value of pass.Run(). True if this pass changed the module, or, in + // the case where the module was run through this pass as part of a module + // group, true if this pass changed any module in the same module group. + bool module_changed = 5; + + // The unique_id of the module that this pass is run on. May be different from + // the canonical_module_id of the HloModuleMetadata that this HloPassMetadata + // is inside. + int64 module_id = 6; + + // If the module went through this pass as part of a module group, this is + // set as the ids of all the modules in the module group. Empty otherwise. + repeated int64 module_group_module_ids = 7; + + // Timestamp before and after the pass is run. Note they may be equal. + int64 start_timestamp_usec = 8; + int64 end_timestamp_usec = 9; + + // Custom metadata for the pass. + google.protobuf.Any custom_metadata = 10; +} + +// Encodes the underlying Xla runtime executable compiled from the XLA module. +message XlaRuntimeExecutableProto { + HloModuleProto hlo_module_proto = 1; + + // TODO(b/232263665)): We need to know the TargetMachine this executable was + // compiled for, otherwise we can accidentally use illegal instrauctions (e.g. + // use AVX512 when it's not available). + + // TODO(b/232263665)): Serialized executable has to know what APIs it has to + // be linked with, including the version. For example Gpu executable must be + // linked with a runtime layer that abstracts over CUDA. + + // Serialized object file compiled from the XLA module. + bytes obj_file = 3; + + // Serialized MLIR module corresponding to compiled object file. + string mlir_module = 4; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/hlo_execution_profile_data.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/hlo_execution_profile_data.proto new file mode 100644 index 000000000..b0897bd06 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/hlo_execution_profile_data.proto @@ -0,0 +1,27 @@ +/* Copyright 2018 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+syntax = "proto3";
+
+package xla;
+
+import "xla/service/hlo_profile_printer_data.proto";
+
+option cc_enable_arenas = true;
+
+message HloExecutionProfileData {
+  HloProfilePrinterData printer_data = 1;
+  repeated int64 profile_counters = 2;
+}
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/hlo_profile_printer_data.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/hlo_profile_printer_data.proto
new file mode 100644
index 000000000..5231d13d6
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/hlo_profile_printer_data.proto
@@ -0,0 +1,67 @@
+/* Copyright 2018 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package xla;
+
+option cc_enable_arenas = true;
+
+// Describes how to pretty-print a profile counter array gathered for a specific
+// HloModule.
+message HloProfilePrinterData {
+  // Pretty-printer information about an HloInstruction.
+  message HloInstructionInfo {
+    string long_name = 1;
+    string short_name = 2;
+    string category = 3;
+
+    // Metrics computed by HloCostAnalysis.
+    float flop_count = 4;
+    float transcendental_count = 5;
+    reserved 6;  // bytes_accessed used to erroneously be a float
+    int64 bytes_accessed = 9;
+    float optimal_seconds = 7;
+
+    // The index into the profile counters array for the HloInstruction
+    // corresponding to this HloInstructionInfo.
+    int64 profile_index = 8;
+  }
+
+  // Pretty-printer information about an HloComputation.
+  message HloComputationInfo {
+    string name = 1;
+
+    // The index into the profile counters array for the HloComputation
+    // corresponding to this HloComputationInfo.
+    int64 profile_index = 2;
+
+    // HloInstructionInfos for every HloInstruction in the HloComputation
+    // corresponding to this HloComputationInfo.
+    repeated HloInstructionInfo instruction_infos = 3;
+  }
+
+  // HloComputationInfos for every HloComputation in the HloModule.
+  repeated HloComputationInfo computation_infos = 1;
+
+  // The size of the profile counters array we will pretty-print.
+  int64 profile_counters_size = 2;
+
+  // Maps extra metric name to the index into the profile counters array.
+  map<string, int64> extra_metrics = 3;
+
+  // Name of the entry computation.
+  string entry_computation = 4;
+}
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/memory_space_assignment/memory_space_assignment.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/memory_space_assignment/memory_space_assignment.proto
new file mode 100644
index 000000000..77faa69e7
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/memory_space_assignment/memory_space_assignment.proto
@@ -0,0 +1,170 @@
+/* Copyright 2023 The OpenXLA Authors.
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla.memory_space_assignment; + +// Memory space assignment options for slicing prefetches into smaller +// asynchronous copies, reducing prefetch memory allocation pressure. +// +// No prefetch slicing is performed if max_slices == 0. +// +// TODO(b/275905276): Consider adding another option that indicates that we want +// slices of a certain size, rather than just always creating max_slices. +message SlicedPrefetchOptions { + // The maximum number of slices into which to slice a prefetch. + uint32 max_slices = 1; + + // The minimum tensor size in bytes that we will attempt to slice. + uint64 min_bytes = 2; + + // This option should never be set to true in production. When this is true, + // we will crash if we propose a slice (other than the final slice) with a + // size that is not a multiple of the required hardware alignment. Otherwise, + // we will choose not to slice such situations, which is always safe. + bool fail_on_non_alignment_boundary_slice_proposal = 3; + + // The threshold for max_slices after which we limit the permutations of slice + // times that we try when placing a sliced allocation. + uint32 all_slice_time_permutations_threshold = 4; + + // The preferred slize size for MSA sliced prefetches. 0 means there is no + // preferred slice size, in which case, we'll try to slice into max_slices. + uint64 preferred_slice_size = 5; +} + +// Options for memory-bound loop optimizations in memory space assignment. If +// enabled, this pass can optimize memory-bound unrolled loops to maximize the +// bandwidth utilized and minimize the execution time. +message MemoryBoundLoopOptimizerOptions { + // Enable the memory-bound loop optimizations. + optional bool enabled = 1; + + // The desired ratio of overlapped operations that is sufficient to overlap + // prefetches with. If this value is 1, the algorithm will try to fully + // overlap the prefetches with other compute, if less than 1, the algorithm + // may schedule prefetches such that some of the prefetch is not overlapped, + // so may become critical. For example, if this value is 0.5, we are willing + // for the prefetch time to take up to 2X of the overlapped computation time. + optional float desired_copy_ratio = 2; + + // If true, the algorithm allows a fully pipelined prefetch to be scheduled + // even if the copy resources haven't reached the desired copy ratio. A fully + // pipelined prefetch starts the same time as its counterpart in the previous + // iteration finishes. + optional bool allow_unsatisfied_fully_pipelined_prefetch = 3; + + // The minimum number of iterations that the loop needs to be unrolled for the + // memory-bound loop optimizer to kick in. 
+ optional float min_num_iterations = 4; +} + +message TupleShapeIndex { + repeated int64 index = 1; +} + +// A message to filter operands in an HLO schedule, that can be used to override +// compiler behaviour like altering schedule etc. +message HloOperandFilter { + // Regex to match instruction name. + optional string instruction_name_regex = 1; + // Set if filtering operands of an instruction. + optional int64 operand_number = 2; + // If filtering operands based on size in bytes. + optional int64 size_gte = 3; + // If filtering operands based on size in bytes. + optional int64 size_lte = 4; + // If operand of an instruction is a tuple and indexing into the tuple is + // required. + optional TupleShapeIndex tuple_index = 5; +} + +// Options to override preferred prefetch time for an operand. +message PreferredPrefetchOverrideOptions { + oneof options { + // A value X in [0, 1] that tells us the preferred prefetch time is the + // fraction X through the live range. For example, .5 will set the + // preferred prefetch time to the middle of live range. + float prefetch_eagerness = 1; + // Preferred prefetch time is set to after the instruction with instruction + // name. + string after_instruction_name = 2; + // Preferred prefetch time is set to before the instruction with instruction + // name. + string before_instruction_name = 3; + } +} + +// Filters operands in an HLO schedule and overrides preferred prefetch times +// for those operands according to an override strategy specified in +// override_options. +message PreferredPrefetchOverride { + optional HloOperandFilter hlo_operand_filter = 1; + optional xla.memory_space_assignment.PreferredPrefetchOverrideOptions + override_options = 2; +} + +// Encloses chained override configs. The first config has highest precedence +// and so on. +message PreferredPrefetchOverrides { + repeated PreferredPrefetchOverride overrides = 1; +} + +// A message that identifies one or more HloPositions. +message HloPositionMatcher { + // Regex to match the entire instruction HLO. The HLO string is constructed + // using default HloPrintOptions. Refer to the HloPrintOptions class in + // hlo_instruction.h to know more about the format of the HLO string used for + // matching. + optional string instruction_regex = 1; + // Regex to match instruction name. + optional string instruction_name_regex = 2; + // If output of an instruction is a tuple and indexing into the + // tuple is required. + optional TupleShapeIndex tuple_index = 3; + // Filters instructions with output size in bytes greater or equal to a value. + optional int64 size_gte = 4; + // Filters instructions with output size in bytes less or equal to a value. + optional int64 size_lte = 5; +} + +// Options to override preferred prefetch time for an operand. +message MsaSortOrderOverrideOptions { + oneof options { + // Assign alternate memory to the filtered buffer before other buffers. If + // multiple buffers are to be assigned first (within the same override + // config) other tie breakers and stable sort order will take effect. + bool assign_first = 1; + // Assign alternate memory to the filtered buffer after other buffers. If + // multiple buffers are to be assigned last (within the same override + // config) other tie breakers and stable sort order will take effect. + bool assign_last = 2; + } +} + +// Specifies details on how to override the sort order for matching +// HloPositions. 
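+//
+// For example, a MsaSortOrderOverrides text proto that assigns alternate
+// memory to buffers whose defining instruction matches a hypothetical
+// "softmax.*" name pattern before all other buffers could look like:
+//
+//   overrides {
+//     hlo_position_matcher { instruction_name_regex: "softmax.*" }
+//     override_options { assign_first: true }
+//   }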
+message MsaSortOrderOverride { + optional HloPositionMatcher hlo_position_matcher = 1; + optional xla.memory_space_assignment.MsaSortOrderOverrideOptions + override_options = 2; +} + +// Encloses chained override configs. The first config has highest precedence +// and so on. +message MsaSortOrderOverrides { + repeated MsaSortOrderOverride overrides = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/metrics.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/metrics.proto new file mode 100644 index 000000000..d41c4edae --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/metrics.proto @@ -0,0 +1,46 @@ +syntax = "proto3"; + +package xla; + +import "google/protobuf/any.proto"; +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; + +// Defines pass specific metrics. +message PassMetrics { + // Unique ID of the module on which the pass was run. + uint64 module_id = 1; + // The name of the pass. + string pass_name = 2; + // Duration of the pass. + google.protobuf.Duration pass_duration = 3; + // Custom pass metrics. This is kept opaque, via `google.protobuf.Any`, in + // order to decouple pass agnostic compilation logs from possibly proprietary + // compiler passes. + google.protobuf.Any custom_metrics = 4; +} + +// Defines XLA compilation metrics. +message CompilationLogEntry { + // Time when the event captured by this log entry occurred. + google.protobuf.Timestamp timestamp = 1; + // Defines compilation stages for which metrics are collected. + enum CompilationStage { + UNSPECIFIED = 0; + END_TO_END = 1; + HLO_PASSES = 2; + CODE_GENERATION = 3; + BACKEND_PASSES = 4; + } + // Compilation stage recorded by this log entry. + CompilationStage stage = 2; + // Duration of the given compilation stage. + google.protobuf.Duration duration = 3; + // Task index from which this log entry was recorded or + // -1 if the task index could not be fetched. + int32 task_index = 4; + // Pass specific metrics. + repeated PassMetrics pass_metrics = 5; + // IDs of modules on which the compilation stage was run. + repeated uint64 module_ids = 6; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/test_compilation_environment.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/test_compilation_environment.proto new file mode 100644 index 000000000..8aaaa61c9 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/test_compilation_environment.proto @@ -0,0 +1,30 @@ +/* Copyright 2022 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+syntax = "proto3";
+
+package xla.test;
+
+message TestCompilationEnvironment1 {
+ uint32 some_flag = 1;
+}
+
+message TestCompilationEnvironment2 {
+ uint32 some_other_flag = 1;
+}
+
+message TestCompilationEnvironment3 {
+ uint32 a_third_flag = 1;
+}
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/xla_compile_result.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/xla_compile_result.proto
new file mode 100644
index 000000000..ed5982e27
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/service/xla_compile_result.proto
@@ -0,0 +1,51 @@
+/* Copyright 2023 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto2";
+
+package xla;
+
+import "google/protobuf/duration.proto";
+import "xla/service/hlo.proto";
+import "tsl/protobuf/status.proto";
+
+// Statistics on how long various parts of compilation took.
+// Not all durations may be relevant for all producers of this message, in
+// which case irrelevant fields should simply be skipped.
+message CompilerPerfStats {
+ // How long did it take to initialize the compiler?
+ optional google.protobuf.Duration init_duration = 1;
+ // How long did it take to verify the HLO?
+ optional google.protobuf.Duration hlo_verification_duration = 2;
+ // How long did it take to prepare for compilation after verification?
+ optional google.protobuf.Duration compilation_prologue_duration = 3;
+ // How long did it take to compile?
+ optional google.protobuf.Duration compilation_duration = 4;
+ // How long did everything take?
+ optional google.protobuf.Duration total_duration = 5;
+}
+
+message CompilationResult {
+ // The compiled HLO. Only set when compilation succeeds.
+ optional xla.HloModuleProto hlo_module = 1;
+ // Always set when compilation succeeds. May or may not be set when
+ // compilation fails.
+ optional CompilerPerfStats perf_stats = 2;
+ // Always set when compilation fails; never set when compilation succeeds.
+ optional tensorflow.StatusProto status = 3;
+ // Counters collected during compilation. Not every producer may support
+ // counters at all, or any particular counter.
+ map<string, int64> counters = 4;
+}
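The schemas above are plain .proto sources; before they can be used they have to be compiled into language bindings. A minimal sketch of decoding a serialized CompilationResult with stock protobuf tooling might look like the following. The gen/ output directory, the generated module path, and the compilation_result.pb input are illustrative assumptions, and the imported schemas (xla/service/hlo.proto, tsl/protobuf/status.proto, ...) are assumed to be available under the same --proto_path.

    # Assumes the bundled schemas were compiled first, e.g. with:
    #   protoc --proto_path=protos --python_out=gen $(find protos -name '*.proto')
    import sys

    sys.path.insert(0, "gen")
    # Generated module name follows protoc's convention; the path is an assumption.
    from xla.service import xla_compile_result_pb2

    # Decode a serialized CompilationResult (the file name is illustrative).
    result = xla_compile_result_pb2.CompilationResult()
    with open("compilation_result.pb", "rb") as f:
        result.ParseFromString(f.read())

    # Print the end-to-end compile time and any counters that were recorded.
    if result.HasField("perf_stats"):
        d = result.perf_stats.total_duration
        print(f"total compile time: {d.seconds + d.nanos / 1e9:.3f} s")
    for name, value in result.counters.items():
        print(f"{name}: {value}")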
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/stream_executor/device_description.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/stream_executor/device_description.proto
new file mode 100644
index 000000000..c01d365b7
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/stream_executor/device_description.proto
@@ -0,0 +1,72 @@
+/* Copyright 2022 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package stream_executor;
+
+import "xla/autotune_results.proto";
+
+message CudaComputeCapabilityProto {
+ int32 major = 1;
+ int32 minor = 2;
+}
+
+message RocmComputeCapabilityProto {
+ string gcn_arch_name = 1;
+}
+
+message GpuDeviceInfoProto {
+ int32 threads_per_block_limit = 1;
+ int32 threads_per_warp = 2;
+ int32 shared_memory_per_block = 3;
+ int32 shared_memory_per_core = 4;
+ int32 threads_per_core_limit = 5;
+ int32 core_count = 6;
+ int64 fpus_per_core = 7;
+ int32 block_dim_limit_x = 8;
+ int32 block_dim_limit_y = 9;
+ int32 block_dim_limit_z = 10;
+ int64 memory_bandwidth = 11;
+ int64 l2_cache_size = 12;
+ float clock_rate_ghz = 13;
+ int64 device_memory_size = 14;
+ int32 shared_memory_per_block_optin = 15;
+ oneof compute_capability {
+ CudaComputeCapabilityProto cuda_compute_capability = 16;
+ RocmComputeCapabilityProto rocm_compute_capability = 17;
+ }
+}
+
+message DnnVersionInfoProto {
+ int32 major = 1;
+ int32 minor = 2;
+ int32 patch = 3;
+}
+
+message GpuTargetConfigProto {
+ GpuDeviceInfoProto gpu_device_info = 1;
+ reserved 2, 3;
+ reserved "cuda_compute_capability", "rocm_compute_capability";
+ string platform_name = 4;
+ DnnVersionInfoProto dnn_version_info = 5;
+
+ // TODO(b/248362914): Autotuning results should be separate from
+ // GpuTargetConfig because autotuning can be updated regularly separate from
+ // the target.
+ xla.AutotuneResults autotune_results = 6;
+
+ string device_description_str = 7;
+}
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/tools/run_hlo_module.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/tools/run_hlo_module.proto
new file mode 100644
index 000000000..21df780d4
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/tools/run_hlo_module.proto
@@ -0,0 +1,21 @@
+syntax = "proto3";
+
+package xla;
+
+import "xla/xla_data.proto";
+
+message RunHloModuleIterationLiterals {
+ // Arguments used by the iteration.
+ repeated LiteralProto arguments = 2;
+
+ // Result of the iteration on the target platform.
+ LiteralProto result = 3;
+
+ // Result of the iteration on the reference platform.
+ LiteralProto reference_result = 4;
+}
+
+message RunHloModuleLiterals {
+ // Iterations of run hlo module.
+ repeated RunHloModuleIterationLiterals iterations = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/tsl/distributed_runtime/coordination/test_device.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/tsl/distributed_runtime/coordination/test_device.proto new file mode 100644 index 000000000..c2b308aaf --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/tsl/distributed_runtime/coordination/test_device.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +package tensorflow; + +message TestDevice { + string name = 1; + int64 local_id = 2; + int64 global_id = 3; +} + +message TestDeviceList { + repeated TestDevice device = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/tsl/distributed_runtime/rpc/test_request.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/tsl/distributed_runtime/rpc/test_request.proto new file mode 100644 index 000000000..f6378cd0e --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/tsl/distributed_runtime/rpc/test_request.proto @@ -0,0 +1,23 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package tsl.test; + +// Dummy proto for testing. +message TestRequest { + repeated string data = 1; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/xla.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/xla.proto new file mode 100644 index 000000000..c64a751e4 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/xla.proto @@ -0,0 +1,1204 @@ +/* Copyright 2017 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +import "google/protobuf/any.proto"; +import "xla/service/hlo.proto"; +import "xla/xla_data.proto"; + +// Proto version of `xla::CompilationEnvironments`. +message CompilationEnvironmentsProto { + repeated google.protobuf.Any environments = 1; +} + +// Debugging options for XLA. These options may change at any time - there are +// no guarantees about backward or forward compatibility for these fields. +message DebugOptions { + // Show addresses of HLO ops in graph dump. + bool xla_hlo_graph_addresses = 2; + + // Instrument the computation to collect per-HLO cycle counts. 
+ bool xla_hlo_profile = 9;
+
+ // List of HLO passes to disable/enable. These names must exactly match the
+ // pass names as specified by the HloPassInterface::name() method.
+ //
+ // At least one of xla_disable_hlo_passes and xla_enable_hlo_passes_only must
+ // be empty.
+ repeated string xla_disable_hlo_passes = 30;
+ repeated string xla_enable_hlo_passes_only = 124;
+
+ // Disables all HLO passes. Note that some passes are necessary for
+ // correctness and the invariants that must be satisfied by "fully optimized"
+ // HLO are different for different devices and may change over time. The only
+ // "guarantee", such as it is, is that if you compile XLA and dump the
+ // optimized HLO for some graph, you should be able to run it again on the
+ // same device with the same build of XLA.
+ bool xla_disable_all_hlo_passes = 104;
+
+ // Numerical optimization level for the XLA compiler backend; the specific
+ // interpretation of this value is left to the backends.
+ int32 xla_backend_optimization_level = 31;
+
+ // Embed the compiler IR as a string in the executable.
+ bool xla_embed_ir_in_executable = 33;
+
+ // Eliminate implicit broadcasts when lowering user computations to HLO
+ // instructions; use explicit broadcast instead.
+ bool xla_eliminate_hlo_implicit_broadcast = 35;
+
+ // When generating calls to Eigen in the CPU backend, use multi-threaded Eigen
+ // mode.
+ bool xla_cpu_multi_thread_eigen = 60;
+
+ // Path to directory with cuda/ptx tools and libraries.
+ string xla_gpu_cuda_data_dir = 61;
+
+ // Enable flush-to-zero semantics in the GPU backend.
+ bool xla_gpu_ftz = 62;
+
+ reserved 63; // Was xla_gpu_disable_multi_streaming
+ reserved 134; // Was xla_gpu_use_random_streams
+
+ // If true, in LLVM-based backends, emit !alias.scope metadata in
+ // generated IR.
+ bool xla_llvm_enable_alias_scope_metadata = 70;
+
+ // If true, in LLVM-based backends, emit !noalias metadata in the
+ // generated IR.
+ bool xla_llvm_enable_noalias_metadata = 71;
+
+ // If true, in LLVM-based backends, emit !invariant.load metadata in
+ // the generated IR.
+ bool xla_llvm_enable_invariant_load_metadata = 72;
+
+ // If true, a set of expensive LLVM optimization passes will not be run.
+ bool xla_llvm_disable_expensive_passes = 73;
+
+ reserved 80; // Was hlo_reduce_precision_options
+
+ // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
+ // computation will run n! times with all permutations of layouts for the
+ // output shape in rank n. For example, with a 3D shape, all permutations of
+ // the set {0, 1, 2} are tried.
+ bool xla_test_all_output_layouts = 90;
+
+ // This is used by ClientLibraryTestBase::ComputeAndCompare*. If true, the
+ // computation will run for all permutations of layouts of all input
+ // arguments. For example, with 2 input arguments in 2D and 4D shapes, the
+ // computation will run 2! * 4! times.
+ bool xla_test_all_input_layouts = 91;
+
+ // Assign colors based on sharding information when generating the Graphviz
+ // HLO graph.
+ bool xla_hlo_graph_sharding_color = 92;
+
+ reserved 93; // Was xla_hlo_tfgraph_device_scopes
+ reserved 94; // Was xla_gpu_use_cudnn_batchnorm
+
+ // Generate calls to MKL-DNN in the CPU backend.
+ bool xla_cpu_use_mkl_dnn = 97;
+
+ reserved 177; // Was xla_cpu_use_xla_runtime
+ bool xla_cpu_use_thunk_runtime = 298;
+
+ reserved 98; // Was xla_gpu_max_kernel_unroll_factor
+
+ // When true, "unsafe" mathematical optimizations are enabled. These
+ // transformations include but are not limited to:
+ //
+ // - Reducing the precision of operations (e.g. using an approximate sin
+ // function, or transforming x/y into x * (1/y)).
+ // - Assuming that operations never produce or consume NaN or +/- Inf (this
+ // behavior can be adjusted using xla_cpu_fast_math_allow_{nans|infs}).
+ // - Assuming that +0 and -0 are indistinguishable.
+ bool xla_cpu_enable_fast_math = 99;
+
+ // When xla_cpu_enable_fast_math is true then this controls whether we allow
+ // operations to produce NaNs. Ignored when xla_cpu_enable_fast_math is
+ // false.
+ bool xla_cpu_fast_math_honor_nans = 120;
+
+ // When xla_cpu_enable_fast_math is true then this controls whether we allow
+ // operations to produce infinities. Ignored when xla_cpu_enable_fast_math is
+ // false.
+ bool xla_cpu_fast_math_honor_infs = 121;
+
+ // When xla_cpu_enable_fast_math is true then this controls whether we forbid
+ // using the reciprocal of an argument instead of division. Ignored when
+ // xla_cpu_enable_fast_math is false.
+ bool xla_cpu_fast_math_honor_division = 126;
+
+ // When xla_cpu_enable_fast_math is true then this controls whether we forbid
+ // approximating calculations for functions. Ignored when
+ // xla_cpu_enable_fast_math is false.
+ bool xla_cpu_fast_math_honor_functions = 129;
+
+ // When false we lower the Minimum and Maximum hlos in the CPU backend such
+ // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NaN. In other words, if this
+ // flag is false we always propagate NaNs through Min and Max.
+ //
+ // Note, this does not correspond to the exact same behavior as the gpu flag
+ // below!
+ bool xla_cpu_enable_fast_min_max = 140;
+
+ // When true we lower the Minimum and Maximum hlos in the GPU backend such
+ // that Min(NotNaN, NaN) = Min(NaN, NotNaN) = NotNaN. In other words, if this
+ // flag is true we don't propagate NaNs through Min and Max.
+ //
+ // Note, this does not correspond to the exact same behavior as the cpu flag
+ // above!
+ bool xla_gpu_enable_fast_min_max = 100;
+
+ reserved 207; // Was xla_cpu_sparse_cuda_threads
+
+ // Allows xla to increase the output precision of floating point operations
+ // and all floating-point conversions to be simplified, including those
+ // that affect the numerics. The `FloatNormalization` pass inserts many
+ // `f32 -> bf16 -> f32` conversion pairs. These are not removed by the
+ // `AlgebraicSimplifier`, as that will only simplify conversions that are
+ // no-ops, e.g. `bf16 -> f32 -> bf16`. Removing these improves accuracy.
+ bool xla_allow_excess_precision = 122;
+
+ // Crashes the program when any kind of verification fails, instead of just
+ // logging the failures. One example is cross checking of convolution results
+ // among different algorithms.
+ bool xla_gpu_crash_on_verification_failures = 101;
+
+ // 0: Disable gemm and convolution autotuning.
+ // 1: Enable autotuning, but disable correctness checking.
+ // 2: Also set output buffers to random numbers during autotuning.
+ // 3: Also reset output buffers to random numbers after autotuning each
+ // algorithm.
+ // 4+: Also check for correct outputs and for out-of-bounds reads/writes.
+ //
+ // Default: 4.
+ int32 xla_gpu_autotune_level = 123;
+
+ // Force the host platform to pretend that there are this many host
+ // "devices". All these devices are backed by the same threadpool. Defaults
+ // to 1.
+ // + // Setting this to anything other than 1 can increase overhead from context + // switching but we let the user override this behavior to help run tests on + // the host that run models in parallel across multiple devices. + int32 xla_force_host_platform_device_count = 102; + + // If set to true XLA:GPU invokes `ptxas` with -O0 (default is -O3). + bool xla_gpu_disable_gpuasm_optimizations = 103; + + enum ShapeChecks { + // Do not insert any shape checks for dynamically shaped operations; output + // buffers might contain garbage data if shapes don't match. + IGNORE = 0; + + // Check shapes at runtime, will insert an extra synchronization if shapes + // cannot be proven correct at compile time. + RUNTIME = 1; + + // Will refuse to compile any program where shape correctness can not be + // established at compile time. + COMPILE_TIME = 2; + } + + ShapeChecks xla_gpu_shape_checks = 170; + + reserved 171; // Was xla_cpu_enable_mlir_lowering + + reserved 173; // Was xla_gpu_enable_mlir_lowering + + reserved 179; // Was xla_gpu_enable_softmax_fusion + + // Enable fast math with eigen in the HLO evaluator. + bool xla_hlo_evaluator_use_fast_path = 106; + + // Temporary option to allow support for both the R1 and the scalar index + // versions of DynamicSlice and DynamicUpdateSlice. Only used for testing. + bool xla_allow_scalar_index_dynamic_ops = 107; + + enum StepMarkerLocation { + // Generate a step marker at the program entry. This handles the case where + // each step is done by one or multiple program execution(s). Only the first + // program will be tagged for generating a step marker at the program entry. + // This is the default. + STEP_MARK_AT_ENTRY = 0; + // Generate a step marker at each iteration of the top level while loop, + // which is assumed to be a training loop. + STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP = 1; + // Generate a step marker at each iteration of the second level while loops, + // which is assumed to be a training or eval loop. + STEP_MARK_AT_SECOND_LEVEL_WHILE_LOOP = 3; + // No step marker generated. + STEP_MARK_NONE = 2; + } + // Option to emit a target-specific marker to indicate the start of a training + // step. The location of the marker (if any) is determined by the option + // value. + StepMarkerLocation xla_step_marker_location = 108; + + // + // BEGIN flags controlling dumping HLO modules for debugging. + // + // When dumping is enabled, HLO modules dumped at the very beginning and end + // of compilation, and optionally also during the pass pipeline. + // + // In general, if you set one of these flags, we will try to infer reasonable + // defaults for the others. For example: + // + // * Setting --xla_dump_to=/tmp/foo without specifying a format + // with --xla_dump_hlo_as_* will turn on --xla_dump_hlo_as_text. + // + // * Setting --xla_dump_hlo_as_text without specifying --xla_dump_to will + // dump to stdout. + // + + // Directory to dump into. + string xla_dump_to = 109; + + // If specified, will only dump modules which match this regexp. + string xla_dump_hlo_module_re = 110; + + // If this flag is specified, will also dump HLO before and after passes that + // match this regular expression. Set to .* to dump before/after all passes. + string xla_dump_hlo_pass_re = 111; + + // Specifies the format that HLO is dumped in. Multiple of these may be + // specified. 
+ bool xla_dump_hlo_as_text = 112; + bool xla_dump_hlo_as_proto = 113; + bool xla_dump_hlo_as_dot = 114; + bool xla_dump_hlo_as_url = 115; + + // Dump HLO graphs as an HTML (DOT -> SVG inlined in HTML) + bool xla_dump_hlo_as_html = 116; + + // Dump the visualization of the fusion progress. + bool xla_dump_fusion_visualization = 149; + + // If true, every time an HLO module is run, we will dump an HloSnapshot + // (essentially, a serialized module plus its inputs) to the --xla_dump_to + // directory. + bool xla_dump_hlo_snapshots = 118; + + // Include a timestamp in the dumped filenames. + bool xla_dump_include_timestamp = 131; + + // Max number of hlo module dumps in a directory. Set to < 0 for unbounded. + int32 xla_dump_max_hlo_modules = 132; + + // Dump HloModuleMetadata as a text proto for each HLO module. + bool xla_dump_module_metadata = 144; + + // GZip-compress protos dumped via --xla_dump_hlo_as_proto. + bool xla_dump_compress_protos = 151; + + // Dump HLO in long text format. Ignored unless xla_dump_hlo_as_text is true. + bool xla_dump_hlo_as_long_text = 164; + + // + // END flags controlling dumping HLO modules. + // + + // Overrides for XLA GPU's convolution layout heuristic. + bool xla_gpu_force_conv_nchw = 125; + bool xla_gpu_force_conv_nhwc = 146; + + // Paths to files with ptx code. + repeated string xla_gpu_ptx_file = 127; + + // Whether to dump llvm ir when compiling to ptx. + bool xla_gpu_dump_llvmir = 155; + + // Whether to dump mlir using pretty print form. + bool xla_dump_enable_mlir_pretty_form = 185; + + // Denylist for cuDNN convolutions. + string xla_gpu_algorithm_denylist_path = 128; + + reserved 130; // Was xla_gpu_deterministic_reductions + + // Debug options that trigger execution errors when NaN or Inf are detected. + bool xla_tpu_detect_nan = 135; + bool xla_tpu_detect_inf = 136; + + // True if TraceMe annotations are enabled for XLA:CPU. + bool xla_cpu_enable_xprof_traceme = 137; + + // It is usually preferable to not fallback to the driver; it can consume more + // memory, or have bugs. + bool xla_gpu_unsafe_fallback_to_driver_on_ptxas_not_found = 138; + + // Extra parameters to pass the GPU assembler. + string xla_gpu_asm_extra_flags = 141; + + // Per-heap size constraint. New heaps will be created if per-heap max size is + // reached. + int32 xla_multiheap_size_constraint_per_heap = 142; + + reserved 143; // Was xla_detailed_logging_and_dumping + + // Enable detailed logging into vlog. If this is disabled, no + // compilation summary will be printed in the end of computation. + bool xla_detailed_logging = 252; + + // Enable HLO dumping. If this is disabled, no HLO modules will be dumped. + bool xla_enable_dumping = 253; + + // Overrides normal multi-threaded compilation setting to use this many + // threads. Setting to 0 (the default value) means no enforcement. + int32 xla_gpu_force_compilation_parallelism = 147; + bool xla_gpu_enable_llvm_module_compilation_parallelism = 268; + + // Guarantees run-to-run determinism. + // This flag implies --xla_gpu_exclude_nondeterministic_ops and in addition + // disables autotuning. + bool xla_gpu_deterministic_ops = 148; + + // Paths to files with LLVM code. + repeated string xla_gpu_llvm_ir_file = 150; + + // Enum to define all collective ops + // that xla supports. 
+ enum CollectiveOpType {
+ NOOP = 0;
+ ALLREDUCE = 1;
+ ALLGATHER = 2;
+ REDUCESCATTER = 3;
+ COLLECTIVEBROADCAST = 4;
+ ALLTOALL = 5;
+ COLLECTIVEPERMUTE = 6;
+ }
+
+ repeated CollectiveOpType xla_gpu_disable_async_collectives = 289;
+
+ // Used to be xla_gpu_enable_async_all_reduce
+ // xla_gpu_enable_async_collective_broadcast
+ // xla_gpu_enable_async_collective_permute
+ // xla_gpu_enable_async_all_gather
+ // xla_gpu_enable_async_reduce_scatter
+ // xla_gpu_enable_async_all_to_all
+ // xla_gpu_enable_async_collectives
+ reserved 152, 278, 183, 199, 200, 201, 238;
+
+ // Size threshold (in bytes) for the GPU collective combiners.
+ int64 xla_gpu_all_reduce_combine_threshold_bytes = 157;
+ int64 xla_gpu_all_gather_combine_threshold_bytes = 212;
+ int64 xla_gpu_reduce_scatter_combine_threshold_bytes = 213;
+
+ // Combine all-gather/scatter-reduce ops with the same dimension or
+ // irrespective of their dimension.
+ bool xla_gpu_enable_all_gather_combine_by_dim = 254;
+ bool xla_gpu_enable_reduce_scatter_combine_by_dim = 257;
+
+ // Combine GPU all-reduces into a single operation over a contiguous buffer.
+ bool xla_gpu_all_reduce_contiguous = 158;
+
+ // Enable allreduce reassociation on allreduces that are converted to a wider
+ // type. The resulting allreduce will be promoted to a wider-typed allreduce.
+ bool xla_gpu_enable_reassociation_for_converted_ar = 209;
+
+ // Number of devices per host for first stage of BlueConnect decomposition
+ // pass. The pass will attempt to decompose all-reduce ops into a
+ // ReduceScatter-AllReduce-AllGather sequence, with the initial ReduceScatter
+ // being performed over all of the devices in the same host. Set to < 1 to
+ // disable all-reduce decomposition.
+ int32 xla_gpu_all_reduce_blueconnect_num_devices_per_host = 159;
+
+ // Enable hoisting of reduce-scatter out of while loops.
+ bool xla_gpu_enable_while_loop_reduce_scatter_code_motion = 203;
+
+ // Inflate collective cost by running each collective multiple times.
+ int32 xla_gpu_collective_inflation_factor = 205;
+
+ // Whether to force inline before llvm module split to get more balanced
+ // splits for parallel compilation.
+ bool xla_llvm_force_inline_before_split = 300;
+
+ // Whether to use the cuDNN frontend API for convolutions when possible.
+ bool xla_gpu_enable_cudnn_frontend = 160;
+
+ bool xla_gpu_enable_cudnn_fmha = 218;
+ bool xla_gpu_fused_attention_use_cudnn_rng = 235;
+
+ // Rewrite layer norm patterns into cuDNN library calls.
+ bool xla_gpu_enable_cudnn_layer_norm = 262;
+
+ // Disable dumping metadata in HLO dumps.
+ bool xla_dump_disable_metadata = 153;
+
+ // If this flag is specified, will only dump HLO before and after passes in
+ // the pass pipeline that matches this regular expression. Default empty value
+ // enables dumping in all pipelines.
+ string xla_dump_hlo_pipeline_re = 154;
+
+ // If true, abort immediately when conv algorithm picker fails, rather than
+ // logging a warning and proceeding with fallback.
+ bool xla_gpu_strict_conv_algorithm_picker = 156;
+
+ reserved 161; // Was xla_gpu_bef_executable
+ reserved 162; // Was xla_gpu_bef_thunk
+
+ reserved 169; // Was xla_gpu_enable_xla_runtime_executable
+
+ // If true, XLA will try to pattern match subgraphs of HLO operations into
+ // custom fusions registered in the current process (pre-compiled hand-written
+ // kernels, e.g. various GEMM fusions written in CUTLASS).
+ bool xla_gpu_enable_custom_fusions = 263;
+
+ // A regular expression enabling only a subset of custom fusions.
Enabled only + // if `xla_gpu_enable_custom_fusion` set to true. + string xla_gpu_enable_custom_fusions_re = 264; + + // Enables address computation fusion to optimize dynamic-slice and + // dynamic-update-slice operations around library calls. + bool xla_gpu_enable_address_computation_fusion = 105; + + reserved 233; // was xla_gpu_enable_gpu2_runtime + reserved 234; // was xla_gpu_enable_gpu2_hal + + // Timeout in seconds before terminating jobs that are stuck in a NCCL + // Rendezvous. Negative value disables the timeout and will not terminate. + int64 xla_gpu_nccl_termination_timeout_seconds = 163; + + // Enables shared constants for XLA/GPU. This allows large constants to be + // shared among multiple GPU executables. + bool xla_gpu_enable_shared_constants = 165; + + // Whether to use cuBLASLt for GEMMs on GPUs. + bool xla_gpu_enable_cublaslt = 166; + + // Commands are categorized into 5 types: + // FUSION represents regular fusion kernels. + // CUBLAS/CUBLASLT, CUDNN, and COLLECTIVES represent library calls. + // CONDITIONALS represents control flow. + enum CommandBufferCmdType { + INVALID = 0; + FUSION = 1; + CUBLAS = 2; + CUDNN = 3; + COLLECTIVES = 4; + CONDITIONALS = 5; + CUSTOM_CALL = 6; + CUBLASLT = 7; + } + + // Determine the types of commands that are recorded into command buffers. + repeated CommandBufferCmdType xla_gpu_enable_command_buffer = 258; + + reserved 202; // Was xla_gpu_graph_num_runs_to_instantiate + + // This number determines how many moved instructions like fusion kernels are + // required for a region to be captured as a function to be launched as a GPU + // graph. + int32 xla_gpu_graph_min_graph_size = 208; + + // Identify concurrent regions in GPU graphs and execute them concurrently. + bool xla_gpu_graph_enable_concurrent_region = 215; + + reserved 230; // Was xla_gpu_graph_eviction_timeout_seconds + + // Size threshold (in megabytes) for the GPU redzone scratch allocator. + int64 xla_gpu_redzone_scratch_max_megabytes = 167; + + // Amount of padding the redzone allocator will put on one side of each buffer + // it allocates. (So the buffer's total size will be increased by 2x this + // value.) + // + // Higher values make it more likely that we'll catch an out-of-bounds read or + // write. Smaller values consume less memory during autotuning. Note that a + // fused cudnn conv has up to 6 total buffers (4 inputs, 1 output, and 1 + // scratch), so this can be multiplied by quite a lot. + int64 xla_gpu_redzone_padding_bytes = 228; + + reserved 168; // Was xla_gpu_simplify_all_fp_conversions. + + reserved 172; // Was xla_gpu_normalize_layouts. + + // Generate calls to Arm Compute Library in the CPU backend. + bool xla_cpu_use_acl = 174; + + // By default, XLA:CPU will run fp16 dot/conv as fp32, as this is generally + // (much) faster on our hardware. Set this flag to disable this behavior. + bool xla_cpu_strict_dot_conv_math = 175; + + // An option to enable using cuDNN runtime compiled fusion kernels which is + // available and recommended for Ampere+ GPUs. + bool xla_gpu_use_runtime_fusion = 181; + + bool xla_dump_latency_hiding_schedule = 182; + + // By default, MLIR lowering will use Linalg elementwise fusion. If this flag + // is enabled, the pipeline will use tiling, fusion, peeling, vectorization + // instead. + bool xla_cpu_enable_mlir_tiling_and_fusion = 184; + + // XLA:CPU-Next tiling parameters for matmul. 
+ bool xla_cpu_enable_custom_matmul_tiling = 195; + int64 xla_cpu_matmul_tiling_m_dim = 196; + int64 xla_cpu_matmul_tiling_n_dim = 197; + int64 xla_cpu_matmul_tiling_k_dim = 198; + + bool xla_cpu_enable_mlir_fusion_outlining = 192; + + // If set, use the experimental deallocation pass from mlir-hlo. + bool xla_cpu_enable_experimental_deallocation = 191; + + bool xla_gpu_enable_latency_hiding_scheduler = 186; + bool xla_gpu_enable_highest_priority_async_stream = 216; + bool xla_gpu_enable_analytical_latency_estimator = 255; + + bool xla_gpu_lhs_enable_gpu_async_tracker = 204; + string xla_gpu_pgle_profile_file_or_directory_path = 210; + int32 xla_gpu_memory_limit_slop_factor = 260; + + bool xla_gpu_enable_pipelined_collectives = 239; + bool xla_gpu_enable_pipelined_all_reduce = 217; + bool xla_gpu_enable_pipelined_all_gather = 227; + bool xla_gpu_enable_pipelined_reduce_scatter = 231; + bool xla_gpu_enable_pipelined_p2p = 246; + + // The minimum data size in bytes to trigger collective-permute-decomposer + // transformation. + int64 xla_gpu_collective_permute_decomposer_threshold = 237; + + enum PartitioningAlgorithm { + PARTITIONING_ALGORITHM_NOOP = 0; + PARTITIONING_ALGORITHM_EXP0 = 1; + PARTITIONING_ALGORITHM_EXP1 = 2; + PARTITIONING_ALGORITHM_EXP2 = 3; + } + // The partitioning algorithm to be used in the PartitionAssignment pass. + PartitioningAlgorithm xla_partitioning_algorithm = 187; + + bool xla_gpu_enable_triton_gemm = 188; + + bool xla_gpu_enable_cudnn_int8x32_convolution_reordering = 189; + + // Creates triton fusion for all supported gemms. + // To make sure only triton gemm is chosen by the autotuner run with + // `xla_gpu_cublas_fallback` set to false. + bool xla_gpu_triton_gemm_any = 190; + + reserved 211; // Was xla_gpu_enable_dot_strength_reduction + + bool xla_gpu_exhaustive_tiling_search = 219; + + bool xla_gpu_enable_triton_softmax_fusion = 220; + + bool xla_gpu_enable_priority_fusion = 221; + bool xla_gpu_enable_triton_softmax_priority_fusion = 286; + + // File to write autotune results to. It will be a binary file unless the name + // ends with .txt or .textproto. Warning: The results are written at every + // compilation, possibly multiple times per process. This only works on CUDA. + string xla_gpu_dump_autotune_results_to = 222; + + // File to load autotune results from. It will be considered a binary file + // unless the name ends with .txt or .textproto. At most one loading will + // happen during the lifetime of one process, even if the first one is + // unsuccessful or different file paths are passed here. This only works on + // CUDA. + string xla_gpu_load_autotune_results_from = 223; + + // Description of the target platform in GpuTargetConfigProto format; if + // provided, deviceless compilation is assumed, and the current device is + // ignored. + string xla_gpu_target_config_filename = 261; + + // Memory budget in GB per device for AutoSharding. + int32 xla_gpu_auto_spmd_partitioning_memory_budget_gb = 224; + + // See the definition of the + // xla_gpu_auto_spmd_partitioning_memory_budget_ratio flag for the meaning of + // this field. 
+ float xla_gpu_auto_spmd_partitioning_memory_budget_ratio = 225; + + bool xla_gpu_triton_gemm_disable_reduced_precision_reduction = 226; + + int32 xla_gpu_triton_fusion_level = 229; + + bool xla_gpu_dump_autotuned_gemm_fusions = 232; + + string xla_gpu_override_gemm_autotuner = 295; + + bool xla_gpu_copy_insertion_use_region_analysis = 236; + + // If true, each fusion instruction will have a cost model runtime estimate in + // backend config after compilation. + bool xla_gpu_collect_cost_model_stats = 240; + + bool xla_gpu_enable_split_k_autotuning = 241; + + // Whether reduction epilogue fusion is enabled in fusion passes. + bool xla_gpu_enable_reduction_epilogue_fusion = 243; + // Allow early return when acquiring NCCL cliques. + bool xla_gpu_enable_nccl_clique_optimization = 244; + + // Replace custom calls with noop operations. + bool xla_gpu_mock_custom_calls = 245; + + // Allow Triton GEMM autotuning to fall back to cuBLAS when that is + // faster. + bool xla_gpu_cublas_fallback = 247; + + // Enable double buffering for loops. + bool xla_gpu_enable_while_loop_double_buffering = 248; + + enum WhileLoopUnrolling { + WHILE_LOOP_UNROLLING_NO_UNROLL = 0; + // Has the same effect as setting + // `xla_gpu_enable_while_loop_double_buffering`. + WHILE_LOOP_UNROLLING_DOUBLE_BUFFER = 1; + // Enables full loop unrolling using the same strategy as `DOUBLE_BUFFER`. + WHILE_LOOP_UNROLLING_FULL_UNROLL = 2; + } + + // Determine the while loop unrolling scheme. + WhileLoopUnrolling xla_gpu_enable_while_loop_unrolling = 294; + + // Change the layout of the second triton dot operand to be column major. + // Only works for (bf16 x bf16) -> bf16. + bool xla_gpu_ensure_minor_dot_contraction_dims = 249; + + // Filter out kernels that spill registers during autotuning. + bool xla_gpu_filter_kernels_spilling_registers_on_autotuning = 250; + + // Maximum number of buffers to print when debugging buffer assignment. + int64 xla_debug_buffer_assignment_show_max = 251; + + int32 xla_gpu_llvm_verification_level = 256; + + // Enable radix sort using CUB. + bool xla_gpu_enable_cub_radix_sort = 259; + + // Threshold to enable windowed einsum (collective matmul) in MB. + int64 xla_gpu_threshold_for_windowed_einsum_mib = 265; + + // Enables currently disabled features within Triton for Hopper. + bool xla_gpu_enable_triton_hopper = 266; + + // Enable NCCL user buffers. + bool xla_gpu_enable_nccl_user_buffers = 267; + + // Enable NCCL communicator splitting. + bool xla_gpu_enable_nccl_comm_splitting = 272; + + // Enable NCCL per stream communicators. + bool xla_gpu_enable_nccl_per_stream_comms = 276; + + // If enabled, uses the libnvptxcompiler library to compile PTX to cuBIN. + bool xla_gpu_enable_libnvptxcompiler = 269; + + bool xla_gpu_enable_dot_strength_reduction = 270; + // Whether to use multiple compute streams to run windowed einsum. + bool xla_gpu_multi_streamed_windowed_einsum = 280; + + // If enabled, uses bf16_6way gemm to compute F32 gemm. + bool xla_gpu_enable_bf16_6way_gemm = 271; + + // If enabled, uses bf16_3way gemm to compute F32 gemm. + bool xla_gpu_enable_bf16_3way_gemm = 279; + + // Specify the maximum number of channels(SMs) NCCL + // will use for collective operations. + int64 xla_gpu_nccl_collective_max_nchannels = 273; + + // Specify the maximum number of channels(SMs) NCCL + // will use for p2p operations. + int64 xla_gpu_nccl_p2p_max_nchannels = 274; + + bool xla_gpu_enable_mlir_emitters = 275; + // The maximum number of kernels to emit with MLIR. Unlimited if 0. 
+ int64 xla_gpu_max_mlir_kernels = 281;
+ // The number of initial kernels to not emit with MLIR. Only supported kernels
+ // are counted.
+ int64 xla_gpu_skip_mlir_kernels = 282;
+
+ // Threshold to rewrite matmul to cuBLAS or Triton (minimum combined number of
+ // elements of both matrices in non-batch dimensions to be considered for a
+ // rewrite).
+ int64 xla_gpu_gemm_rewrite_size_threshold = 283;
+
+ // If true, will require complete AOT autotuning results; in the case of a
+ // missing AOT result, the model will not be compiled or executed, and a
+ // `NotFound` error will be returned.
+ bool xla_gpu_require_complete_aot_autotune_results = 284;
+
+ // Let GEMM fusion autotuning probe cuDNN as a backend.
+ // Current levels:
+ // 0: Disabled.
+ // 1: Fusions of GEMM, elementwise, transpose/reshape operations.
+ // 2: + Broadcasts.
+ // 3: + Nontrivial noncontracting dimension reshapes/transposes.
+ int32 xla_gpu_cudnn_gemm_fusion_level = 285;
+
+ // This instructs the runtime whether to use
+ // memcpy for p2p communication when source and
+ // target are located within a node (nvlink).
+ bool xla_gpu_use_memcpy_local_p2p = 287;
+
+ // If non-zero, limits the number of solutions to be used by the GEMM
+ // autotuner. This might be useful if the underlying math library returns too
+ // many GEMM solutions.
+ int64 xla_gpu_autotune_max_solutions = 288;
+
+ // If true, large constants will be printed out when dumping HLOs.
+ bool xla_dump_large_constants = 290;
+
+ // If true, will verify that the numerical results of Triton fusions match
+ // the results of regular emitters.
+ bool xla_gpu_verify_triton_fusion_numerics = 291;
+
+ // File to write autotune logs to. It will be stored in txt format.
+ string xla_gpu_dump_autotune_logs_to = 292;
+
+ // Base length to rewrite the reduce window to, no rewrite if set to 0.
+ int64 xla_reduce_window_rewrite_base_length = 293;
+
+ // If true, will enable host memory offloading on a device.
+ bool xla_gpu_enable_host_memory_offloading = 296;
+
+ // Excludes non-deterministic ops from compiled executables.
+ // Unlike --xla_gpu_deterministic_ops, this does not disable autotuning - the
+ // compilation itself can be non-deterministic.
+ // At present, the HLO op SelectAndScatter does not have a
+ // deterministic XLA:GPU implementation.
+ // Compilation errors out if SelectAndScatter is encountered.
+ // Scatter ops can be non-deterministic by default; these get converted to
+ // a deterministic implementation.
+ bool xla_gpu_exclude_nondeterministic_ops = 297;
+
+ // Next id: 299
+
+ // Extra options to pass to the compilation backend (e.g. LLVM); specific
+ // interpretation of these values is left to the backend.
+ map<string, string> xla_backend_extra_options = 500;
+
+ // Reserved tags were xla_hlo_dump_as_graphdef, xla_dump_to,
+ // xla_gpu_use_horizontal_fusion,
+ // xla_gpu_unsafe_fallback_to_driver_on_ptxas_error,
+ // xla_gpu_simplify_scatters, xla_gpu_simplify_gathers
+ // xla_gpu_enable_cuda_graphs
+ // xla_gpu_allow_all_reduce_kernel
+ // xla_gpu_enable_experimental_block_size
+ // xla_gpu_graph_level
+ // xla_gpu_single_wave_autotuning
+ // xla_gpu_enable_persistent_temp_buffers
+ reserved 5, 117, 133, 139, 176, 178, 180, 193, 214, 194, 242, 206;
+}
+
+// Contains flags which affect the GPU compilation result.
+// These flags are part of Debug Options as of now, and will be migrated to
+// this proto.
+message GpuCompilationEnvironment {
+ // Temporary dummy flag is added to test the flow.
+ // To be removed when we add flags here.
+ int64 dummy_flag = 1; +} + +message ShardableValueUpdatePairProto { + int64 input_parameter_number = 1; + repeated int64 parameter_shape_index = 2; + repeated int64 output_shape_index = 3; +} + +// These settings control how XLA compiles and/or runs code. Not all settings +// will have an effect on every platform. +// +// When adding new fields, keep in mind that boolean fields default to false. +// Next id: 24. +message ExecutionOptions { + // This optional field's layout is used as a hint when storing the output of + // this computation. Subsequent transfers of this output array to the client + // may be faster when using this layout. + // + // We use a Shape here to accommodate computations that return a tuple. + ShapeProto shape_with_output_layout = 2; + + // Used to seed random-number generators used in this computation. If this is + // 0, we generate a seed ourselves. + // + // TODO(b/32083678): Changing the seed unnecessarily forces a recompilation. + uint64 seed = 3; + + DebugOptions debug_options = 4; + + // This optional field specifies a particular set of devices to run the + // computation on. The computation will be partitioned across these devices. + // If not provided, the default device will be chosen. + repeated DeviceHandle device_handles = 5; + + // Number of replicas of the computation to run. If zero, uses the default + // number of replicas for the XLA service. + int32 num_replicas = 6; + + // This optional field specifies the device assignment if known at compile + // time. + DeviceAssignmentProto device_assignment = 7; + + // Alias input and output buffers for parameters that are passed-through XLA + // modules without being changed. + bool alias_passthrough_params = 8; + + // Number of partitions of the computation to run (model parallelism). + // If zero, uses the default number of partitions for the XLA service. + int32 num_partitions = 9; + + // Used to identify a set of programs that should be launch together. + int32 launch_id = 10; + + // Indicates whether to use SPMD (true) or MPMD (false) partitioning when + // num_partitions > 1 and XLA is requested to partition the input program. + bool use_spmd_partitioning = 11; + + // Whether to automatically generate XLA shardings for SPMD partitioner. + bool use_auto_spmd_partitioning = 15; + + // Device mesh shape used to create the sharding search space when + // use_auto_spmd_partitioning=true. + repeated int64 auto_spmd_partitioning_mesh_shape = 16; + + // Device mesh ids compatible with the above mesh_shape used when + // use_auto_spmd_partitioning=true. + repeated int64 auto_spmd_partitioning_mesh_ids = 17; + + // If set, deduplicate hlo into function calls to reduce binary size. Only + // works on TPU. + bool deduplicate_hlo = 12; + + reserved 13; // Was broadcast_replicated_parameters_via_collectives + + // Allows sharding propagation to propagate to the parameters. This changes + // the input shape of the computation (which is undesirable), but it can be + // used to allow to run partial compilation to determine what would be the + // input sharding of a computation if XLA would be allowed to propagate the + // sharding which can be used by higher level framework as a way to query + // intermediate sharding of operations when multiple computation would be + // chained and merged together. + // This is a vector of bool, because the user can control which parameters can + // have the sharding substituted. 
If only one boolean value is passed in the + // vector that is interpreted as the value to be applied for every parameter. + repeated bool allow_spmd_sharding_propagation_to_parameters = 23; + + // Allows sharding propagation to propagate to the outputs. This changes the + // output shape of the computation (which is undesirable), but it can be used + // to allow to run partial compilation to determine what would be the output + // sharding of a computation if XLA would be allowed to propagate the sharding + // which can be used by higher level framework as a way to query intermediate + // sharding of operations when multiple computation would be chained and + // merged together. + // This is a vector of bool, because the user can control (if the output of + // the computation is a tuple) which elements of the tuple can have the + // sharding substituted and which don't. If only one boolean value is passed + // in the vector that's interpreted as the value to be applied for every + // single element of the output tuple. One value per element of the tuple + // means that each value is attached to one of the output elements. + repeated bool allow_spmd_sharding_propagation_to_output = 14; + + // Whether to broadcast args across all replicas. One entry per arg. + repeated bool param_requires_broadcast_via_collectives = 18; + + // If enabled, the compiler may generate sharding and unsharding programs as + // separate HLO modules, and modify the main program's input and output to + // be sharded. + bool allow_separate_sharding_programs = 19; + + // The list of input/output pairs in the main program that could be sharded. + repeated ShardableValueUpdatePairProto shardable_value_update_pairs = 20; + + // Profiling data for feedback directed optimizations. Note that this is not + // the only way to feed FDO data into the compiler and individual backends + // may choose to get FDO data by other means. + bytes fdo_profile = 21; + + // Amount of device memory available for the executable to use. + int64 device_memory_size = 22; +} + +// Serialization of HloModuleConfig. See the C++ class definition for +// descriptions of each field. +// There are no guarantees of backwards or forwards compatibility. +// Next id: 34. +message HloModuleConfigProto { + enum FusionConfigCollection { + OFF = 0; // Do not collect configuration. + PER_EDGE = 1; // Collect per-edge configuration. + PER_NODE = 2; // Collect per-node configuration. 
+ }
+
+ message BoolList {
+ repeated bool vals = 1;
+ }
+ message Int64List {
+ repeated int64 vals = 1;
+ }
+ message Int64ListList {
+ repeated Int64List lists = 1;
+ }
+
+ xla.ProgramShapeProto entry_computation_layout = 1;
+ uint64 seed = 2;
+ int32 launch_id = 3;
+ int64 replica_count = 4;
+ int64 num_partitions = 5;
+ repeated bool param_requires_broadcast_via_collectives = 6;
+ bool use_spmd_partitioning = 7;
+ bool use_auto_spmd_partitioning = 8;
+ repeated int64 auto_spmd_partitioning_mesh_shape = 9;
+ repeated int64 auto_spmd_partitioning_mesh_ids = 10;
+ bool deduplicate_hlo = 11;
+ int64 intra_op_parallelism_threads = 12;
+ string device_type = 13;
+
+ DebugOptions debug_options = 14;
+ DeviceAssignmentProto static_device_assignment = 15;
+ bool allow_separate_sharding_programs = 30;
+ repeated ShardableValueUpdatePairProto shardable_value_update_pairs = 16;
+ bool alias_passthrough_params = 17;
+ bool content_aware_computation_sorting = 18;
+ FusionConfigCollection fusion_config_collection = 19;
+
+ repeated BoolList fusion_config = 20;
+ map<string, Int64List> dot_config = 21;
+ repeated Int64ListList layout_config = 22;
+
+ repeated uint64 memory_space_assignment_config = 23;
+ repeated BoolList phase_ordering_config = 24;
+ int32 phase_index = 25;
+ reserved 26; // Was flag_config
+ repeated bool allow_spmd_sharding_propagation_to_parameters = 33;
+ repeated bool allow_spmd_sharding_propagation_to_output = 27;
+ map<string, int64> analysis_allowance_map = 28;
+ xla.PrecisionConfig.Precision matrix_unit_operand_precision = 29;
+ bytes fdo_profile = 31;
+ int64 device_memory_size = 32;
+}
+
+message HloModuleProtoWithConfig {
+ HloModuleProto hlo_module = 1;
+ HloModuleConfigProto config = 2;
+}
+
+message GetDeviceHandlesRequest {
+ int64 device_count = 1;
+}
+
+message GetDeviceHandlesResponse {
+ repeated DeviceHandle device_handles = 1;
+}
+
+message TransferToClientRequest {
+ GlobalDataHandle data = 1;
+
+ // This optional field directs the service to return the literal in this
+ // layout. A shape is used to hold the layout to accommodate tuples.
+ ShapeProto shape_with_layout = 2;
+}
+
+message TransferToClientResponse {
+ LiteralProto literal = 1;
+}
+
+message TransferToServerRequest {
+ LiteralProto literal = 1;
+ DeviceHandle device_handle = 2;
+}
+
+message TransferToServerResponse {
+ GlobalDataHandle data = 1;
+}
+
+message TransferToInfeedRequest {
+ LiteralProto literal = 1;
+ int64 replica_id = 2;
+ DeviceHandle device_handle = 3;
+}
+
+message TransferToInfeedResponse {}
+
+message TransferFromOutfeedRequest {
+ // This optional field directs the service to return the literal in this
+ // layout. A shape is used to hold the layout to accommodate tuples.
+ ShapeProto shape_with_layout = 1;
+
+ int64 replica_id = 2;
+ DeviceHandle device_handle = 3;
+}
+
+message TransferFromOutfeedResponse {
+ LiteralProto literal = 1;
+}
+
+message ResetDeviceRequest {
+ DeviceHandle device_handle = 1;
+}
+
+message ResetDeviceResponse {}
+
+message ComputationGraphStatsRequest {
+ HloModuleProto computation = 1;
+ DebugOptions debug_options = 2;
+}
+
+message ComputationStatsResponse {
+ ComputationStats stats = 1;
+}
+
+message CreateChannelHandleRequest {
+ ChannelHandle.ChannelType channel_type = 1;
+}
+
+message CreateChannelHandleResponse {
+ ChannelHandle channel = 1;
+}
+
+message UnregisterRequest {
+ repeated GlobalDataHandle data = 1;
+}
+
+message UnregisterResponse {}
+
+message CompileRequest {
+ // The graph to be compiled.
+ HloModuleProto computation = 1; + + // Options that affect how XLA compiles code to service this request. + ExecutionOptions execution_options = 2; + + // The layouts of the input arguments. If not set, the default layout will be + // used. Although the real arguments are not needed in compilation, the + // layouts of the arguments can affect the compilation. + repeated ShapeProto input_shape_with_layout = 3; +} + +message CompileResponse { + // The handle to the executable. + ExecutionHandle handle = 1; +} + +message ExecuteRequest { + ExecutionHandle handle = 1; + + // The shape and layout of the arguments must be the same as the those of the + // executable's parameters. + repeated GlobalDataHandle arguments = 2; +} + +// TODO(b/118493728): Remove this and ExecuteGraphParallelRequest and replace +// the uses with calls to Compile and Execute. +message ExecuteGraphRequest { + HloModuleProto computation = 1; + repeated GlobalDataHandle arguments = 2; + + // Options that affect how XLA compiles and runs code to service this request. + ExecutionOptions execution_options = 3; +} + +message ExecuteGraphParallelRequest { + repeated ExecuteGraphRequest requests = 1; +} + +message ExecuteResponse { + GlobalDataHandle output = 1; + ExecutionProfile profile = 2; +} + +message ExecuteParallelResponse { + repeated ExecuteResponse responses = 1; +} + +message ComputeConstantGraphRequest { + HloModuleProto computation = 1; + LayoutProto output_layout = 2; +} + +message ComputeConstantResponse { + // A LiteralProto is returned directly for this request. + LiteralProto literal = 1; +} + +message DeconstructTupleRequest { + GlobalDataHandle tuple_handle = 2; +} + +message DeconstructTupleResponse { + repeated GlobalDataHandle element_handles = 1; +} + +message LoadDataRequest { + // Describes the path of the ColumnIO tablet to load. + string columnio_tablet_path = 1; + + // Describes the field to load within the ColumnIO tablet. + string columnio_field = 2; + + // Individual element shape, excluding rows. + ShapeProto element_shape = 3; + + // Warning: ColumnIO does not support random-access, so use offset with + // caution in performance-critical scenarios. + int64 offset = 4; + + // Maximum number of elements (with shape element_shape) to load. + int64 limit = 5; + + // If more than one item is requested (via limit > 1), then this request + // attribute zips together the produced vectors. + bool zip = 6; +} + +message LoadDataResponse { + GlobalDataHandle data = 1; + ShapeProto data_shape = 2; + int64 available_rows = 3; + int64 rows_loaded = 4; + int64 nanoseconds = 5; +} + +message GetShapeRequest { + GlobalDataHandle data = 1; +} + +message GetShapeResponse { + ShapeProto shape = 1; +} + +message UnpackRequest { + GlobalDataHandle data = 1; +} + +message UnpackResponse { + repeated GlobalDataHandle tied_data = 1; +} + +// A trace estimated by the Latency Hiding Scheduler. +message ScheduleProto { + message Instruction { + // Instruction id (matches the id in HloInstructionProto). + int64 id = 1; + + // Start and end timestamps in cycles. + double start_timestamp_cycles = 2; + double end_timestamp_cycles = 3; + } + repeated Instruction instructions = 1; + // Computation id (matches the id in HloComputationProto). 
+ int64 computation_id = 2;
+ HloModuleProto hlo_module = 3;
+ int64 cycles_per_microsecond = 4;
+}
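The ScheduleProto defined at the end of xla.proto is a per-instruction trace with start/end timestamps in cycles, presumably the format behind the xla_dump_latency_hiding_schedule flag listed earlier in DebugOptions. A rough sketch of turning such a dump into a table follows; it assumes the schema has already been compiled to a Python module importable as xla.xla_pb2 and that a text-format dump named schedule.pbtxt is at hand (both names are illustrative assumptions).

    import pandas as pd
    from google.protobuf import text_format

    # Generated from xla/xla.proto; the module path is an assumption.
    from xla import xla_pb2

    # Parse a hypothetical text-format ScheduleProto dump.
    schedule = xla_pb2.ScheduleProto()
    with open("schedule.pbtxt") as f:
        text_format.Parse(f.read(), schedule)

    # Tabulate the estimated duration of each instruction, longest first.
    df = pd.DataFrame(
        {
            "instruction_id": [i.id for i in schedule.instructions],
            "cycles": [
                i.end_timestamp_cycles - i.start_timestamp_cycles
                for i in schedule.instructions
            ],
        }
    ).sort_values("cycles", ascending=False)
    print(df.head(10))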
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/xla_data.proto b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/xla_data.proto
new file mode 100644
index 000000000..06594c84f
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/protos/xla/xla_data.proto
@@ -0,0 +1,1054 @@
+/* Copyright 2017 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+syntax = "proto3";
+
+package xla;
+
+option cc_enable_arenas = true;
+
+// Primitive types are the individual values that can be held in rectangular
+// multidimensional arrays. A description of the rectangular multidimensional
+// array dimensions / primitive type is given by Shape, below.
+//
+// LINT.IfChange
+enum PrimitiveType {
+ // Invalid primitive type to serve as default.
+ PRIMITIVE_TYPE_INVALID = 0;
+
+ // Predicates are two-state booleans.
+ PRED = 1;
+
+ // Signed integral values of fixed width.
+ S2 = 26;
+ S4 = 21;
+ S8 = 2;
+ S16 = 3;
+ S32 = 4;
+ S64 = 5;
+
+ // Unsigned integral values of fixed width.
+ U2 = 27;
+ U4 = 22;
+ U8 = 6;
+ U16 = 7;
+ U32 = 8;
+ U64 = 9;
+
+ // Floating-point values of fixed width.
+ //
+ // Note: if f16s are not natively supported on the device, they will be
+ // converted to f16 from f32 at arbitrary points in the computation.
+ F16 = 10;
+ F32 = 11;
+
+ // Truncated 16 bit floating-point format. This is similar to IEEE's 16 bit
+ // floating-point format, but uses 1 bit for the sign, 8 bits for the exponent
+ // and 7 bits for the mantissa.
+ BF16 = 16;
+
+ F64 = 12;
+
+ // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2209.05433
+ //
+ // F8E5M2 has 5 exponent bits and 2 mantissa bits, and is similar to the
+ // existing IEEE types.
+ //
+ // F8E4M3FN has 4 exponent bits and 3 mantissa bits. The "FN" means only
+ // Finite and NaN values are supported. Unlike IEEE types, infinities are not
+ // supported. NaN is represented when the exponent and mantissa bits are all
+ // 1s. All other values are finite.
+ //
+ // F8E4M3B11FNUZ has 4 exponent bits and 3 mantissa bits and a bias of 11. The
+ // "FNUZ" means only Finite and NaN values are supported; zero is unsigned.
+ // Unlike IEEE types, infinities are not supported. NaN is represented when
+ // the exponent and mantissa bits are all 0s with a sign bit of 1. All other
+ // values are finite.
+ //
+ // Support for these dtypes is under development. They do not yet work
+ // properly in most cases.
+ // TODO(b/259609697): Fully support FP8.
+ F8E5M2 = 19;
+ F8E4M3FN = 20;
+ F8E4M3B11FNUZ = 23;
+
+ // FP8 dtypes, as described in this paper: https://arxiv.org/abs/2206.02915
+ //
+ // F8E5M2FNUZ has 5 exponent bits and 2 mantissa bits.
+ // F8E4M3FNUZ has 4 exponent bits and 3 mantissa bits.
+ //
+ // The "FNUZ" means only Finite and NaN values are supported; zero is
+ // unsigned. Unlike IEEE types, infinities are not supported. NaN is
+ // represented when the exponent and mantissa bits are all 0s with a sign bit
+ // of 1. All other values are finite.
+ //
+ // These differences mean there's an additional exponent value available. To
+ // keep the same dynamic range as an IEEE-like FP8 type, the exponent is
+ // biased one more than would be expected given the number of exponent bits
+ // (8 for Float8E4M3FNUZ and 16 for Float8E5M2FNUZ).
+ F8E5M2FNUZ = 24;
+ F8E4M3FNUZ = 25;
+
+ // Complex values of fixed width.
+ C64 = 15; // Paired F32 (real, imag), as in std::complex<float>.
+ C128 = 18; // Paired F64 (real, imag), as in std::complex<double>.
+
+ // A tuple is a polymorphic sequence; e.g. a shape that holds different
+ // sub-shapes. They are used for things like returning multiple values from a
+ // computation; e.g. a computation that returns weights and biases may have a
+ // signature that results in a tuple like (f32[784x2000], f32[2000])
+ //
+ // If a shape proto has the tuple element type, it may not have any entries
+ // in the dimensions field.
+ TUPLE = 13;
+
+ // An opaque type used for passing context-specific data to a custom
+ // operation. Shapes of this primitive type will have empty dimensions and
+ // tuple_shapes fields.
+ //
+ // (OPAQUE would be a better name for this identifier, but that conflicts with
+ // a macro defined in windows.h.)
+ OPAQUE_TYPE = 14;
+
+ // A token type threaded between side-effecting operations. Shapes of this
+ // primitive type will have empty dimensions and tuple_shapes fields.
+ TOKEN = 17;
+
+ // Next = 28
+}
+// LINT.ThenChange(
+// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc,
+// https://www.tensorflow.org/code/tensorflow/compiler/xla/tools/driver.cc
+// )
+
+// Describes the padding configuration for Pad operation. The padding amount on
+// both edges as well as between the elements are specified for each dimension.
+message PaddingConfig {
+ // Describes the padding configuration for a dimension.
+ message PaddingConfigDimension {
+ // Padding amount on the low-end (next to the index 0). May be negative.
+ int64 edge_padding_low = 1;
+
+ // Padding amount on the high-end (next to the highest index). May be
+ // negative.
+ int64 edge_padding_high = 2;
+
+ // Padding amount between the elements. May not be negative.
+ int64 interior_padding = 3;
+ }
+
+ // The padding configuration for all dimensions.
+ repeated PaddingConfigDimension dimensions = 1;
+}
+
+// A DimLevelType indicates the encoding method for a dimension in an array.
+// The semantics of this field are identical to those of the MLIR SparseTensor
+// dialect.
+// This should be kept in sync with the SparseTensor DimLevelType enum:
+// https://github.com/llvm/llvm-project/blob/5674a3c88088e668b684326c2194a6282e8270ff/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorAttrDefs.td#L86
+enum DimLevelType {
+ // The corresponding dimension is Dense, every entry is stored.
+ DIM_DENSE = 0;
+ // The corresponding dimension is Compressed, only nonzeros are stored.
+ DIM_COMPRESSED = 1;
+ // The corresponding dimension contains a single coordinate, no sibling
+ // elements for each parent.
+ DIM_SINGLETON = 2;
+ // The corresponding dimension is Compressed, but with potential trailing
+ // zeros, thus an extra upper bound (high) is used to exclude those zeros.
+ // E.g., indices = [1, 2, 0, 0, 3, 4, 0, 0], position = [(0, 2), (4, 6)].
+ DIM_LOOSE_COMPRESSED = 3;
+}
+
+// Describes a tile used in tiling-based layout.
Refer to +// g3doc/third_party/xla/docs/tiled_layout.md for details about tiling-based +// layout. +message TileProto { + // Number of elements in each dimension of the tile. It's ordered from the + // most major dimension of the tile to the most minor dimension of the tile. + // The dimensions correspond to a suffix of the dimensions of the shape being + // tiled. + repeated int64 dimensions = 1; +} + +// Describes how data should be split between different memories. +message SplitConfigProto { + // The dimension that is split. + int64 dimension = 1; + // The indices where each split point occurs. For example, if the dimension + // size is 1024, a split_indices value of {512} indicates a two-way split of + // data through the middle. + repeated int64 split_indices = 2; +} + +// A layout describes how the array is placed in (1D) memory space. This +// includes the minor-to-major ordering of dimensions within a shape. +// +// Clients must specify the layouts of input Literals to the +// computation. Layouts specified in interior operations which take Shapes (for +// example, Convert) are ignored. +// +// See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange +message LayoutProto { + // The dimension level type list for this array, specifying the way in which + // each array dimension is represented in memory. If this list is empty, the + // array is assumed to be dense. + repeated DimLevelType dim_level_types = 9; + + // Whether each dimension is unique or ordered. Each of the following lists + // must be empty, or have one entry for each entry of dim_level_types. If + // either list is empty, all dimensions are assumed to be unique and ordered, + // respectively. Entries in this list may not be false for some DimLevelType + // values (such as DIM_DENSE in particular). + repeated bool dim_unique = 13; + repeated bool dim_ordered = 14; + + // Sequence of dimension numbers, from minor (fastest varying index) to major + // (slowest varying index). This field is required. + repeated int64 minor_to_major = 1; + + // A sequence of tiles, starting from the tile that's applied first to the + // Shape. + // + // TODO(b/119839262): implement tiling in each backend or add Unimplemented + // error. + repeated TileProto tiles = 6; + + // The shape is padded at the end to multiple of, in terms of number of + // elements. This is useful when tiling does not bring the shape to certain + // desired granules. Tiling effectively pads/reshapes/transposes the shape + // to another shape. This field pads the total number of elements of that + // new shape to a multiple of certain number of elements. This is useful such + // as we want a layout which does not tile the data but still requires it to + // be padded to certain number of elements. + int64 tail_padding_alignment_in_elements = 16; + + // (Optional) Bit size of each element. When unspecified or being 0, default + // to ShapeUtil::ByteSizeOfPrimitiveType. + int64 element_size_in_bits = 7; + + // Memory space where this array resides. The integer field is interpreted in + // a backend-specific manner. + int64 memory_space = 8; + + // The integer types to be used for indices and pointers. These fields must + // not be used unless the layout represents a sparse array. The PrimitiveType + // must correspond to an unsigned integer (U8, U16, U32, or U64). 
+ // If not provided, the compiler will use the largest unsigned integer + // that is naturally supported by the target device (U32 or U64 in currently + // supported devices). + PrimitiveType index_primitive_type = 11; + PrimitiveType pointer_primitive_type = 12; + + // The physical, on-device shape used to represent the shape this layout + // belongs to. Only used for sparse arrays. + // The layout(s) contained within the physical shape should not also contain + // a physical shape. + ShapeProto physical_shape = 10; + + // The dynamic shape metadata size in bytes in front of the shape data. The + // field may be non-zero for a static shape whose associated buffer is for a + // dynamic shape, e.g. a result of SliceToDynamic. + int64 dynamic_shape_metadata_prefix_bytes = 15; + + // The split configurations which describe if/how the data is split between + // different memories. + repeated SplitConfigProto split_configs = 17; + + // Important: if any field is added, be sure to modify ShapeUtil::Equal() and + // LayoutUtil::Hash appropriately to account for the new field. + + reserved 2; + reserved "padded_dimensions"; + reserved 3; + reserved "padding_value"; + reserved 4; + reserved "format"; + reserved 5; + reserved "max_sparse_elements"; +} +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc, \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/layout_util.cc) + +// A shape describes the number of dimensions in the array, the size of each +// dimension, and the primitive component type. +// +// Tuples are a special case in that they have rank zero and have tuple_shapes +// defined. +// +// See the XLA documentation for more information on shapes and layouts. +// +// LINT.IfChange +message ShapeProto { + reserved 1; + reserved "rank"; + + // The element type for this shape. + PrimitiveType element_type = 2; + + // The size (number of elements) for each dimension, or an upper bound on the + // size if the dimension is dynamic. In XLA, dimensions are numbered from 0 + // to N-1 for an N-dimensional array. The first element of 'dimensions' is the + // size of dimension 0, the second element is the size of dimension 1, and so + // forth. Empty list indicates a scalar. + // + // If the respective element in 'is_dimension_dynamic' is true then the value + // in this field represents an upper bound on the size of the dimension. + repeated int64 dimensions = 3; + + // For tuples only, the shapes of constituent shapes in the tuple sequence. + repeated ShapeProto tuple_shapes = 4; + + // The layout used to back this shape. + LayoutProto layout = 5; + + // For arrays, this indicates whether or not each dimension is + // dynamically-sized. The number of elements in this repeated field should be + // zero (indicating that no dimensions are dynamic) or equal to the number of + // elements in the 'dimensions' field. + repeated bool is_dynamic_dimension = 6; + + // Important: if any field is added, be sure to modify ShapeUtil::Equal(), + // ShapeUtil::Compatible() and ShapeUtil::Hash() appropriately to account for + // the new field. +} +// LINT.ThenChange( \ +// https://www.tensorflow.org/code/tensorflow/compiler/xla/shape_util.cc) + +// Shape of the parameters and output of a computation (like a traditional +// function signature). +message ProgramShapeProto { + repeated ShapeProto parameters = 1; + ShapeProto result = 2; + repeated string parameter_names = 3; +} + +// Statistics of a computation. 
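For reference, once Python bindings have been generated from this file (protoc emits an xla_data_pb2 module for it; that step is an assumption here, not something this diff performs), a shape such as f32[2,3] with a row-major layout is assembled from ShapeProto and LayoutProto as sketched below.

    # Sketch only; assumes `protoc --python_out=.` has produced xla_data_pb2.
    import xla_data_pb2 as xd

    shape = xd.ShapeProto(element_type=xd.F32, dimensions=[2, 3])
    # minor_to_major runs from fastest- to slowest-varying dimension, so [1, 0]
    # describes a row-major rank-2 array.
    shape.layout.minor_to_major.extend([1, 0])
    print(shape)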
+message ComputationStats { + // The number of floating point operations in the computation. + double flop_count = 1; + + // The number of transcendental operations (e.g., exp) in the computation. + double transcendental_count = 2; +} + +// The type optimization profiles in use for Op-level optimizations. +enum ProfileType { + INVALID = 0; + WINDOW = 1; + FLAG = 2; + INTEGER = 3; +} + +// The source of the optimization profile. +enum ProfileSource { + PROFILE_SOURCE_UNKNOWN_SOURCE = 0; + PROFILE_SOURCE_EMBEDDED = 1; + PROFILE_SOURCE_REMOTE = 2; +} + +// The compilation event that triggered the use of the profile. +enum CompilationEvent { + COMPILATION_EVENT_UNKNOWN_EVENT = 0; + COMPILATION_EVENT_FIRST_COMPILATION = 1; + COMPILATION_EVENT_RECOMPILATION = 2; +} + +// Symbolization metadata for HLO Instructions. +// +// This metadata is used for debugging XLA code generation, as well as +// performance profiling of XLA-generated executables. +message OpMetadata { + // The framework op name that generated this XLA op. + // + // Frameworks that build on top of XLA should mirror the names of their ops + // back to users by specifying the op_type. In this way, even if the + // framework's "ops" are implemented as multiple XLA HLO Ops, they can be + // grouped appropriately. (e.g. if a SoftMax layer is emitted into XLA as + // multiple ops, then each op should have the op_type be "SoftMax".) + string op_type = 1; + // The user-specified name of the op. + // + // This name is often unique within a computation. Note: some frameworks + // add auto-generated names if the user does not provide one. + string op_name = 2; + // Indicate a file and line that this op is associated to in a user's program. + // + // e.g. it could be the file and line of user code that generated the op. + string source_file = 3; + int32 source_line = 4; + + // Deprecated, use [ProfileInfo][profile_type] instead. + repeated ProfileType profile_type = 5 [deprecated = true]; + + reserved 6; + reserved "creation_pass_id"; + + reserved 7; + reserved "logical_creation_pass_id"; + + // The footprint of the generated code for the instruction. + int64 size_of_generated_code_in_bytes = 8; + // The size of the working set, i.e., the amount of memory, used by the + // instruction in a compiler-managed fast device memory. + int64 size_of_memory_working_set_in_bytes = 9; + + // Information about the optimization profile that this operation contains. + message ProfileInfo { + // The type of optimization profiles that this operation contains. + repeated ProfileType profile_type = 1; + // Speedup of tuned config compared to default config. + // TODO(b/203817882) Set the relative_speedup. + double relative_speedup = 2; + // The source of the optimization profiles that this operation contains. + ProfileSource profile_source = 3; + // The compilation event that triggered the use of the profiles. + CompilationEvent compilation_event = 4; + } + + // Profile information for the Op. + ProfileInfo profile_info = 10; + + // Deduplicated HLO name for this op. In some cases, we can have multiple + // instructions (e.g. fusions) that are considered duplicates. We want to + // group them together under the same name so that we can group them together + // during analysis (e.g. HLO Op Profile tool in Xprof). + // E.g. If we have fusion.1, fusion.2, and fusion.3 marked as duplicates, + // fusion.2 and fusion.3 will have deduplicated_name = fusion.1 + string deduplicated_name = 12; + + // Whether to preserve the layout of the HLO op. 
+ bool preserve_layout = 13; + + // 1-based position of the frame in frames flat array. + // Ids are 1-based to keep 0 value as representation of non-set property. + int32 stack_frame_id = 15; + + reserved 14; +} + +// Profile data from the execution of a computation. +message ExecutionProfile { + // Whether the executable was read from the compilation cache. + bool compilation_cache_hit = 1; + + // The time in milliseconds spent to compile the computation. This only set if + // the executable was not read from the compilation cache + // (compilation_cache_hit == false). + int64 compile_time_ms = 2; + + // The number of cycles spent for the computation. This does not include the + // time taken for the data transfers between the host and the device. This is + // a target-dependent field and only used for debugging purposes. + int64 compute_cycle_count = 3; + + // The time in nanoseconds spent for the computation, without data transfer. + int64 compute_time_ns = 4; + + // The time in nanoseconds spent for the entire computation, including the + // result data transfer time. Current implementation does not spend any cycles + // for the input data transfer since the memory is initialized with the proper + // values before the execution. + int64 compute_and_transfer_time_ns = 5; + + // The size of the binary code in the executable. + int64 executable_size_in_bytes = 6; + + // Whether this profile was drawn from a cache of profiles instead of from + // execution on the hardware. + bool profile_cache_hit = 7; + + // Whether a warm-up run of the computation was executed before the + // measured execution. + bool warmup_run_executed = 8; +} + +// Handle given to a user that represents an execution that the user launched +// asynchronously on the device. +message ExecutionHandle { + int64 handle = 1; +} + +// Handle given to a user that represents a globally accessible allocation. +// Contrast this against a ComputationDataHandle, which is not globally +// accessible, since it only exists within a specific computation. +message GlobalDataHandle { + int64 handle = 1; +} + +// Handle given to a user that represents a replicated virtual device. Each +// replicated device represents N physical devices for execution where N is the +// number of replicas. +message DeviceHandle { + int64 handle = 1; + + // The number of model-parallel virtual devices that communicate via XLA + // Send/Recv instructions. + int64 device_count = 2; +} + +// Handle given to a user to represent a channel between two computations +// via a Send and Recv instruction pair. Channels are unbuffered, so Send +// Send instructions will be blocked until the data is transferred. +message ChannelHandle { + int64 handle = 1; + enum ChannelType { + // Invalid primitive type to serve as default. + CHANNEL_TYPE_INVALID = 0; + + // A channel for sending data between devices. + DEVICE_TO_DEVICE = 1; + + // A channel for sending data from the device to the host. Can only be used + // with a Send operation. + DEVICE_TO_HOST = 2; + + // A channel for sending data from the host to the device. Can only be used + // with a Recv operation. + HOST_TO_DEVICE = 3; + } + ChannelType type = 2; +} + +// DeviceAssignmentProto is a serialized form of DeviceAssignment class, which +// represents the device ids assigned to a set of replicated computations. +// See xla::DeviceAssignment class comment for more details. 
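For orientation (an assumption about typical JAX usage, not something this test data exercises): OpMetadata above is where profiling tools find the op_name / source_file / source_line that tie device kernels back to user code, and in JAX those fields are largely driven by tracing locations and jax.named_scope annotations.

    # Sketch: the lowered text generally carries the named scope and source
    # location; the exact rendering is JAX-version dependent.
    import jax
    import jax.numpy as jnp

    def f(x):
        with jax.named_scope("scale_and_sum"):
            return jnp.sum(2.0 * x)

    print(jax.jit(f).lower(jnp.ones(4)).as_text())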
+message DeviceAssignmentProto { + int32 replica_count = 1; + int32 computation_count = 2; + + // Each logical computation runs on replica_count physical devices. + // ComputationDevice represents the device ids assinged to the replicas. + message ComputationDevice { + repeated int64 replica_device_ids = 1; + } + repeated ComputationDevice computation_devices = 3; +} + +// Literals are used when the server and client need to exchange materialized +// data / results. Literals are also used to describe constants used in +// computations. +// +// Transfers to/from the client are encoded in literal form, and the structure +// of the repeated fields is implied by the shape. +message LiteralProto { + ShapeProto shape = 1; + repeated bool preds = 2; + bytes s2s = 26; + bytes s4s = 21; + bytes s8s = 15; + bytes u2s = 27; + bytes u4s = 22; + bytes u8s = 3; + repeated int32 s32s = 4; + repeated int64 s64s = 5; + repeated uint32 u32s = 6; + repeated uint64 u64s = 7; + repeated float f32s = 8; + repeated double f64s = 9; + repeated float c64s = 12; // Stored as interleaved real, imag floats. + repeated double c128s = 18; // Stored as interleaved real, imag doubles. + repeated LiteralProto tuple_literals = 10; + // The F16s, BF16s, U16s and S16s are encoded in little endian byte order + bytes f16s = 11; + bytes bf16s = 13; + bytes u16s = 16; + bytes s16s = 17; + bytes f8e5m2s = 19; + bytes f8e4m3fns = 20; + bytes f8e4m3b11fnuzs = 23; + bytes f8e5m2fnuzs = 24; + bytes f8e4m3fnuzs = 25; + repeated int64 sparse_indices = 14; + // Next = 28 +} + +message WindowDimension { + // The size of the window in this dimension. For a rectangle, this would be + // the width or height. + int64 size = 1; + + // The stride at which the window moves across the base area in this + // dimension. In other words, this is the spacing between different + // positions of the window in this dimension. + int64 stride = 2; + + // If positive, means the amount of padding to add to the base area at the low + // end of this dimension; if negative, its negative means the number of + // elements removed from the low end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of padding + // values to pad on the left, given that indices increase when going right. + // The actual padding value depends upon the context. Convolution pads with + // zeros. ReduceWindow and SelectAndScatter pads with the reduce function's + // init value. + int64 padding_low = 3; + + // As padding_low, but on the high end of this dimension. For example, in the + // horizontal dimension of a rectangle, this would be the number of values to + // pad on the right, given that indices increase when going right. + int64 padding_high = 4; + + // Dilation factor of the sliding window in this dimension. A dilation factor + // of 1 means no dilation. window_dilation - 1 no-op entries ("holes") are + // implicitly placed between each kernel element. This value may not be less + // than 1. See documentation for convolution. + int64 window_dilation = 5; + + // Dilation factor of the base area in this dimension. A dilation factor of 1 + // means no dilation. base_dilation - 1 no-op entries ("holes") are implicitly + // placed between each base area element. This value may not be less than 1. + // See documentation for convolution. + int64 base_dilation = 6; + + // Window reversal means that this dimension was logically reversed before the + // operation. 
+ bool window_reversal = 7; +} + +// Describes the windowing in an operation such as convolution. +// +// The window is moved across a base area and for each position of the +// window a computation is performed. The field below describes the +// window and the movement of the window across a base area. +message Window { + repeated WindowDimension dimensions = 1; +} + +// Describes the dimension numbers for a gather operation. +// +// See https://www.tensorflow.org/performance/xla/operation_semantics#gather for +// more details. +message GatherDimensionNumbers { + // "Window indices" is a term for a set of indices that index into the + // interior of a dynamic-slice from the input tensor, the starting indices for + // which were computed from output_gather_dims (see the operation semantic for + // how this is defined) and the start_indices tensor. + // + // The window indices for a specific output index Out is computed as: + // + // i = 0 + // for (k : [0, input_tensor_shape.rank)) + // window_indices[k] = + // if k in collapsed_slice_dims + // then 0 + // else Out[offset_dims[i++]] + repeated int64 offset_dims = 1; + repeated int64 collapsed_slice_dims = 2; + + // This is interpreted as a map from i to start_index_map[i]. It + // transforms the gather index looked up from the start_indices tensor into + // the starting index in the input space. + repeated int64 start_index_map = 3; + + // The dimension in the start_indices input that contains the starting + // indices. + int64 index_vector_dim = 4; +} + +// Describes the dimension numbers for a scatter operation. +// +// All the fields are similar to the corresponding fields in +// GatherDimensionNumbers. Differences are noted below. +message ScatterDimensionNumbers { + // The set of dimensions in the updates shape that are window dimensions. + repeated int64 update_window_dims = 1; + // The set of window dimensions that must be inserted into the updates shape. + repeated int64 inserted_window_dims = 2; + + repeated int64 scatter_dims_to_operand_dims = 3; + int64 index_vector_dim = 4; +} + +message ConvolutionDimensionNumbers { + // The number of the dimension that represents batch in the input. + int64 input_batch_dimension = 7; + + // The number of the dimension that represents features in the input. + int64 input_feature_dimension = 8; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the input. + repeated int64 input_spatial_dimensions = 11; + + // The number of the dimension that represents input features in the + // convolutional kernel (rhs). + int64 kernel_input_feature_dimension = 3; + + // The number of the dimension that represents output features in + // the convolutional kernel (rhs). + int64 kernel_output_feature_dimension = 4; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the kernel (rhs). window.strides(0) is the + // stride in the kernel_spatial_dimensions(0) dimension. + repeated int64 kernel_spatial_dimensions = 6; + + // The number of the dimension that represents batch in the output. + int64 output_batch_dimension = 9; + + // The number of the dimension that represents features in the output. + int64 output_feature_dimension = 10; + + // The dimension numbers for the spatial dimensions that the window + // moves through in the output. + repeated int64 output_spatial_dimensions = 12; + + // Next = 13 +} + +enum PaddingType { + PADDING_INVALID = 0; + PADDING_VALID = 1; // Only valid portion of the base are covered. 
+ PADDING_SAME = 2; // Extra is added to produce same output size as the input. +} + +enum FftType { + FFT = 0; // Forward FFT; complex in, complex out. + IFFT = 1; // Inverse FFT; complex in, complex out. + RFFT = 2; // Forward real FFT; real in, fft_length / 2 + 1 complex out + IRFFT = 3; // Inverse real FFT; fft_length / 2 + 1 complex in, + // fft_length real out +} + +message DotDimensionNumbers { + // The dimension numbers that represent the 'lhs' contracting dimensions. + repeated int64 lhs_contracting_dimensions = 1; + // The dimension numbers that represent the 'rhs' contracting dimensions. + repeated int64 rhs_contracting_dimensions = 2; + // The dimension numbers that represent the 'lhs' batch dimensions. + repeated int64 lhs_batch_dimensions = 3; + // The dimension numbers that represent the 'rhs' batch dimensions. + repeated int64 rhs_batch_dimensions = 4; +} + +enum SparsityType { + SPARSITY_INVALID = 0; + + // Structured N:M sparsity. + SPARSITY_STRUCTURED_N_M = 1; + + // Next: 2 +} + +// Contains sparsity metadata for a sparse dot operation. +// The only supported type atm is structured 2:4 sparsity, which is natively +// supported on NVidia GPUs. +// Restrictions: +// - only one operand of the dot operation may be sparse; +// - only the contracting dimension may be sparse. +message SparsityDescriptor { + SparsityType type = 1; + + // Sparse operand index (0 or 1). + int32 index = 2; + // Sparse dimension number. + int32 dimension = 3; + + // Structured N:M sparsity (N < M). + int32 n = 4; + int32 m = 5; + + // Next: 6 +} + +enum RandomDistribution { + RNG_INVALID = 0; + + // Creates a uniform-distribution-generated random number on the semi-open + // interval [parameter[0], parameter[1]). + RNG_UNIFORM = 1; + + // Creates a normal-distribution-generated random number with mean + // parameter[0] and standard deviation parameter[1]. + RNG_NORMAL = 2; + + // Next: 4 +} + +enum RandomAlgorithm { + RNG_DEFAULT = 0; // Backend dependent default algorithm. + RNG_THREE_FRY = 1; + RNG_PHILOX = 2; + // Next: 2 +} + +message TriangularSolveOptions { + // If true, solves ax = b. If false, solves xa = b. + bool left_side = 1; + + // If true, 'a' is lower triangular. If false, 'a' is upper triangular. + bool lower = 2; + + // If true, the diagonal elements of 'a' are assumed to be 1 and not accessed. + bool unit_diagonal = 3; + + // Should we transpose or use the adjoint of 'a'? + enum Transpose { + TRANSPOSE_INVALID = 0; + NO_TRANSPOSE = 1; // Don't transpose 'a'. + TRANSPOSE = 2; // Transpose 'a'. + ADJOINT = 3; // Complex conjugate and transpose 'a'. + } + Transpose transpose_a = 4; +} + +message CholeskyOptions { + // If true, uses the lower triangle of `a`. If false, uses the upper triangle + // of `a`. + bool lower = 1; +} + +// Attributes of the sort custom call (cub::DeviceRadixSort). +message SortOptions { + bool descending = 1; +} + +// Generic map of attributes used to pass hints / configuration options from +// the Python frontend to the XLA backend. +message FrontendAttributes { + map map = 1; +} + +// Represents a single statistic to track. +message Statistic { + // Must be a single word consisting of any alphanumeric characters + string stat_name = 1; + // Must be within a range of [0, 100], in order for the graph dumper to + // properly render the statistic onto the graph. + double stat_val = 2; +} + +// Represents the information needed to visualize propagation statistics when +// rendering an HLO graph. 
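DotDimensionNumbers above carries the same information that jax.lax.dot_general takes as its dimension_numbers argument, a pair of (contracting, batch) dimension pairs; the batched matmul below is only meant to illustrate that correspondence.

    import jax.numpy as jnp
    from jax import lax

    lhs = jnp.ones((8, 4, 5))   # [batch, m, k]
    rhs = jnp.ones((8, 5, 3))   # [batch, k, n]
    # ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
    out = lax.dot_general(lhs, rhs, dimension_numbers=(((2,), (1,)), ((0,), (0,))))
    # out.shape == (8, 4, 3)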
This includes an array of statistics as well as the +// index of the statistic to render. +message StatisticsViz { + int64 stat_index_to_visualize = 1; + repeated Statistic statistics = 2; +} + +// LINT.IfChange +message OpSharding { + enum Type { + // This sharding is replicated across all devices (implies maximal, + // all other fields are unused). + REPLICATED = 0; + // This sharding is maximal - one device runs the entire operation. + MAXIMAL = 1; + // This sharding is a tuple - only the tuple_shardings field is valid. + TUPLE = 2; + // None of the above; tile_shape and tile_assignment are both used. + OTHER = 3; + // This op is manually sharded: the shapes are already partitioned and the + // partitioner should not change this op. + MANUAL = 4; + // This sharding is a placeholder sharding with lowest precedence, it can be + // overwriten by any other shardings. + UNKNOWN = 5; + } + Type type = 1; + // The shape of the sharded tile. + ShapeProto tile_shape = 2; + // The shape of the tile assignment tensor - this must be the same rank as + // tile_shape and the product of its dimensions must equal + // tile_assignment_devices.size(). + repeated int64 tile_assignment_dimensions = 3; + // Flattened list of device IDs. The order of flattening is the same as used + // by IndexUtil::MultiToLinearIndex(tile_assignment_shape). + // Only one of tile_assignment_devices and iota_dimensions shall be non-empty. + repeated int64 tile_assignment_devices = 4; + // If type == TUPLE, the sub-shardings, one per leaf node in the tuple shape, + // in pre-order. The tuple shape could be nested; here we store just a + // flattened list of all leaves in the tuple shape. Note that the tuple shape + // is not stored here; shardings do not store the shapes to which they are + // applied, this is inferred from the instruction this sharding gets attached + // to. + repeated OpSharding tuple_shardings = 5; + + // Only used for OTHER type. If true, data is sharded according to other + // dimensions of tile_assignment(), but replicated across devices along the + // last dimension. (Experimental) + bool replicate_on_last_tile_dim = 6; + // This field is used to track the source of this sharding, usually derived + // from instructions. Multple metadata may be populated if sharding is + // combined with other shardings. Metadata are to not be populated when + // type == TUPLE and instead metadata should be set on individual tuple + // elements. + repeated OpMetadata metadata = 7; + + // This field is used to represented the sharding type of each subgroup. + // For example, sharding={devices=[2,2,2,2]0,1,2,...,15 last_tile_dims={ + // replicate, manual, unreduced}} means that each of the last 3 dimensions + // in [2,2,2,2] represents a subgrouping in replicate, manual, + // unreduced sharding type respectively. + repeated Type last_tile_dims = 8; + + // Dimensions used to reshape the 1D iota array of device IDs. + // Only one of tile_assignment_devices and iota_reshape_dims shall be + // non-empty. + repeated int64 iota_reshape_dims = 9; + + // Dimension permutations to transposed the iota array reshaped to + // iota_reshape_dims. This must have the same size as iota_reshape_dims. + repeated int32 iota_transpose_perm = 10; + + // This field decides whether this op is in a shard group. + bool is_shard_group = 11; + + // This field is used to store the unique id of the shard group. 
+ int64 shard_group_id = 12; + + // Used to decide whether this op is to be sharded like some other ops, or to + // which other ops will be sharded like. + enum ShardGroupType { + // This op will be sharded exactly the same as the other op. (hard + // restriction) + AS = 0; + // This op will try to allow sharding propagation within the same group even + // there is no data dependencies among them, but there is no guarantee that + // the final shardings within the same group will be exactly the same. (soft + // restriction) + LIKE = 1; + } + + ShardGroupType shard_group_type = 13; +} +// LINT.ThenChange() + +// Describes the replica groups in a cross replica op (e.g., all-reduce and +// all-to-all). +message ReplicaGroup { + // The ids of the replicas that belongs to the same group. The ordering of the + // ids matters in some ops (e.g., all-to-all). + repeated int64 replica_ids = 1; +} + +// Describes the source target pair in the collective permute op. +message SourceTarget { + int64 source = 1; + int64 target = 2; +} + +// Used to indicate the precision configuration. It has backend specific +// meaning. +message PrecisionConfig { + enum Precision { + DEFAULT = 0; + HIGH = 1; + HIGHEST = 2; + // Each U8/S8 value in a tensor actually represents 2 nibble values. + PACKED_NIBBLE = 3; + + // Next: 4 + } + + // The algorithm used to evaluate the instruction. + // + // The naming convention for the dot instruction is + // ALG_DOT_{A_TYPE}_{B_TYPE}_{ACCUM_TYPE}[_X{NUM_OPS}] where A_TYPE, B_TYPE + // and ACCUM_TYPE correspond to the types in the "primitive dot operations" + // (such as TensorCore operations) and NUM_OPS is the number of such + // operations used per "primitive tile". When the NUM_OPS + // field is skipped, it is assumed to be 1. The types mentioned in the name + // are independent of the storage types. + // + // In general ATYPE and BTYPE are the precisions that the LHS and RHS of the + // operation are rounded to and ACCUMTYPE is the accumulation type. If a + // backend does not support the given algorithm, an error is raised. The + // Algorithm enum is intended to eventually replace the Precision enum. + // + enum Algorithm { + // If the algorithm is `ALG_UNSET`, we will decide the algorithm based on + // the operand_precision values (for now). + ALG_UNSET = 0; + // The storage type can be any 8-bit floating point type. + ALG_DOT_ANY_F8_ANY_F8_F32 = 1; + // The storage type can be any 8-bit floating point type. Intermediate + // results will not periodically be promoted to a higher precision. This + // corresponds to CUBLASLT_MATMUL_DESC_FAST_ACCUM. Triton's + // maxNumImpreciseAcc=32 setting may be similar. + ALG_DOT_ANY_F8_ANY_F8_F32_FAST_ACCUM = 2; + ALG_DOT_F16_F16_F16 = 3; + ALG_DOT_F16_F16_F32 = 4; + ALG_DOT_BF16_BF16_BF16 = 5; + ALG_DOT_BF16_BF16_F32 = 6; + // An algorithm which uses 3 BF16_BF16_F32 matmuls to achieve better + // precision. + ALG_DOT_BF16_BF16_F32_X3 = 7; + // An algorithm which uses 6 BF16_BF16_F32 matmuls to achieve better + // precision (similar to F32). + ALG_DOT_BF16_BF16_F32_X6 = 8; + ALG_DOT_TF32_TF32_F32 = 9; + // An algorithm which uses 3 TF32_TF32_F32 matmuls to achieve better + // precision (similar to F32). + ALG_DOT_TF32_TF32_F32_X3 = 10; + ALG_DOT_F32_F32_F32 = 11; + ALG_DOT_F64_F64_F64 = 12; + + // Next: 13 + } + + repeated Precision operand_precision = 1; + + // Currently doesn't do anything, but we plan to support it for dot and + // possibly more instructions. 
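The Precision enum above is what JAX's precision= argument on dot-like operations ultimately populates (via PrecisionConfig.operand_precision in the lowered HLO); the snippet is illustrative only.

    import jax
    import jax.numpy as jnp

    a = jnp.ones((128, 128))
    b = jnp.ones((128, 128))
    default = jnp.dot(a, b)                                       # Precision.DEFAULT
    exact = jnp.dot(a, b, precision=jax.lax.Precision.HIGHEST)    # Precision.HIGHEST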
+ // + // TODO(b/316147294): Support this on GPU and add this to StableHLO as well. + // + // If this is set, then `operand_precision` should be set to DEFAULT and it + // will be ignored. + Algorithm algorithm = 2; + + // Next: 3 +} + +// Describes whether all data-parallelism replicas will receive the same +// parameter data at each buffer. +message ParameterReplication { + // A list of boolean values for the flattened leaf buffers. Each value + // indicates whether the corresponding leaf buffer is replicated. + // + // If this field is empty, it means no buffer is replicated. Otherwise, the + // number of elements in this field must match the number of leaf buffers in + // the HLO instruction's shape. + repeated bool replicated_at_leaf_buffers = 1; +} + +// A backend-config for kWhile loops that stores the loop's trip count, if it is +// known. +// +// This is useful for backends that can implement a `for i in 0..N` loop more +// efficiently than a `while` loop. For example, on GPUs, we can implement a +// `for i in 0..N` loop by enqueueing the kernels for the loop body N times, +// whereas implementing a `while` loop requires a host-device sync on each +// iteration. +message WhileLoopBackendConfig { + message KnownTripCount { + int64 n = 1; + } + // This indirection lets us distinguish between known-trip-count == 0 and + // unknown-trip-count. + KnownTripCount known_trip_count = 1; +} + +// Specifies a pair of output/operand buffers that alias each other for +// kCustomCall and kFusion +message OutputOperandAliasing { + repeated int64 output_shape_index = 1; + int64 operand_index = 2; + repeated int64 operand_shape_index = 3; +} diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/report_nvtx_pushpop_trace.csv.xz b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/report_nvtx_pushpop_trace.csv.xz new file mode 100644 index 000000000..a42402d65 Binary files /dev/null and b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/report_nvtx_pushpop_trace.csv.xz differ diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/core/axes_scan.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/core/axes_scan.py new file mode 100644 index 000000000..ec0f18d0d --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/core/axes_scan.py @@ -0,0 +1,182 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Wrapper around jax.lax.scan with in_axes/out_axes API.""" +import functools +from typing import Any, Callable, Optional + +import jax +import jax.numpy as jnp +import numpy as np +from jax import core, lax +from jax.extend import linear_util as lu +from jax.interpreters import partial_eval as pe + +ScanAxis = Optional[int] + + +class _Broadcast: + pass + + +broadcast = _Broadcast() + + +def scan( + fn: Callable[..., Any], + in_axes: Any, + out_axes: Any, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + _split_transpose: bool = False +): + """A wrapper around `jax.lax.scan` with in_axes/out_axes api. + + Example:: + def body_fn(b, c, x): + return b + 2, c + 1, 2 * x + + loop = scan(body_fn, in_axes=0, out_axes=0) + broadcast_in = 1 + carry = 2 + xs = jnp.arange(3) + broadcast_out, carry, ys = loop(broadcast_in, carry, xs) + print(broadcast_out) # prints: 3 + print(carry) # prints: 5 + print(ys) # prints: [0, 2, 4] + + + Args: + fn: the body function of the scan loop of the form + `(broadcast_in, carry, *args) -> (broadcast_out, carry, scan_out)`. + the broadcast argument allows for loop independent inputs/outputs to + be computed inside `fn`. `fn` will be called once to compute + `broadcast_out`. The actual loop will receive `broadcast_out` as the new + `broadcast_in`. This is useful for initializing values inside the loop. + in_axes: specifies the axis along which arguments are scanned. + Use `broadcast` to use the same value across iterations. + out_axes: specifies the axis along which outputs are concatenated. + Use `broadcast` if a return value should not be concatenated and + is independent of the loop body. + length: number of iterations. Only needs to be specified if there + is no scan axis from which it can be derived. + reverse: scan in reverse order from end to start. + unroll: how many scan iterations to unroll within a single + iteration of a loop (default: 1). + _split_transpose: An experimental feature to split the transpose of scan + into a scan and a map, backed by an experimental Jax lax.scan() feature. + Returns: + the function that performs the scan of the form: + (broadcast_in, carry_in, *args) -> (broadcast_out, carry_out, scan_out). 
+ """ + + def transpose_to_front(ax, xs): + if ax is broadcast: + return () + if ax == 0: + return xs + + def trans(x): + perm = tuple(range(x.ndim)) + perm = (ax,) + tuple(np.delete(perm, ax)) + return jnp.transpose(x, perm) + + return jax.tree_util.tree_map(trans, xs) + + def transpose_from_front(ax, xs): + if ax is broadcast: + return () + if ax == 0: + return xs + + def trans(x): + if ax < 0: + pax = x.ndim + ax + else: + pax = ax + assert pax < x.ndim + perm = tuple(range(1, pax + 1)) + (0,) + tuple(range(pax + 1, x.ndim)) + return jnp.transpose(x, perm) + + return jax.tree_util.tree_map(trans, xs) + + def scan_fn(broadcast_in, init, *args): + xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args) + + def body_fn(c, xs, init_mode=False): + # inject constants + xs = jax.tree_util.tree_map( + lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs + ) + broadcast_out, c, ys = fn(broadcast_in, c, *xs) + + if init_mode: + ys = jax.tree_util.tree_map( + lambda ax, y: (y if ax is broadcast else ()), out_axes, ys + ) + return broadcast_out, ys + else: + ys = jax.tree_util.tree_map( + lambda ax, y: (() if ax is broadcast else y), out_axes, ys + ) + return c, ys + + broadcast_body = functools.partial(body_fn, init_mode=True) + + carry_avals = jax.tree_util.tree_map( + lambda x: core.ShapedArray(jnp.shape(x), jnp.result_type(x)), init + ) + scan_avals = jax.tree_util.tree_map( + lambda x: core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x)), xs + ) + input_avals = (carry_avals, scan_avals) + + in_avals, in_tree = jax.tree_util.tree_flatten(input_avals) + f_flat, out_tree = jax.api_util.flatten_fun_nokwargs( + lu.wrap_init(broadcast_body), in_tree + ) + in_pvals = list(map(pe.PartialVal.unknown, in_avals)) + _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals) + + out_flat = [] + for pv, const in out_pvals: + if pv is not None: + raise ValueError( + 'broadcasted variable has a data dependency on the scan body.' + ) + out_flat.append(const) + broadcast_in, constants_out = jax.tree_util.tree_unflatten( + out_tree(), out_flat + ) + + if jax.version.__version_info__ > (0, 4, 25): + c, ys = lax.scan( + body_fn, init, xs, length=length, reverse=reverse, unroll=unroll, + _split_transpose=_split_transpose + ) + else: + c, ys = lax.scan( + body_fn, init, xs, length=length, reverse=reverse, unroll=unroll + ) + ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys) + ys = jax.tree_util.tree_map( + lambda ax, const, y: (const if ax is broadcast else y), + out_axes, + constants_out, + ys, + ) + return broadcast_in, c, ys + + return scan_fn diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/core/lift.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/core/lift.py new file mode 100644 index 000000000..34925370c --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/core/lift.py @@ -0,0 +1,1687 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Jax transform lifting.""" + +import collections +import functools +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Sequence, + Tuple, + TypeVar, + Union, +) +import warnings + +from flax import traceback_util +from flax.typing import ( + In, + Out, + InOutAxis, + InOutScanAxis, +) +import jax +from jax import random + +from . import axes_scan, meta +from .frozen_dict import freeze, unfreeze +from .scope import ( + CollectionFilter, + DenyList, # pylint: disable=g-multiple-import + Filter, + PRNGSequenceFilter, + Scope, + group_collections, + in_filter, + intersect_filters, + is_filter_empty, + subtract_filters, + union_filters, +) + +traceback_util.register_exclusion(__file__) + + +def tree_map_rngs(fn, tree): + """Needed for mapping JAX random.* functions over PRNGKey leaves.""" + return jax.tree_util.tree_map( + fn, + tree, + is_leaf=lambda x: hasattr(x, 'dtype') + and jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key), + ) + + +def _dedup_scopes(scopes): + """Deduplicated scopes.""" + paths = [] + # must preseve insertion order for duplication to work correctly + minimal_set = collections.OrderedDict((s, ()) for s in scopes) + for leaf in scopes: + scope = leaf.parent + max_parent = leaf + max_parent_path = () + path = [leaf.name] + while scope is not None: + if scope in minimal_set: + max_parent = scope + max_parent_path = tuple(reversed(path)) + path.append(scope.name) + scope = scope.parent + if max_parent is not leaf and leaf in minimal_set: + del minimal_set[leaf] + paths.append((max_parent, max_parent_path)) + return tuple(minimal_set), tuple(paths) + + +def _dup_scopes(orig_scopes, scopes, paths): + """Duplicated scopes.""" + mapping = dict(zip(orig_scopes, scopes)) + scopes = [] + for root, path in paths: + scope = mapping[root] + for name in path: + scope = scope.push(name, reuse=True) + scopes.append(scope) + return scopes + + +def _transpose(xs): + return tuple(zip(*xs)) + + +def _partial_pack( + scope_tree: Scope, + in_variable_filters: Sequence[CollectionFilter], + out_variable_filters: Sequence[CollectionFilter], + rng_filters: Sequence[PRNGSequenceFilter], + name=None, +) -> tuple[Callable[..., Any], Callable[..., Any], Any, Any, Callable[..., Any], Callable[..., Any]]: + """Pack variables and rngs for functional transformations. + + The _partial_pack function is the building block for all other lifted transformations. + + Args: + fn: The function to pack. `fn` has the signature + in_variable_filters: Input variable filters. + out_variable_filters: Output variable filters. + rng_filters: RNG filters. + name: The name of the packed scope. + enable_kwargs: Whether to enable kwargs or not. 
+ Returns: + `(scope_fn, repack_fn, variable_groups, rng_groups, publish_results_fn)` + """ + # pylint: disable=protected-access + scopes, treedef = jax.tree_util.tree_flatten(scope_tree) + scopes, paths = _dedup_scopes(scopes) + + variable_groups_xs = [] + + for scope in scopes: + scope._validate_trace_level() + scope._populate_collections() + variable_groups_xs.append( + group_collections(scope._variables, in_variable_filters) + ) + variable_groups_xs_t = _transpose(variable_groups_xs) + + # Make sure that in-only variable collections are frozen + for variable_group_xs in variable_groups_xs_t: + for variable_group in variable_group_xs: + for col_name, collection in variable_group.items(): + col_in_out = any( + in_filter(col_filter, col_name) + for col_filter in out_variable_filters + ) + if not col_in_out: + variable_group[col_name] = freeze(collection) + rng_groups_xs = [] + inner_rng_counters = [] + for scope in scopes: + rng_counters = scope.rng_counters + rng_groups = group_collections(scope.rngs, rng_filters) + rng_groups_xs.append(rng_groups) + inner_rng_counters.append(rng_counters) + rng_groups_xs_t = _transpose(rng_groups_xs) + + inner_scopes: List[Scope] = [] + + def scope_fn( + variable_groups_xs_t, + rng_groups_xs_t, + mutable_filter: CollectionFilter = True, + ): + nonlocal inner_scopes + for inner_scope in inner_scopes: + inner_scope.invalidate() + inner_scopes = [] + mutable: Filter = False + for out_filter in out_variable_filters: + mutable = union_filters(mutable, out_filter) + # could be () in the edge case where no rngs or variable_groups are lifted + # in this case fallback to ((),) * len(scopes) to make sure the zip has + # something to iterate over for each scope. + variable_groups_xs = _transpose(variable_groups_xs_t) or ((),) * len( + scopes + ) + rng_groups_xs = _transpose(rng_groups_xs_t) or ((),) * len(scopes) + assert len(variable_groups_xs) == len(scopes) + assert len(rng_groups_xs) == len(scopes) + for variable_groups, rng_groups, scope, rng_counters in zip( + variable_groups_xs, rng_groups_xs, scopes, inner_rng_counters + ): + variables = {} + rngs = {} + for variable_group in variable_groups: + variables.update(variable_group) + for rng_group in rng_groups: + rngs.update(rng_group) + # make sure variable dicts are cloned and can't be manipulated by ref + # sharing. 
+ variables = jax.tree_util.tree_map(lambda x: x, variables) + scope_mutable = intersect_filters( + intersect_filters(scope.mutable, mutable), mutable_filter + ) + new_debug_path = scope.debug_path + if name: + if new_debug_path: + new_debug_path = new_debug_path[:-1] + ( + f'{name}({new_debug_path[-1]})', + ) + else: + new_debug_path = (f'{name}()',) + inner_scope = Scope( + variables, + name=scope.name, + rngs=rngs, + mutable=scope_mutable, + parent=None, + path=scope.path, + debug_path=new_debug_path, + flags=scope.flags, + ) + inner_scope.rng_counters = rng_counters + inner_scopes.append(inner_scope) + inner_scopes = _dup_scopes(scopes, inner_scopes, paths) + return treedef.unflatten(inner_scopes) + + def repack_fn(inner_scope_tree): + inner_scopes = treedef.flatten_up_to(inner_scope_tree) + inner_scopes, inner_paths = _dedup_scopes(inner_scopes) + inner_scopes = list(inner_scopes) + assert [p for _, p in paths] == [p for _, p in inner_paths] + out_variable_groups_xs = [] + for inner_scope in inner_scopes: + inner_scope.invalidate() + inner_scope._validate_trace_level() + mutable_variables = { + key: val + for key, val in inner_scope._variables.items() + if in_filter(inner_scope.mutable, key) + } + out_variable_groups = group_collections( + mutable_variables, tuple(out_variable_filters) + (True,) + ) + remainder = tuple(out_variable_groups[-1].keys()) + if remainder: + raise ValueError(f'unmapped output variables: {remainder}') + out_variable_groups_xs.append(out_variable_groups[:-1]) + + return _transpose(out_variable_groups_xs) + + def invalidate_scopes_fn(): + for inner_scope in inner_scopes: + inner_scope.invalidate() + + def publish_results_fn(out_variable_groups_xs_t): + out_variable_groups_xs = _transpose(out_variable_groups_xs_t) + for scope, out_variable_groups, rng_counters in zip( + scopes, out_variable_groups_xs, inner_rng_counters + ): + for out_variable_group in out_variable_groups: + for col_name, collection in out_variable_group.items(): + if not scope.is_mutable_collection(col_name): + # Some lifted transforms like scan return redundant variables. + continue + for var_name, value in collection.items(): + scope.put_variable(col_name, var_name, value) + + return ( + scope_fn, + repack_fn, + variable_groups_xs_t, + rng_groups_xs_t, + publish_results_fn, + invalidate_scopes_fn, + ) + +def pack( + fn: Callable[..., Any], + in_variable_filters: Sequence[CollectionFilter], + out_variable_filters: Sequence[CollectionFilter], + rng_filters: Sequence[PRNGSequenceFilter], + name=None, + enable_kwargs=False, +) -> Callable[..., Any]: + """Pack variables and rngs for functional transformations. + + The pack function is the building block for all other lifted transformations. + + Args: + fn: The function to pack. `fn` has the signature + `(scope_fn, repack_fn, variable_groups, rng_groups, *args) -> + (output, packed_variables)`. + in_variable_filters: Input variable filters. + out_variable_filters: Output variable filters. + rng_filters: RNG filters. + name: The name of the packed scope. + enable_kwargs: Whether to enable kwargs or not. + Returns: + A callable which expects a scope as the first argument. 
+ """ + + @functools.wraps(fn) + def wrapper(scope_tree: Scope, *args, **kwargs): + if not enable_kwargs and kwargs: + msg = 'kwargs are not supported in {}, so "{}" is(are) ignored' + warnings.warn(msg.format(name, ', '.join(kwargs.keys())), RuntimeWarning) + ( + scope_fn, + repack_fn, + variable_groups_xs_t, + rng_groups_xs_t, + publish_results_fn, + invalidate_scopes_fn, + ) = _partial_pack(scope_tree, in_variable_filters, out_variable_filters, rng_filters, name) + try: + if enable_kwargs: + y, out_variable_groups_xs_t = fn( + scope_fn, + repack_fn, + variable_groups_xs_t, + rng_groups_xs_t, + *args, + **kwargs, + ) + else: + y, out_variable_groups_xs_t = fn( + scope_fn, repack_fn, variable_groups_xs_t, rng_groups_xs_t, *args + ) + finally: + invalidate_scopes_fn() + publish_results_fn(out_variable_groups_xs_t) + return y + + return wrapper + + +id_fn = lambda x: x + + +def map_variables( + fn: Callable[..., Any], + mapped_collections: CollectionFilter, + map_in_fn: Callable[..., Any] = id_fn, + map_out_fn: Callable[..., Any] = id_fn, + init: bool = False, + mutable: bool = False, + rngs: PRNGSequenceFilter = True, + variables: CollectionFilter = True, +) -> Callable[..., Any]: + """Map Variables inside a scope. + + Args: + fn: the function to be transformed. + mapped_collections: the collection(s) to be transformed. + map_in_fn: creates a view of the target variables. + map_out_fn: transforms the updated variables in the view after mutation. + init: If True, variables are initialized before transformation. + mutable: If True, the mapped variable collections will be mutable. + rngs: PRNGSequences added to the transformed scope (default: all). + variables: Additional Variable collections added to the transformed scope. + Besides those specified by `target` (default: all). + + Returns: + A callable expecting a scope as the first argument. 
+ """ + is_target_out = mutable or init + + def wrapper(scope_fn, repack, variable_groups, rng_groups, *args, **kwargs): + target, variables = variable_groups + if init: + scopes = scope_fn((target, variables), rng_groups) + has_mutable_cols = any( + not is_filter_empty(scope.mutable) + for scope in jax.tree_util.tree_leaves(scopes) + ) + if has_mutable_cols: + fn(scopes, *args, **kwargs) + target, _ = repack(scopes) + target = tuple(map_out_fn(x) for x in target) + target = tuple(map_in_fn(unfreeze(x)) for x in target) + mfilter = True + if not is_target_out: + # mapped collections should not be mutable + # unless the mapping supports it (by init=True or mutable=True) + mfilter = subtract_filters(mfilter, mapped_collections) + scopes = scope_fn((target, variables), rng_groups, mutable_filter=mfilter) + y = fn(scopes, *args, **kwargs) + out_target, out_vars = repack(scopes) + if is_target_out: + out_target = tuple(map_out_fn(x) for x in out_target) + return y, (out_target, out_vars) + + in_vars = (mapped_collections, variables) + out_vars = ( + in_vars + if is_target_out + else (False, subtract_filters(variables, mapped_collections)) + ) + return pack( + wrapper, + in_vars, + out_vars, + (rngs,), + enable_kwargs=True, + name='map_variables', + ) + + +def swap_collection(fn: Callable[..., Any], col_a: str, col_b: str): + """Swap two collections.""" + + def swap(target): + a = target[col_a] if col_a in target else {} + b = target[col_b] if col_b in target else {} + target[col_b], target[col_a] = a, b + return target + + return map_variables(fn, (col_a, col_b), swap, swap, mutable=True) + + +def _split_in_out_axes(xs: Mapping[CollectionFilter, Any]): + unpack = lambda v: v.axis if isinstance(v, (In, Out)) else v + in_axes = {k: unpack(v) for k, v in xs.items() if not isinstance(v, Out)} + out_axes = {k: unpack(v) for k, v in xs.items() if not isinstance(v, In)} + return in_axes, out_axes + + +def _bwd_wrapper(treedef, bwd_fn, tangent): + vars_grad, *inputs_grad = bwd_fn(tangent) + vars_grad = treedef.unflatten(vars_grad) + return (vars_grad, *inputs_grad) + + +def vjp( + fn: Callable[..., Any], + scope: Scope, + *primals, + has_aux: bool = False, + reduce_axes=(), + vjp_variables: CollectionFilter = 'params', + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +) -> Union[Tuple[Any, Callable[..., Any]], Tuple[Any, Callable[..., Any], Any]]: + """A lifted version of ``jax.vjp``. + + See ``jax.vjp`` for the unlifted vector-Jacobian product (backward gradient). + + Note that a gradient is returned for all variables in the collections + specified by `vjp_variables`. However, the backward function only expects + a cotangent for the return value of `fn`. If variables require a co-tangent + as well they can be returned from `fn` using `scope.variables()`. + + Example:: + + def learn_scale(scope, x, y): + p = scope.param('scale', nn.initializers.zeros_init(), ()) + return p * x * y + def f(scope, x, y): + z, bwd = lift.vjp(learn_scale, scope, x, y) + params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape)) + return z, params_grad, x_grad, y_grad + + Args: + fn: Function to be differentiated. Its arguments should be arrays, scalars, + or standard Python containers of arrays or scalars. It should return an + array, scalar, or standard Python container of arrays or scalars. It will + receive the scope and primals as arguments. + scope: The scope of which the variables will be differentiated. + *primals: A sequence of primal values at which the Jacobian of ``fn`` + should be evaluated. 
The length of ``primals`` should be equal to the + number of positional parameters to ``fn``. Each primal value should be a + tuple of arrays, scalar, or standard Python containers thereof. + has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default ``False``. + reduce_axes: Optional, tuple of axis names. If an axis is listed here, and + ``fn`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding gradient. Otherwise, the + VJP will be per-example over named axes. For example, if ``'batch'`` + is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will + create a VJP function that sums over the batch while ``vjp(f, *args)`` + will create a per-example VJP. + vjp_variables: The vjpfun will return a cotangent vector for all + variable collections specified by this filter. + variables: other variables collections that are available inside `fn` but + do not receive a cotangent. + rngs: the prngs that are available inside `fn`. + + Returns: + If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where + ``primals_out`` is ``fn(*primals)``. + ``vjpfun`` is a function from a cotangent vector with the same shape as + ``primals_out`` to a tuple of cotangent vectors with the same shape as + ``primals``, representing the vector-Jacobian product of ``fn`` evaluated at + ``primals``. If ``has_aux`` is ``True``, returns a + ``(primals_out, vjpfun, aux)`` tuple where ``aux`` is the auxiliary data + returned by ``fn``. + """ + + def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): + vjp_vars, other_vars = variable_groups + + @functools.wraps(fn) + def wrapper(vjp_vars, *args): + variable_groups = (vjp_vars, other_vars) + scope = scope_fn(variable_groups, rng_groups) + if has_aux: + y, aux = fn(scope, *args) + else: + y = fn(scope, *args) + aux = () + return y, (aux, repack_fn(scope)) + + y, bwd, (aux, out_vars) = jax.vjp( + wrapper, vjp_vars, *args, reduce_axes=reduce_axes, has_aux=True + ) + treedef = jax.tree_util.tree_structure(scope) + bwd = jax.tree_util.Partial(functools.partial(_bwd_wrapper, treedef), bwd) + if has_aux: + return (y, bwd, aux), out_vars + else: + return (y, bwd), out_vars + + return pack( + inner, + (vjp_variables, variables), + (variables,), + (rngs,), + name='vjp', + enable_kwargs=False, + )(scope, *primals) + + +def value_and_grad( + fn: Callable[..., Any], + scope: Scope, + *primals, + has_aux: bool = False, + reduce_axes=(), + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +) -> Union[Tuple[Any, Callable[..., Any]], Tuple[Any, Callable[..., Any], Any]]: + """A limited lifted version of ``jax.value_and_grad``. + + See ``jax.value_and_grad`` for the unlifted reverse mode gradient. + + Note that for this convenience function, gradients are only calculated for + the function inputs (all function inputs), and not with respect to any scope + variables. The target function must return a scalar-valued output. + + Example:: + + def learn_scale(scope, x, y): + p = scope.param('scale', nn.initializers.zeros_init(), ()) + return p * x * y + def f(scope, x, y): + z, x_grad, y_grad = lift.value_and_grad(learn_scale, scope, x, y) + return z, x_grad, y_grad + + Args: + fn: Function to be differentiated. Its arguments should be arrays, scalars, + or standard Python containers of arrays or scalars. 
It should return an + array, scalar, or standard Python container of arrays or scalars. It will + receive the scope and primals as arguments. + scope: The scope of which the variables will be differentiated. + *primals: A sequence of primal values at which the Jacobian of ``fn`` + should be evaluated. The length of ``primals`` should be equal to the + number of positional parameters to ``fn``. Each primal value should be a + tuple of arrays, scalar, or standard Python containers thereof. + has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default ``False``. + reduce_axes: Optional, tuple of axis names. If an axis is listed here, and + ``fn`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding gradient. Otherwise, the + VJP will be per-example over named axes. For example, if ``'batch'`` + is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will + create a VJP function that sums over the batch while ``vjp(f, *args)`` + will create a per-example VJP. + variables: other variables collections that are available inside `fn` but + do not receive a cotangent. + rngs: the prngs that are available inside `fn`. + + Returns: + If ``has_aux`` is ``False``, returns a ``(primals_out, grads)`` pair, where + ``primals_out`` is ``fn(*primals)``. + If ``has_aux`` is ``True``, returns a + ``(primals_out, aux, grads)`` tuple where ``aux`` is the auxiliary data + returned by ``fn``. + """ + + def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): + @functools.wraps(fn) + def wrapper(*args): + scope = scope_fn(variable_groups, rng_groups) + if has_aux: + y, aux = fn(scope, *args) + else: + y = fn(scope, *args) + aux = () + return y, (aux, repack_fn(scope)) + + y, bwd, (aux, out_vars) = jax.vjp( + wrapper, + *args, + has_aux=True, + reduce_axes=reduce_axes, + ) + + inputs_grad = bwd(jax.numpy.ones_like(y)) + + if has_aux: + return (y, aux, inputs_grad), out_vars + else: + return (y, inputs_grad), out_vars + + return pack( + inner, + (variables,), + (variables,), + (rngs,), + name='value_and_grad', + enable_kwargs=False, + )(scope, *primals) + + +def jvp( + fn: Callable[..., Any], + scope: Scope, + primals, + tangents, + variable_tangents, + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +) -> Tuple[Any, Any]: + """A lifted version of ``jax.jvp``. + + See ``jax.jvp`` for the unlifted Jacobian-vector product (forward gradient). + + Note that no tangents are returned for variables. When variable tangents + are required their value should be returned explicitly by `fn` + using `scope.variables()`. + + Example:: + + def learn_scale(scope, x): + p = scope.param('scale', nn.initializers.zeros_init(), ()) + return p * x + + def f(scope, x): + vars_t = jax.tree_util.tree_map(jnp.ones_like, + scope.variables().get('params', {})) + x, out_t = lift.jvp( + learn_scale, scope, (x,), (jnp.zeros_like(x),), + variable_tangents={'params': vars_t}) + return out_t + + Args: + fn: The function to be transformed. + scope: The scope(s) which should be lifted into the transform. + primals: The primal values at which the Jacobian of ``fun`` should be + evaluated. Should be either a tuple or a list of arguments, + and its length should be equal to the number of positional parameters of + ``fun``. 
+ tangents: The tangent vector for which the Jacobian-vector product should be + evaluated. Should be either a tuple or a list of tangents, with the same + tree structure and array shapes as ``primals``. + variable_tangents: A dict or PyTree fo dicts with the same structure as + scopes. Each entry in the dict specifies the tangents for a variable + collection. Not specifying a collection in variable_tangents is + equivalent to passing a zero vector as the tangent. + variables: other variables collections that are available inside `fn` but + do not receive a tangent. + rngs: the prngs that are available inside `fn`. + + Returns: + A ``(primals_out, tangents_out)`` pair, where ``primals_out`` is + ``fun(*primals)``, and ``tangents_out`` is the Jacobian-vector product of + ``function`` evaluated at ``primals`` with ``tangents``. The + ``tangents_out`` value has the same Python tree structure and shapes as + ``primals_out``. + """ + + def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): + jvp_vars, other_vars = variable_groups + + @functools.wraps(fn) + def wrapper(vars_primals, args): + variable_groups = (vars_primals, other_vars) + scope = scope_fn(variable_groups, rng_groups) + y = fn(scope, *args) + return y, repack_fn(scope) + + (y, out_vars), out_tangents = jax.jvp( + wrapper, (jvp_vars, args), (variable_tangents, tangents) + ) + return (y, out_tangents[0]), out_vars + + # filter out empty tangent collections because JAX will error on non-equal + # tree structure for example: {"params": {}} != {}. + treedef = jax.tree_util.tree_structure(scope) + + variable_tangents = tuple( + {k: v for k, v in vt.items() if v} # pylint: disable=g-complex-comprehension + for vt in treedef.flatten_up_to(variable_tangents) + ) + target = tuple(variable_tangents[0].keys()) + return pack( + inner, + (target, variables), + (variables,), + (rngs,), + name='jvp', + enable_kwargs=False, + )(scope, *primals) + + +def vmap( + fn: Callable[..., Any], + variable_axes: Mapping[CollectionFilter, InOutAxis], + split_rngs: Mapping[PRNGSequenceFilter, bool], + in_axes=0, + out_axes=0, + axis_size: Optional[int] = None, + axis_name: Optional[str] = None, + spmd_axis_name: Optional[str] = None, + metadata_params: Dict[Any, Any] = {}, +) -> Callable[..., Any]: + """A lifted version of ``jax.vmap``. + + See ``jax.vmap`` for the unlifted batch transform in Jax. + + ``vmap`` can be used to add a batch axis to a scope function. + For example we could create a version of ``dense`` with + a batch axis that does not share parameters:: + + batch_dense = lift.vmap( + nn.dense, + in_axes=(0, None), + variable_axes={'params': 0}, + split_rngs={'params': True}) + + By using ``variable_axes={'params': 0}``, we indicate that the + parameters themselves are mapped over and therefore not shared along + the mapped axis. Consequently, we also split the 'params' RNG, + otherwise the parameters would be initialized identically along + the mapped axis. + + Similarly, ``vmap`` could be use to add a batch axis with parameter + sharing:: + + batch_foo = lift.vmap( + foo, + in_axes=0, out_axes=0, + variable_axes={'params': None}, + split_rngs={'params': False}) + + Here we use ``variable_axes={'params': None}`` to indicate the parameter + variables are shared along the mapped axis. Consequently, the 'params' + RNG must also be shared. + + Args: + fn: the function to be transformed. + variable_axes: the variable collections that are lifted into the batching + transformation. 
Use `None` to indicate a broadcasted collection or an + integer to map over an axis. + split_rngs: Split PRNG sequences will be different for each index of the + batch dimension. Unsplit PRNGs will be broadcasted. + in_axes: Specifies the mapping of the input arguments (see `jax.vmap). + out_axes: Specifies the mapping of the return value (see `jax.vmap). + axis_size: Specifies the size of the batch axis. This only needs to be + specified if it cannot be derived from the input arguments. + axis_name: Specifies a name for the batch axis. Can be used together with + parallel reduction primitives (e.g. `jax.lax.pmean`, `jax.lax.ppermute`, + etc.). Note, this is only used for pmap and shmap. For SPMD jit, you do + not need to manually synchronize. Just make sure that the axes are + correctly annotated and XLA:SPMD will insert the necessary collectives. + spmd_axis_name: Axis name added to any pjit sharding constraints appearing + in `fn`. See also + https://github.com/google/flax/blob/main/flax/linen/partitioning.py. + metadata_params: arguments dict passed to AxisMetadata instances in the + variable tree. + + Returns: + A vectorized version of the input scope function. + """ + variable_in_axes, variable_out_axes = _split_in_out_axes(variable_axes) + variable_in_groups, variable_in_axes = _unzip2(variable_in_axes.items()) + variable_out_groups, variable_out_axes = _unzip2(variable_out_axes.items()) + rng_groups, rng_splits = _unzip2(split_rngs.items()) + rng_axes = tuple(0 if rng_split else None for rng_split in rng_splits) + + def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): + def find_axis_size(axis, x): + if axis is not None: + leaves = jax.tree_util.tree_leaves(x) + if leaves: + return leaves[0].shape[axis] + return () + + # split rngs + axis_sizes = jax.tree_util.tree_map( + find_axis_size, (variable_in_axes, in_axes), (variable_groups, args) + ) + axis_sizes = set(jax.tree_util.tree_leaves(axis_sizes)) + if axis_size is None and len(axis_sizes) == 1: + (d_axis_size,) = axis_sizes + elif len(axis_sizes) > 1: + raise ValueError(f'Inconsistent batch axis sizes: {axis_sizes}') + elif axis_size is None: + raise ValueError('axis_size should be specified manually.') + else: + d_axis_size = axis_size + # random.clone is only available on Jax versions 0.4.26 or newer + # see: https://jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html + if hasattr(random, 'clone'): + split_fn = lambda rng: random.split(random.clone(rng), d_axis_size) + else: + split_fn = lambda rng: random.split(rng, d_axis_size) + + rng_groups = tuple( + tree_map_rngs(split_fn, rng_group) if split else rng_group + for rng_group, split in zip(rng_groups, rng_splits) + ) + + new_variable_groups = [] + for var_group, axis in zip(variable_groups, variable_in_axes): + if axis is not None: + new_variable_groups.append( + meta.remove_axis(var_group, axis, metadata_params) + ) + else: + new_variable_groups.append(var_group) + variable_groups = tuple(new_variable_groups) + + @functools.partial( + jax.vmap, + in_axes=(variable_in_axes, rng_axes, in_axes), + out_axes=(out_axes, variable_out_axes), + axis_name=axis_name, + axis_size=axis_size, + spmd_axis_name=spmd_axis_name, + ) + @functools.wraps(fn) + def mapped(variable_groups, rng_groups, args): + scope = scope_fn(variable_groups, rng_groups) + y = fn(scope, *args) + return y, repack_fn(scope) + + y, vars_out = mapped(variable_groups, rng_groups, args) + new_vars_out = [] + for var_group, axis in zip(vars_out, variable_out_axes): + if axis is not None: + 
new_vars_out.append(meta.add_axis(var_group, axis, metadata_params)) + else: + new_vars_out.append(var_group) + vars_out = tuple(new_vars_out) + return y, vars_out + + return pack( + inner, variable_in_groups, variable_out_groups, rng_groups, name='vmap' + ) + + +def scan( + fn: Callable[..., Any], + variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {}, + variable_broadcast: CollectionFilter = False, + variable_carry: CollectionFilter = False, + split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, + in_axes=0, + out_axes=0, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + _split_transpose: bool = False, + data_transform: Optional[Callable[..., Any]] = None, + metadata_params: Dict[Any, Any] = {}, +) -> Callable[..., Any]: + """A lifted version of ``jax.lax.scan``. + + See ``jax.lax.scan`` for the unlifted scan in Jax. + + To improve consistency with ``vmap``, this version of scan + uses ``in_axes`` and ``out_axes`` to determine which arguments + are scanned over and along which axis. + + ``scan`` distinguishes between 3 different types of values inside the loop: + + 1. **scan**: a value that is iterated over in a loop. All scan values must + have the same size in the axis they are scanned over. Scanned outputs + will be stacked along the scan axis. + 2. **carry**: A carried value is updated at each loop iteration. It must + have the same shape and dtype throughout the loop. + 3. **broadcast**: a value that is closed over by the loop. When a variable + is broadcasted they are typically initialized inside the loop body but + independent of the loop variables. + + The loop body should have the signature + ``(scope, body, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys`` + are the scan values that go in and out of the loop. + + Example:: + + scope.variable('counter', 'i', jnp.zeros, ()) + def body_fn(scope, c, x): + counter = scope.variable('counter', 'i', jnp.zeros, ()) + counter.value += 1 + x = scope.child(nn.dense)(x, 1) + return c, x + + _, ys = lift.scan( + body_fn, + variable_carry='counter', + variable_broadcast='params', + split_rngs={'params': False})(scope, (), xs) + + Args: + fn: the function to be transformed. + variable_axes: the variable collections that are scanned over. + variable_broadcast: Specifies the broadcasted variable collections. + A broadcasted variable should not depend on any computation that cannot b + lifted out of the loop. This is typically used to define shared parameters + inside the fn. + variable_carry: Specifies the variable collections that are carried through + the loop. Mutations to these variables are carried to the next iteration + and will be preserved when the scan finishes. + split_rngs: Split PRNG sequences will be different for each loop iterations. + If split is False the PRNGs will be the same across iterations. + in_axes: Specifies the axis to scan over for the arguments. Should be a + prefix tree of the arguments. Use `flax.core.broadcast` to feed an entire + input to each iteration of the scan body. + out_axes: Specifies the axis to scan over for the return value. Should be a + prefix tree of the return value. + length: Specifies the number of loop iterations. This only needs + to be specified if it cannot be derived from the scan arguments. + reverse: If true, scan from end to start in reverse order. + unroll: how many scan iterations to unroll within a single + iteration of a loop (default: 1). 
+ _split_transpose: An experimental feature to split the transpose of a scan + into a scan and a map, backed by an experimental Jax lax.scan() feature. + data_transform: optional function to transform raw variable and rng groups, + intended for inline SPMD annotations. + metadata_params: arguments dict passed to AxisMetadata instances in the + variable tree. + + Returns: + The scan function with the signature + ``(scope, carry, *xxs) -> (carry, yys)``, where ``xxs`` and ``yys`` are the + scan values that go in and out of the loop. + """ + variable_in_axes, variable_out_axes = _split_in_out_axes(variable_axes) + variable_in_groups, variable_in_axes = _unzip2(variable_in_axes.items()) + variable_out_groups, variable_out_axes = _unzip2(variable_out_axes.items()) + assert all(isinstance(ax, int) for ax in variable_in_axes) + assert all(isinstance(ax, int) for ax in variable_out_axes) + rng_groups, rng_splits = _unzip2(split_rngs.items()) + rng_axes = tuple( + 0 if rng_split else axes_scan.broadcast for rng_split in rng_splits + ) + + def inner(scope_fn, repack_fn, variable_groups, rng_groups, init, *args): + def find_length(axis, x): + if axis is not axes_scan.broadcast: + leaves = jax.tree_util.tree_leaves(x) + if leaves: + return leaves[0].shape[axis] + return () + + # split rngs + lengths = jax.tree_util.tree_map(find_length, in_axes, args) + lengths = set(jax.tree_util.tree_leaves(lengths)) + if length is None and len(lengths) == 1: + (d_length,) = lengths + elif len(lengths) > 1: + raise ValueError(f'Inconsistent scan lengths: {lengths}') + elif length is None: + raise ValueError('length should be specified manually.') + else: + d_length = length + # random.clone is only available on Jax versions 0.4.26 or newer + # see: https://jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html + if hasattr(random, 'clone'): + split_fn = lambda rng: random.split(random.clone(rng), d_length) + else: + split_fn = lambda rng: random.split(rng, d_length) + + rng_groups = tuple( + tree_map_rngs(split_fn, rng_group) if split else rng_group + for rng_group, split in zip(rng_groups, rng_splits) + ) + + @functools.partial( + axes_scan.scan, + in_axes=(variable_in_axes, rng_axes, in_axes), + out_axes=(out_axes, variable_out_axes), + length=length, + reverse=reverse, + unroll=unroll, + _split_transpose=_split_transpose + ) + def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args): + carry_vars, c = carry + variable_groups = (broadcast_vars, carry_vars) + scan_variable_groups + if data_transform is not None: + variable_groups, rng_groups = data_transform( + variable_groups, rng_groups + ) + scope = scope_fn(variable_groups, rng_groups) + c, y = fn(scope, c, *args) + out_vars = repack_fn(scope) + broadcast_vars_out = out_vars[0] + carry_vars = out_vars[1] + scan_vars = out_vars[2:] + # add immutable broadcast vars back to broadcast output + # otherwise they won't be fed to the actual scan body + for in_group, out_group in zip(broadcast_vars, broadcast_vars_out): + for col in in_group: + if col not in out_group: + out_group[col] = in_group[col] + return broadcast_vars_out, (carry_vars, c), (y, scan_vars) + + broadcast_vars = variable_groups[0] + carry_vars = variable_groups[1] + scan_vars = variable_groups[2:] + new_scan_vars = [] + for scan_group, axis in zip(scan_vars, variable_in_axes): + new_scan_vars.append(meta.remove_axis(scan_group, axis, metadata_params)) + broadcast_vars, (carry_vars, c), (ys, scan_vars) = scanned( + broadcast_vars, + (carry_vars, init), + tuple(new_scan_vars), + 
rng_groups, + args, + ) + new_scan_vars = [] + for scan_group, axis in zip(scan_vars, variable_out_axes): + new_scan_vars.append(meta.add_axis(scan_group, axis, metadata_params)) + scan_vars = tuple(new_scan_vars) + out_vars = ( + broadcast_vars, + carry_vars, + ) + scan_vars + return (c, ys), out_vars + + return pack( + inner, + (variable_broadcast, variable_carry) + variable_in_groups, + (variable_broadcast, variable_carry) + variable_out_groups, + rng_groups, + name='scan', + ) + + +C = TypeVar('C') + + +def while_loop( + cond_fn: Callable[[Scope, C], bool], + body_fn: Callable[[Scope, C], C], + scope: Scope, + init: C, + carry_variables: CollectionFilter = False, + broadcast_variables: CollectionFilter = True, + split_rngs: Mapping[PRNGSequenceFilter, bool] = {}, +) -> C: + """Lifted version of jax.lax.while_loop. + + The lifted scope is passed to `cond_fn` and `body_fn`. + Broadcasted variables are immutable. The carry variable are + mutable but cannot change shape and dtype. + This also means you cannot initialize variables inside + the body. Consider calling `body_fn` once manually before + calling `while_loop` if variable initialization is required. + + Example:: + + def f(scope, x): + def cond_fn(scope, c): + return scope.get_variable('state', 'acc') < 10 + def body_fn(scope, c): + acc = scope.variable('state', 'acc') + acc += 1 + y = scope.child(nn.dense)(c, c.shape[-1]) + return y + + c = x + c = body_fn(scope, c) + return lift.while_loop(cond_fn, body_fn, scope, (), + carry_variables='state') + + Args: + cond_fn: Should return True as long as the loop should continue. + body_fn: The body of the while loop. + scope: The scope(s) which should be lifted into the loop. + init: The initial state passed to the loop + carry_variables: collections that are carried through the loop + and are therefore mutable (default: none). + broadcast_variables: collections that are closed over and are + therefore read-only (default: all collections) + split_rngs: Split PRNG sequences will be different for each loop iterations. + If split is False the PRNGs will be the same across iterations. + Returns: + The final state after executing the while loop. 
+ """ + rng_groups, rng_splits = _unzip2(split_rngs.items()) + + def inner(scope_fn, repack_fn, variable_groups, rng_groups): + carry_variables, broadcast_variables = variable_groups + + def make_loop_rngs(i): + local_rng_groups = [] + for rng_group, rng_split in zip(rng_groups, rng_splits): + if rng_split: + rng_group = tree_map_rngs( + lambda rng: random.fold_in(rng, i), rng_group + ) + local_rng_groups.append(rng_group) + return local_rng_groups + + def cond_wrapper(c): + i, carry_variables, carry = c + scope = scope_fn( + (carry_variables, broadcast_variables), + make_loop_rngs(-i), + mutable_filter=False, + ) + return cond_fn(scope, carry) + + def body_wrapper(c): + i, carry_variables, carry = c + scope = scope_fn( + (carry_variables, broadcast_variables), make_loop_rngs(i) + ) + carry = body_fn(scope, carry) + (carry_variables,) = repack_fn(scope) + return (i + 1, carry_variables, carry) + + c = (0, carry_variables, init) + _, carry_variables, carry = jax.lax.while_loop( + cond_wrapper, body_wrapper, c + ) + return carry, (carry_variables,) + + return pack( + inner, + (carry_variables, broadcast_variables), + (carry_variables,), + rng_groups, + name='while_loop', + )(scope) + + +def cond( + pred: Any, + true_fun: Callable[..., C], + false_fun: Callable[..., C], + scope: Scope, + *operands, + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +) -> C: + """Lifted version of ``jax.lax.cond``. + + The returned values from ``true_fun`` and ``false_fun`` + must have the same Pytree structure, shapes, and dtypes. + The variables created or updated inside the + branches must also have the same structure. + Note that this constraint is violated when + creating variables or submodules in only one branch. + Because initializing variables in just one branch + causes the paramater structure to be different. + + Example:: + + def cond_example(scope, x, pred): + scope.variable('state', 'true_count', lambda: 0) + scope.variable('state', 'false_count', lambda: 0) + def true_fn(scope, x): + scope.variable('state', 'true_count').value += 1 + return scope.child(nn.dense)(x, 2) + def false_fn(scope, x): + scope.variable('state', 'false_count').value += 1 + return -scope.child(nn.dense)(x, 2) + return lift.cond(pred, true_fn, false_fn, scope, x) + + + Args: + pred: determines if true_fun or false_fun is evaluated. + true_fun: The function evalauted when ``pred`` is `True`. + The signature is (Scope, *operands) -> T. + false_fun: The function evalauted when ``pred`` is `False`. + The signature is (Scope, *operands) -> T. + scope: A Scope or Pytree of scopes to pass + *operands: The arguments passed to ``true_fun`` and ``false_fun`` + variables: The variable collections passed to the conditional + branches (default: all) + rngs: The PRNG sequences passed to the conditionals (default: all) + Returns: + The result of the evaluated branch (``true_fun`` or ``false_fun``). 
+ """ + branches = [true_fun, false_fun] + + def inner(scope_fn, repack_fn, variable_groups, rng_groups): + def branch_wrapper(branch_fn, *operands): + scope = scope_fn(variable_groups, rng_groups) + y = branch_fn(scope, *operands) + return y, repack_fn(scope) + + pure_branches = [ + functools.partial(branch_wrapper, branch_fn) for branch_fn in branches + ] + return jax.lax.cond(pred, pure_branches[0], pure_branches[1], *operands) + + return pack(inner, (variables,), (variables,), (rngs,), name='cond')(scope) + + +def switch( + index: Any, + branches: Sequence[Callable[..., C]], + scope: Scope, + *operands, + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +) -> C: + """Lifted version of ``jax.lax.switch``. + + The returned values from ``branches`` + must have the same Pytree structure, shapes, and dtypes. + The variables created or updated inside the + branches must also have the same structure. + Note that this constraint is violated when + creating variables or submodules in only one branch. + Because initializing variables in just one branch + causes the parameter structure to be different. + + Example:: + + def switch_example(scope, x, index): + scope.variable('state', 'a_count', lambda: 0) + scope.variable('state', 'b_count', lambda: 0) + scope.variable('state', 'c_count', lambda: 0) + def a_fn(scope, x): + scope.variable('state', 'a_count').value += 1 + return scope.child(nn.dense)(x, 2) + def b_fn(scope, x): + scope.variable('state', 'b_count').value += 1 + return -scope.child(nn.dense)(x, 2) + def c_fn(scope, x): + scope.variable('state', 'c_count').value += 1 + return scope.child(nn.dense)(x, 2) + return lift.switch(index, [a_fn, b_fn, c_fn], scope, x) + + If you want to have a different parameter structure for each branch + you should run all branch on initialization before calling switch:: + + def multihead_switch_example(scope, x, index): + def a_fn(scope, x): + x = scope.child(nn.dense)(x, 10) + x = scope.child(nn.dense)(x, 7) + return scope.child(nn.dense)(x, 5) + def b_fn(scope, x): + x = scope.child(nn.dense)(x, 11) + return scope.child(nn.dense)(x, 5) + def c_fn(scope, x): + return scope.child(nn.dense)(x, 5) + + branches = [a_fn, b_fn, c_fn] + + # run all branches on init + if scope.is_mutable_collection('params'): + for branch in branches: + _ = branch(scope, x) + + return lift.switch(index, branches, scope, x) + + Args: + index: Integer scalar type, indicating which branch function to apply. + branches: Sequence of functions to be applied based on index. + The signature of each function is (Scope, *operands) -> T. + scope: A Scope or Pytree of scopes to pass + *operands: The arguments passed to ``true_fun`` and ``false_fun`` + variables: The variable collections passed to the conditional + branches (default: all) + rngs: The PRNG sequences passed to the conditionals (default: all) + Returns: + The result of the evaluated branch. 
+ """ + + def inner(scope_fn, repack_fn, variable_groups, rng_groups): + def branch_wrapper(branch_fn, *operands): + scope = scope_fn(variable_groups, rng_groups) + y = branch_fn(scope, *operands) + return y, repack_fn(scope) + + pure_branches = [ + functools.partial(branch_wrapper, branch_fn) for branch_fn in branches + ] + return jax.lax.switch(index, pure_branches, *operands) + + return pack(inner, (variables,), (variables,), (rngs,), name='switch')(scope) + + +def custom_vjp( + fn: Callable[..., Any], + forward_fn: Callable[..., Any], + backward_fn: Callable[..., Any], + grad_vars: CollectionFilter = 'params', + nondiff_argnums=(), +): + """Lifted version of `jax.custom_vjp`. + + `forward_fn` and `backward_fn` together define a custom vjp for `fn`. + The original `fn` will run in case a vjp (backward gradient) is not computed. + + The `forward_fn` receives the same arguments as `fn` but is expected to return + a tuple containing the output of `fn(scope, *args)` and the residuals that are + passed to `backward_fn`. + + The `backward_fn` receives the nondiff arguments, residuals, and the output + tangents. It should return a tuple containing the variable and input tangents. + + Note that the vjp function returned by `lift.vjp` can be passed as residual + and used in the `backward_fn`. The scope is unavailable during the backward + pass. If the scope is required in `backward_fn`, a snapshot of the variables + can be taken and returned as a residual in the `forward_fn`. + + Example:: + + f = nn.dense + + def fwd(scope, x, features): + y, vjp_fn = lift.vjp(partial(f, features=features), scope, x) + return y, vjp_fn + + def bwd(features, vjp_fn, y_t): + params_t, *inputs_t = vjp_fn(y_t) + params_t = jax.tree_util.tree_map(jnp.sign, params_t) + return (params_t, *inputs_t) + + dense_sign_grad = lift.custom_vjp( + f, forward_fn=fwd, backward_fn=bwd, nondiff_argnums=(2,)) + + Args: + fn: The function to define a custom_vjp for. The first argument + should be a ``Module`` instance. + forward_fn: A function with the same arguments as `fn` returning an tuple + with the original output and the residuals that will be passed to + `backward_fn`. + backward_fn: arguments are passed as (*nondiff_args, residuals, tangents) + The function should return a tuple containing the tangents for the + variable in the collections specified by `grad_vars` and the input + arguments (except the scope and nondiff args). + grad_vars: The collections for which a vjp will be computed + (default: "params"). + nondiff_argnums: arguments for which no vjp is computed. + Returns: + A function with the same signature as `fn` with the custom vjp. + """ + + def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args): + grad_variables, other_variables = variable_groups + scopes_treedef = None + + def f(grad_variables, *args): + scope = scope_fn((grad_variables, other_variables), rng_groups) + y = fn(scope, *args) + vars_out = repack_fn(scope) + return y, vars_out + + f = jax.custom_vjp(f, nondiff_argnums=nondiff_argnums) + + def f_fwd(grad_variables, *args): + nonlocal scopes_treedef + scopes = scope_fn((grad_variables, other_variables), rng_groups) + scopes_treedef = jax.tree_util.tree_structure(scopes) + y, res = forward_fn(scopes, *args) + vars_out = repack_fn(scopes) + return (y, vars_out), res + + def f_bwd(*args): + # the backward function does not pass a lifted scope to the user. + # Currently, there is no way to have side effects flow out of backward + # pass. Even without mutation variables would be ill-defined. 
For example, + # would we take a snapshot of the variables before or after calling + # `forward_fn`? + nondiff_args = args[:-2] + res, g = args[-2:] # pylint: disable=unbalanced-tuple-unpacking + g_y, _ = g + var_t, *inputs_t = backward_fn(*nondiff_args, res, g_y) + assert scopes_treedef is not None, 'backward called before forward?!' + var_t = tuple(scopes_treedef.flatten_up_to(var_t)) + return (var_t, *inputs_t) + + f.defvjp(f_fwd, f_bwd) + + return f(grad_variables, *args) + + variable_in_groups = (grad_vars, True) + variable_out_groups = (grad_vars, True) + rng_groups = (True,) + return pack( + inner, + variable_in_groups, + variable_out_groups, + rng_groups, + name='custom_vjp', + ) + + +def checkpoint( + fn: Callable[..., Any], + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, + concrete: bool = False, + prevent_cse: bool = True, + static_argnums: Union[int, Tuple[int, ...]] = (), + policy: Optional[Callable[..., bool]] = None, +) -> Callable[..., Any]: + """Lifted version of ``jax.checkpoint``. + + This function is aliased to ``lift.remat`` just like ``jax.remat``. + + Args: + fn: scope function for which intermediate computations should be + re-computed when computing gradients. + variables: The variable collections that are lifted. By default all + collections are lifted. + rngs: The PRNG sequences that are lifted. By default all PRNG sequences + are lifted. + concrete: Optional, boolean indicating whether ``fun`` may involve + value-dependent Python control flow (default ``False``). Support for such + control flow is optional, and disabled by default, because in some + edge-case compositions with :func:`jax.jit` it can lead to some extra + computation. + prevent_cse: Optional, boolean indicating whether to prevent common + subexpression elimination (CSE) optimizations in the HLO generated from + differentiation. This CSE prevention has costs because it can foil other + optimizations, and because it can incur high overheads on some backends, + especially GPU. The default is True because otherwise, under a ``jit`` or + ``pmap``, CSE can defeat the purpose of this decorator. But in some + settings, like when used inside a ``scan``, this CSE prevention mechanism + is unnecessary, in which case ``prevent_cse`` can be set to False. + static_argnums: Optional, int or sequence of ints, indicates which argument + values on which to specialize for tracing and caching purposes. Specifying + arguments as static can avoid ConcretizationTypeErrors when tracing, but + at the cost of more retracing overheads. + policy: Experimental checkpoint policy, see ``jax.checkpoint``. + Returns: + A wrapped version of ``fn``. When computing gradients intermediate + computations will be re-computed when computing gradients. 
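+
+  Example (an illustrative sketch; ``mlp`` and the feature sizes below are
+  invented for this docstring and are not part of the API)::
+
+    def mlp(scope, x):
+      x = scope.child(nn.dense)(x, 128)
+      return scope.child(nn.dense)(x, 1)
+
+    y = lift.checkpoint(mlp)(scope, x)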
+ """ + + def inner(scope_fn, repack_fn, variable_groups, rng_groups, *args, **kwargs): + # add 2 to each static_argnums because we add two initial arguments to rematted + static_argnums_ = jax.tree_util.tree_map(lambda x: x + 2, static_argnums) + + @functools.partial( + jax.remat, + concrete=concrete, + static_argnums=static_argnums_, + prevent_cse=prevent_cse, + policy=policy, + ) + @functools.wraps(fn) + def rematted(variable_groups, rng_groups, *args, **kwargs): + scope = scope_fn(variable_groups, rng_groups) + y = fn(scope, *args, **kwargs) + return y, repack_fn(scope) + + return rematted(variable_groups, rng_groups, *args, **kwargs) + + return pack( + inner, + (variables,), + (variables,), + (rngs,), + name='remat', + enable_kwargs=True, + ) + + +remat = checkpoint + + +def _hashable_filter(x): + """Hashable version of CollectionFilter.""" + if isinstance(x, str): + return (x,) + if isinstance(x, Iterable): + return tuple(x) # convert un-hashable list & sets to tuple + if isinstance(x, DenyList): + return DenyList( + _hashable_filter(x.deny) + ) # convert inner filter recursively + return x + + +def jit( + fn: Callable[..., Any], + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, + static_argnums: Union[int, Iterable[int]] = (), + static_argnames: Union[str, Iterable[str]] = (), + donate_argnums: Union[int, Iterable[int]] = (), + device=None, + backend: Union[str, None] = None, +) -> Callable[..., Any]: + """Lifted version of ``jax.jit``. + + Args: + fn: Scope function to be jitted. + variables: The variable collections that are lifted. By default all + collections are lifted. + rngs: The PRNG sequences that are lifted. By default all PRNG sequences + are lifted. + static_argnums: An int or collection of ints specifying which positional + arguments to treat as static (compile-time constant). Operations that only + depend on static arguments will be constant-folded in Python (during + tracing), and so the corresponding argument values can be any Python + object. Static arguments should be hashable, meaning both ``__hash__`` and + ``__eq__`` are implemented, and immutable. Calling the jitted function + with different values for these constants will trigger recompilation. If + the jitted function is called with fewer positional arguments than + indicated by ``static_argnums`` then an error is raised. Arguments that + are not arrays or containers thereof must be marked as static. + Defaults to (). + static_argnames: An optional string or collection of strings specifying + which named arguments to treat as static (compile-time constant). See the + comment on ``static_argnums`` for details. If not + provided but ``static_argnums`` is set, the default is based on calling + ``inspect.signature(fun)`` to find corresponding named arguments. + donate_argnums: Specify which arguments are "donated" to the computation. + It is safe to donate arguments if you no longer need them once the + computation has finished. In some cases XLA can make use of donated + buffers to reduce the amount of memory needed to perform a computation, + for example recycling one of your input buffers to store a result. You + should not reuse buffers that you donate to a computation, JAX will raise + an error if you try to. + device: This is an experimental feature and the API is likely to change. + Optional, the Device the jitted function will run on. (Available devices + can be retrieved via :py:func:`jax.devices`.) 
The default is inherited + from XLA's DeviceAssignment logic and is usually to use + ``jax.devices()[0]``. + backend: a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or + ``'tpu'``. + + Returns: + A wrapped version of ``fn``, set up for just-in-time compilation. + """ + if not isinstance(static_argnums, Iterable): + static_argnums = (static_argnums,) + if not isinstance(donate_argnums, Iterable): + donate_argnums = (donate_argnums,) + # offset argnums by two because first argument in the original function is the + # scope while jitted has 3 functions before the user arguments. + static_argnums = (0,) + tuple(i + 2 for i in static_argnums if i > 0) + donate_argnums = tuple(i + 2 for i in donate_argnums if i > 0) + + # Close over scope_fn & repack_fn to avoid recompilation + # this is impure but we use the fingerprint arg to differentiate between cases + # where scope_fn or repack_fn actually produce non-identical results. + scope_fn = None # type: Optional[Callable] + repack_fn = None # type: Optional[Callable] + + @functools.partial( + jax.jit, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + device=device, + backend=backend, + ) + @functools.wraps(fn) + def jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs): + nonlocal scope_fn, repack_fn + hash_key = fingerprint[1] + # fingerprint is only used to differentiate the cache signature + del fingerprint + scope = scope_fn(variable_groups, rng_groups) # pylint: disable=not-callable + y = fn(scope, hash_key, *args, **kwargs) + return y, repack_fn(scope) # pylint: disable=not-callable + + def inner( + scope_fun, + repack_fun, + variable_groups, + rng_groups, + module_hash_key, + *args, + **kwargs, + ): + nonlocal scope_fn, repack_fn + try: + scope_fn = scope_fun + repack_fn = repack_fun + scopes = jax.tree_util.tree_leaves(scope_fn(variable_groups, rng_groups)) + mutable = tuple(_hashable_filter(scope.mutable) for scope in scopes) + fingerprint = (mutable, module_hash_key) + return jitted(fingerprint, variable_groups, rng_groups, *args, **kwargs) + finally: + scope_fn, repack_fn = None, None + + return pack( + inner, (variables,), (variables,), (rngs,), name='jit', enable_kwargs=True + ) + + +def remat_scan( + body_fn: Callable[..., Any], + lengths: Sequence[int], + policy: Optional[Callable[..., bool]] = None, + variable_broadcast: CollectionFilter = False, + variable_carry: CollectionFilter = False, + variable_axes: Mapping[CollectionFilter, InOutScanAxis] = {True: 0}, + split_rngs: Mapping[PRNGSequenceFilter, bool] = {True: True}, +) -> Callable[..., Any]: + """Combines `lift.remat` and `lift.scan` for memory efficiency and constant time compilation. + + ``remat_scan`` allows for constant compile times and sublinear + memory usage with respect to model depth. At a small constant + penalty. This is typically beneficial for very deep models. + + Example:: + + def body_fn(scope, x): + return nn.dense(scope, x, features=x.shape[-1]) + # 100x dense with O(sqrt(N)) memory for gradient computation + y = lift.remat_scan(body_fn, lengths=(10, 10))(scope, x) + + Args: + body_fn: Scope function to be repeated using a (nested scan) + lengths: number of loop iterations at the given level. The total number of + iterations `n = prod(lengths)`. each loop is rematerialized. This way the + memory consumption is proportional to `n^(1 / d)` where `d = + len(lengths)`. 
Minimal memory consumptions requires tuning the lengths + such that the same amount of memory is consumed at each level of the + nested loop. + policy: Experimental checkpoint policy, see ``jax.checkpoint``. + variable_broadcast: Specifies the broadcasted variable collections. A + broadcasted variable should not depend on any computation that cannot be + lifted out of the loop. This is typically used to define shared parameters + inside the fn. + variable_carry: Specifies the variable collections that are carried through + the loop. Mutations to these variables are carried to the next iteration + and will be preserved when the scan finishes. + variable_axes: the variable collections that are scanned over. + split_rngs: Split PRNG sequences will be different for each loop iterations. + If split is False the PRNGs will be the same across iterations. + Returns: + A wrapped version of ``body_fn`` that repeats itself prod(lengths) times. + """ + # TODO(jheek) should remat scan have scan inputs/outputs? + scan_fn = functools.partial( + scan, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + variable_axes=variable_axes, + split_rngs=split_rngs, + ) + if len(lengths) == 1: + + def wrapper(scope, carry): + return body_fn(scope, carry), () + + fn = lambda scope, c: scan_fn(wrapper, length=lengths[0])(scope, c)[0] + else: + + @functools.partial(remat, policy=policy, prevent_cse=False) + def inner_loop(scope, carry): + carry = remat_scan( + body_fn, + lengths[1:], + policy, + variable_broadcast, + variable_carry, + variable_axes, + split_rngs, + )(scope, carry) + return carry, () + + fn = lambda scope, c: scan_fn(inner_loop, length=lengths[0])(scope, c)[0] + return fn + + +def _unzip2(xs): + ys = tuple(zip(*xs)) + return ys if ys else ((), ()) diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/core/scope.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/core/scope.py new file mode 100644 index 000000000..1d7f430ad --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/core/scope.py @@ -0,0 +1,1238 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Flax functional core: Scopes.""" + +import collections +import contextlib +import dataclasses +import functools +import hashlib +import typing +from typing import ( + Any, + Callable, + Dict, + Generic, + Iterable, + Literal, + Mapping, + Optional, + Sequence, + Set, + Tuple, + TypeVar, + Union, + cast, + overload, +) + +import jax +import numpy as np +from jax import numpy as jnp +from jax import random, tree_util + +from flax import config as config +from flax import configurations as legacy_config # only for flax_lazy_rng +from flax import errors, struct, traceback_util +from flax.ids import uuid +from flax.typing import ( + PRNGKey, + Array, + RNGSequences, + Collection, + MutableCollection, + VariableDict, + FrozenVariableDict as FrozenVariableDict, + MutableVariableDict, + PRNGFoldable, +) + +from . import meta, partial_eval, tracers +from .frozen_dict import FrozenDict, freeze, unfreeze + +traceback_util.register_exclusion(__file__) + +T = TypeVar('T') + + +Filter = Union[bool, str, typing.Collection[str], 'DenyList'] + +# When conditioning on filters we require explicit boolean comparisons. +# pylint: disable=g-bool-id-comparison + + +@dataclasses.dataclass(frozen=True, eq=True) +class DenyList: + """DenyList represents an opt-out based mutability filter. + DenyList can be used to make every collection mutable except the ones + defined in the given filter. + To for example make everything but the params collection mutable:: + nn.apply(fn, mutable=nn.DenyList(["params"])) + Attributes: + deny: The filter representing the collections that are not mutable. + """ + + deny: Filter + + +CollectionFilter = Filter +PRNGSequenceFilter = Filter + + +class LazyRng(struct.PyTreeNode): + """Wrapper around JAX PRNGKey that lazily maintains a tuple of static data to be folded into the rng.""" + + rng: PRNGKey + suffix: Tuple[PRNGFoldable, ...] = struct.field(pytree_node=False) + + def as_jax_rng(self) -> PRNGKey: + return _fold_in_static(self.rng, self.suffix) + + @staticmethod + def create( + rng: Union['LazyRng', PRNGKey], *suffix: PRNGFoldable + ) -> 'LazyRng': + if not legacy_config.flax_lazy_rng: + if isinstance(rng, LazyRng): + assert not rng.suffix + rng = rng.rng + return LazyRng(_legacy_rng_fold_in(rng, suffix), ()) + if isinstance(rng, LazyRng): + return LazyRng(rng.rng, rng.suffix + suffix) + else: + return LazyRng(rng, suffix) + + +def _legacy_rng_fold_in(rng: PRNGKey, data: Iterable[PRNGFoldable]) -> PRNGKey: + """Legacy RNG folding.""" + for x in data: + if isinstance(x, str): + m = hashlib.sha1() + m.update(x.encode('utf-8')) + d = m.digest() + hash_int = int.from_bytes(d[:4], byteorder='big') + rng = random.fold_in(rng, jnp.uint32(hash_int)) # type: ignore + elif isinstance(x, int): + rng = random.fold_in(rng, x) + else: + raise ValueError(f'Expected int or string, got: {x}') + return rng + + +def _fold_in_static( + rng: PRNGKey, data: typing.Collection[PRNGFoldable] +) -> PRNGKey: + """Folds static data (strings & ints) into a jax.random.PRNGKey using its SHA-1 hash. + + This is faster than splitting an PRNGKey because it allows generating new PRNG + keys in parallel that are independent of each other. + + Args: + rng: the rng to fold the string into. + data: the string to be folded in. + + Returns: + The newly generated PRNG key. 
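+
+  Example (illustrative only; ``rng`` stands for any JAX PRNG key)::
+
+    dense_rng = _fold_in_static(rng, ('dense', 0))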
+ """ + if not data: + return rng + m = hashlib.sha1() + for x in data: + if config.flax_fix_rng_separator: + # encode seperate to avoid collisions like for example: ("ab", "c") and ("a", "bc") + m.update(b'\00') + if isinstance(x, str): + m.update(x.encode('utf-8')) + elif isinstance(x, int): + m.update(x.to_bytes((x.bit_length() + 7) // 8, byteorder='big')) + else: + raise ValueError(f'Expected int or string, got: {x}') + d = m.digest() + hash_int = int.from_bytes(d[:4], byteorder='big') + return random.fold_in(rng, jnp.uint32(hash_int)) # type: ignore + + +def is_filter_empty(filter_like: Filter) -> bool: + """Returns True if `filter_like` is an empty filter. + + Args: + filter_like: The filter to test. + + Returns: + A filter is empty when it is an empty collection, it is a bool with value + False, ir it is a DenyList that matches everything. A string filter is never + empty. + """ + if isinstance(filter_like, str): + return False + if isinstance(filter_like, typing.Collection): + return not filter_like + if isinstance(filter_like, bool): + return not filter_like + if isinstance(filter_like, DenyList): + # if any arbitrary collection is in the denylist it matches everything so + # the filter is empty. This is checked with a stub. + return in_filter(filter_like.deny, '__flax_internal_stub__') + raise errors.InvalidFilterError(filter_like) + + +def in_filter(filter_like: Filter, col: str) -> bool: + """Checks whether a filter can be applied to a collection. + + Used for both collections and rng sequence filters. + + Args: + filter_like: a filter (either a boolean, a string, or a list of strings) for + a collection. + col: a collection, which is a string identifying a dictionary of data, for + instance "params" or "batch_stats". + + Returns: + True if either `filter_like` is True, equal to `col`, or a sequence + containing `col`. + """ + if isinstance(filter_like, str): + return col == filter_like + if isinstance(filter_like, typing.Collection): + return col in filter_like + if isinstance(filter_like, bool): + return filter_like + if isinstance(filter_like, DenyList): + return not in_filter(filter_like.deny, col) + raise errors.InvalidFilterError(filter_like) + + +def filter_to_set(x: Filter) -> Set[str]: + """Converts a Filter into a set of collections, fails on the infinite set. + + Args: + x: a filter (boolean, string, or list of strings). + + Returns: + The input filter represented as a set of strings. + """ + assert x is not True and not isinstance(x, DenyList), 'Infinite set' + if x is False: + return set() + if isinstance(x, str): + return set([x]) + if isinstance(x, typing.Collection): + return set(x) + raise errors.InvalidFilterError(x) + + +def union_filters(a: Filter, b: Filter) -> Filter: + """Takes the union of two filters (similar to a logical or). + + Args: + a: a filter. + b: a filter. + + Returns: + The union of the two input filters. For instance, + `union_filters('f1', ['f2']) = {'f1', 'f2'}`. + """ + if a is True or b is True: + return True + if isinstance(a, DenyList) and isinstance(b, DenyList): + return DenyList(intersect_filters(a.deny, b.deny)) + if isinstance(b, DenyList): + a, b = b, a + if isinstance(a, DenyList): + return DenyList(subtract_filters(a.deny, b)) + + a = filter_to_set(a) + b = filter_to_set(b) + return a.union(b) + + +def subtract_filters(a: Filter, b: Filter) -> Filter: + """Returns the subtraction of b from a. + + Args: + a: a filter. + b: a filter. + + Returns: + A filter matching with values in a that are not in b. 
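+    For instance, `subtract_filters(['f1', 'f2'], 'f1') = {'f2'}` and
+    `subtract_filters(True, 'f1') = DenyList('f1')`.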
+ """ + if b is True: + return False + if a is True: + return DenyList(b) + if isinstance(a, DenyList) and isinstance(b, DenyList): + return subtract_filters(b.deny, a.deny) + if isinstance(a, DenyList): + return DenyList(union_filters(a.deny, b)) + if isinstance(b, DenyList): + return intersect_filters(a, b.deny) + a = filter_to_set(a) + b = filter_to_set(b) + return a - b + + +def intersect_filters(a: Filter, b: Filter) -> Filter: + """Take the intersection of two filters (similar to a logical and). + + Args: + a: a filter. + b: a filter. + + Returns: + The intersection of the two input filters. For instance, + `intersect_filters('f1', ['f1', 'f2']) = {'f1'}`. + """ + if a is True: + return b + if b is True: + return a + if isinstance(a, DenyList) and isinstance(b, DenyList): + return DenyList(union_filters(b.deny, a.deny)) + if isinstance(b, DenyList): + b, a = a, b + if isinstance(a, DenyList): + return subtract_filters(b, a.deny) + a = filter_to_set(a) + b = filter_to_set(b) + return a.intersection(b) + + +def group_collections( + xs: VariableDict, col_filters: Sequence[CollectionFilter] +) -> Sequence[MutableVariableDict]: + """Groups variables by collection filters. + + Iteratively applies the filters in `col_filters` to `xs`, and adds the result + of applying each filter to the output sequence. Each key in `xs` is only added + to the output once. + + Args: + xs: a dictionary of variables, keyed by collections (strings). + col_filters: a list of collection filters. + + Returns: + A sequence S with `len(S) == len(col_filters)`. Each `S[i]` is the result of + applying filter `col_filters[i]` to the remaining keys in `xs`. + """ + cols: Iterable[str] + cols = xs.keys() + groups = [] + for col_filter in col_filters: + remaining_cols = [] + group = {} + for col in cols: + if in_filter(col_filter, col): + group[col] = jax.tree_util.tree_map(lambda x: x, xs[col]) + else: + remaining_cols.append(col) + cols = remaining_cols + groups.append(group) + return tuple(groups) + + +class Variable(Generic[T]): + """A Variable object allows mutable access to a variable in a VariableDict. + + Variables are identified by a collection (e.g., "batch_stats") and a name + (e.g., "moving_mean"). The value property gives access to the variable's + content and can be assigned to for mutation. + """ + + def __init__(self, scope: 'Scope', collection: str, name: str, unbox: bool): + """Initializes a variable. + + Args: + scope: The scope in which the variable is stored. + collection: The collection of the variable (e.g., "params"). + name: The name of the variable (e.g., "dense"). + unbox: Whether to unbox boxed values with metadata. 
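+
+    Example (a sketch; in practice a ``Variable`` is usually obtained via
+    ``Scope.variable`` rather than constructed directly, and the collection
+    and name below are arbitrary)::
+
+      counter = scope.variable('state', 'count', jnp.zeros, ())
+      counter.value += 1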
+ """ + self._id = uuid() + self.scope = scope + self.collection = collection + self.name = name + self.unbox = unbox + + @property + def value(self) -> T: + """Returns the value of this Variable.""" + v = self.scope.get_variable(self.collection, self.name) + return meta.unbox(v) if self.unbox else v + + @value.setter + def value(self, value: T): + """Updates the value of this Variable.""" + if self.unbox: + cur = self.scope.get_variable(self.collection, self.name) + cur_struct = tree_util.tree_structure(cur, is_leaf=meta.is_axis_metadata) + value_struct = tree_util.tree_structure( + value, is_leaf=meta.is_axis_metadata + ) + has_meta = any(map(meta.is_axis_metadata, cur_struct.flatten_up_to(cur))) + if cur_struct == value_struct and has_meta: + value = meta.replace_boxed(cur, value) + + self.scope.put_variable(self.collection, self.name, value) + + def is_mutable(self) -> bool: + """Checks if this Variable is mutable.""" + return self.scope.is_mutable_collection(self.collection) + + +class _ChildRNGSentinel: + pass + + +# used to identify that an rng counter is meant for a child scope +child_rng_token = _ChildRNGSentinel() + + +class _DefaultSentinel: + pass + + +# used to denote no default flag value on scope +no_flag = _DefaultSentinel() + + +class Scope: + """A Scope allows easy access to variables and manages RNGS of a neural network layer. + + Scopes are purely functional and encapsulated in + :class:`flax.linen.module.Module`, so users writing neural network code + usually generally do not interact with ``Scopes`` directly. + + See `core design tests + `_ + for a number of examples using ``Scopes``. + """ + + reservations: Dict[str, Set[Optional[str]]] + + def __init__( + self, + variables: MutableVariableDict, + rngs: Optional[Union[RNGSequences, Dict[str, LazyRng]]] = None, + name: Optional[str] = None, + mutable: CollectionFilter = False, + parent: Optional['Scope'] = None, + path: Iterable[str] = (), + debug_path: Iterable[str] = (), + flags: Optional[Mapping] = None, + ): + """Initializes a Scope. + + Args: + variables: VariableDict to initialize the Scope with. + rngs: RNGs used in this scope or one of the child scopes. + name: name of this scope. + mutable: A CollectionFilter determining which variables are mutable. + parent: The parent scope. + path: The path in the variable tree from the root scope to this scope. It + exactly matches the module path. + debug_path: Similar to path but could contain transformation decorators. + flags: internal flags. + """ + rngs = {k: LazyRng.create(v) for k, v in rngs.items()} if rngs else {} + self._variables = variables + self.parent = parent + self.name = name + self.path = tuple(path) + self.debug_path = tuple(debug_path) or self.path + self.rngs = rngs + self.mutable = mutable + self.flags = freeze({} if flags is None else flags) + + self._root = parent.root if parent else None + self.trace_level = tracers.trace_level(tracers.current_trace()) + + self.rng_counters = {key: 0 for key in self.rngs} + self.reservations = collections.defaultdict(set) + + self._invalid = False + + def __eq__(self, other: Any) -> bool: + # If the root variable dict and path are the same, then two scopes behave + # identically. Effectively, a scope is nothing more than a cursor into a + # variable dict and an rng counter dict. 
+ if not isinstance(other, Scope): + return False + if self is other: + return True + return ( + self.root._variables is other.root._variables + and self.path == other.path + and self.rng_counters is other.rng_counters + ) + + def __hash__(self) -> int: + # see __eq__ + return hash((id(self.root._variables), self.path, id(self.rng_counters))) + + @property + def root(self) -> 'Scope': + return self._root or self + + @property + def path_text(self) -> str: + """Returns the debug path as a human readable string.""" + return '/' + '/'.join(self.debug_path) + + @property + def invalid(self) -> bool: + """Returns true if this scope is invalidated as a result of `Scope.temporary`.""" + return self._invalid + + def _check_valid(self): + if self._invalid: + raise errors.InvalidScopeError(self.name) + + @contextlib.contextmanager + def temporary(self): + """Returns a context manager that will invalidate this Scope when leaving the context.""" + try: + yield self + finally: + self.invalidate() + + def invalidate(self): + """Invalidates the Scope.""" + self._invalid = True + + def mutable_variables(self) -> Union[VariableDict, Dict[str, Any]]: + """Returns an immutable copy of the mutable variables belonging to this Scope.""" + self._populate_collections() + xs = { + k: v for k, v in self._variables.items() if in_filter(self.mutable, k) + } + if config.flax_return_frozendict: + return freeze(xs) + return xs + + def variables(self) -> Union[VariableDict, Dict[str, Any]]: + """Returns an immutable copy of the variables belonging to this Scope.""" + self._populate_collections() + if config.flax_return_frozendict: + return freeze(self._variables) + return self._variables + + def _validate_trace_level(self): + tracers.check_trace_level(self.trace_level) + + def rewound(self, rewind_rngs: bool = False) -> 'Scope': + """Returns a rewound version of this Scope. + + Args: + rewind_rngs: if true, reset the RNG counter of this scope. + + Returns: + A rewound version of this scope, which means reservations are + emptied, and the rng counter is optionally rewound. + """ + self._check_valid() + scope = Scope( + self._variables, + self.rngs, + self.name, + self.mutable, + self.parent, + path=self.path, + debug_path=self.debug_path, + flags=self.flags, + ) + if not rewind_rngs: + scope.rng_counters = self.rng_counters + return scope + + def name_reserved(self, name: str, col: Optional[str] = None) -> bool: + """Checks whether a name for a child Scope or Variable is taken. + + Args: + name: the name to check for collision. + col: if a variable, the collection used. + """ + if name in self.reservations: + # allow the same name for two variables in + # different collections, otherwise raise error. + if ( + None in self.reservations[name] + or col is None + or col in self.reservations[name] + ): + return True + return False + + def reserve(self, name: str, col: Optional[str] = None): + """Reserves a name for a child Scope or Variable. + + Throws an error if the name exists already. + + Args: + name: the name to reserve. + col: if a variable, the collection used. + """ + if not isinstance(name, str): + raise TypeError( + 'The type of scope "{name}" should be string but ' f'it is {type(name)}' + ) + if self.name_reserved(name, col): + raise ValueError(f'Duplicate use of scope name: "{name}"') + self.reservations[name].add(col) + + def default_name(self, prefix: str) -> str: + """Generates an unreserved name with the given prefix. + + Args: + prefix: prefix to use for generating an unreserved name. 
+ + Returns: + The generated name. + """ + i = 0 + while True: + name = f'{prefix}{i}' + if name not in self.reservations: + return name + i += 1 + + def push( + self, name: Optional[str] = None, prefix: str = '', reuse=False + ) -> 'Scope': + """Creates a child Scope. + + Args: + name: optional name of the child. + prefix: prefix used for generating the name if `name` is `None`. + reuse: if True will return a pre-existing child scope with the given name + instead of throwing an error. + + Returns: + The child scope. + """ + self._check_valid() + self._validate_trace_level() + if name is None: + name = self.default_name(prefix) + if not reuse or name not in self.reservations: + self.reserve(name) + rngs = {key: LazyRng.create(rng, name) for key, rng in self.rngs.items()} + rng_key = (child_rng_token, name) + if rng_key in self.rng_counters: + rng_counters = self.rng_counters.get(rng_key) # type: ignore + else: + rng_counters = {key: 0 for key in rngs} + self.rng_counters[rng_key] = rng_counters # type: ignore + scope = Scope( + {}, + name=name, + rngs=rngs, + parent=self, + mutable=self.mutable, + path=self.path + (name,), + debug_path=self.debug_path + (name,), + flags=self.flags, + ) + scope.rng_counters = rng_counters + return scope + + def child( + self, + fn: Callable[..., Any], + name: Optional[str] = None, + prefix: Optional[str] = None, + named_call: bool = True, + **partial_kwargs, + ) -> Callable[..., Any]: + """Partially applies a child scope to fn. + + When calling the returned function multiple times variables will be reused. + + Args: + fn: the function to partially apply the child Scope to. + name: optional name of the child. + prefix: prefix used for generating name if it is `None`. + named_call: if true, `fn` will be run under `jax.named_scope`. The XLA + profiler will use this to name tag the computation. + **partial_kwargs: additional kwargs partially applied to `fn`. + + Returns: + The function with a partially applied scope. + """ + if name is None: + if prefix is None: + prefix = fn.__name__ + '_' if hasattr(fn, '__name__') else '' + name = self.default_name(prefix) + scope = self.push(name) + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + kwargs = dict(partial_kwargs, **kwargs) + if named_call: + with jax.named_scope(name): + res = fn(scope.rewound(), *args, **kwargs) + else: + res = fn(scope.rewound(), *args, **kwargs) + return res + + return wrapper + + def is_mutable_collection(self, col: str) -> bool: + """Returns true if the collection `col` is mutable.""" + return in_filter(self.mutable, col) + + def is_collection_empty(self, col: str) -> bool: + """Returns true if the collection is empty.""" + if col in self.root._variables: # pylint: disable=protected-access + return not self.root._variables[col] # pylint: disable=protected-access + return True + + def _mutable_collection(self, col: str) -> MutableCollection: + """Returns the collection `col` as a mutable object.""" + assert self.is_mutable_collection(col), f'Collection {col} is not mutable' + + # The actual variable dict is stored in the root scope only, and subscopes + # hold references to subtrees relevant to them. This function ensures that + # the collections are created in the top-level Scope and we return the + # correct reference. + if col not in self._variables: + if not self.parent: + # If this is the top-level Scope, just add an empty collection. + self._variables[col] = {} + else: + assert self.name is not None # Only top-level Scope have name None. 
+ # Populate the parent collections recursively and obtain a reference to + # the direct parent (which, by transitivity, is be a reference to a + # dict in the root Scope). + parent_col = self.parent._mutable_collection(col) # pylint: disable=protected-access + if self.name not in parent_col: + # If this Scope's name does not occur in the parent collection, add it + # to the parent scope (updating the parent's variable dict). + parent_col[self.name] = {} + # Store a reference to the parent's scope collection for in this scope's + # variable dict. + self._variables[col] = parent_col[self.name] + + return self._variables[col] + + def _collection(self, col: str) -> Collection: + """Returns a collection of variables of collection `col`.""" + if col not in self._variables: + if self.parent: + assert self.name is not None + parent_col = self.parent._collection(col) # pylint: disable=protected-access + if self.name not in parent_col: + return FrozenDict() + self._variables[col] = parent_col[self.name] + else: + return FrozenDict() + return self._variables[col] + + def has_rng(self, name: str) -> bool: + """Returns true if a PRNGSequence with name `name` exists.""" + return name in self.rngs + + def make_rng(self, name: str = 'params') -> PRNGKey: + """Generates A PRNGKey from a PRNGSequence with name `name`.""" + if not self.has_rng(name): + if self.has_rng('params'): + name = 'params' + else: + raise errors.InvalidRngError(f'{self.name} needs PRNG for "{name}"') + self._check_valid() + self._validate_trace_level() + self.rng_counters[name] += 1 + return LazyRng.create(self.rngs[name], self.rng_counters[name]).as_jax_rng() + + def get_variable(self, col: str, name: str, default: Any = None) -> Any: + """Retrieves the value of a Variable. + + Args: + col: the variable collection. + name: the name of the variable. + default: the default value to return if the variable does not exist in + this scope. + + Returns: + The value of the input variable, of the default value if the variable + doesn't exist in this scope. + """ + variables = self._collection(col) + if name in variables: + return variables[name] + else: + return default + + def has_variable(self, col: str, name: str) -> bool: + """Returns true if the given variable exists in this scope. + + Args: + col: the collection of the variable. + name: the name of the variable. + """ + variables = self._collection(col) + return name in variables + + def put_variable(self, col: str, name: str, value: Any): + """Updates the value of the given variable if it is mutable, or an error otherwise. + + Args: + col: the collection of the variable. + name: the name of the variable. + value: the new value of the given variable. + """ + self._check_valid() + self._validate_trace_level() + if not self.is_mutable_collection(col): + raise errors.ModifyScopeVariableError(col, name, self.path_text) + variables = self._mutable_collection(col) + + # Make sure reference sharing of child variable dictionaries isn't broken. + # See https://github.com/google/flax/issues/2022 for more details. + def put(target, key, val): + if ( + key in target + and isinstance(target[key], dict) + and isinstance(val, Mapping) + ): + for k, v in val.items(): + put(target[key], k, v) + else: + target[key] = val + + put(variables, name, value) + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + ) -> Variable[T]: + ... 
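+ # These ``variable`` overloads only refine the static return type based on
+ # ``unbox``. A typical call is, illustratively,
+ # ``count = scope.variable('state', 'count', jnp.zeros, ())`` followed by
+ # ``count.value += 1`` when the ``'state'`` collection is mutable.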
+ + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[True], + **init_kwargs, + ) -> Variable[T]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[False], + **init_kwargs, + ) -> Variable[meta.AxisMetadata[T]]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: bool = True, + **init_kwargs, + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + ... + + def variable( + self, + col: str, + name: str, # pylint: disable=keyword-arg-before-vararg + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: bool = True, + **init_kwargs, + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + """Creates a variable if it doesn't exist yet in this scope and returns it. + + Args: + col: the collection of the variable. + name: the name of the variable. + init_fn: a function taking a PRNGKey plus any other number of positional + arguments. If None, the variable must already be initialized otherwise + an error is raised. + *init_args: the positional arguments to evaluate init_fn on lazily. + unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed + value, see ``flax.nn.meta.unbox`` (default: True). + **init_kwargs: the key-word arguments to evaluate init_fn on lazily. + + Returns: + The variable. Throws an error if the variable exists already. + """ + self.reserve(name, col) + if not self.has_variable(col, name): + if not self.is_mutable_collection(col) or init_fn is None: + if self.is_collection_empty(col): + raise errors.ScopeCollectionNotFound(col, name, self.path_text) + raise errors.ScopeVariableNotFoundError(name, col, self.path_text) + init_value = init_fn(*init_args, **init_kwargs) + self.put_variable(col, name, init_value) + # cast to make static analyzers happy + return cast( + Union[Variable[T], Variable[meta.AxisMetadata[T]]], + Variable(self, col, name, unbox=unbox), + ) + + @overload + def param( + self, name: str, init_fn: Callable[..., T], *init_args, + ) -> T: + ... + + @overload + def param( + self, + name: str, + init_fn: Callable[..., T], + *init_args, + unbox: Literal[True], + **init_kwargs, + ) -> T: + ... + + @overload + def param( + self, + name: str, + init_fn: Callable[..., T], + *init_args, + unbox: Literal[False], + **init_kwargs, + ) -> meta.AxisMetadata[T]: + ... + + @overload + def param( + self, + name: str, + init_fn: Callable[..., T], + *init_args, + unbox: bool, + **init_kwargs, + ) -> Union[T, meta.AxisMetadata[T]]: + ... + + def param( + self, + name: str, + init_fn: Callable[..., T], + *init_args, + unbox: bool = True, + **init_kwargs, + ) -> Union[T, meta.AxisMetadata[T]]: + """Creates a parameter if it doesn't exist yet in this scope and returns it. + + If the parameter exists already, the existing value is simply returned. + + Args: + name: the name of the parameter. + init_fn: a function taking a PRNGKey plus any other number of positional + arguments. + *init_args: the positional arguments to evaluate init_fn on lazily. + unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed + value, see ``flax.nn.meta.unbox`` (default: True). + **init_kwargs: the key-word arguments to evaluate init_fn on lazily. + + Returns: + The parameters. Throws an error if the params exist already. 
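+
+ Example (minimal sketch, using the module-level ``init`` wrapper defined
+ further down in this file)::
+
+ def f(scope, x):
+ k = scope.param('kernel', nn.initializers.lecun_normal(), (x.shape[-1], 4))
+ return x @ k
+
+ y, variables = init(f)(random.key(0), jnp.ones((1, 3)))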
+ """ + self.reserve(name, 'params') + if self.has_variable('params', name): + value = self.get_variable('params', name) + # Validate that the shape of the init_fn output is the same as the shape + # of the existing parameter. This is to make sure that the hparams set up + # in a Flax Module match the shapes coming in during apply, and if not, + # catch it with an error message. + # NOTE: We could consider moving this to `self.` + abs_value = jax.eval_shape( + lambda: init_fn(random.key(0), *init_args, **init_kwargs) + ) + abs_value_flat = jax.tree_util.tree_leaves(abs_value) + value_flat = jax.tree_util.tree_leaves(value) + for val, abs_val in zip(value_flat, abs_value_flat): + # NOTE: We could check dtype consistency here as well but it's + # usefuleness is less obvious. We might intentionally change the dtype + # for inference to a half float type for example. + if jnp.shape(val) != jnp.shape(abs_val): + raise errors.ScopeParamShapeError( + name, self.path_text, jnp.shape(abs_val), jnp.shape(val) + ) + else: + if not self.is_mutable_collection('params'): + if self.is_collection_empty('params'): + raise errors.ScopeCollectionNotFound('params', name, self.path_text) + raise errors.ScopeParamNotFoundError(name, self.path_text) + value = init_fn(self.make_rng('params'), *init_args, **init_kwargs) + self.put_variable('params', name, value) + if unbox: + value = meta.unbox(value) + return value + + def _populate_collections(self): + collections = self.root._variables.keys() # pylint: disable=protected-access + for col in collections: + self._collection(col) + + def has_flag(self, key) -> bool: + return key in self.flags + + def get_flag(self, key, default=no_flag) -> Any: + if key not in self.flags and default is no_flag: + return ValueError(f'Flag {key} not present on scope.') + return self.flags.get(key, default) + + +def _unfreeze_variables(variables, mutable): + new_variables = {} + for key, value in variables.items(): + if in_filter(mutable, key): + new_variables[key] = unfreeze(value) + else: + new_variables[key] = value + return new_variables + + +def bind( + variables: VariableDict, + rngs: Optional[RNGSequences] = None, + mutable: CollectionFilter = False, + flags: Optional[Mapping] = None, +): + """Binds variables and rngs to a new ``Scope``. + + bind provides a ``Scope`` instance without transforming a function with + ``apply``. This is particularly useful for debugging and interactive use cases + like notebooks where a function would limit the ability split up code into + different cells. + + a ``Scope`` instance is a stateful object. Note that idiomatic JAX is + functional and therefore a ``Scope` does not mix well well with vanilla JAX + APIs. Therefore, we recommend using ``apply`` when code should be reusable and + compatible across the JAX software ecosystem. + + Args: + variables: Variable dictionary to bind. + rngs: RNGs to bind. + mutable: Which variable collections to treat as mutable. + flags: internal flags. + + Returns: + A new scope with the variables and rngs bound to it. + """ + if not _is_valid_variables(variables): + raise errors.ApplyScopeInvalidVariablesTypeError() + if rngs is not None and not _is_valid_rngs(rngs): + raise errors.InvalidRngError( + 'rngs should be a dictionary mapping strings to `jax.PRNGKey`.' 
+ ) + new_variables = _unfreeze_variables(variables, mutable) + return Scope(new_variables, rngs=rngs, mutable=mutable, flags=flags) + + +def apply( + fn: Callable[..., Any], + mutable: CollectionFilter = False, + flags: Optional[Mapping] = None, +) -> Callable[..., Any]: + """Functionalize a `Scope` function. + + Args: + fn: a function taking a `Scope` as its first argument. + mutable: the filter determining which variable collections are mutable. + flags: internal flags. + + Returns: + `fn` with the scope partially applied. + """ + + @functools.wraps(fn) + def wrapper( + variables: VariableDict, + *args, + rngs: Optional[Union[PRNGKey, RNGSequences]] = None, + **kwargs, + ) -> Union[Any, Tuple[Any, Union[VariableDict, Dict[str, Any]]]]: + if rngs is not None: + if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): + raise ValueError( + 'The ``rngs`` argument passed to an apply function should be a ' + '``jax.PRNGKey`` or a dictionary mapping strings to ' + '``jax.PRNGKey``.' + ) + if not isinstance(rngs, (dict, FrozenDict)): + rngs = {'params': rngs} + + # Try to detect if user accidentally passed {'params': {'params': ...}. + if ( + 'params' in variables + and isinstance(variables['params'], (dict, FrozenDict)) + and 'params' in variables['params'] + ): + raise errors.ApplyScopeInvalidVariablesStructureError(variables) + + with bind( + variables, rngs=rngs, mutable=mutable, flags=flags + ).temporary() as root: + y = fn(root, *args, **kwargs) + if mutable is not False: + return y, root.mutable_variables() + else: + return y + + return wrapper + + +def init( + fn: Callable[..., Any], + mutable: CollectionFilter = True, + flags: Optional[Mapping] = None, +) -> Callable[..., Any]: + """Functionalize a `Scope` function for initialization. + + Args: + fn: a function taking a `Scope` as its first argument. + mutable: the filter determining which variable collections are mutable. + flags: internal flags. + + Returns: + `fn` with the scope partially applied. + """ + + @functools.wraps(fn) + def wrapper(rngs, *args, **kwargs) -> Tuple[Any, VariableDict]: + if not _is_valid_rng(rngs) and not _is_valid_rngs(rngs): + raise ValueError( + 'First argument passed to an init function should be a ' + '``jax.PRNGKey`` or a dictionary mapping strings to ' + '``jax.PRNGKey``.' + ) + if not isinstance(rngs, (dict, FrozenDict)): + rngs = {'params': rngs} + init_flags = {**(flags if flags is not None else {}), 'initializing': True} + return apply(fn, mutable=mutable, flags=init_flags)( + {}, *args, rngs=rngs, **kwargs + ) + + return wrapper + + +def lazy_init( + fn: Callable[..., Any], + mutable: CollectionFilter = True, + flags: Optional[Mapping] = None, +) -> Callable[..., Any]: + """Functionalizes a `Scope` function for lazy initialization. + + Similair to ``init`` except that the init function now accepts + ``jax.ShapeDtypeStruct`` instances for arguments that do not + affect the variable initialization (typically this is all the input data). + + Example:: + + def f(scope, x): + # the kernel init only uses the shape of x so we don't actually + # need a value for x and can pass it as a ShapeDtypeStruct in lazy_init. + k = scope.param("kernel", nn.initializers.lecun_normal(), (x.shape[-1], x.shape[-1])) + return x @ k + init_fn = lazy_init(f) + variables = init_fn(random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32)) + + + Args: + fn: a function taking a `Scope` as its first argument. + mutable: the filter determining which variable collections are mutable. + flags: internal flags. 
+ + Returns: + `fn` with the scope partially applied. Unlike ``init`` which returns a tuple of function + output and variables, the lazy init function only returns the variables. + """ + return partial_eval.lazy_init( + lambda *args, **kwargs: init(fn, mutable, flags)(*args, **kwargs)[1] + ) + + +def _is_valid_collection(col: VariableDict): + if not isinstance(col, (FrozenDict, dict)): + return False + for name in col.keys(): + # Any value can be stored in a collection so only keys can be verified. + if not isinstance(name, str): + return False + return True + + +def _is_valid_variables(variables: VariableDict) -> bool: + """Checks whether the given variable dict is valid. + + Args: + variables: A variable dict. + + Returns: + True if `variables` is a valid variable dict. + """ + for name, col in variables.items(): + if not isinstance(name, str): + return False + if not _is_valid_collection(col): + return False + return True + + +def _is_valid_rng(rng: Array): + """Checks whether rng is a valid JAX PRNGKey, also handling custom prngs.""" + # This check is valid for either new-style or old-style PRNG keys + if not isinstance(rng, (np.ndarray, jnp.ndarray)): + return False + + # Handle new-style typed PRNG keys + if hasattr(jax.dtypes, 'prng_key'): # JAX 0.4.14 or newer + if jax.dtypes.issubdtype(rng.dtype, jax.dtypes.prng_key): + return rng.shape == () + elif hasattr(jax.random, 'PRNGKeyArray'): # Previous JAX versions + if isinstance(rng, jax.random.PRNGKeyArray): + return rng.shape == () + + # Handle old-style raw PRNG keys + expected_rng = jax.eval_shape( + lambda s: jax.random.key_data(jax.random.key(s)), 0 + ) + if (rng.shape, rng.dtype) != (expected_rng.shape, expected_rng.dtype): + return False + return True + + +def _is_valid_rngs(rngs: Union[PRNGKey, RNGSequences]): + if not isinstance(rngs, (FrozenDict, dict)): + return False + for key, val in rngs.items(): + if not isinstance(key, str): + return False + if not _is_valid_rng(val): + return False + return True diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/linen/module.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/linen/module.py new file mode 100644 index 000000000..f8234eb66 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/linen/module.py @@ -0,0 +1,3198 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Flax Module.""" + +import contextlib +import dataclasses +import enum +import functools +import inspect +import sys +import threading +import typing +import weakref +from types import MappingProxyType +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + Literal, + Mapping, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +import jax +import jax.numpy as jnp +import typing_extensions as tpe + +import flax +import flax.linen as nn +from flax import ( + config, + core, + errors, + serialization, + traceback_util, + traverse_util, +) +from flax.core import Scope, meta, partial_eval +from flax.core.frozen_dict import FrozenDict +from flax.core.scope import ( + CollectionFilter, + DenyList, + Variable, + union_filters, +) +from flax.ids import FlaxId, uuid +from flax.linen import kw_only_dataclasses +from flax.typing import ( + RNGSequences, + PRNGKey, + FrozenVariableDict, + VariableDict, +) + +traceback_util.register_exclusion(__file__) + + +T = TypeVar('T') +K = TypeVar('K') +M = TypeVar('M', bound='Module') +_CallableT = TypeVar('_CallableT', bound=Callable) + + +# Used for abstractly testing module behavior. +TestScope = type( + 'TestScope', + (Scope,), + {'make_rng': lambda self, name: jax.random.key(0)}, +) + + +# pylint: disable=protected-access,attribute-defined-outside-init +def _get_fn_name(fn): + if isinstance(fn, functools.partial): + return _get_fn_name(fn.func) + return getattr(fn, '__name__', 'unnamed_function') + + +def _indent(x: str, num_spaces: int): + indent_str = ' ' * num_spaces + lines = x.split('\n') + # skip last line because it is always empty and should not be indented. + assert not lines[-1] + return '\n'.join(indent_str + line for line in lines[:-1]) + '\n' + + +def _attr_repr(value: Any): + if callable(value) and ( + (isinstance(value, nn.Module) and value.__dict__.get('__name__', None)) + or (not isinstance(value, nn.Module) and getattr(value, '__name__', None)) + ): + value_rep = value.__name__ + else: + value_rep = repr(value) + return value_rep + + +def _module_repr(module: 'Module', num_spaces: int = 4): + """Returns a pretty printed representation of the module.""" + cls = type(module) + cls_name = cls.__name__ + rep = '' + + attributes = { + f.name: f.type + for f in dataclasses.fields(cls) + if f.name not in ('parent', 'name') and f.repr + } + child_modules = { + k: v + for k, v in module._state.children.items() # pytype: disable=attribute-error + if isinstance(v, Module) + } + if attributes: + rep += '# attributes\n' + for attr in attributes.keys(): + # TODO(jheek): can we get a nice string representation of attribute types? + value = module.__dict__.get(attr, None) + value_rep = _attr_repr(value) + rep += f'{attr} = {value_rep}\n' + if child_modules: + rep += '# children\n' + for name, child in child_modules.items(): + child_rep = _module_repr(child, num_spaces) + rep += f'{name} = {child_rep}\n' + if rep: + return f'{cls_name}(\n{_indent(rep, num_spaces)})' + else: + return f'{cls_name}()' + + +# Tabulation utilities. +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class _CallInfo: + index: int + path: Tuple[str, ...] + module: 'Module' + rngs: Optional[Dict[str, Union[core.scope.PRNGKey, core.scope.LazyRng]]] + mutable: bool + method: str + args: Tuple[Any, ...] 
+ kwargs: Dict[str, Any] + outputs: Any + + +@dataclasses.dataclass +class _CallInfoContext(threading.local): + index: int + calls: List[_CallInfo] + + def get_call_index(self) -> int: + index = self.index + self.index += 1 + return index + + +@contextlib.contextmanager +def _tabulate_context(): + _context.call_info_stack.append(_CallInfoContext(0, [])) + try: + yield + finally: + _context.call_info_stack.pop() + + +# Track parent relationship across Modules. +# ----------------------------------------------------------------------------- +class _DynamicContext(threading.local): + """Dynamic context.""" + + # TODO(marcvanzee): switch to using contextvars once minimum python version is + # 3.7 + + def __init__(self): + self.module_stack = [ + None, + ] + self.capture_stack = [] + self.call_info_stack: list[_CallInfoContext] = [] + + +# The global context +_context = _DynamicContext() + + +class _Sentinel: + def __copy__(self): + return self # Do not copy singleton sentinel. + + def __deepcopy__(self, memo): + del memo + return self # Do not copy singleton sentinel. + + def __reduce__(self): + return _get_unspecified_parent, () + + +def _get_unspecified_parent(): + return _unspecified_parent + + +_unspecified_parent = _Sentinel() + + +# Enable automatic named_call wrapping for labelling profile traces. +# ----------------------------------------------------------------------------- +_use_named_call = config.flax_profile + + +def _derive_profiling_name(module, fn): + fn_name = _get_fn_name(fn) + method_suffix = f'.{fn_name}' if fn_name != '__call__' else '' + module_name = module.name or module.__class__.__name__ + return f'{module_name}{method_suffix}' + + +def enable_named_call(): + """Enables named call wrapping for labelling profile traces. + + When named call wrapping is enabled all JAX ops executed in a Module + will be run under ``jax.named_scope``. The ``Module`` class name will + show up around the operations belonging to that Module in the + Tensorboard profiling UI, simplifying the profiling process. + + Note that ``jax.named_scope`` only works for + compiled functions (e.g.: using jax.jit or jax.pmap). + """ + global _use_named_call + _use_named_call = True + + +def disable_named_call(): + """Disables named call wrapping. + + See ``enable_named_call`` + """ + global _use_named_call + _use_named_call = False + + +@contextlib.contextmanager +def override_named_call(enable: bool = True): + # pylint: disable=g-doc-return-or-yield + """Returns a context manager that enables/disables named call wrapping. + + Args: + enable: If true, enables named call wrapping for labelling profile traces. + (see ``enabled_named_call``). + """ + # pylint: enable=g-doc-return-or-yield + global _use_named_call + use_named_call_prev = _use_named_call + _use_named_call = enable + try: + yield + finally: + _use_named_call = use_named_call_prev + + +# Intercept module methods. +# ----------------------------------------------------------------------------- +@dataclasses.dataclass(frozen=True) +class InterceptorContext: + """Read only state showing the calling context for method interceptors. + + Attributes: + module: The Module instance whose method is being called. + method_name: The name of the method being called on the module. + orig_method: The original method defined on the module. Calling it will + short circuit all other interceptors. 
+ """ + + module: 'Module' + method_name: str + orig_method: Callable[..., Any] + + +class ThreadLocalStack(threading.local): + """Thread-local stack.""" + + def __init__(self): + self._storage = [] + + def push(self, elem: Any) -> None: + self._storage.append(elem) + + def pop(self) -> Any: + return self._storage.pop() + + def __iter__(self) -> Iterator[Any]: + return iter(reversed(self._storage)) + + def __len__(self) -> int: + return len(self._storage) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self._storage})' + + +Args = Tuple[Any] +Kwargs = Dict[str, Any] +NextGetter = Callable[..., Any] +Interceptor = Callable[[NextGetter, Args, Kwargs, InterceptorContext], Any] +_global_interceptor_stack = ThreadLocalStack() + + +@contextlib.contextmanager +def intercept_methods(interceptor: Interceptor): + # pylint: disable=g-doc-return-or-yield + r"""Registers a new method interceptor. + + Method interceptors allow you to (at a distance) intercept method calls to + modules. It works similarly to decorators. You could modify args/kwargs before + calling the underlying method and/or modify the result returning from calling + the underlying method. Or you could completely skip calling the underlying + method and decide to do something differently. For example:: + + >>> import flax.linen as nn + >>> import jax.numpy as jnp + ... + >>> class Foo(nn.Module): + ... def __call__(self, x): + ... return x + ... + >>> def my_interceptor1(next_fun, args, kwargs, context): + ... print('calling my_interceptor1') + ... return next_fun(*args, **kwargs) + ... + >>> foo = Foo() + >>> with nn.intercept_methods(my_interceptor1): + ... _ = foo(jnp.ones([1])) + calling my_interceptor1 + + You could also register multiple interceptors on the same method. Interceptors + will run in order. For example:: + + >>> def my_interceptor2(next_fun, args, kwargs, context): + ... print('calling my_interceptor2') + ... return next_fun(*args, **kwargs) + ... + >>> with nn.intercept_methods(my_interceptor1), \ + ... nn.intercept_methods(my_interceptor2): + ... _ = foo(jnp.ones([1])) + calling my_interceptor1 + calling my_interceptor2 + + You could skip other interceptors by directly calling the + ``context.orig_method``. For example:: + + >>> def my_interceptor3(next_fun, args, kwargs, context): + ... print('calling my_interceptor3') + ... return context.orig_method(*args, **kwargs) + >>> with nn.intercept_methods(my_interceptor3), \ + ... nn.intercept_methods(my_interceptor1), \ + ... nn.intercept_methods(my_interceptor2): + ... _ = foo(jnp.ones([1])) + calling my_interceptor3 + + The following methods couldn't be intercepted: + + 1. Methods decoratored with ``nn.nowrap``. + 2. Dunder methods including ``__eq__``, ``__repr__``, ``__init__``, ``__hash__``, and ``__post_init__``. + 3. Module dataclass fields. + 4. Module descriptors. + + Args: + interceptor: A method interceptor. 
+ """ + _global_interceptor_stack.push(interceptor) + try: + yield + finally: + assert _global_interceptor_stack.pop() is interceptor + + +def run_interceptors( + orig_method: Callable[..., Any], + module: 'Module', + *args, + **kwargs, +) -> Any: + """Runs method interceptors.""" + method_name = _get_fn_name(orig_method) + fun = functools.partial(orig_method, module) + context = InterceptorContext(module, method_name, fun) + + def wrap_interceptor(interceptor, fun): + """Wraps `fun` with `interceptor`.""" + + @functools.wraps(fun) + def wrapped(*args, **kwargs): + return interceptor(fun, args, kwargs, context) + + return wrapped + + # Wraps interceptors around the original method. The innermost interceptor is + # the last one added and directly wrapped around the original bound method. + for interceptor in _global_interceptor_stack: + fun = wrap_interceptor(interceptor, fun) + return fun(*args, **kwargs) + + +# Utilities for pytrees of Modules defined inside setup() +# ----------------------------------------------------------------------------- + + +def _sorted_items(x): + """Returns items of a dict ordered by keys.""" + return sorted(x.items(), key=lambda x: x[0]) + + +def _get_suffix_value_pairs( + tree_or_leaf: Any, +) -> List[Tuple[str, Type['Module']]]: + """Helper for naming pytrees of submodules.""" + dict_or_leaf = serialization.to_state_dict(tree_or_leaf) + if not isinstance(dict_or_leaf, dict) or not dict_or_leaf: + return [('', tree_or_leaf)] + else: + flat_dict = traverse_util.flatten_dict(dict_or_leaf) + return [('_' + '_'.join(k), v) for k, v in _sorted_items(flat_dict)] + + +def _map_over_modules_in_tree(fn, tree_or_leaf): + """Helper for mapping function over submodules.""" + dict_or_leaf = serialization.to_state_dict(tree_or_leaf) + if not isinstance(dict_or_leaf, dict) or not dict_or_leaf: + return fn('', tree_or_leaf) + else: + flat_dict = traverse_util.flatten_dict(dict_or_leaf, keep_empty_nodes=True) + mapped_flat_dict = { + k: fn('_' + '_'.join(k), v) for k, v in _sorted_items(flat_dict) + } + return serialization.from_state_dict( + tree_or_leaf, traverse_util.unflatten_dict(mapped_flat_dict) + ) + + +def _freeze_attr(val: Any) -> Any: + """Recursively wrap the given attribute `var` in ``FrozenDict``.""" + if isinstance(val, (dict, FrozenDict)): + return FrozenDict({k: _freeze_attr(v) for k, v in val.items()}) + elif isinstance(val, tuple): + # Special case namedtuples and special JAX tuple structures otherwise they + # would be downgraded to normal tuples. + if hasattr(val, '_fields') or type(val).__name__ == 'PartitionSpec': + return type(val)(*[_freeze_attr(v) for v in val]) + else: + return tuple(_freeze_attr(v) for v in val) + elif isinstance(val, list): + return tuple(_freeze_attr(v) for v in val) + else: + return val + + +# Method wrapping of "compact methods" and setup() +# ----------------------------------------------------------------------------- +def compact(fun: _CallableT) -> _CallableT: + """Marks the given module method allowing inlined submodules. + + Methods wrapped in @compact can define submodules directly within the method. + + For instance:: + + >>> import flax.linen as nn + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x, features): + ... x = nn.Dense(features)(x) + ... ... + ... return x + + At most one method in each Module may be wrapped with @compact. + + Args: + fun: The Module method to mark as compact. + + Returns: + The given function ``fun`` marked as compact. 
+ """ + fun.compact = True # type: ignore[attr-defined] + return fun + + +def nowrap(fun: _CallableT) -> _CallableT: + """Marks the given module method as a helper method that needn't be wrapped. + + Methods wrapped in ``@nowrap`` are private helper methods that needn't be wrapped + with the state handler or a separate named_call transform. + + This is needed in several concrete instances: + - if you're subclassing a method like Module.param and don't want this + overriden core function decorated with the state management wrapper. + - If you want a method to be callable from an unbound Module (e.g.: a + function of construction of arguments that doesn't depend on params/RNGs). + If you want to learn more about how Flax Modules manage their state read the + [The Flax Module lifecycle](https://flax.readthedocs.io/en/latest/developer_notes/module_lifecycle.html) + guide. + + For instance:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + + >>> class Foo(nn.Module): + ... num_features: int + + ... @nn.nowrap + ... def _make_dense(self, num_features): + ... return nn.Dense(num_features) + + ... @nn.compact + ... def __call__(self, x): + ... # now safe to use constructor helper even if using named_call + ... dense = self._make_dense(self.num_features) + ... return dense(x) + + Args: + fun: The Module method to mark as nowrap. + + Returns: + The given function ``fun`` marked as nowrap. + """ + fun.nowrap = True # type: ignore[attr-defined] + return fun + + +def compact_name_scope(fun: _CallableT) -> _CallableT: + """Creates compact submodules from a method. + + This is a decorator that allows you to define compact submodules from a + method. It's intention is to make it easier to port code Haiku code to Flax + by providing the same functionality. + + Example:: + + >>> import flax.linen as nn + >>> import jax + >>> import jax.numpy as jnp + >>> from flax.core import pretty_repr + ... + >>> class Foo(nn.Module): + ... @nn.compact_name_scope + ... def up(self, x): + ... return nn.Dense(3)(x) + ... + ... @nn.compact_name_scope + ... def down(self, x): + ... return nn.Dense(3)(x) + ... + ... def __call__(self, x): + ... return self.up(x) + self.down(x) + ... + >>> module = Foo() + >>> variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 2))) + >>> params = variables['params'] + >>> print(pretty_repr(jax.tree_util.tree_map(jnp.shape, params))) + { + down: { + Dense_0: { + bias: (3,), + kernel: (2, 3), + }, + }, + up: { + Dense_0: { + bias: (3,), + kernel: (2, 3), + }, + }, + } + + You can also use ``compact_name_scope`` inside ``@compact`` methods or even + other + ``compact_name_scope`` methods. Methods that are decorated with + ``compact_name_scope`` + can also be called directly from ``init`` or ``apply`` via the ``method`` + argument:: + + >>> y_down = module.apply({'params': params}, jnp.ones((1, 2)), method='down') + >>> y_down.shape + (1, 3) + + Args: + fun: The Module method to mark as compact_name_scope. + + Returns: + The given function ``fun`` marked as compact_name_scope. + """ + + @functools.wraps(fun) + def compact_name_scope_wrapper(self: nn.Module, *args, **kwargs): + name = fun.__name__ + if not hasattr(self, '_compact_name_scope_modules'): + raise ValueError( + f'Cannot call compact_name_scope method {name!r} on a Module that has not been ' + f'setup. This is likely because you are calling {name!r} ' + 'from outside of init or apply.' 
+ ) + module = self._compact_name_scope_modules[name] + return module(*args, **kwargs) + + compact_name_scope_wrapper.compact_name_scope = True # type: ignore[attr-defined] + compact_name_scope_wrapper.inner_fun = fun # type: ignore[attr-defined] + compact_name_scope_wrapper.nowrap = True # type: ignore[attr-defined] + return compact_name_scope_wrapper # type: ignore[return-value] + + +def _get_local_method_names( + cls: Any, exclude: Iterable[str] = () +) -> Tuple[str, ...]: + """Gets method names of a class, excluding class and static methods. + + Args: + cls: The class to get method names for. + exclude: Names to exclude from output. + + Returns: + A list of method names. + """ + true_methods = set() + for m in cls.__dict__: + if callable(cls.__dict__[m]) and not inspect.isclass( + cls.__dict__[m] + ): # pytype: disable=not-supported-yet + mtype = type(cls.__dict__[m]) + if mtype != staticmethod and mtype != classmethod: + true_methods.add(m) + return tuple(true_methods.difference(set(exclude))) + + +def _get_local_descriptor_names( + cls: Any, exclude: Iterable[str] = () +) -> Tuple[str, ...]: + """Gets descriptor names of a class. + + Args: + cls: The class to get property names for. + exclude: Names to exclude from output. + + Returns: + A list of property names. + """ + true_properties = set() + for m, attr in cls.__dict__.items(): + if not callable(attr) and ( + hasattr(attr, '__get__') + or hasattr(attr, '__set__') + or hasattr(attr, '__delete__') + ): + mtype = type(attr) + if mtype != staticmethod and mtype != classmethod: + true_properties.add(m) + return tuple(true_properties.difference(set(exclude))) + + +def wrap_method_once(fun: Callable[..., Any]) -> Callable[..., Any]: + """Manages Module state for a given user-defined method. + + Args: + fun: User-defined Module method to manage state for. + + Returns: + Wrapped method. + """ + # Don't rewrap methods that have already had the state management wrapper + # applied in the decorator stack. This wrapper should always be applied + # before transformation wrappers. + if hasattr(fun, 'method_handler_wrapped'): + return fun + + @functools.wraps(fun) + def wrapped_module_method(*args, **kwargs): + # We might have incorrectly wrappped a callable + # that is not a method. Check whether the first arg is self, + # otherwise call the wrapped function as is. + if args and isinstance(args[0], Module): + self, args = args[0], args[1:] + return self._call_wrapped_method(fun, args, kwargs) + else: + return fun(*args, **kwargs) + + wrapped_module_method.method_handler_wrapped = True # type: ignore[attr-defined] + return wrapped_module_method + + +def wrap_descriptor_once(descriptor) -> 'DescriptorWrapper': + """Wraps a descriptor to give better error messages. + + Args: + descriptor: User-defined Module attribute descriptor. + + Returns: + Wrapped descriptor. + """ + # Don't rewrap descriptors. + if isinstance(descriptor, DescriptorWrapper): + return descriptor + + return create_descriptor_wrapper(descriptor) + + +def _wrap_hash(hash_fn: Callable[..., Any]) -> Callable[..., Any]: + """Wraps a hash function with some check for Flax Modules.""" + + @functools.wraps(hash_fn) + def wrapped(self): + if self.scope is not None: + raise TypeError("Can't call __hash__ on modules that hold variables.") + try: + hash_value = hash_fn(self) + except TypeError as exc: + raise TypeError( + 'Failed to hash Flax Module. ' + 'The module probably contains unhashable attributes. 
' + f'Module={self}' + ) from exc + return hash_value + + return wrapped + + +def _get_unbound_fn(method_or_fn: Callable[..., Any]) -> Callable[..., Any]: + """Returns an unbound function from a method that is possibly bound. + + This means that if the passed function belongs of an instance of a class, then + the returned function does no longer depend on the instance, which is passed + as the first argument to the function. + + Args: + method_or_fn: A class method or function. + + Returns: + An unbound version of input function. + """ + if inspect.ismethod(method_or_fn) and isinstance( + method_or_fn.__self__, Module + ): # pytype: disable=attribute-error + method_or_fn = method_or_fn.__func__ # pytype: disable=attribute-error + + # The method should be callable, and it should have at least one argument + # representing the class that is passed in. + if ( + not callable(method_or_fn) + or len(inspect.signature(method_or_fn).parameters) < 1 + ): + raise errors.ApplyModuleInvalidMethodError(method_or_fn) + + return method_or_fn + + +def _map_submodules(fn: Callable[['Module'], Any], tree): + """Map a function over all submodules in a tree.""" + g = lambda _, x: fn(x) if isinstance(x, Module) else x + return _freeze_attr(_map_over_modules_in_tree(g, tree)) + + +class SetupState(enum.IntEnum): + # setup() has not been called. + NEW = 0 + # setup() has been called outside a transform boundary. + TRANSFORMED = 1 + # setup() has been called. + DONE = 2 + + +@dataclasses.dataclass +class _ModuleInternalState: + """Ephemeral Module Evaluation State. + + For clarity, we collect all of the temporary flags and ephemeral state used by + Modules for autonaming and error messages here, alongside the rules used + to pass this ephemeral state across transform boundaries. + """ + + in_compact_method: bool = False + in_setup: bool = False + setup_called: SetupState = SetupState.NEW + is_initialized: bool = False + autoname_cursor: Dict[str, int] = dataclasses.field(default_factory=dict) + children: Dict[str, Union[str, 'Module']] = dataclasses.field( + default_factory=dict + ) + + def reset(self) -> None: + """Resets transient state. + + This function is called after each module method, so only attributes that + are method-dependent are reset. 
+ """ + self.in_compact_method = False + self.in_setup = False + self.autoname_cursor = dict() + + def export(self) -> '_ModuleInternalState': + """Exports transform-preserved state across transform boundary.""" + setup_state = ( + SetupState.TRANSFORMED if self.setup_called else SetupState.NEW + ) + cloned = _ModuleInternalState( + in_compact_method=self.in_compact_method, + in_setup=self.in_setup, + setup_called=setup_state, + is_initialized=self.is_initialized, + autoname_cursor=dict(self.autoname_cursor), + ) + return cloned + + def reimport(self, other: '_ModuleInternalState') -> None: + """Re-imports transform-preserved state from across transform boundary.""" + self.in_compact_method = other.in_compact_method + self.in_setup = other.in_setup + self.is_initialized = other.is_initialized + self.autoname_cursor = dict(other.autoname_cursor) + + +_uninitialized_module_internal_state = _ModuleInternalState() + + +_UNDEFINED_COPY_PICKLE_METHODS = ( + '__getstate__', + '__setstate__', + '__getnewargs_ex__', + '__reduce__', + '__reduce_ex__', + '__copy__', + '__deepcopy__', +) + + +_caches: 'weakref.WeakKeyDictionary[Scope, weakref.WeakValueDictionary[FlaxId, Module]]' = weakref.WeakKeyDictionary() + + +tuple_reduce = lambda xs, x: xs + (x,) +tuple_init = lambda: () + + +capture_call_intermediates = lambda _, method_name: method_name == '__call__' + + +class ParentDescriptor: + """Wraps parent module references in weak refs. + + This prevents reference cycles from forming via parent links which can lead + to accidental OOMs in eager mode due to slow garbage collection as well as + spurious tracer leaks during jit compilation. + + Note: "descriptors" are the underlying python mechanism for implementing + dynamic @property decorators. We need to use a raw descriptor instead of the + more common decorator in order to force that the appropriate getter/setter + logic applies in subclasses even after various dataclass transforms. + """ + + def __get__(self, obj, objtype=None): + # check if obj is None, happens during %autoreload + if obj is None: + return None + parent = object.__getattribute__(obj, '_parent_ref') + return parent() if isinstance(parent, weakref.ReferenceType) else parent + + def __set__(self, obj, value): + maybe_weak = weakref.ref(value) if isinstance(value, Module) else value + object.__setattr__(obj, '_parent_ref', maybe_weak) + + +class Descriptor(tpe.Protocol): + __isabstractmethod__: bool + + def __get__(self, obj, objtype=None) -> Any: + ... + + def __set__(self, obj, value) -> None: + ... + + def __delete__(self, obj) -> None: + ... + + def __set_name__(self, owner, name) -> None: + ... + + +class DescriptorWrapper: + pass + + +def create_descriptor_wrapper(descriptor: Descriptor): + """Creates a descriptor wrapper that calls a get_fn on the descriptor.""" + + class _DescriptorWrapper(DescriptorWrapper): + """A descriptor that can wrap any descriptor.""" + + if hasattr(descriptor, '__isabstractmethod__'): + __isabstractmethod__ = descriptor.__isabstractmethod__ + + def __init__(self, wrapped: Descriptor): + self.wrapped = wrapped + + # conditionally define descriptor methods + if hasattr(descriptor, '__get__'): + + def __get__(self, *args, **kwargs): + # here we will catch internal AttributeError and re-raise it as a + # more informative and correct error message. 
+ try: + return self.wrapped.__get__(*args, **kwargs) + except AttributeError as e: + raise errors.DescriptorAttributeError() from e + + if hasattr(descriptor, '__set__'): + + def __set__(self, *args, **kwargs): + return self.wrapped.__set__(*args, **kwargs) + + if hasattr(descriptor, '__delete__'): + + def __delete__(self, *args, **kwargs): + return self.wrapped.__delete__(*args, **kwargs) + + if hasattr(descriptor, '__set_name__'): + + def __set_name__(self, *args, **kwargs): + self.wrapped.__set_name__(*args, **kwargs) + + def __getattr__(self, name): + if 'wrapped' not in vars(self): + raise AttributeError() + return getattr(self.wrapped, name) + + return _DescriptorWrapper(descriptor) + + +# Base Module definition. +# ----------------------------------------------------------------------------- + + +def module_field(*, kw_only: bool = False, default: Optional[Any] = ...) -> Any: + ... + + +# The ModuleBase class is created only to make static analyzers happy +# mainly pytype and pyright. Some notes: +# * pyright (correctly) complains that Module itself is not a dataclass, even +# though all its subclasses and intances ARE dataclasses. Because there is no +# way to annotate this in a way that pyright understands, we create a +# ModuleBase class decorated with `dataclass_transform` such that pyright +# thinks Module is a dataclass (in reality only subclasses are instantiated +# so this is fine). +# * The `__dataclass_fields__` attribute is needed because pytype seems to +# not understand the `dataclass_transform` decorator, therefore we need +# to add the attribute manually. +# * Other attributes are annotated for completeness. Because we are using +# the `if typing.TYPE_CHECKING` pattern, these annotations are not present +# at runtime so they don't affect the dataclass behavior. +@tpe.dataclass_transform(field_specifiers=(module_field,)) # type: ignore[literal-required] +class ModuleBase: + if typing.TYPE_CHECKING: + scope: Optional[Scope] + _state: _ModuleInternalState + _parent_ref: Union['Module', weakref.ReferenceType['Module'], None] + __dataclass_fields__: Dict[str, dataclasses.Field] + + +class Module(ModuleBase): + """Base class for all neural network modules. + + Layers and models should subclass this class. + + All Flax Modules are Python 3.7 + `dataclasses `_. Since + dataclasses take over ``__init__``, you should instead override :meth:`setup`, + which is automatically called to initialize the module. + + Modules can contain submodules, and in this way can be nested in a tree + structure. Submodels can be assigned as regular attributes inside the + :meth:`setup` method. + + You can define arbitrary "forward pass" methods on your Module subclass. + While no methods are special-cased, ``__call__`` is a popular choice because + it allows you to use module instances as if they are functions:: + + >>> from flax import linen as nn + >>> from typing import Tuple + + >>> class Module(nn.Module): + ... features: Tuple[int, ...] = (16, 4) + + ... def setup(self): + ... self.dense1 = nn.Dense(self.features[0]) + ... self.dense2 = nn.Dense(self.features[1]) + + ... def __call__(self, x): + ... return self.dense2(nn.relu(self.dense1(x))) + + Optionally, for more concise module implementations where submodules + definitions are co-located with their usage, you can use the + :meth:`compact` wrapper. 
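+
+ For example, the model above could equivalently be written in compact style
+ (illustrative sketch)::
+
+ >>> class CompactModule(nn.Module):
+ ... features: Tuple[int, ...] = (16, 4)
+
+ ... @nn.compact
+ ... def __call__(self, x):
+ ... x = nn.relu(nn.Dense(self.features[0])(x))
+ ... return nn.Dense(self.features[1])(x)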
+ """ + + if typing.TYPE_CHECKING: + name: Optional[str] = module_field(kw_only=True, default=None) + parent: Union['Module', _Sentinel, None] = module_field( + kw_only=True, default=None + ) + + def __init__(self, *args, **kwargs): + # this stub makes sure pytype accepts constructor arguments. + pass + + def __call__(self, *args, **kwargs) -> Any: + # this stub allows pytype to accept Modules as Callables. + pass + + @classmethod + def __init_subclass__(cls, kw_only: bool = False, **kwargs: Any) -> None: + """Automatically initializes all subclasses as custom dataclasses.""" + super().__init_subclass__(**kwargs) + # All Flax Modules are dataclasses. We force this convention since + # it encourages the stateless behavior needed to clone module instances for + # functional transformation. Instead of using a python metaclass, we + # automatically transform Modules into dataclasses at subclass creation + # time, and we set the last dataclass arguments to `parent` and `name`. + cls._customized_dataclass_transform(kw_only) + # We wrap user-defined methods including setup and __call__ to enforce + # a number of different checks and to provide clear error messages. + cls._verify_single_or_no_compact() + cls._find_compact_name_scope_methods() + cls._wrap_module_attributes() + # Set empty class defaults. + cls._state = _uninitialized_module_internal_state # type: ignore[attr-defined] + cls.scope: Optional[Scope] = None # type: ignore + # Handles weak referencing of parent Modules to prevent reference cycles. + cls._parent_ref = None # type: ignore[attr-defined] + cls.parent = ParentDescriptor() # type: ignore[assignment] + + @classmethod + def _customized_dataclass_transform(cls, kw_only: bool): + """Transforms `cls` into a dataclass, with custom additional behavior. + + 1. Inject `parent` and `name` fields. (If they are already present, + then check that they have the expected types.) + 2. Set compare, hash, and repr to False for non-init fields. + 3. Generate a hash function (if not provided by cls). + """ + # Check reserved attributes have expected type annotations. + annotations = dict(cls.__dict__.get('__annotations__', {})) + if annotations.get('parent', _ParentType) != _ParentType: + raise errors.ReservedModuleAttributeError(annotations) + if annotations.get('name', str) not in ('str', str, Optional[str]): + raise errors.ReservedModuleAttributeError(annotations) + + # any non-init field will only be set in setup + # During __hash__ and __eq__ the field is not set yet + # so it should not be used in compare, hash or repr. + for field in annotations: + field_meta = getattr(cls, field, None) + if isinstance(field_meta, dataclasses.Field) and not field_meta.init: + field_meta.compare = False + field_meta.hash = False + field_meta.repr = False + + extra_fields = [ + ( + 'parent', + _ParentType, + kw_only_dataclasses.field( + repr=False, default=_unspecified_parent, kw_only=True + ), + ), + ( + 'name', + Optional[str], + kw_only_dataclasses.field(default=None, kw_only=True), + ), + ] + + if kw_only: + if tuple(sys.version_info)[:3] >= (3, 10, 0): + for ( + name, + annotation, # pytype: disable=invalid-annotation + default, + ) in extra_fields: + setattr(cls, name, default) + cls.__annotations__[name] = annotation + dataclasses.dataclass( # type: ignore[call-overload] + unsafe_hash='__hash__' not in cls.__dict__, + repr=False, + kw_only=True, + )(cls) + else: + raise TypeError('`kw_only` is not available before Py 3.10.') + else: + # Now apply dataclass transform (which operates in-place). 
+ # Do generate a hash function only if not provided by the class. + kw_only_dataclasses.dataclass( + cls, + unsafe_hash='__hash__' not in cls.__dict__, + repr=False, + extra_fields=extra_fields, + ) # pytype: disable=wrong-keyword-args + + cls.__hash__ = _wrap_hash(cls.__hash__) # type: ignore[method-assign] + + @classmethod + def _verify_single_or_no_compact(cls): + """Statically verifies that at most a single method is labelled compact.""" + methods = [m[0] for m in inspect.getmembers(cls, predicate=callable)] + n_compact_fns = len( + [ + method_name + for method_name in methods + if hasattr(getattr(cls, method_name), 'compact') + ] + ) + if n_compact_fns > 1: + raise errors.MultipleMethodsCompactError() + + @classmethod + def _find_compact_name_scope_methods(cls): + """Finds all compact_name_scope methods in the class.""" + methods = [m[0] for m in inspect.getmembers(cls, predicate=callable)] + compact_name_scope_fns = tuple( + method_name + for method_name in methods + if hasattr(getattr(cls, method_name), 'compact_name_scope') + ) + cls._compact_name_scope_methods = compact_name_scope_fns + + @classmethod + def _wrap_module_attributes(cls): + """Wraps user-defined non-inherited methods and descriptors with state + + management functions. + """ + # wrap methods + method_exclusions = [f.name for f in dataclasses.fields(cls)] + [ + '__eq__', + '__repr__', + '__init__', + '__hash__', + '__post_init__', + ] + for key in _get_local_method_names(cls, exclude=method_exclusions): + method = getattr(cls, key) + if hasattr(method, 'nowrap'): + continue + setattr(cls, key, wrap_method_once(method)) + + # wrap descriptors + descriptor_exclusions = [f.name for f in dataclasses.fields(cls)] + [ + 'parent', + '__dict__', + ] + for key in _get_local_descriptor_names(cls, descriptor_exclusions): + # don't use getattr here, since it will call the descriptor + descriptor = cls.__dict__[key] + if hasattr(descriptor, 'nowrap'): + continue + setattr(cls, key, wrap_descriptor_once(descriptor)) + return cls + + def _call_wrapped_method(self, fun, args, kwargs): + """Calls a wrapped method. + + This function is responsible for setting up the thread local state + correctly before calling the method and cleaning up afterwards. + This includes storing intermediates, setup of the compact scope, + and making sure setup is called before any other method. + + Args: + fun: The wrapped method. + args: Named arguments passed to ``fun``. + kwargs: Keyword arguments passed to ``fun``. + + Returns: + The results of calling ``fun``. + """ + is_compact_method = hasattr(fun, 'compact') + fun_name = _get_fn_name(fun) + is_setup_method = fun_name == 'setup' + add_call_info = not is_setup_method and len(_context.call_info_stack) > 0 + # We lazily call setup() only when needed. 
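+ # Descriptive note: for a bound module ``m``, a call such as ``m(x)`` during
+ # ``apply`` takes the ``else`` branch below and runs ``self._try_setup()``,
+ # so attributes assigned in ``setup()`` exist before the method body runs.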
+ if is_setup_method: + if self.scope is None: + raise errors.CallSetupUnboundModuleError() + is_recurrent = self._state.in_setup + self._state.in_setup = True + else: + self._try_setup() + + if is_compact_method: + if self.scope is None: + raise errors.CallCompactUnboundModuleError() + is_recurrent = self._state.in_compact_method + self._state.in_compact_method = True + _context.module_stack.append(self) + try: + # get call info + if add_call_info: + assert self.scope is not None + call_index = _context.call_info_stack[-1].get_call_index() + + if _global_interceptor_stack: + run_fun = functools.partial(run_interceptors, fun) + else: + run_fun = fun + + # call method + if _use_named_call: + with jax.named_scope(_derive_profiling_name(self, fun)): + y = run_fun(self, *args, **kwargs) + else: + y = run_fun(self, *args, **kwargs) + + if _context.capture_stack: + filter_fn = _context.capture_stack[-1] + if filter_fn and filter_fn(self, fun_name): + self.sow('intermediates', fun_name, y) + if add_call_info: + _args, _kwargs, _y = flax.linen.summary._represent_tree( + (args, kwargs, y) + ) + _context.call_info_stack[-1].calls.append( + _CallInfo( + call_index, + self.path, + self.clone(), + self.scope.rngs, + self.scope.mutable, + fun.__name__, + _args, + _kwargs, + _y, + ) + ) + return y + finally: + _context.module_stack.pop() + if is_compact_method: + object.__setattr__(self, 'scope', self.scope.rewound()) + # setup or compact calls can be recurrent for example due to super calls + # resetting the state would cause is compact/setup method + # to be set to False prematurely. + if (is_compact_method or is_setup_method) and not is_recurrent: + self._state.reset() + + def __setattr__(self, name: str, val: Any): + """Sets an attribute on this Module. + + We overload setattr solely to support pythonic naming via assignment of + submodules in the special :meth:`setup` function:: + + self.submodule_name = MyModule(...) + + We also support lists and other general pytrees, e.g.:: + + self.submodules = [MyModule0(..), MyModule1(..), ...] + + Args: + name: Attribute to set. + val: Value of the attribute. + """ + fields = self.__dataclass_fields__ # pytype: disable=attribute-error + is_dataclass_attr = name in fields and fields[name].init + + if not self._state.in_setup: + if not self._state.is_initialized: + # Setting attributes before end of Module.__post_init__() + object.__setattr__(self, name, val) + return + else: + # We're past all initialization and setup logic: + # Raises a TypeError just like frozen python dataclasses. + raise errors.SetAttributeFrozenModuleError( + self.__class__.__name__, name, val + ) + + # We're inside the setup() method: + if is_dataclass_attr: + # These names are specified as dataclass fields. They should not be + # initialized within the setup() method, but can be modified freely + # before it. + raise errors.SetAttributeInModuleSetupError() + + # Values (that may be variables or submodules) are being defined and + # attached in setup(), we run some extra logic in that case. + self._register_submodules(name, val) + + def __getattr__(self, name: str) -> Any: + """Call setup() before getting any setup-defined attributes.""" + # We don't want to return anything for python copy / pickle methods. + if name in _UNDEFINED_COPY_PICKLE_METHODS: + raise AttributeError() + self._try_setup() + if name in self.__dict__: + return self.__dict__[name] + else: + msg = f'"{self.__class__.__name__}" object has no attribute "{name}".' 
+ if self.scope is None: + msg += ( + f' If "{name}" is defined in \'.setup()\', remember these fields ' + "are only accessible from inside 'init' or 'apply'." + ) + raise AttributeError(msg) + + def __dir__(self) -> List[str]: + """Call setup() before listing attributes.""" + self._try_setup() + return object.__dir__(self) # type: ignore + + def __post_init__(self) -> None: + # DO NOT REMOVE - Marker for internal logging. + # In dataclasses, __init__ is overridden to process dataclass arguments, + # and __post_init__ is called immediately afterwards. Here, depending on the + # type of `parent` passed to initialize the Module, we either defer + # initialization, attach this Module as a submodule of a parent, or bind + # this Module at the top-level to variables and rngs. + + object.__setattr__(self, '_id', uuid()) + object.__setattr__(self, '_state', _ModuleInternalState()) + + # Typically we set the parent based on the dynamic module context. + if self.parent is _unspecified_parent: # pytype: disable=attribute-error + object.__setattr__(self, 'parent', _context.module_stack[-1]) + + # Initialization is deferred for top level Modules or any other "orphan" + # Modules until attachment by __setattr__ i.e. MyModule(..., parent=None) + if self.parent is None: + return + + # Register submodule on parent Module. + if isinstance(self.parent, Module): + # When initializing an unnamed Module inside setup() + # initialization is deferred until attachment by __setattr__ + # i.e. self.mymodule = MyModule(...) + self.name: Optional[str] + if ( + self.parent._state.in_setup and self.name is None + ): # pytype: disable=attribute-error + return + if not self.parent._initialization_allowed: + raise errors.AssignSubModuleError(self.__class__.__name__) + # Autonaming of submodules. + if self.name is None: # pytype: disable=attribute-error + prefix = f'{self.__class__.__name__}' + cursor = self.parent._state.autoname_cursor.get(prefix, 0) + self.name = f'{prefix}_{cursor}' + self.parent._state.autoname_cursor[prefix] = cursor + 1 + # Allow scope aliasing under transforms for submodules defined in setup. + reuse_scopes = ( + self.parent._state.in_setup + and self.parent._state.setup_called == SetupState.TRANSFORMED + ) + # Perform name-collision check. + if self.parent._name_taken(self.name, reuse_scopes=reuse_scopes): + parent_class = self.parent.__class__.__name__ + raise errors.NameInUseError('submodule', self.name, parent_class) + # Finalize attachment to parent and scope initialization. + self.parent._state.children[self.name] = self + assert self.parent.scope is not None + object.__setattr__( + self, 'scope', self.parent.scope.push(self.name, reuse=reuse_scopes) + ) + + # Top-level invocation with a functional Scope. + elif isinstance(self.parent, Scope): + object.__setattr__(self, 'scope', self.parent) + else: + raise ValueError('parent must be None, Module or Scope') + + # eagerly bind submodules if scope is available + if self.scope is not None: + for field in dataclasses.fields(self): + if field.name not in ('parent', 'name') and field.init: + self._register_submodules(field.name, getattr(self, field.name)) + + self._state.is_initialized = True + + def __repr__(self) -> str: + return _module_repr(self) + + def setup(self) -> None: + """Initializes a Module lazily (similar to a lazy ``__init__``). 
+ + ``setup`` is called once lazily on a module instance when a module + is bound, immediately before any other methods like ``__call__`` are + invoked, or before a ``setup``-defined attribute on ``self`` is accessed. + + This can happen in three cases: + + 1. Immediately when invoking :meth:`apply`, :meth:`init` or + :meth:`init_and_output`. + + 2. Once the module is given a name by being assigned to an attribute of + another module inside the other module's ``setup`` method + (see :meth:`__setattr__`):: + + >>> class MyModule(nn.Module): + ... def setup(self): + ... submodule = nn.Conv(...) + + ... # Accessing `submodule` attributes does not yet work here. + + ... # The following line invokes `self.__setattr__`, which gives + ... # `submodule` the name "conv1". + ... self.conv1 = submodule + + ... # Accessing `submodule` attributes or methods is now safe and + ... # either causes setup() to be called once. + + 3. Once a module is constructed inside a method wrapped with + :meth:`compact`, immediately before another method is called or + ``setup`` defined attribute is accessed. + """ + pass + + def _register_submodules(self, name, val): + """Registers a submodule.""" + assert self.scope, 'Trying to register submodules on unbound scope.' + root = self.scope.root + cache = _caches.get(root, weakref.WeakValueDictionary()) + _caches[root] = cache + queue = [] + preserve_adopted_names = config.flax_preserve_adopted_names + if hasattr(type(self), 'preserve_adopted_names'): + preserve_adopted_names = type(self).preserve_adopted_names + + def adopt_attr_modules(cache, queue, suffix, subvalue): + if isinstance(subvalue, Module): + current_name = subvalue.name + adopted_name = None + if subvalue.parent is None: + # Preserve sharing-by-reference relationships during adoption + # via cache keyed on unique instance ids. + key = subvalue._id + # Module was passed from outside. It needs to be cloned. + # Outside modules are named by attachment, not an outer name, + # UNLESS we're using new adopted name policy, in which case an existing + # name will be used, as is often supplied by config systems. + if preserve_adopted_names: + adopted_name = object.__getattribute__(subvalue, 'name') + if key in cache: + subvalue = cache[key] + else: + subvalue = subvalue.clone(name=None) + cache[key] = subvalue + if subvalue.name is None: + object.__setattr__(subvalue, 'parent', self) + if adopted_name is None: + adopted_name = ( + f'{name}{suffix}' + if not isinstance(subvalue, CompactNameScope) + else current_name + ) + object.__setattr__(subvalue, 'name', adopted_name) + queue.append(subvalue) + return subvalue + + val = _freeze_attr( + _map_over_modules_in_tree( + functools.partial(adopt_attr_modules, cache, queue), val + ) + ) + object.__setattr__(self, name, val) + for x in queue: + x.__post_init__() + + def _try_setup(self, shallow: bool = False) -> None: + """Tries to setup module if scope is available and setup has not been called yet.""" + if ( + self.scope + and not self._state.in_setup + and self._state.setup_called != SetupState.DONE + ): + try: + self._state.in_setup = True + # A shallow setup will only register attribute submodules but it does + # not call the user's setup. This avoids running before a + # transformation. 
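+        # First bind any submodules that were passed in as dataclass fields,
+        # then (unless shallow) run the user-defined setup().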
+ for field in dataclasses.fields(self): + if field.name not in ('parent', 'name') and field.init: + self._register_submodules(field.name, getattr(self, field.name)) + if not shallow: + self.setup() + # create NonTransparent Modules + self._compact_name_scope_modules = { + name: CompactNameScope( + getattr(type(self), name).inner_fun, lambda: self, name=name + ) + for name in self._compact_name_scope_methods + } + + # We run static checks abstractly once for setup before any transforms + # to detect name collisions and other python errors. + elif self._state.setup_called == SetupState.NEW: + self._validate_setup() + finally: + self._state.in_setup = False + if not shallow: + self._state.setup_called = SetupState.DONE + + def _validate_setup(self) -> None: + """Abstractly evaluates setup only to run static checks.""" + + def run_setup_only(x): + wrapped_id = wrap_method_once(lambda m, x: x) + with TestScope({}, rngs={}, mutable=True).temporary() as root: + return wrapped_id(self.clone(parent=root), x) + + _ = jax.eval_shape(run_setup_only, 0) + + def _name_taken( + self, + name: str, + reuse_scopes: bool = False, + collection: Optional[str] = None, + ) -> bool: + assert self.scope is not None + if reuse_scopes: + return False + return self.scope.name_reserved(name, collection) + + @property + def _initialization_allowed(self): + return ( + not self._state.is_initialized # allow eager attachment in post-init + or self._state.in_setup + or self._state.in_compact_method + ) + + @property + def path(self): + """Get the path of this Module. Top-level root modules have an empty path ``()``. + Note that this method can only be used on bound modules that have a valid scope. + + Example usage:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + + >>> class SubModel(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... print(f'SubModel path: {self.path}') + ... return x + + >>> class Model(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... print(f'Model path: {self.path}') + ... return SubModel()(x) + + >>> model = Model() + >>> variables = model.init(jax.random.key(0), jnp.ones((1, 2))) + Model path: () + SubModel path: ('SubModel_0',) + """ + + if self.scope is None: + raise ValueError("Can't access module paths on unbound modules.") + + return self.scope.path + + def clone( + self: M, + *, + parent: Optional[Union[Scope, 'Module', _Sentinel]] = None, + _deep_clone: Union[bool, weakref.WeakValueDictionary] = False, + _reset_names: bool = False, + **updates, + ) -> M: + """Creates a clone of this Module, with optionally updated arguments. + + NOTE: end users are encouraged to use the ``copy`` method. ``clone`` is used + primarily for internal routines, and ``copy`` offers simpler arguments and + better defaults. + + Args: + parent: The parent of the clone. The clone will have no parent if no + explicit parent is specified. + _deep_clone: A boolean or a weak value dictionary to control deep cloning + of submodules. If True, submodules will be cloned recursively. If a weak + value dictionary is passed, it will be used to cache cloned submodules. + This flag is used by init/apply/bind to avoid scope leakage. + _reset_names: If True, ``name=None`` is also passed to submodules when + cloning. Resetting names in submodules is necessary when calling ``.unbind``. + **updates: Attribute updates. + + Returns: + A clone of the this Module with the updated attributes and parent. 
+ """ + attrs = { + f.name: getattr(self, f.name) for f in dataclasses.fields(self) if f.init + } + + attrs.update(parent=parent, **updates) + + # Here we implement deep cloning of submodules, this is necessary to avoid scope leakage + # from external submodules into init/apply/bind while preserving sharing-by-reference + # relationships between submodules. + if _deep_clone != False: + # We use a weak value dictionary to cache cloned submodules. When a shared + # submodule is cloned, its only cloned once else its fetched from the cache. + cache = ( + weakref.WeakValueDictionary() + if isinstance(_deep_clone, bool) + else _deep_clone + ) + + def clone_fn(m: Module) -> Module: + if hasattr(m, '_id'): + key = m._id + if key in cache: + return cache[key] + else: + if _reset_names: + clone = m.clone( + _deep_clone=cache, _reset_names=_reset_names, name=None + ) + else: + clone = m.clone(_deep_clone=cache) + cache[key] = clone + return clone + else: + # If the module doesn't have an _id attribute it could be a mock object + # so we return it as is. + return m + + # _map_submodules will map over all submodules inside attrs + # value here can be any pytree, non-module values are ignored + for field_name, value in attrs.items(): + if field_name == 'parent': + continue + attrs[field_name] = _map_submodules(clone_fn, value) + + module = self.__class__(**attrs) + + return module + + def copy( + self: M, + *, + parent: Optional[Union[Scope, 'Module', _Sentinel]] = _unspecified_parent, + name: Optional[str] = None, + **updates, + ) -> M: + """Creates a copy of this Module, with optionally updated arguments. + + Args: + parent: The parent of the copy. By default the current module is taken + as parent if not explicitly specified. + name: A new name for the copied Module, by default a new automatic name + will be given. + **updates: Attribute updates. + + Returns: + A copy of the this Module with the updated name, parent, and attributes. + """ + return self.clone( + parent=parent, name=name, _deep_clone=True, _reset_names=False, **updates + ) + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + ) -> Variable[T]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[True], + **init_kwargs, + ) -> Variable[T]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[False], + **init_kwargs, + ) -> Variable[meta.AxisMetadata[T]]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: bool = True, + **init_kwargs, + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + ... + + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: bool = True, + **init_kwargs, + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + """Declares and returns a variable in this Module. + + See :mod:`flax.core.variables` for more information. See also :meth:`param` + for a shorthand way to define read-only variables in the "params" + collection. + + Contrary to :meth:`param`, all arguments passing using ``init_fn`` should be + passed on explicitly:: + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... x = nn.Dense(4)(x) + ... key = self.make_rng('stats') + ... 
mean = self.variable('stats', 'mean', nn.initializers.lecun_normal(), key, x.shape) + ... ... + ... return x * mean.value + >>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3))) + >>> jax.tree_util.tree_map(jnp.shape, variables) + {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}}, 'stats': {'mean': (2, 4)}} + + In the example above, the function ``lecun_normal`` expects two arguments: + ``key`` and ``shape``, and both have to be passed on. The PRNG for ``stats`` + has to be provided explicitly when calling :meth:`init` and :meth:`apply`. + + Args: + col: The variable collection name. + name: The variable name. + init_fn: The function that will be called to compute the initial value of + this variable. This function will only be called the first time this + variable is used in this module. If None, the variable must already be + initialized otherwise an error is raised. + *init_args: The positional arguments to pass to init_fn. + unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed + value, see ``flax.nn.meta.unbox`` (default: True). + **init_kwargs: The key-word arguments to pass to init_fn + + Returns: + A :class:`flax.core.variables.Variable` that can be read or set via + ".value" attribute. Throws an error if the variable exists already. + """ + if not self._initialization_allowed: + raise ValueError( + 'Variables must be initialized in `setup()` or in a method ' + 'wrapped in `@compact`' + ) + if self._name_taken(name, collection=col): + raise errors.NameInUseError('variable', name, self.__class__.__name__) + assert self.scope is not None + v = self.scope.variable( + col, name, init_fn, *init_args, unbox=unbox, **init_kwargs + ) + self._state.children[name] = col + return v + + @overload + def param( + self, name: str, init_fn: Callable[..., T], *init_args, + ) -> T: + ... + + @overload + def param( + self, + name: str, + init_fn: Callable[..., T], + *init_args, + unbox: Literal[True], + **init_kwargs, + ) -> T: + ... + + @overload + def param( + self, + name: str, + init_fn: Callable[..., T], + *init_args, + unbox: Literal[False], + **init_kwargs, + ) -> meta.AxisMetadata[T]: + ... + + @overload + def param( + self, + name: str, + init_fn: Callable[..., T], + *init_args, + unbox: bool, + **init_kwargs, + ) -> Union[T, meta.AxisMetadata[T]]: + ... + + def param( + self, + name: str, + init_fn: Callable[..., T], + *init_args, + unbox: bool = True, + **init_kwargs, + ) -> Union[T, meta.AxisMetadata[T]]: + """Declares and returns a parameter in this Module. + + Parameters are read-only variables in the collection named "params". See + :mod:`flax.core.variables` for more details on variables. + + The first argument of ``init_fn`` is assumed to be a PRNG key, which is + provided automatically and does not have to be passed using ``init_args`` + or ``init_kwargs``:: + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... x = nn.Dense(4)(x) + ... mean = self.param('mean', nn.initializers.lecun_normal(), x.shape) + ... ... + ... 
return x * mean + >>> variables = Foo().init({'params': jax.random.key(0), 'stats': jax.random.key(1)}, jnp.ones((2, 3))) + >>> jax.tree_util.tree_map(jnp.shape, variables) + {'params': {'Dense_0': {'bias': (4,), 'kernel': (3, 4)}, 'mean': (2, 4)}} + + In the example above, the function ``lecun_normal`` expects two arguments: + ``key`` and ``shape``, but only ``shape`` has to be provided explicitly; + ``key`` is set automatically using the PRNG for ``params`` that is passed + when initializing the module using :meth:`init`. + + Args: + name: The parameter name. + init_fn: The function that will be called to compute the initial value of + this variable. This function will only be called the first time this + parameter is used in this module. + *init_args: The positional arguments to pass to init_fn. + unbox: If True, ``AxisMetadata`` instances are replaced by their unboxed + value, see ``flax.nn.meta.unbox`` (default: True). + **init_kwargs: The key-word arguments to pass to init_fn. + + Returns: + The value of the initialized parameter. Throws an error if the parameter + exists already. + """ + if not self._initialization_allowed: + raise ValueError( + 'Parameters must be initialized in `setup()` or in a method ' + 'wrapped in `@compact`' + ) + if self._name_taken(name, collection='params'): + raise errors.NameInUseError('param', name, self.__class__.__name__) + assert self.scope is not None + v = self.scope.param(name, init_fn, *init_args, unbox=unbox, **init_kwargs) + self._state.children[name] = 'params' + return v + + def has_variable(self, col: str, name: str) -> bool: + """Checks if a variable of given collection and name exists in this Module. + + See :mod:`flax.core.variables` for more explanation on variables and + collections. + + Args: + col: The variable collection name. + name: The name of the variable. + + Returns: + True if the variable exists. + """ + if self.scope is None: + raise ValueError("Can't access variables on unbound modules") + return self.scope.has_variable(col, name) + + def is_mutable_collection(self, col: str) -> bool: + """Returns true if the collection ``col`` is mutable.""" + if self.scope is None: + raise ValueError("Can't check mutability on unbound modules") + return self.scope.is_mutable_collection(col) + + def has_rng(self, name: str) -> bool: + """Returns true if a PRNGSequence with name ``name`` exists.""" + if self.scope is None: + raise ValueError("Can't query for RNGs on unbound modules") + return self.scope.has_rng(name) + + def make_rng(self, name: str = 'params') -> PRNGKey: + """Returns a new RNG key from a given RNG sequence for this Module. + + The new RNG key is split from the previous one. Thus, every call to + ``make_rng`` returns a new RNG key, while still guaranteeing full + reproducibility. + + .. note:: + If an invalid name is passed (i.e. no RNG key was passed by + the user in ``.init`` or ``.apply`` for this name), then ``name`` + will default to ``'params'``. + + Example:: + + >>> import jax + >>> import flax.linen as nn + + >>> class ParamsModule(nn.Module): + ... def __call__(self): + ... return self.make_rng('params') + >>> class OtherModule(nn.Module): + ... def __call__(self): + ... 
return self.make_rng('other') + + >>> key = jax.random.key(0) + >>> params_out, _ = ParamsModule().init_with_output({'params': key}) + >>> # self.make_rng('other') will default to using the 'params' RNG stream + >>> other_out, _ = OtherModule().init_with_output({'params': key}) + >>> assert params_out == other_out + + Learn more about RNG's by reading the Flax RNG guide: + https://flax.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html + + Args: + name: The RNG sequence name. + + Returns: + The newly generated RNG key. + """ + if self.scope is None: + raise ValueError("Can't use RNGs on unbound modules") + return self.scope.make_rng(name) + + def is_initializing(self) -> bool: + """Returns True if running under self.init(...) or nn.init(...)(). + + This is a helper method to handle the common case of simple initialization + where we wish to have setup logic occur when only called under + ``module.init`` or ``nn.init``. For more complicated multi-phase + initialization scenarios it is better to test for the mutability of + particular variable collections or for the presence of particular + variables that potentially need to be initialized. + """ + if self.scope is None: + raise ValueError("Can't check if running under init() on unbound modules") + return self.scope.get_flag('initializing', False) + + def _module_checks(self): + """Run standard runtime checks.""" + + if not isinstance(self, Module): + raise errors.InvalidInstanceModuleError() + + overridden_post_init = self.__post_init__ != Module.__post_init__ + if overridden_post_init and not hasattr(self, '_id'): + raise errors.IncorrectPostInitOverrideError() + + @traceback_util.api_boundary + def bind( + self: M, + variables: VariableDict, + *args, + rngs: Optional[RNGSequences] = None, + mutable: CollectionFilter = False, + ) -> M: + """Creates an interactive Module instance by binding variables and RNGs. + + ``bind`` provides an "interactive" instance of a Module directly without + transforming a function with ``apply``. This is particularly useful for + debugging and interactive use cases like notebooks where a function would + limit the ability to split up code into different cells. + + Once the variables (and optionally RNGs) are bound to a ``Module`` it + becomes a stateful object. Note that idiomatic JAX is functional and + therefore an interactive instance does not mix well with vanilla JAX APIs. + ``bind()`` should only be used for interactive experimentation, and in all + other cases we strongly encourage users to use ``apply()`` instead. + + Example:: + + >>> import jax + >>> import jax.numpy as jnp + >>> import flax.linen as nn + + >>> class AutoEncoder(nn.Module): + ... def setup(self): + ... self.encoder = nn.Dense(3) + ... self.decoder = nn.Dense(5) + ... + ... def __call__(self, x): + ... return self.decoder(self.encoder(x)) + + >>> x = jnp.ones((16, 9)) + >>> ae = AutoEncoder() + >>> variables = ae.init(jax.random.key(0), x) + >>> model = ae.bind(variables) + >>> z = model.encoder(x) + >>> x_reconstructed = model.decoder(z) + + Args: + variables: A dictionary containing variables keyed by variable + collections. See :mod:`flax.core.variables` for more details about + variables. + *args: Named arguments (not used). + rngs: a dict of PRNGKeys to initialize the PRNG sequences. + mutable: Can be bool, str, or list. Specifies which collections should be + treated as mutable: ``bool``: all/no collections are mutable. ``str``: + The name of a single mutable collection. 
``list``: A list of names of + mutable collections. + + Returns: + A copy of this instance with bound variables and RNGs. + """ + Module._module_checks(self) + + del args + scope = core.bind(variables, rngs=rngs, mutable=mutable) + return self.clone(parent=scope, _deep_clone=True) + + def unbind(self: M) -> Tuple[M, VariableDict]: + """Returns an unbound copy of a Module and its variables. + + ``unbind`` helps create a stateless version of a bound Module. + + An example of a common use case: to extract a sub-Module defined inside + ``setup()`` and its corresponding variables: 1) temporarily ``bind`` the + parent Module; and then 2) ``unbind`` the desired sub-Module. (Recall that + ``setup()`` is only called when the Module is bound.):: + + >>> class Encoder(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... ... + ... return nn.Dense(256)(x) + + >>> class Decoder(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... ... + ... return nn.Dense(784)(x) + + >>> class AutoEncoder(nn.Module): + ... def setup(self): + ... self.encoder = Encoder() + ... self.decoder = Decoder() + ... + ... def __call__(self, x): + ... return self.decoder(self.encoder(x)) + + >>> module = AutoEncoder() + >>> variables = module.init(jax.random.key(0), jnp.ones((1, 784))) + + >>> # Extract the Encoder sub-Module and its variables + >>> encoder, encoder_vars = module.bind(variables).encoder.unbind() + + Returns: + A tuple with an unbound copy of this Module and its variables. + """ + Module._module_checks(self) + + if self.scope is None: + raise errors.CallUnbindOnUnboundModuleError() + + variables = self.variables + module = self.clone(_deep_clone=True, _reset_names=True, name=None) + return module, variables + + @traceback_util.api_boundary + def apply( + self, + variables: VariableDict, + *args, + rngs: Optional[Union[PRNGKey, RNGSequences]] = None, + method: Union[Callable[..., Any], str, None] = None, + mutable: CollectionFilter = False, + capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, + **kwargs, + ) -> Union[Any, Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]: + """Applies a module method to variables and returns output and modified variables. + + Note that ``method`` should be set if one would like to call ``apply`` on a + different class method than ``__call__``. For instance, suppose a + Transformer modules has a method called ``encode``, then the following calls + ``apply`` on that method:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + >>> import numpy as np + + >>> class Transformer(nn.Module): + ... def encode(self, x): + ... ... + + >>> x = jnp.ones((16, 9)) + >>> model = Transformer() + >>> variables = model.init(jax.random.key(0), x, method=Transformer.encode) + + >>> encoded = model.apply(variables, x, method=Transformer.encode) + + If a function instance is provided, the unbound function is used. For + instance, the example below is equivalent to the one above:: + + >>> encoded = model.apply(variables, x, method=model.encode) + + You can also pass a string to a callable attribute of the module. For + example, the previous can be written as:: + + >>> encoded = model.apply(variables, x, method='encode') + + Note ``method`` can also be a function that is not defined in + ``Transformer``. In that case, the function should have at least one + argument representing an instance of the Module class:: + + >>> def other_fn(instance, x): + ... # instance.some_module_attr(...) + ... instance.encode + ... ... 
+ + >>> model.apply(variables, x, method=other_fn) + + If you pass a single ``PRNGKey``, Flax will use it to feed the ``'params'`` + RNG stream. If you want to use a different RNG stream or need to use + multiple streams, you can pass a dictionary mapping each RNG stream name + to its corresponding ``PRNGKey`` to ``apply``. If ``self.make_rng(name)`` + is called on an RNG stream name that isn't passed by the user, it will + default to using the ``'params'`` RNG stream. + + Example:: + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x, add_noise=False): + ... x = nn.Dense(16)(x) + ... x = nn.relu(x) + ... + ... if add_noise: + ... # Add gaussian noise + ... noise_key = self.make_rng('noise') + ... x = x + jax.random.normal(noise_key, x.shape) + ... + ... return nn.Dense(1)(x) + + >>> x = jnp.empty((1, 7)) + >>> module = Foo() + >>> rngs = {'params': jax.random.key(0), 'noise': jax.random.key(1)} + >>> variables = module.init(rngs, x) + >>> out0 = module.apply(variables, x, add_noise=True, rngs=rngs) + + >>> rngs['noise'] = jax.random.key(0) + >>> out1 = module.apply(variables, x, add_noise=True, rngs=rngs) + >>> # different output (key(1) vs key(0)) + >>> np.testing.assert_raises(AssertionError, np.testing.assert_allclose, out0, out1) + + >>> del rngs['noise'] + >>> # self.make_rng('noise') will default to using the 'params' RNG stream + >>> out2 = module.apply(variables, x, add_noise=True, rngs=rngs) + >>> # same output (key(0)) + >>> np.testing.assert_allclose(out1, out2) + + >>> # passing in a single key is equivalent to passing in {'params': key} + >>> out3 = module.apply(variables, x, add_noise=True, rngs=jax.random.key(0)) + >>> # same output (key(0)) + >>> np.testing.assert_allclose(out2, out3) + + Args: + variables: A dictionary containing variables keyed by variable + collections. See :mod:`flax.core.variables` for more details about + variables. + *args: Named arguments passed to the specified apply method. + rngs: a dict of PRNGKeys to initialize the PRNG sequences. The "params" + PRNG sequence is used to initialize parameters. + method: A function to call apply on. This is generally a function in the + module. If provided, applies this method. If not provided, applies the + ``__call__`` method of the module. A string can also be provided to + specify a method by name. + mutable: Can be bool, str, or list. Specifies which collections should be + treated as mutable: ``bool``: all/no collections are mutable. ``str``: + The name of a single mutable collection. ``list``: A list of names of + mutable collections. + capture_intermediates: If ``True``, captures intermediate return values of + all Modules inside the "intermediates" collection. By default, only the + return values of all ``__call__`` methods are stored. A function can be + passed to change the filter behavior. The filter function takes the + Module instance and method name and returns a bool indicating whether + the output of that method invocation should be stored. + **kwargs: Keyword arguments passed to the specified apply method. + + Returns: + If ``mutable`` is False, returns output. If any collections are + mutable, returns ``(output, vars)``, where ``vars`` are is a dict + of the modified collections. 
+ """ + Module._module_checks(self) + + if rngs is not None and not isinstance(rngs, dict): + if not core.scope._is_valid_rng(rngs): + raise errors.InvalidRngError( + 'RNGs should be of shape (2,) or PRNGKey in Module ' + f'{self.__class__.__name__}, but rngs are: {rngs}' + ) + rngs = {'params': rngs} + + if isinstance(method, str): + attribute_name = method + method = getattr(self, attribute_name) + if not callable(method): + class_name = type(self).__name__ + raise TypeError( + f"'{class_name}.{attribute_name}' must be a callable, got" + f' {type(method)}.' + ) + # if the `method` string is a submodule, we create a lambda function + # that calls the submodule, forwarding all arguments. + if isinstance(method, Module): + method = lambda self, *args, **kwargs: getattr(self, attribute_name)( + *args, **kwargs + ) + elif method is None: + method = self.__call__ + method = _get_unbound_fn(method) + return apply( + method, + self, + mutable=mutable, + capture_intermediates=capture_intermediates, + )(variables, *args, **kwargs, rngs=rngs) + + @traceback_util.api_boundary + def init_with_output( + self, + rngs: Union[PRNGKey, RNGSequences], + *args, + method: Union[Callable[..., Any], str, None] = None, + mutable: CollectionFilter = DenyList('intermediates'), + capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, + **kwargs, + ) -> Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]: + """Initializes a module method with variables and returns output and modified variables. + + Args: + rngs: The rngs for the variable collections. + *args: Named arguments passed to the init function. + method: An optional method. If provided, applies this method. If not + provided, applies the ``__call__`` method. A string can also be + provided to specify a method by name. + mutable: Can be bool, str, or list. Specifies which collections should be + treated as mutable: ``bool``: all/no collections are mutable. ``str``: + The name of a single mutable collection. ``list``: A list of names of + mutable collections. By default, all collections except "intermediates" + are mutable. + capture_intermediates: If ``True``, captures intermediate return values of + all Modules inside the "intermediates" collection. By default only the + return values of all ``__call__`` methods are stored. A function can be + passed to change the filter behavior. The filter function takes the + Module instance and method name and returns a bool indicating whether + the output of that method invocation should be stored. + **kwargs: Keyword arguments passed to the init function. + + Returns: + ``(output, vars)``, where ``vars`` are is a dict of the modified + collections. + """ + Module._module_checks(self) + + if not isinstance(rngs, dict): + if not core.scope._is_valid_rng(rngs): + raise errors.InvalidRngError( + 'RNGs should be of shape (2,) or PRNGKey in Module ' + f'{self.__class__.__name__}, but rngs are: {rngs}' + ) + rngs = {'params': rngs} + + if isinstance(method, str): + attribute_name = method + method = getattr(self, attribute_name) + if not callable(method): + class_name = type(self).__name__ + raise TypeError( + f"'{class_name}.{attribute_name}' must be a callable, got" + f' {type(method)}.' 
+ ) + elif method is None: + method = self.__call__ + method = _get_unbound_fn(method) + return init_with_output( + method, + self, + mutable=mutable, + capture_intermediates=capture_intermediates, + )(rngs, *args, **kwargs) + + @traceback_util.api_boundary + def init( + self, + rngs: Union[PRNGKey, RNGSequences], + *args, + method: Union[Callable[..., Any], str, None] = None, + mutable: CollectionFilter = DenyList('intermediates'), + capture_intermediates: Union[bool, Callable[['Module', str], bool]] = False, + **kwargs, + ) -> Union[FrozenVariableDict, Dict[str, Any]]: + """Initializes a module method with variables and returns modified variables. + + ``init`` takes as first argument either a single ``PRNGKey``, or a + dictionary mapping variable collections names to their ``PRNGKeys``, and + will call ``method`` (which is the module's ``__call__`` function by + default) passing ``*args`` and ``**kwargs``, and returns + a dictionary of initialized variables. + + Example:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + >>> import numpy as np + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x, train): + ... x = nn.Dense(16)(x) + ... x = nn.BatchNorm(use_running_average=not train)(x) + ... x = nn.relu(x) + ... return nn.Dense(1)(x) + + >>> x = jnp.empty((1, 7)) + >>> module = Foo() + >>> key = jax.random.key(0) + >>> variables = module.init(key, x, train=False) + + If you pass a single ``PRNGKey``, Flax will use it to feed the ``'params'`` + RNG stream. If you want to use a different RNG stream or need to use + multiple streams, you can pass a dictionary mapping each RNG stream name + to its corresponding ``PRNGKey`` to ``init``. If ``self.make_rng(name)`` + is called on an RNG stream name that isn't passed by the user, it will + default to using the ``'params'`` RNG stream. + + Example:: + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... x = nn.Dense(16)(x) + ... x = nn.relu(x) + ... + ... other_variable = self.variable( + ... 'other_collection', + ... 'other_variable', + ... lambda x: jax.random.normal(self.make_rng('other_rng'), x.shape), + ... x, + ... ) + ... x = x + other_variable.value + ... + ... return nn.Dense(1)(x) + + >>> module = Foo() + >>> rngs = {'params': jax.random.key(0), 'other_rng': jax.random.key(1)} + >>> variables0 = module.init(rngs, x) + + >>> rngs['other_rng'] = jax.random.key(0) + >>> variables1 = module.init(rngs, x) + >>> # equivalent params (key(0)) + >>> _ = jax.tree_util.tree_map( + ... np.testing.assert_allclose, variables0['params'], variables1['params'] + ... ) + >>> # different other_variable (key(1) vs key(0)) + >>> np.testing.assert_raises( + ... AssertionError, + ... np.testing.assert_allclose, + ... variables0['other_collection']['other_variable'], + ... variables1['other_collection']['other_variable'], + ... ) + + >>> del rngs['other_rng'] + >>> # self.make_rng('other_rng') will default to using the 'params' RNG stream + >>> variables2 = module.init(rngs, x) + >>> # equivalent params (key(0)) + >>> _ = jax.tree_util.tree_map( + ... np.testing.assert_allclose, variables1['params'], variables2['params'] + ... ) + >>> # equivalent other_variable (key(0)) + >>> np.testing.assert_allclose( + ... variables1['other_collection']['other_variable'], + ... variables2['other_collection']['other_variable'], + ... 
)
+
+      >>> # passing in a single key is equivalent to passing in {'params': key}
+      >>> variables3 = module.init(jax.random.key(0), x)
+      >>> # equivalent params (key(0))
+      >>> _ = jax.tree_util.tree_map(
+      ...   np.testing.assert_allclose, variables2['params'], variables3['params']
+      ... )
+      >>> # equivalent other_variable (key(0))
+      >>> np.testing.assert_allclose(
+      ...   variables2['other_collection']['other_variable'],
+      ...   variables3['other_collection']['other_variable'],
+      ... )
+
+    Jitting ``init`` initializes a model lazily using only the shapes of the
+    provided arguments, and avoids computing the forward pass with actual
+    values. Example::
+
+      >>> module = nn.Dense(1)
+      >>> init_jit = jax.jit(module.init)
+      >>> variables = init_jit(jax.random.key(0), x)
+
+    ``init`` is a light wrapper over ``apply``, so other ``apply`` arguments
+    like ``method``, ``mutable``, and ``capture_intermediates`` are also
+    available.
+
+    Args:
+      rngs: The rngs for the variable collections.
+      *args: Named arguments passed to the init function.
+      method: An optional method. If provided, applies this method. If not
+        provided, applies the ``__call__`` method. A string can also be provided
+        to specify a method by name.
+      mutable: Can be bool, str, or list. Specifies which collections should be
+        treated as mutable: ``bool``: all/no collections are mutable. ``str``:
+        The name of a single mutable collection. ``list``: A list of names of
+        mutable collections. By default all collections except "intermediates"
+        are mutable.
+      capture_intermediates: If ``True``, captures intermediate return values of
+        all Modules inside the "intermediates" collection. By default only the
+        return values of all ``__call__`` methods are stored. A function can be
+        passed to change the filter behavior. The filter function takes the
+        Module instance and method name and returns a bool indicating whether
+        the output of that method invocation should be stored.
+      **kwargs: Keyword arguments passed to the init function.
+
+    Returns:
+      The initialized variable dict.
+    """
+    Module._module_checks(self)
+
+    _, v_out = self.init_with_output(
+      rngs,
+      *args,
+      method=method,
+      mutable=mutable,
+      capture_intermediates=capture_intermediates,
+      **kwargs,
+    )
+    return v_out
+
+  @traceback_util.api_boundary
+  def lazy_init(
+    self,
+    rngs: Union[PRNGKey, RNGSequences],
+    *args,
+    method: Optional[Callable[..., Any]] = None,
+    mutable: CollectionFilter = DenyList('intermediates'),
+    **kwargs,
+  ) -> FrozenVariableDict:
+    """Initializes a module without computing on an actual input.
+
+    lazy_init will initialize the variables without doing unnecessary compute.
+    The input data should be passed as a ``jax.ShapeDtypeStruct`` which
+    specifies the shape and dtype of the input but no concrete data.
+
+    Example::
+
+      >>> model = nn.Dense(features=256)
+      >>> variables = model.lazy_init(
+      ...     jax.random.key(0), jax.ShapeDtypeStruct((1, 128), jnp.float32))
+
+    The args and kwargs passed to ``lazy_init`` can be a mix of
+    concrete (jax arrays, scalars, bools) and abstract (ShapeDtypeStruct)
+    values. Concrete values are only necessary for arguments that affect
+    the initialization of variables. For example, the model might expect
+    a keyword arg that enables/disables a subpart of the model.
+    In this case, an explicit value (True/False) should be passed, otherwise
+    ``lazy_init`` cannot infer which variables should be initialized.
+
+    Args:
+      rngs: The rngs for the variable collections.
+      *args: arguments passed to the init function.
+ method: An optional method. If provided, applies this method. If not + provided, applies the ``__call__`` method. + mutable: Can be bool, str, or list. Specifies which collections should be + treated as mutable: ``bool``: all/no collections are mutable. ``str``: + The name of a single mutable collection. ``list``: A list of names of + mutable collections. By default all collections except "intermediates" + are mutable. + **kwargs: Keyword arguments passed to the init function. + + Returns: + The initialized variable dict. + """ + Module._module_checks(self) + + def lazy_wrapper(rngs, *args, **kwargs): + return self.init(rngs, *args, method=method, mutable=mutable, **kwargs) + + return partial_eval.lazy_init(lazy_wrapper)(rngs, *args, **kwargs) + + @property + def variables(self) -> VariableDict: + """Returns the variables in this module.""" + if self.scope is None: + raise ValueError("Can't access variables on unbound modules") + return self.scope.variables() + + def get_variable(self, col: str, name: str, default: Optional[T] = None) -> T: + """Retrieves the value of a Variable. + + Args: + col: the variable collection. + name: the name of the variable. + default: the default value to return if the variable does not exist in + this scope. + + Returns: + The value of the input variable, of the default value if the variable + doesn't exist in this scope. + """ + if self.scope is None: + raise ValueError("Can't access variables on unbound modules") + return self.scope.get_variable(col, name, default) + + def put_variable(self, col: str, name: str, value: Any): + """Updates the value of the given variable if it is mutable, or an error otherwise. + + Args: + col: the variable collection. + name: the name of the variable. + value: the new value of the variable. + """ + if self.scope is None: + raise ValueError("Can't access variables on unbound modules") + self.scope.put_variable(col, name, value) + + @overload + def sow(self, col: str, name: str, value: Any) -> bool: + ... + + @overload + def sow( + self, + col: str, + name: str, + value: T, + reduce_fn: Callable[[K, T], K] = tuple_reduce, + init_fn: Callable[[], K] = tuple_init, # type: ignore + ) -> bool: + ... + + def sow( + self, + col: str, + name: str, + value: T, + reduce_fn: Callable[[K, T], K] = tuple_reduce, + init_fn: Callable[[], K] = tuple_init, # type: ignore + ) -> bool: + """Stores a value in a collection. + + Collections can be used to collect intermediate values without + the overhead of explicitly passing a container through each Module call. + + If the target collection is not mutable ``sow`` behaves like a no-op + and returns ``False``. + + Example:: + + >>> import jax + >>> import jax.numpy as jnp + >>> import flax.linen as nn + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... h = nn.Dense(4)(x) + ... self.sow('intermediates', 'h', h) + ... return nn.Dense(2)(h) + + >>> x = jnp.ones((16, 9)) + >>> model = Foo() + >>> variables = model.init(jax.random.key(0), x) + >>> y, state = model.apply(variables, x, mutable=['intermediates']) + >>> jax.tree.map(jnp.shape, state['intermediates']) + {'h': ((16, 4),)} + + By default the values are stored in a tuple and each stored value + is appended at the end. This way all intermediates can be tracked when + the same module is called multiple times. Alternatively, a custom + init/reduce function can be passed:: + + >>> class Foo2(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... init_fn = lambda: 0 + ... 
reduce_fn = lambda a, b: a + b
+      ...     self.sow('intermediates', 'h', x,
+      ...       init_fn=init_fn, reduce_fn=reduce_fn)
+      ...     self.sow('intermediates', 'h', x * 2,
+      ...       init_fn=init_fn, reduce_fn=reduce_fn)
+      ...     return x
+
+      >>> x = jnp.ones((1, 1))
+      >>> model = Foo2()
+      >>> variables = model.init(jax.random.key(0), x)
+      >>> y, state = model.apply(
+      ...     variables, x, mutable=['intermediates'])
+      >>> print(state['intermediates'])
+      {'h': Array([[3.]], dtype=float32)}
+
+    Args:
+      col: The name of the variable collection.
+      name: The name of the variable.
+      value: The value of the variable.
+      reduce_fn: The function used to combine the existing value with the new
+        value. The default is to append the value to a tuple.
+      init_fn: For the first value stored, ``reduce_fn`` will be passed the result
+        of ``init_fn`` together with the value to be stored. The default is an
+        empty tuple.
+
+    Returns:
+      ``True`` if the value has been stored successfully, ``False`` otherwise.
+    """
+    if self.scope is None:
+      raise ValueError("Can't store variables on unbound modules")
+    if not self.scope.is_mutable_collection(col):
+      return False
+    if self.scope.has_variable(col, name):
+      xs = self.scope.get_variable(col, name)
+    else:
+      self.scope.reserve(name, col)
+      self._state.children[name] = col
+      xs = init_fn()
+    xs = reduce_fn(xs, value)
+    self.scope.put_variable(col, name, xs)
+    return True
+
+  def perturb(
+    self, name: str, value: T, collection: str = 'perturbations'
+  ) -> T:
+    """Add a zero-value variable ('perturbation') to the intermediate value.
+
+    The gradient of ``value`` would be the same as the gradient of this
+    perturbation variable. Therefore, if you define your loss function with
+    both params and perturbations as standalone arguments, you can get the
+    intermediate gradients of ``value`` by running ``jax.grad`` on the perturbation
+    argument.
+
+    .. note::
+      This is an experimental API and may be tweaked later for better
+      performance and usability.
+      At its current stage, it creates extra dummy variables that occupy extra
+      memory space. Use it only to debug gradients in training.
+
+    Example::
+
+      >>> class Foo(nn.Module):
+      ...   @nn.compact
+      ...   def __call__(self, x):
+      ...     x = nn.Dense(3)(x)
+      ...     x = self.perturb('dense3', x)
+      ...     return nn.Dense(2)(x)
+
+      >>> def loss(variables, inputs, targets):
+      ...   preds = model.apply(variables, inputs)
+      ...   
return jnp.square(preds - targets).mean() + + >>> x = jnp.ones((2, 9)) + >>> y = jnp.ones((2, 2)) + >>> model = Foo() + >>> variables = model.init(jax.random.key(0), x) + >>> intm_grads = jax.grad(loss, argnums=0)(variables, x, y) + >>> print(intm_grads['perturbations']['dense3']) + [[-1.456924 -0.44332537 0.02422847] + [-1.456924 -0.44332537 0.02422847]] + + If perturbations are not passed to ``apply``, ``perturb`` behaves like a no-op + so you can easily disable the behavior when not needed:: + + >>> model.apply(variables, x) # works as expected + Array([[-1.0980128 , -0.67961735], + [-1.0980128 , -0.67961735]], dtype=float32) + >>> model.apply({'params': variables['params']}, x) # behaves like a no-op + Array([[-1.0980128 , -0.67961735], + [-1.0980128 , -0.67961735]], dtype=float32) + >>> intm_grads = jax.grad(loss, argnums=0)({'params': variables['params']}, x, y) + >>> 'perturbations' not in intm_grads + True + """ + if self.scope is None: + raise ValueError("Can't store variables on unbound modules") + + if self.is_mutable_collection(collection): + if not self.scope.has_variable(collection, name): + self.scope.reserve(name, collection) + self._state.children[name] = collection + self.scope.put_variable(collection, name, jnp.zeros_like(value)) # type: ignore + + if collection in self.scope.root._variables: + if self.scope.has_variable(collection, name): + value += self.scope.get_variable(collection, name) # type: ignore + else: + raise ValueError(f"Perturbation collection {collection} present, but " + f"missing perturbation variable {name}") + + return value + + def tabulate( + self, + rngs: Union[PRNGKey, RNGSequences], + *args, + depth: Optional[int] = None, + show_repeated: bool = False, + mutable: CollectionFilter = DenyList('intermediates'), + console_kwargs: Optional[Mapping[str, Any]] = None, + table_kwargs: Mapping[str, Any] = MappingProxyType({}), + column_kwargs: Mapping[str, Any] = MappingProxyType({}), + compute_flops: bool = False, + compute_vjp_flops: bool = False, + **kwargs, + ) -> str: + """Creates a summary of the Module represented as a table. + + This method has the same signature and internally calls ``Module.init``, + but instead of returning the variables, it returns the string summarizing + the Module in a table. ``tabulate`` uses ``jax.eval_shape`` to run the forward + computation without consuming any FLOPs or allocating memory. + + Additional arguments can be passed into the ``console_kwargs`` argument, for + example, ``{'width': 120}``. For a full list of ``console_kwargs`` arguments, + see: + https://rich.readthedocs.io/en/stable/reference/console.html#rich.console.Console + + Example:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... h = nn.Dense(4)(x) + ... 
return nn.Dense(2)(h) + + >>> x = jnp.ones((16, 9)) + + >>> # print(Foo().tabulate( + >>> # jax.random.key(0), x, compute_flops=True, compute_vjp_flops=True)) + + This gives the following output:: + + Foo Summary + ┏━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ + ┃ path ┃ module ┃ inputs ┃ outputs ┃ flops ┃ vjp_flops ┃ params ┃ + ┡━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ + │ │ Foo │ float32[16,9] │ float32[16,2] │ 1504 │ 4460 │ │ + ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ + │ Dense_0 │ Dense │ float32[16,9] │ float32[16,4] │ 1216 │ 3620 │ bias: │ + │ │ │ │ │ │ │ float32[4] │ + │ │ │ │ │ │ │ kernel: │ + │ │ │ │ │ │ │ float32[9,4] │ + │ │ │ │ │ │ │ │ + │ │ │ │ │ │ │ 40 (160 B) │ + ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ + │ Dense_1 │ Dense │ float32[16,4] │ float32[16,2] │ 288 │ 840 │ bias: │ + │ │ │ │ │ │ │ float32[2] │ + │ │ │ │ │ │ │ kernel: │ + │ │ │ │ │ │ │ float32[4,2] │ + │ │ │ │ │ │ │ │ + │ │ │ │ │ │ │ 10 (40 B) │ + ├─────────┼────────┼───────────────┼───────────────┼───────┼───────────┼─────────────────┤ + │ │ │ │ │ │ Total │ 50 (200 B) │ + └─────────┴────────┴───────────────┴───────────────┴───────┴───────────┴─────────────────┘ + + Total Parameters: 50 (200 B) + + **Note**: rows order in the table does not represent execution order, + instead it aligns with the order of keys in ``variables`` which are sorted + alphabetically. + + **Note**: ``vjp_flops`` returns ``0`` if the module is not differentiable. + + Args: + rngs: The rngs for the variable collections as passed to ``Module.init``. + *args: The arguments to the forward computation. + depth: controls how many submodule deep the summary can go. By default, + its ``None`` which means no limit. If a submodule is not shown because of + the depth limit, its parameter count and bytes will be added to the row + of its first shown ancestor such that the sum of all rows always adds + up to the total number of parameters of the Module. + show_repeated: If ``True``, repeated calls to the same module will be shown + in the table, otherwise only the first call will be shown. Default is + ``False``. + mutable: Can be bool, str, or list. Specifies which collections should be + treated as mutable: ``bool``: all/no collections are mutable. ``str``: + The name of a single mutable collection. ``list``: A list of names of + mutable collections. By default, all collections except 'intermediates' + are mutable. + console_kwargs: An optional dictionary with additional keyword arguments + that are passed to ``rich.console.Console`` when rendering the table. + Default arguments are ``{'force_terminal': True, 'force_jupyter': + False}``. + table_kwargs: An optional dictionary with additional keyword arguments + that are passed to ``rich.table.Table`` constructor. + column_kwargs: An optional dictionary with additional keyword arguments + that are passed to ``rich.table.Table.add_column`` when adding columns to + the table. + compute_flops: whether to include a ``flops`` column in the table listing + the estimated FLOPs cost of each module forward pass. Does incur actual + on-device computation / compilation / memory allocation, but still + introduces overhead for large modules (e.g. extra 20 seconds for a + Stable Diffusion's UNet, whereas otherwise tabulation would finish in 5 + seconds). 
+ compute_vjp_flops: whether to include a ``vjp_flops`` column in the table + listing the estimated FLOPs cost of each module backward pass. + Introduces a compute overhead of about 2-3X of ``compute_flops``. + **kwargs: keyword arguments to pass to the forward computation. + + Returns: + A string summarizing the Module. + """ + from flax.linen import summary + + tabulate_fn = summary.tabulate( + self, + rngs, + depth=depth, + show_repeated=show_repeated, + mutable=mutable, + console_kwargs=console_kwargs, + table_kwargs=table_kwargs, + column_kwargs=column_kwargs, + compute_flops=compute_flops, + compute_vjp_flops=compute_vjp_flops, + ) + return tabulate_fn(*args, **kwargs) + + def module_paths( + self, + rngs: Union[PRNGKey, RNGSequences], + *args, + show_repeated: bool = False, + mutable: CollectionFilter = DenyList('intermediates'), + **kwargs, + ) -> dict[str, 'Module']: + """Returns a dictionary mapping module paths to module instances. + + This method has the same signature and internally calls ``Module.init``, + but instead of returning the variables, it returns a dictionary mapping + module paths to unbounded copies of module instances that were used + at runtime. ``module_paths`` uses ``jax.eval_shape`` to run the forward + computation without consuming any FLOPs or allocating memory. + + Example:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... h = nn.Dense(4)(x) + ... return nn.Dense(2)(h) + + >>> x = jnp.ones((16, 9)) + >>> modules = Foo().module_paths(jax.random.key(0), x) + >>> print({ + ... p: type(m).__name__ for p, m in modules.items() + ... }) + {'': 'Foo', 'Dense_0': 'Dense', 'Dense_1': 'Dense'} + + Args: + rngs: The rngs for the variable collections as passed to ``Module.init``. + *args: The arguments to the forward computation. + show_repeated: If ``True``, repeated calls to the same module will be + shown in the table, otherwise only the first call will be shown. + Default is ``False``. + mutable: Can be bool, str, or list. Specifies which collections should + be treated as mutable: ``bool``: all/no collections are mutable. + ``str``: The name of a single mutable collection. ``list``: A list of + names of mutable collections. By default, all collections except + 'intermediates' are mutable. + **kwargs: keyword arguments to pass to the forward computation. + + Returns: + A dict`ionary mapping module paths to module instances. + """ + from flax.linen import summary + + table = summary._get_module_table( + module=self, + depth=None, + show_repeated=show_repeated, + compute_flops=False, + compute_vjp_flops=False, + )(rngs, *args, **kwargs, mutable=mutable) + + return {'/'.join(row.path): row.module_copy for row in table} + + +_ParentType = Union[Type[Module], Scope, Type[_Sentinel], None] + + +def merge_param(name: str, a: Optional[T], b: Optional[T]) -> T: + """Merges construction- and call-time argument. + + This is a utility for supporting a pattern where a Module hyperparameter + can be passed either to ``__init__`` or ``__call__``, and the value that is + not ``None`` will be used. + + Example:: + + >>> import flax.linen as nn + >>> from typing import Optional + + >>> class Foo(nn.Module): + ... train: Optional[bool] = None + + ... def __call__(self, train: Optional[bool] = None): + ... train = nn.merge_param('train', self.train, train) + + An error is thrown when both arguments are ``None`` or both values are not + ``None``. 
+ + Args: + name: the name of the parameter. Used for error messages. + a: option a + b: option b + + Returns: + a or b whichever is not ``None``. + """ + if a is None and b is None: + raise ValueError( + f'Parameter "{name}" must be passed to the constructor or at call time.' + ) + if a is not None and b is not None: + raise ValueError( + f'Parameter "{name}" was passed to the constructor and at call time.' + ' Should be passed just once.' + ) + if a is None: + assert b is not None + return b + return a + + +@traceback_util.api_boundary +def apply( + fn: Callable[..., Any], + module: Module, + mutable: CollectionFilter = False, + capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, +) -> Callable[..., Any]: + """Creates an apply function to call ``fn`` with a bound module. + + Unlike ``Module.apply`` this function returns a new function with the + signature ``(variables, *args, rngs=None, **kwargs) -> T`` where ``T`` is the + return type of ``fn``. If ``mutable`` is not ``False`` the return type is a + tuple where the second item is a ``FrozenDict`` with the mutated variables. + + The apply function that is returned can be directly composed with + JAX transformations like ``jax.jit``:: + + >>> class Foo(nn.Module): + ... def encode(self, x): + ... ... + ... def decode(self, x): + ... ... + + >>> def f(foo, x): + ... z = foo.encode(x) + ... y = foo.decode(z) + ... # ... + ... return y + + >>> variables = {} + >>> foo = Foo() + >>> f_jitted = jax.jit(nn.apply(f, foo)) + >>> f_jitted(variables, jnp.ones((1, 3))) + + Args: + fn: The function that should be applied. The first argument passed will be + a module instance of the ``module`` with variables and RNGs bound to it. + module: The ``Module`` that will be used to bind variables and RNGs to. The + ``Module`` passed as the first argument to ``fn`` will be a clone of + module. + mutable: Can be bool, str, or list. Specifies which collections should be + treated as mutable: ``bool``: all/no collections are mutable. ``str``: The + name of a single mutable collection. ``list``: A list of names of mutable + collections. + capture_intermediates: If ``True``, captures intermediate return values of all + Modules inside the "intermediates" collection. By default, only the return + values of all `__call__` methods are stored. A function can be passed to + change the filter behavior. The filter function takes the Module instance + and method name and returns a bool indicating whether the output of that + method invocation should be stored. + + Returns: + The apply function wrapping ``fn``. + """ + + @functools.wraps(fn) + def scope_fn(scope, *args, **kwargs): + _context.capture_stack.append(capture_intermediates) + try: + return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs) + finally: + _context.capture_stack.pop() + + if capture_intermediates is True: # pylint: disable=g-bool-id-comparison + capture_intermediates = capture_call_intermediates + if capture_intermediates: + mutable = union_filters(mutable, 'intermediates') + return core.apply(scope_fn, mutable=mutable) + + +@traceback_util.api_boundary +def init_with_output( + fn: Callable[..., Any], + module: Module, + mutable: CollectionFilter = DenyList('intermediates'), + capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, +) -> Callable[..., Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]: + """Creates an init function to call ``fn`` with a bound module that also returns the function outputs. 
+ + Unlike ``Module.init_with_output`` this function returns a new function with + the signature ``(rngs, *args, **kwargs) -> (T, variables)`` where ``T`` is the + return type of ``fn``. The rngs can be a dict of PRNGKeys or a single + ```PRNGKey`` which is equivalent to passing a dict with one PRNGKey with the + name "params". + + The init function that is returned can be directly composed with + JAX transformations like ``jax.jit``:: + + >>> class Foo(nn.Module): + ... def encode(self, x): + ... ... + ... def decode(self, x): + ... ... + + >>> def f(foo, x): + ... z = foo.encode(x) + ... y = foo.decode(z) + ... # ... + ... return y + + >>> foo = Foo() + >>> f_jitted = jax.jit(nn.init_with_output(f, foo)) + >>> y, variables = f_jitted(jax.random.key(0), jnp.ones((1, 3))) + + Args: + fn: The function that should be applied. The first argument passed will be + a module instance of the ``module`` with variables and RNGs bound to it. + module: The ``Module`` that will be used to bind variables and RNGs to. The + ``Module`` passed as the first argument to ``fn`` will be a clone of + module. + mutable: Can be bool, str, or list. Specifies which collections should be + treated as mutable: ``bool``: all/no collections are mutable. ``str``: The + name of a single mutable collection. ``list``: A list of names of mutable + collections. By default, all collections except "intermediates" are + mutable. + capture_intermediates: If ``True``, captures intermediate return values of all + Modules inside the "intermediates" collection. By default, only the return + values of all `__call__` methods are stored. A function can be passed to + change the filter behavior. The filter function takes the Module instance + and method name and returns a bool indicating whether the output of that + method invocation should be stored. + + Returns: + The init function wrapping ``fn``. + """ + + @functools.wraps(fn) + def scope_fn(scope, *args, **kwargs): + _context.capture_stack.append(capture_intermediates) + try: + return fn(module.clone(parent=scope, _deep_clone=True), *args, **kwargs) + finally: + _context.capture_stack.pop() + + if capture_intermediates is True: # pylint: disable=g-bool-id-comparison + capture_intermediates = capture_call_intermediates + if capture_intermediates: + mutable = union_filters(mutable, 'intermediates') + return core.init(scope_fn, mutable=mutable) + + +@traceback_util.api_boundary +def init( + fn: Callable[..., Any], + module: Module, + mutable: CollectionFilter = DenyList('intermediates'), + capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False, +) -> Callable[..., Union[FrozenVariableDict, Dict[str, Any]]]: + """Creates an init function to call ``fn`` with a bound module. + + Unlike ``Module.init`` this function returns a new function with the signature + ``(rngs, *args, **kwargs) -> variables``. + The rngs can be a dict of PRNGKeys or a single ```PRNGKey`` which is + equivalent to passing a dict with one PRNGKey with the name "params". + + The init function that is returned can be directly composed with + JAX transformations like ``jax.jit``:: + + >>> class Foo(nn.Module): + ... def encode(self, x): + ... ... + ... def decode(self, x): + ... ... + + >>> def f(foo, x): + ... z = foo.encode(x) + ... y = foo.decode(z) + ... # ... + ... return y + + >>> foo = Foo() + >>> f_jitted = jax.jit(nn.init(f, foo)) + >>> variables = f_jitted(jax.random.key(0), jnp.ones((1, 3))) + + Args: + fn: The function that should be applied. 
 The first argument passed will be
+      a module instance of the ``module`` with variables and RNGs bound to it.
+    module: The ``Module`` that will be used to bind variables and RNGs to. The
+      ``Module`` passed as the first argument to ``fn`` will be a clone of
+      module.
+    mutable: Can be bool, str, or list. Specifies which collections should be
+      treated as mutable: ``bool``: all/no collections are mutable. ``str``: The
+      name of a single mutable collection. ``list``: A list of names of mutable
+      collections. By default, all collections except "intermediates" are
+      mutable.
+    capture_intermediates: If ``True``, captures intermediate return values of all
+      Modules inside the "intermediates" collection. By default, only the return
+      values of all `__call__` methods are stored. A function can be passed to
+      change the filter behavior. The filter function takes the Module instance
+      and method name and returns a bool indicating whether the output of that
+      method invocation should be stored.
+
+  Returns:
+    The init function wrapping ``fn``.
+  """
+  init_fn = init_with_output(fn, module, mutable, capture_intermediates)
+
+  @functools.wraps(init_fn)
+  def init_wrapper(*args, **kwargs):
+    return init_fn(*args, **kwargs)[1]
+
+  return init_wrapper
+
+
+# TODO(cgarciae): we are defining CompactNameScope just to
+# avoid a pytype bug with the Flax overlay. We should aim to
+# remove it at some point as it's not ergonomic.
+if not typing.TYPE_CHECKING:
+
+  class CompactNameScope(Module):
+    fn: Callable
+    module_fn: Callable[[], Module]
+
+    @compact
+    def __call__(self, *args, **kwargs) -> Any:
+      return self.fn(self.module_fn(), *args, **kwargs)
+else:
+
+  @dataclasses.dataclass
+  class CompactNameScope:
+    fn: Callable
+    module_fn: Callable
+    name: str
+
+    def __call__(self, *args, **kwargs) -> Any:
+      ...
diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/linen/spmd.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/linen/spmd.py
new file mode 100644
index 000000000..56a4b9677
--- /dev/null
+++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/linen/spmd.py
@@ -0,0 +1,364 @@
+# Copyright 2024 The Flax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for working with jit and partitioned models.
+
+This module introduces ``axis_rules``, ``logical_to_mesh_axes``,
+``logical_to_mesh``, ``with_logical_constraint`` for applying jit sharding
+constraints in terms of "logical named axes" rather than jit's default mesh
+axes.
+
+Additionally the ``LogicallyPartitioned`` metadata wrapper is defined as
+well as the initializer function wrapper ``with_logical_partitioning`` for
+introducing logical axis metadata into a model's variables.
+""" + +import collections +import contextlib +import dataclasses +import enum +import functools +import threading +from typing import Any, Callable, List, Optional, Sequence, Tuple, Union + +import jax +from jax import lax +from jax.interpreters import pxla + +from flax import struct +from flax.core import meta +from flax.typing import ( + Array, + LogicalNames, + LogicalRules, + ArrayPytree, # pylint: disable=invalid-name + LogicalPartitionSpec, # pylint: disable=unused-import + LogicalPartitionSpecPytree, # pylint: disable=invalid-name + ) + + +# Dynamic Axis Mapping Context +# ------------------------------------------------------------------------------ + + +@dataclasses.dataclass +class _AxisRules(threading.local): + """Dynamic logical axis to mesh axis binding context.""" + + rules: LogicalRules = () + + +# Global axis binding context. +_axis_rules = _AxisRules() + + +def set_logical_axis_rules(rules: LogicalRules): + """Sets the global logical axis to mesh axis binding.""" + _axis_rules.rules = rules + + +def get_logical_axis_rules() -> LogicalRules: + """Returns the global logical axis to mesh axis binding.""" + return _axis_rules.rules + + +@contextlib.contextmanager +def logical_axis_rules(rules: LogicalRules): + """Context manager for setting the logical to mesh axis bindings.""" + old_rules = _axis_rules.rules + try: + _axis_rules.rules = rules + yield + finally: + _axis_rules.rules = old_rules + + +class _UnassignedAxis: + """Sentinel class for unassigned logical axis name.""" + + def __repr__(self): + return 'UnassignedAxis' + + def __bool__(self): + return False + + +_unassigned_axis = _UnassignedAxis() + + +def _mesh_assignment_free(new_assignment, existing_assignments): + """Determines if a given mesh axis has already been assigned.""" + new = set(jax.tree_util.tree_leaves(new_assignment)) + existing = set(jax.tree_util.tree_leaves(existing_assignments)) + if existing.intersection(new): + return False + return True + + +def _logical_to_mesh_axes( + array_dim_names: Optional[Sequence[Optional[str]]], + rules: Optional[LogicalRules] = None, +) -> Optional[List[Union[_UnassignedAxis, None, str, Tuple[str, ...]]]]: + """Same as logical_to_mesh_axes, but doesn't fill in _unassigned_axis.""" + if array_dim_names is None: + return None + if rules is None: + rules = _axis_rules.rules + axis_name_counts = collections.Counter(array_dim_names) + dups = tuple( + k for k, v in axis_name_counts.items() if v > 1 and k is not None + ) + if dups: + raise ValueError( + f'Unsupported: Dimensions {dups} occur more than once in array names.' + ) + if not isinstance(rules, (tuple, list)): + raise ValueError('Unknown axis rule specification type.') + # We assign mesh axes using a priority based ruleset over logical axis names. + result: List[Union[_UnassignedAxis, None, str, Tuple[str, ...]]] + result = [ + (_unassigned_axis if isinstance(name, str) else name) + for name in array_dim_names + ] + for rule_model_name, rule_mesh_names in rules: + if rule_model_name in array_dim_names: + pos = array_dim_names.index(rule_model_name) + if ( + _mesh_assignment_free(rule_mesh_names, result) + and result[pos] == _unassigned_axis + ): + result[pos] = rule_mesh_names + return result + + +def logical_to_mesh_axes( + array_dim_names: Optional[Sequence[Optional[str]]], + rules: Optional[LogicalRules] = None, +) -> Optional[jax.sharding.PartitionSpec]: + """Compute layout for an array. 
+ + The rules are in order of precedence, and consist of pairs: + ``(ArrayDimensionName, MeshDimensionName)``, meaning that the given array + dimension (if present and unused) should be sharded across the given + mesh dimension (if present and unused). + + A Layout of an Array is expressed as a tuple with one element for each + dimension in the Array. The element is either None, or is the name of a + mesh-dimension, meaning that this dimension of the array is sharded across + this dimension of the mesh. + + For example, given an array with:: + + array_dim_names = ('batch', 'length', 'heads', 'features') + + and the layout rules are:: + + rules = (('batch', 'X'), + ('features', 'X'), + ('heads', 'Y'), + ('batch', 'Z')) + + then this function will return:: + + PartitionSpec('X', None, 'Y', None) + + Args: + array_dim_names: Tuple of array dimension names or None. + rules: Optional logical to mesh rules override. Defaults to using the + rules defined in the dynamic context set from the ``axis_rules`` function. + + Returns: + PartitionSpec for the parameter. + """ + result = _logical_to_mesh_axes(array_dim_names, rules) + if result is None: + return None + # We default to None - ie unsharded along the dimension. + result = [None if x is _unassigned_axis else x for x in result] + return jax.sharding.PartitionSpec(*result) + + +def logical_to_mesh(tree: Any, rules: Optional[LogicalRules] = None) -> Any: + """Applies logical_to_mesh_axes to pytrees of logical PartitionSpecs.""" + return jax.tree_util.tree_map( + lambda x: logical_to_mesh_axes(x, rules), + tree, + is_leaf=lambda x: isinstance(x, jax.sharding.PartitionSpec), + ) + + +def logical_to_mesh_sharding( + tree: Any, + mesh: jax.sharding.Mesh, + rules: Optional[LogicalRules] = None, +) -> Any: + """Convert pytrees of logical PartitionSpecs to shardings.""" + return jax.tree_util.tree_map( + lambda x: jax.sharding.NamedSharding(mesh, x), + logical_to_mesh(tree, rules), + is_leaf=lambda x: isinstance(x, jax.sharding.PartitionSpec), + ) + + +def _global_mesh_defined() -> bool: + """Checks if global mesh resource environment is defined.""" + env = pxla.thread_resources.env + return env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison + + +class RulesFallback(enum.Enum): + """How a sharding constraint should behave when no matching rule is found.""" + + AXIS_IS_UNSHARDED = 'axis_is_unsharded' + RAISE_ERROR = 'raise_error' + NO_CONSTRAINT = 'no_constraint' + + +def _with_sharding_constraint( + x: Array, + axis_resources: Optional[jax.sharding.PartitionSpec], + mesh: Optional[jax.sharding.Mesh] = None, +): + """Wrapper for lax.with_sharding_constraint, no-op on cpu or outside jit.""" + if jax.devices()[0].platform == 'cpu' or ( + not _global_mesh_defined() and mesh is None + ): + return x + else: + if mesh is not None and axis_resources is not None: + sharding = jax.sharding.NamedSharding(mesh, axis_resources) + return lax.with_sharding_constraint(x, sharding) + return lax.with_sharding_constraint(x, axis_resources) + + +def _with_sharding_constraint_one_fallback( + axis_resources: LogicalPartitionSpec, + x: Array, + fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, + rules: Optional[LogicalRules] = None, + mesh: Optional[jax.sharding.Mesh] = None, +): + """Either imposes a sharding constraint or applies fallback.""" + mesh_axes = _logical_to_mesh_axes(axis_resources, rules) + if mesh_axes is None: + return _with_sharding_constraint(x, None, mesh=mesh) + + if fallback == RulesFallback.AXIS_IS_UNSHARDED: + 
mesh_axes = [None if x is _unassigned_axis else x for x in mesh_axes] + else: + if any(x is _unassigned_axis for x in mesh_axes): + if fallback == RulesFallback.RAISE_ERROR: + raise ValueError(f'Axis names {axis_resources} did not match a rule') + else: + return x + return _with_sharding_constraint( + x, jax.sharding.PartitionSpec(*mesh_axes), mesh=mesh + ) + + +def _is_axis_spec(x): + return ( + isinstance(x, str) + or x is jax.sharding.PartitionSpec.UNCONSTRAINED + or x is None + ) + + +def _is_logical_spec(x): + return x is None or ( + isinstance(x, tuple) and all(_is_axis_spec(e) for e in x) + ) + + +def with_logical_constraint( + x: ArrayPytree, + logical_axis_resources: LogicalPartitionSpecPytree, + rules: Optional[LogicalRules] = None, + mesh: Optional[jax.sharding.Mesh] = None, + fallback: RulesFallback = RulesFallback.AXIS_IS_UNSHARDED, +): + """Version of jit's with_sharding_constraint that uses logical axis names.""" + # If no axis binding is set, this is a no-op. + if rules is None: + rules = _axis_rules.rules + if not rules or logical_axis_resources is None: + return x + # Translate logical names to mesh assignments. + return jax.tree_util.tree_map( + functools.partial( + _with_sharding_constraint_one_fallback, + fallback=fallback, + rules=rules, + mesh=mesh, + ), + logical_axis_resources, + x, + is_leaf=_is_logical_spec, + ) + + +# Logical Partitioning Axis Metadata +# ------------------------------------------------------------------------------ + + +class LogicallyPartitioned(meta.Partitioned): + rules: Optional[LogicalRules] = struct.field(default=None, pytree_node=False) + + def unbox(self, apply_constraint=True) -> Any: + """Returns the wrapped value with the partitioning constraint applied.""" + if apply_constraint and (_global_mesh_defined() or self.mesh is not None): + return with_logical_constraint( + self.value, + self.get_partition_spec(), + rules=self.rules, + mesh=self.mesh, + ) + else: + return self.value + + +def with_logical_partitioning( + fn: Callable[..., Any], + names: LogicalNames, + mesh: Optional[jax.sharding.Mesh] = None, + rules: Optional[LogicalRules] = None, +) -> Callable[..., LogicallyPartitioned]: + """Wraps a function's return value with LogicallyPartitioned. + + Example:: + + >>> import flax.linen as nn + >>> kernel_init = nn.with_logical_partitioning( + ... nn.initializers.lecun_normal(), (None, "data")) + >>> partitioned_dense = nn.Dense(features=3, kernel_init=kernel_init) + + Args: + fn: The function to be wrapped. Typically this is an initializer. + names: The logical axis passed to ``LogicallyPartitioned``. + mesh: The mesh to use for the partitioning. If None, the global mesh + resource is used if available. + rules: Optional logical to mesh rules use. If None, the global rules + are used if available. + Returns: + A function wrapping ``fn`` that will return an instance of + ``LogicallyPartitioned``. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return LogicallyPartitioned( + fn(*args, **kwargs), names, mesh=mesh, rules=rules + ) # pytype: disable=wrong-keyword-args + + return wrapper diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/linen/transforms.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/linen/transforms.py new file mode 100644 index 000000000..4a8843646 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/linen/transforms.py @@ -0,0 +1,2149 @@ +# Copyright 2024 The Flax Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2023 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""JAX transformations on Modules. + +Jax functional transformations operate on pure functions. +Flax extends these transformations to also operate on Module's which +have stateful variables and PRNG sequences. We refer to these extended +versions as "lifted transformations". + +A lifted transformation can be applied to a ``Module`` class or a +function that takes a ``Module`` instance as its first argument. +""" + +import dataclasses +import functools +import inspect +from typing import ( + Any, + Callable, + Dict, + Iterable, + Mapping, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) + +from flax import core +from flax import errors, struct, traceback_util +from flax import serialization +from flax.core import Scope, lift, meta +from flax.core.frozen_dict import FrozenDict +from flax.core.scope import ( + CollectionFilter, + PRNGSequenceFilter, +) +from flax.ids import FlaxId +from flax.linen import module as linen_module +from flax.linen.module import ( + Module, + Variable, + _derive_profiling_name, + _get_unbound_fn, + wrap_method_once, +) +from flax.typing import ( + InOutAxis, + InOutScanAxis, +) +import jax + +traceback_util.register_exclusion(__file__) + +# pylint: disable=protected-access,dangerous-default-value + + +# Utils +# ----------------------------------------------------------------------------- +def clean_clone(x): + """Remove scopes and tracers from children.""" + if isinstance(x, Module): + object.__setattr__( + x, 'children', {k: clean_clone(v) for k, v in x.children.items()} + ) + object.__setattr__(x, 'scope', None) + return x + + +@struct.dataclass +class VariablePlaceholder: + """Used to mark Variables in a JAX-compatible way when lifting arguments.""" + + collection: str = struct.field(pytree_node=False) + name: str = struct.field(pytree_node=False) + unbox: bool = struct.field(pytree_node=False) + id: int = struct.field(pytree_node=False) + + +@struct.dataclass +class InstancePlaceholder: + """Marks module instances in a JAX-compatible way when lifting arguments.""" + + cls: Type[Any] = struct.field(pytree_node=False) + attrs: Dict[Any, Any] = struct.field(pytree_node=False) + id: int = struct.field(pytree_node=False) + + +def _memoize_by_id(fn, refs): + """Memoization by module/variable id to handle aliasing in traversal.""" + + @functools.wraps(fn) + def wrapped_fn(x): + nonlocal refs + if isinstance(x, 
(VariablePlaceholder, InstancePlaceholder)): + x_id = x.id + elif isinstance(x, (Variable, Module)): + x_id = x._id + else: + return fn(x) + if x_id not in refs: + refs[x_id] = fn(x) + return refs[x_id] + + return wrapped_fn + + +def get_module_scopes(module, args=None, kwargs=None): + """Get all scopes on module, including constructor Module arguments. + + To properly functionalize a Module that has other bound Modules passed in + "from the outside" as dataclass attributes, we need to traverse all dataclass + fields to find the Scopes associated with the Module. Additionally, because + we allow Modules to be passed inside pytrees on the dataclass attributes, we + must traverse all dataclass attributes as pytrees to find all Modules. We + additionally handle lifting Variables (which are just references to data in + particular scopes) and Module instances that are passed as arguments to + methods. + + Args: + module: a bound flax Module. + args: an *args list possibly containing Variables or Module instances + referencing a scope. + kwargs: a **kwargs dict possibly containing Variables or Module instances + referencing a scope. + + Returns: + A list of all functional-core Scopes bound on self and inside dataclass + fields as well as any Scopes passed via argument Variables, an updated args + list, and an updated kwargs dict that have both had Variables replaced with + VariablePlaceholders and Module instances replaced with InstancePlaceholders + that are compatible with jax functions. + """ + scopes = [] + refs = {} + + # Gather scopes associated with Variables and Module instances passed as + # positional and keyword arguments. + @functools.partial(_memoize_by_id, refs=refs) + def get_arg_scope(x): + nonlocal scopes + if isinstance(x, Variable) and isinstance(x.scope, Scope): + scopes.append(x.scope) + return VariablePlaceholder(x.collection, x.name, x.unbox, x._id) + elif isinstance(x, Module) and isinstance(x.scope, Scope): + x._try_setup(shallow=True) + scopes.append(x.scope) + attrs = { + f.name: getattr(x, f.name) + for f in dataclasses.fields(x) + if f.name != 'parent' and f.init + } + attrs = jax.tree_util.tree_map(get_arg_scope, attrs) + return InstancePlaceholder(x.__class__, attrs, x._id) + return x + + new_args, new_kwargs = jax.tree_util.tree_map(get_arg_scope, (args, kwargs)) + + # Gather scopes in Variables and Submodules passed as Module attributes. + @functools.partial(_memoize_by_id, refs=refs) + def get_scopes(module): + nonlocal scopes + module._try_setup(shallow=True) + + def get_scopes_inner(x): + nonlocal scopes + if isinstance(x, Module) and isinstance(x.scope, Scope): + get_scopes(x) + elif isinstance(x, Variable) and isinstance(x.scope, Scope): + scopes.append(x.scope) + + attrs = { + f.name: getattr(module, f.name) + for f in dataclasses.fields(module) + if f.name != 'parent' and f.init + } + for leaf in jax.tree_util.tree_leaves(attrs): + get_scopes_inner(leaf) + scopes.append(module.scope) + + get_scopes(module) + return scopes, new_args, new_kwargs + + +def set_module_scopes(module, args, kwargs, scopes): + """Set all scopes on module, including those on Modules in dataclass fields. + + To properly functionalize a Module we must also "rehydrate" it with Scopes + from `get_module_scopes`. We need to set scopes not just on the Module but + also on any Module living inside dataclass attributes or even pytrees in its + dataclass attributes. 
We additionally handle restoring Variables and Module + instances from their placeholders in the method positional and keyword + arguments. The order of traversal through this method is the same as in + `get_module_scopes`, guaranteeing the correct Scopes are applied to each + Module. + + Args: + module: a flax Module. + args: an *args list possibly containing VariablePlaceholder or + InstancePlaceholder members. + kwargs: a **kwargs dict possibly containing VariablePlaceholder or + InstancePlaceholder members. + scopes: a list of Scopes corresponding to this Module and its arguments that + was created by the `get_module_scopes` function. + + Returns: + A copy of the module with it and its attributes bound to the scopes passed + to this function, an updated args list, and an updated kwargs dict with + updated Variable and Module instance references. + """ + idx = 0 + refs = {} + + # Set scopes associated with Variables and Module instances passed as + # positional and keyword arguments. + @functools.partial(_memoize_by_id, refs=refs) + def set_arg_scope(x): + nonlocal idx + if isinstance(x, VariablePlaceholder): + new_x = Variable( + scope=scopes[idx], collection=x.collection, name=x.name, unbox=x.unbox + ) + idx += 1 + return new_x + elif isinstance(x, InstancePlaceholder): + instance_scope = scopes[idx] + idx += 1 + instance_attrs = jax.tree_util.tree_map(set_arg_scope, x.attrs) + return x.cls(parent=instance_scope, **instance_attrs) + return x + + def is_placeholder(x): + return isinstance(x, (VariablePlaceholder, InstancePlaceholder)) + + new_args, new_kwargs = jax.tree_util.tree_map( + set_arg_scope, (args, kwargs), is_leaf=is_placeholder + ) + + # set scopes in Variables and Submodules passed as Module attributes + @functools.partial(_memoize_by_id, refs=refs) + def set_scopes(module): + nonlocal idx + + def set_scopes_inner(x): + nonlocal idx + if isinstance(x, Module) and isinstance(x.scope, Scope): + return set_scopes(x) + elif isinstance(x, Variable) and isinstance(x.scope, Scope): + new_x = Variable( + scope=scopes[idx], + collection=x.collection, + name=x.name, + unbox=x.unbox, + ) + idx += 1 + return new_x + else: + return x + + attrs = { + f.name: getattr(module, f.name) + for f in dataclasses.fields(module) + if f.name != 'parent' and f.init + } + new_attrs = jax.tree_util.tree_map(set_scopes_inner, attrs) + new_module = module.clone(parent=scopes[idx], **new_attrs) + idx += 1 + return new_module + + new_module = set_scopes(module) + assert len(scopes) == idx, f'scope list mismatch {len(scopes)} != {idx}' + return new_module, new_args, new_kwargs + + +def _test_transformed_return_values(tree, method_name): + """Tests whether the return value contains any Modules or Variables.""" + impure = any( + map( + lambda x: isinstance(x, (Module, Variable)), + jax.tree_util.tree_leaves(tree), + ) + ) + if impure: + raise errors.TransformedMethodReturnValueError(method_name) + + +# Class lifting +# ----------------------------------------------------------------------------- +def module_class_lift_transform( + transform, module_class, *trafo_args, methods=None, **trafo_kwargs +): + """Module class lift transform.""" + # TODO(marcvanzee): Improve docstrings (#1977). + # TODO(levskaya): find nicer argument convention for multi-method case? + + # Prepare per-method transform args, kwargs. 
+ if methods is None: + # Default case, just transform __call__ + class_trafo_args = {'__call__': (trafo_args, trafo_kwargs)} + elif isinstance(methods, (list, tuple)): + # Transform every method in methods with given args, kwargs. + class_trafo_args = {m: (trafo_args, trafo_kwargs) for m in methods} + elif isinstance(methods, dict): + # Pass different trafo args per each method. + class_trafo_args = {k: ((), v) for k, v in methods.items()} + else: + raise ValueError( + 'transform methods argument must be None, tuple, list, or dict.' + ) + + # Handle partially initialized module class constructors. + if isinstance(module_class, functools.partial) and issubclass( + module_class.func, Module + ): + partial_object = module_class + module_class = module_class.func + else: + partial_object = None + + def create_trans_fn(fn_name, fn_trafo_args): + # get existing unbound method from class + fn = getattr(module_class, fn_name) + trafo_args, trafo_kwargs = fn_trafo_args + + # we need to create a scope-function from our class for the given method + @functools.wraps(fn) + def wrapped_fn(self, *args, **kwargs): + state = self._state.export() + + # make a scope-function to transform + def core_fn(scopes, *args, **kwargs): + # make a clone of self using its arguments + attrs = { + f.name: getattr(self, f.name) + for f in dataclasses.fields(self) + if f.name != 'parent' and f.init + } + # we reference module_class, not self.__class__ to avoid infinite loop + cloned = module_class(parent=None, **attrs) + cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes) + object.__setattr__(cloned, '_state', state.export()) + res = fn(cloned, *args, **kwargs) + self._state.reimport(cloned._state) + _test_transformed_return_values(res, fn_name) + return res + + # here we apply the given lifting transform to the scope-ingesting fn + trafo_fn = transform(core_fn, *trafo_args, **trafo_kwargs) + module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) + ret = trafo_fn(module_scopes, *args, **kwargs) + return ret + + return wrapped_fn + + transformed_fns = { + fn_name: create_trans_fn(fn_name, fn_trafo_args) + for fn_name, fn_trafo_args in class_trafo_args.items() + } + # construct new dynamic class w. transformed methods + transformed_cls = type( + transform.__name__.capitalize() + module_class.__name__, + (module_class,), + transformed_fns, + ) + # Handle partially initialized module class constructors. + if partial_object is not None: + transformed_cls = functools.partial( + transformed_cls, *partial_object.args, **partial_object.keywords + ) + return transformed_cls + + +# Function lifting as decorator on methods __inside__ class definition. +# ----------------------------------------------------------------------------- +def decorator_lift_transform( + transform, class_fn, *trafo_args, multi_scope=True, **trafo_kwargs +): + """Decorator for lifted transform.""" + # TODO(marcvanzee): Improve docstrings (#1977). + # Due to the ordering of method decorators, we must wrap the class_fn + # with the module state management wrapper first to maintain Module state + # correctly. 
+ if isinstance(class_fn, tuple): + class_fns = class_fn + else: + class_fns = (class_fn,) + prewrapped_fns = [wrap_method_once(class_fn) for class_fn in class_fns] + + @functools.wraps(prewrapped_fns[0]) + def wrapped_fn(self, *args, **kwargs): + state = self._state.export() + + # make a scope-function to transform + def core_fn(prewrapped_fn, class_fn, scopes, *args, **kwargs): + if not multi_scope: + scopes = [scopes] + cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes) + object.__setattr__(cloned, '_state', state.export()) + res = prewrapped_fn(cloned, *args, **kwargs) + self._state.reimport(cloned._state) + _test_transformed_return_values(res, getattr(class_fn, '__name__', None)) + return res + + core_fns = [ + functools.partial(core_fn, prewrapped_fn, class_fn) + for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns) + ] + # here we apply the given lifting transform to the scope-ingesting fn + trafo_fn = transform(*core_fns, *trafo_args, **trafo_kwargs) + module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) + if not multi_scope: + if len(module_scopes) != 1: + # TODO(levskaya): transforms like jvp & vjp have args that follow the + # pytree structure of scopes. The user doesn't explicitly control shared + # modules passed as arguments to methods or as attributes to Module + # constructors. Therefore, there is no obvious API for specifying + # arguments per lifted Module. + raise NotImplementedError( + 'This transform does not yet support' + ' Modules that include other Modules passed as arguments.' + ) + module_scopes = module_scopes[0] + return trafo_fn(module_scopes, *args, **kwargs) + + return wrapped_fn + + +@dataclasses.dataclass(frozen=True) +class _HashableProxy: + """A hashable proxy object that is use to define a hash for Modules. + + The hash produced by _HashableProxy is useful for nn.jit to decide if a + function should be retraced or not + """ + + module: Module + hash_key: int + + @classmethod + def from_module(cls, module: Module) -> '_HashableProxy': + fingerprint = _module_fingerprint(module) + hash_key = hash(fingerprint) + return cls(module, hash_key) + + def __hash__(self): + return self.hash_key + + def __eq__(self, other): + return isinstance(other, _HashableProxy) and self.hash_key == other.hash_key + + +def _module_fingerprint(module: Module) -> tuple[type[Any], Any]: + return _fingerprint_recursive(module, (), {}) + + +def _fingerprint_recursive( + obj: Any, path: tuple[str, ...], seen_modules: dict[FlaxId, int] +) -> Any: + """Creates a hashable representation for a Module by traversing its structure recursively.""" + + def _get_fingerprint(name: str, value: Any) -> tuple[str, Any]: + return name, _fingerprint_recursive(value, (*path, name), seen_modules) + + if isinstance(obj, str): + return obj + elif isinstance(obj, Module): + fingerprint: Any + if obj._id in seen_modules: + # if we have already seen the module we just use the index + # as its static component + fingerprint = seen_modules[obj._id] + return type(obj), fingerprint + else: + # if its a new module we add it to the cache and give it + # a new index + seen_modules[obj._id] = len(seen_modules) + # TODO(cgarciae): define a way for the user of nn.jit to define + # what fields it wants to ignore per Module instance. 
+ fingerprints = [] + for field in dataclasses.fields(obj): + if not hasattr(obj, field.name): + continue + if field.name not in ('parent', 'name'): + value = getattr(obj, field.name) + fingerprints.append(_get_fingerprint(field.name, value)) + # add state fingerprint + state_fingerprint = ( + _get_fingerprint('in_compact_method', obj._state.in_compact_method), + _get_fingerprint('in_setup', obj._state.in_setup), + _get_fingerprint('setup_called', obj._state.setup_called), + _get_fingerprint('is_initialized', obj._state.is_initialized), + _get_fingerprint('autoname_cursor', obj._state.autoname_cursor), + ) + fingerprints.append(('_state', state_fingerprint)) + # add scope fingerprint + scope = obj.scope + if scope is not None: + static_scope = ( + _get_fingerprint('mutable', scope.mutable), + _get_fingerprint('flags', scope.flags), + _get_fingerprint('rng_counts', scope.rng_counters), + _get_fingerprint('reservations', scope.reservations), + ) + _check_field_is_hashable((*path, 'scope'), static_scope) + fingerprints.append(('scope', static_scope)) + fingerprint = tuple(fingerprints) + return type(obj), fingerprint + elif dataclasses.is_dataclass(obj): + fingerprints = [] + for field in dataclasses.fields(obj): + if not hasattr(obj, field.name): + continue + value = getattr(obj, field.name) + value_fingerprint = _get_fingerprint(field.name, value) + fingerprints.append((field.name, value_fingerprint)) + return type(obj), tuple(fingerprints) + elif isinstance(obj, core.DenyList): + return type(obj), _get_fingerprint('deny', obj.deny) + elif isinstance(obj, dict): + fingerprint = tuple((k, _get_fingerprint(k, v)) for k, v in obj.items()) + return fingerprint + elif serialization.is_serializable(obj): + state = serialization.to_state_dict(obj) + fingerprint = _fingerprint_recursive(state, path, seen_modules) + return type(obj), fingerprint + elif isinstance(obj, Mapping): + return tuple((k, _get_fingerprint(k, v)) for k, v in obj.items()) + elif isinstance(obj, Iterable): + return tuple(_get_fingerprint(str(i), v) for i, v in enumerate(obj)) + else: + _check_field_is_hashable(path, obj) + return obj + + +def _check_field_is_hashable(path: tuple[str, ...], x: Any): + """Checks if a field is hashable.""" + try: + hash(x) + except Exception as e: + path_name = '/'.join(path) + raise ValueError(f"Value at '{path_name}' is not hashable: {e}") from e + + +def decorator_lift_transform_jit(class_fn, **trafo_kwargs): + """Decorator for lifted transform. + + Similar to `decorator_lift_transform` but specialized for `jit`, it reuses the + previous transform when available to avoid retracing. + """ + # TODO(marcvanzee): Improve docstrings (#1977). + # Due to the ordering of method decorators, we must wrap the class_fn + # with the module state management wrapper first to maintain Module state + # correctly. 
+ transform = lift.jit + multi_scope = True + + if isinstance(class_fn, tuple): + class_fns = class_fn + else: + class_fns = (class_fn,) + prewrapped_fns = [wrap_method_once(class_fn) for class_fn in class_fns] + trafo_fn = None + + @functools.wraps(prewrapped_fns[0]) + def wrapped_fn(self: Module, *args, **kwargs): + nonlocal trafo_fn + state = self._state.export() + + # make a scope-function to transform + def core_fn( + prewrapped_fn, + class_fn, + scopes, + module_hash, + *args, + **kwargs, + ): + # self = hash_key.obj + self: Module = module_hash.module + if not multi_scope: + scopes = [scopes] + cloned, args, kwargs = set_module_scopes(self, args, kwargs, scopes) + object.__setattr__(cloned, '_state', state.export()) + res = prewrapped_fn(cloned, *args, **kwargs) + self._state.reimport(cloned._state) + _test_transformed_return_values(res, getattr(class_fn, '__name__', None)) + return res + + core_fns = [ + functools.partial(core_fn, prewrapped_fn, class_fn) + for prewrapped_fn, class_fn in zip(prewrapped_fns, class_fns) + ] + + # here we apply the given lifting transform to the scope-ingesting fn + if trafo_fn is None: + trafo_fn = transform(*core_fns, **trafo_kwargs) + + module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) + if not multi_scope: + if len(module_scopes) != 1: + # TODO(levskaya): transforms like jvp & vjp have args that follow the + # pytree structure of scopes. The user doesn't explicitly control shared + # modules passed as arguments to methods or as attributes to Module + # constructors. Therefore, there is no obvious API for specifying + # arguments per lifted Module. + raise NotImplementedError( + 'This transform does not yet support' + ' Modules that include other Modules passed as arguments.' + ) + module_scopes = module_scopes[0] + + # get a hashable proxy object for the Module + hash_key = _HashableProxy.from_module(self) + + return trafo_fn(module_scopes, hash_key, *args, **kwargs) + + return wrapped_fn + + +def module_class_lift_transform_jit(module_class, methods=None, **trafo_kwargs): + """Module class lift transform.""" + # TODO(marcvanzee): Improve docstrings (#1977). + # TODO(levskaya): find nicer argument convention for multi-method case? + transform = lift.jit + trafo_args = () + + # Prepare per-method transform args, kwargs. + if methods is None: + # Default case, just transform __call__ + class_trafo_args = {'__call__': (trafo_args, trafo_kwargs)} + elif isinstance(methods, (list, tuple)): + # Transform every method in methods with given args, kwargs. + class_trafo_args = {m: (trafo_args, trafo_kwargs) for m in methods} + elif isinstance(methods, dict): + # Pass different trafo args per each method. + class_trafo_args = {k: ((), v) for k, v in methods.items()} + else: + raise ValueError( + 'transform methods argument must be None, tuple, list, or dict.' + ) + + # Handle partially initialized module class constructors. 
+ if isinstance(module_class, functools.partial) and issubclass( + module_class.func, Module + ): + partial_object = module_class + module_class = module_class.func + else: + partial_object = None + + def create_trans_fn(fn_name, fn_trafo_args): + # get existing unbound method from class + fn = getattr(module_class, fn_name) + trafo_args, trafo_kwargs = fn_trafo_args + trafo_fn = None + + # we need to create a scope-function from our class for the given method + @functools.wraps(fn) + def wrapped_fn(self: Module, *args, **kwargs): + nonlocal trafo_fn + state = self._state.export() + + # make a scope-function to transform + def core_fn(scopes, module_hash, *args, **kwargs): + self: Module = module_hash.module + # make a clone of self using its arguments + attrs = { + f.name: getattr(self, f.name) + for f in dataclasses.fields(self) + if f.name != 'parent' and f.init + } + # we reference module_class, not self.__class__ to avoid infinite loop + cloned = module_class(parent=None, **attrs) + cloned, args, kwargs = set_module_scopes(cloned, args, kwargs, scopes) + object.__setattr__(cloned, '_state', state.export()) + res = fn(cloned, *args, **kwargs) + self._state.reimport(cloned._state) + _test_transformed_return_values(res, fn_name) + return res + + # here we apply the given lifting transform to the scope-ingesting fn + trafo_fn = trafo_fn or transform(core_fn, *trafo_args, **trafo_kwargs) + module_scopes, args, kwargs = get_module_scopes(self, args, kwargs) + + # get a hash for the Module by using its repr as a proxy + hash_key = _HashableProxy.from_module(self) + + ret = trafo_fn(module_scopes, hash_key, *args, **kwargs) + return ret + + return wrapped_fn + + transformed_fns = { + fn_name: create_trans_fn(fn_name, fn_trafo_args) + for fn_name, fn_trafo_args in class_trafo_args.items() + } + # construct new dynamic class w. transformed methods + transformed_cls = type( + transform.__name__.capitalize() + module_class.__name__, + (module_class,), + transformed_fns, + ) + # Handle partially initialized module class constructors. + if partial_object is not None: + transformed_cls = functools.partial( + transformed_cls, *partial_object.args, **partial_object.keywords + ) + return transformed_cls + + +# Utility to wrap a class or to use as decorator in def of class method. +# ----------------------------------------------------------------------------- + +TransformTarget = Union[Type[Module], Callable[..., Any]] +Target = TypeVar('Target', bound=TransformTarget) + + +def _is_module_class(target: TransformTarget) -> bool: + return ( + inspect.isclass(target) + and issubclass(target, Module) + or (isinstance(target, functools.partial)) + and _is_module_class(target.func) + ) + + +def lift_transform( + transform, target, *trafo_args, methods=None, **trafo_kwargs +): + """Applies to class or as a decorator on class fns.""" + # TODO(marcvanzee): Improve docstrings (#1977). 
+ if _is_module_class(target): + return module_class_lift_transform( + transform, target, *trafo_args, methods=methods, **trafo_kwargs + ) + # we presume this is being used as a function decorator in class definition + elif callable(target) and not isinstance(target, Module): + return decorator_lift_transform( + transform, target, *trafo_args, **trafo_kwargs + ) + else: + raise errors.TransformTargetError(target) + + +def lift_direct_transform( + transform: Callable[..., Any], + targets: Tuple[Callable[..., Any], ...], + mdl: Module, + *args, + multi_scope=True, + **kwargs, +): + """Lift direct transform.""" + # TODO(marcvanzee): Improve docstrings (#1977). + for target in targets: + if _is_module_class(target): + raise ValueError( + f'The {transform.__name__} transform can only be applied on a Module' + ' method. That is function that takes a Module instance as its first' + ' arg.' + ) + elif not callable(target): + raise ValueError('transform target must be callable') + # normalize self.foo bound methods to class.foo unbound methods. + targets = tuple(_get_unbound_fn(target) for target in targets) + aug_transform = lambda *fns: functools.partial(transform, *fns) + return decorator_lift_transform( + aug_transform, targets, multi_scope=multi_scope + )(mdl, *args, **kwargs) + + +def vmap( + target: Target, + variable_axes: Mapping[CollectionFilter, InOutAxis] = FrozenDict(), + split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict(), + in_axes=0, + out_axes=0, + axis_size: Optional[int] = None, + axis_name: Optional[str] = None, + spmd_axis_name: Optional[str] = None, + metadata_params: Mapping[Any, Any] = {}, + methods=None, +) -> Target: + """A lifted version of ``jax.vmap``. + + See ``jax.vmap`` for the unlifted batch transform in Jax. + + ``vmap`` can be used to add a batch axis to a ``Module``. + For example we could create a version of ``Dense`` with + a batch axis that does not share parameters:: + + >>> import flax.linen as nn + >>> BatchDense = nn.vmap( + ... nn.Dense, + ... in_axes=0, out_axes=0, + ... variable_axes={'params': 0}, + ... split_rngs={'params': True}) + + By using ``variable_axes={'params': 0}``, we indicate that the + parameters themselves are mapped over and therefore not shared along + the mapped axis. Consequently, we also split the 'params' RNG, + otherwise the parameters would be initialized identically along + the mapped axis. + + Similarly, ``vmap`` could be used to add a batch axis with parameter + sharing:: + + >>> import flax.linen as nn + >>> BatchDense = nn.vmap( + ... nn.Dense, + ... in_axes=0, out_axes=0, + ... variable_axes={'params': None}, + ... split_rngs={'params': False}) + + Here we use ``variable_axes={'params': None}`` to indicate the parameter + variables are shared along the mapped axis. Consequently, the 'params' + RNG must also be shared. + + Args: + target: a ``Module`` or a function taking a ``Module`` as its first + argument. + variable_axes: the variable collections that are lifted into the batching + transformation. Use ``None`` to indicate a broadcasted collection or an + integer to map over an axis. For example, passing in + ``variable_axes={'params': None}`` will indicate that the + parameter variables should be shared along the mapped axis. + split_rngs: Split PRNG sequences will be different for each index of the + batch dimension. Unsplit PRNGs will be broadcasted. + in_axes: Specifies the mapping of the input arguments (see ``jax.vmap``). + out_axes: Specifies the mapping of the return value (see ``jax.vmap``). 
+ axis_size: Specifies the size of the batch axis. This only needs to be + specified if it cannot be derived from the input arguments. + axis_name: Specifies a name for the batch axis. Can be used together with + parallel reduction primitives (e.g. ``jax.lax.pmean``, ``jax.lax.ppermute``, + etc.). Note, this is only used for pmap and shard map. For SPMD jit, you + do not need to manually synchronize. Just make sure that the axes are + correctly annotated and XLA:SPMD will insert the necessary collectives. + methods: If ``target`` is a ``Module``, the methods of ``Module`` to vmap over. + spmd_axis_name: Axis name added to any pjit sharding constraints appearing + in ``fn``. See also + https://github.com/google/flax/blob/main/flax/linen/partitioning.py. + metadata_params: arguments dict passed to AxisMetadata instances in the + variable tree. + + Returns: + A batched/vectorized version of ``target``, with the same arguments but with + extra axes at positions indicated by ``in_axes``, and the same return value, + but with extra axes at positions indicated by ``out_axes``. + """ + return lift_transform( + lift.vmap, + target, + variable_axes, + split_rngs, + methods=methods, + in_axes=in_axes, + out_axes=out_axes, + axis_size=axis_size, + axis_name=axis_name, + metadata_params=metadata_params, + spmd_axis_name=spmd_axis_name, + ) + + +def jit( + target: Target, + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, + static_argnums: Union[int, Iterable[int]] = (), + static_argnames: Union[str, Iterable[str]] = (), + donate_argnums: Union[int, Iterable[int]] = (), + device=None, + backend: Union[str, None] = None, + methods=None, +) -> Target: + """Lifted version of ``jax.jit``. + + Args: + target: a ``Module`` or a function taking a ``Module`` as its first + argument. + variables: The variable collections that are lifted. By default all + collections are lifted. + rngs: The PRNG sequences that are lifted. By default all PRNG sequences are + lifted. + static_argnums: An int or collection of ints specifying which positional + arguments to treat as static (compile-time constant). Operations that only + depend on static arguments will be constant-folded in Python (during + tracing), and so the corresponding argument values can be any Python + object. Static arguments should be hashable, meaning both ``__hash__`` and + ``__eq__`` are implemented, and immutable. Calling the jitted function + with different values for these constants will trigger recompilation. If + the jitted function is called with fewer positional arguments than + indicated by ``static_argnums`` then an error is raised. Arguments that + are not arrays or containers thereof must be marked as static. Defaults to + (). + static_argnames: An optional string or collection of strings specifying + which named arguments to treat as static (compile-time constant). See the + comment on ``static_argnums`` for details. If not provided but + ``static_argnums`` is set, the default is based on calling + ``inspect.signature(fun)`` to find corresponding named arguments. + donate_argnums: Specify which arguments are "donated" to the computation. It + is safe to donate arguments if you no longer need them once the + computation has finished. In some cases XLA can make use of donated + buffers to reduce the amount of memory needed to perform a computation, + for example recycling one of your input buffers to store a result. You + should not reuse buffers that you donate to a computation, JAX will raise + an error if you try to. 
+ device: This is an experimental feature and the API is likely to change. + Optional, the Device the jitted function will run on. (Available devices + can be retrieved via :py:func:`jax.devices`.) The default is inherited + from XLA's DeviceAssignment logic and is usually to use + ``jax.devices()[0]``. + backend: a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or + ``'tpu'``. + methods: If ``target`` is a ``Module``, the methods of ``Module`` to jit. + + Returns: + A wrapped version of target, set up for just-in-time compilation. + """ + # TODO(marcvanzee): Improve docstrings (#1977). + if _is_module_class(target): + return module_class_lift_transform_jit( + target, + variables=variables, + rngs=rngs, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + device=device, + backend=backend, + methods=methods, + ) + # we presume this is being used as a function decorator in class definition + elif callable(target) and not isinstance(target, Module): + return decorator_lift_transform_jit( + target, + variables=variables, + rngs=rngs, + static_argnums=static_argnums, + static_argnames=static_argnames, + donate_argnums=donate_argnums, + device=device, + backend=backend, + ) + else: + raise errors.TransformTargetError(target) + + +def checkpoint( + target: Target, + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, + concrete: bool = False, + prevent_cse: bool = True, + static_argnums: Union[int, Tuple[int, ...]] = (), + policy: Optional[Callable[..., bool]] = None, + methods=None, +) -> Target: + """Lifted version of ``jax.checkpoint``. + + Checkpointing is a technique for reducing memory usage by recomputing + activations during backpropagation. When training large models, it can be + helpful to checkpoint parts of the model to trade off memory usage for + additional computation. + + Example:: + + >>> import jax + >>> import jax.numpy as jnp + >>> import flax.linen as nn + ... + >>> class CheckpointedMLP(nn.Module): + ... @nn.checkpoint + ... @nn.compact + ... def __call__(self, x): + ... x = nn.Dense(128)(x) + ... x = nn.relu(x) + ... x = nn.Dense(1)(x) + ... return x + ... + >>> model = CheckpointedMLP() + >>> variables = model.init(jax.random.key(0), jnp.ones((1, 16))) + + This function is aliased to ``remat`` just like ``jax.remat``. + + Args: + target: a ``Module`` or a function taking a ``Module`` + as its first argument. intermediate computations will be + re-computed when computing gradients for the target. + variables: The variable collections that are lifted. By default all + collections are lifted. + rngs: The PRNG sequences that are lifted. By default all PRNG sequences + are lifted. + concrete: Optional, boolean indicating whether ``fun`` may involve + value-dependent Python control flow (default ``False``). Support for such + control flow is optional, and disabled by default, because in some + edge-case compositions with :func:`jax.jit` it can lead to some extra + computation. + prevent_cse: Optional, boolean indicating whether to prevent common + subexpression elimination (CSE) optimizations in the HLO generated from + differentiation. This CSE prevention has costs because it can foil other + optimizations, and because it can incur high overheads on some backends, + especially GPU. The default is True because otherwise, under a ``jit`` or + ``pmap``, CSE can defeat the purpose of this decorator. 
 But in some
+      settings, like when used inside a ``scan``, this CSE prevention mechanism
+      is unnecessary, in which case ``prevent_cse`` should be set to False.
+    static_argnums: Optional, int or sequence of ints, indicates which argument
+      values to specialize on for tracing and caching purposes. Specifying
+      arguments as static can avoid ConcretizationTypeErrors when tracing, but
+      at the cost of more retracing overheads.
+    policy: Experimental checkpoint policy, see ``jax.checkpoint``.
+    methods: An optional list of method names that will be lifted. If ``methods``
+      is None (default) only the ``__call__`` method will be lifted. If ``target``
+      is a function, ``methods`` is ignored.
+
+  Returns:
+    A wrapped version of ``target``. When computing gradients, intermediate
+    computations will be re-computed on the backward pass.
+  """
+  # subtract 1 from each static_argnums because 'self' is not passed to the
+  # lifted function
+  static_argnums = jax.tree_util.tree_map(lambda x: x - 1, static_argnums)
+  return lift_transform(
+    lift.checkpoint,
+    target,
+    variables=variables,
+    rngs=rngs,
+    concrete=concrete,
+    static_argnums=static_argnums,
+    prevent_cse=prevent_cse,
+    policy=policy,
+    methods=methods,
+  )
+
+
+remat = checkpoint
+
+
+def remat_scan(
+  target: Target,
+  lengths: Optional[Sequence[int]] = (),
+  policy: Optional[Callable[..., bool]] = None,
+  variable_broadcast: CollectionFilter = False,
+  variable_carry: CollectionFilter = False,
+  variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict(
+    {True: 0}
+  ),
+  split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict({True: True}),
+) -> Target:
+  """Combines remat and scan for memory efficiency and constant time compilation.
+
+  ``remat_scan`` allows for constant compile times and sublinear
+  memory usage with respect to model depth, at a small constant
+  penalty. This is typically beneficial for very deep models.
+
+  Example::
+
+    >>> import flax.linen as nn
+
+    >>> class BigModel(nn.Module):
+    ...   @nn.compact
+    ...   def __call__(self, x):
+    ...     DenseStack = nn.remat_scan(nn.Dense, lengths=(10, 10))
+    ...     # 100x dense with O(sqrt(N)) memory for gradient computation
+    ...     return DenseStack(8, name="dense_stack")(x)
+
+  Args:
+    target: a ``Module`` or a function taking a ``Module`` as its first
+      argument.
+    lengths: number of loop iterations at the given level. The total number of
+      iterations ``n = prod(lengths)``. Each loop is rematerialized. This way the
+      memory consumption is proportional to ``n^(1 / d)`` where ``d =
+      len(lengths)``. Minimal memory consumption requires tuning the lengths
+      such that the same amount of memory is consumed at each level of the
+      nested loop.
+    policy: Experimental checkpoint policy, see ``jax.checkpoint``.
+    variable_broadcast: Specifies the broadcasted variable collections. A
+      broadcasted variable should not depend on any computation that cannot be
+      lifted out of the loop. This is typically used to define shared parameters
+      inside the fn.
+    variable_carry: Specifies the variable collections that are carried through
+      the loop. Mutations to these variables are carried to the next iteration
+      and will be preserved when the scan finishes.
+    variable_axes: the variable collections that are scanned over. Defaults to
+      ``{True: 0}``.
+    split_rngs: Split PRNG sequences will be different for each loop iteration.
+      If split is False the PRNGs will be the same across iterations. Defaults
+      to ``{True: True}``.
+ + Returns: + A wrapped version of ``target`` that repeats itself prod(lengths) times. + """ + return lift_transform( + lift.remat_scan, + target, + lengths=lengths, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + variable_axes=variable_axes, + split_rngs=split_rngs, + policy=policy, + ) + + +def scan( + target: Target, + variable_axes: Mapping[CollectionFilter, InOutScanAxis] = FrozenDict(), + variable_broadcast: CollectionFilter = False, + variable_carry: CollectionFilter = False, + split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict(), + in_axes=0, + out_axes=0, + length: Optional[int] = None, + reverse: bool = False, + unroll: int = 1, + data_transform: Optional[Callable[..., Any]] = None, + metadata_params: Mapping[Any, Any] = {}, + methods=None, + _split_transpose: bool = False, +) -> Target: + """A lifted version of ``jax.lax.scan``. + + See ``jax.lax.scan`` for the unlifted scan in Jax. + + To improve consistency with ``vmap``, this version of scan + uses ``in_axes`` and ``out_axes`` to determine which arguments + are scanned over and along which axis. + + ``scan`` distinguishes between 3 different types of values inside the loop: + + #. **scan**: a value that is iterated over in a loop. All scan values must + have the same size in the axis they are scanned over. Scanned outputs + will be stacked along the scan axis. + + #. **carry**: A carried value is updated at each loop iteration. It must + have the same shape and dtype throughout the loop. + + #. **broadcast**: a value that is closed over by the loop. When a variable + is broadcasted they are typically initialized inside the loop body but + independent of the loop variables. + + The ``target`` should have the signature + ``(module, carry, *xs) -> (carry, ys)``, where ``xs`` and ``ys`` + are the scan values that go in and out of the loop. + + Example:: + + >>> import flax.linen as nn + >>> import jax + >>> import jax.numpy as jnp + ... + >>> class LSTM(nn.Module): + ... features: int + ... + ... @nn.compact + ... def __call__(self, x): + ... ScanLSTM = nn.scan( + ... nn.LSTMCell, variable_broadcast="params", + ... split_rngs={"params": False}, in_axes=1, out_axes=1) + ... + ... lstm = ScanLSTM(self.features) + ... input_shape = x[:, 0].shape + ... carry = lstm.initialize_carry(jax.random.key(0), input_shape) + ... carry, x = lstm(carry, x) + ... return x + ... + >>> x = jnp.ones((4, 12, 7)) + >>> module = LSTM(features=32) + >>> y, variables = module.init_with_output(jax.random.key(0), x) + + Note that when providing a function to ``nn.scan``, the scanning happens over + all arguments starting from the third argument, as specified by ``in_axes``. + The previous example could also be written using the functional form as:: + + >>> class LSTM(nn.Module): + ... features: int + ... + ... @nn.compact + ... def __call__(self, x): + ... + ... cell = nn.LSTMCell(self.features) + ... def body_fn(cell, carry, x): + ... carry, y = cell(carry, x) + ... return carry, y + ... scan = nn.scan( + ... body_fn, variable_broadcast="params", + ... split_rngs={"params": False}, in_axes=1, out_axes=1) + ... + ... input_shape = x[:, 0].shape + ... carry = cell.initialize_carry( + ... jax.random.key(0), input_shape) + ... carry, x = scan(cell, carry, x) + ... return x + ... 
+ >>> module = LSTM(features=32) + >>> variables = module.init(jax.random.key(0), jnp.ones((4, 12, 7))) + + You can also use ``scan`` to reduce the compilation time of your JAX program + by merging multiple layers into a single scan loop, you can do this when + you have a sequence of identical layers that you want to apply iteratively + to an input. For example:: + + >>> class ResidualMLPBlock(nn.Module): + ... @nn.compact + ... def __call__(self, x, _): + ... h = nn.Dense(features=2)(x) + ... h = nn.relu(h) + ... return x + h, None + ... + >>> class ResidualMLP(nn.Module): + ... n_layers: int = 4 + ... + ... @nn.compact + ... def __call__(self, x): + ... ScanMLP = nn.scan( + ... ResidualMLPBlock, variable_axes={'params': 0}, + ... variable_broadcast=False, split_rngs={'params': True}, + ... length=self.n_layers) + ... x, _ = ScanMLP()(x, None) + ... return x + ... + >>> model = ResidualMLP(n_layers=4) + >>> variables = model.init(jax.random.key(42), jnp.ones((1, 2))) + + To reduce both compilation and memory usage, you can use :func:`remat_scan` + which will in addition checkpoint each layer in the scan loop. + + Args: + target: a ``Module`` or a function taking a ``Module`` as its first + argument. + variable_axes: the variable collections that are scanned over. + variable_broadcast: Specifies the broadcasted variable collections. A + broadcasted variable should not depend on any computation that cannot be + lifted out of the loop. This is typically used to define shared parameters + inside the fn. + variable_carry: Specifies the variable collections that are carried through + the loop. Mutations to these variables are carried to the next iteration + and will be preserved when the scan finishes. + split_rngs: Split PRNG sequences will be different for each loop iterations. + If split is False the PRNGs will be the same across iterations. + in_axes: Specifies the axis to scan over for the arguments. Should be a + prefix tree of the arguments. Use ``flax.core.broadcast`` to feed an entire + input to each iteration of the scan body. + out_axes: Specifies the axis to scan over for the return value. Should be a + prefix tree of the return value. + length: Specifies the number of loop iterations. This only needs to be + specified if it cannot be derived from the scan arguments. + reverse: If true, scan from end to start in reverse order. + unroll: how many scan iterations to unroll within a single iteration of a + loop (default: 1). + data_transform: optional function to transform raw functional-core variable + and rng groups inside lifted scan body_fn, intended for inline SPMD + annotations. + metadata_params: arguments dict passed to AxisMetadata instances in the + variable tree. + methods: If ``target`` is a ``Module``, the methods of ``Module`` to scan over. + _split_transpose: An experimental feature to split the transpose of a scan + into a scan and a map, backed by an experimental Jax lax.scan() feature. + + Returns: + The scan function with the signature ``(module, carry, *xs) -> (carry, + ys)``, where ``xs`` and ``ys`` are the scan values that go in and out of + the loop. 
+ """ + return lift_transform( + lift.scan, + target, + variable_axes=variable_axes, + variable_broadcast=variable_broadcast, + variable_carry=variable_carry, + split_rngs=split_rngs, + in_axes=in_axes, + out_axes=out_axes, + length=length, + reverse=reverse, + unroll=unroll, + _split_transpose=_split_transpose, + data_transform=data_transform, + metadata_params=metadata_params, + methods=methods, + ) + + +def map_variables( + target: Target, + mapped_collections: CollectionFilter = True, + trans_in_fn: Callable[..., Any] = lift.id_fn, + trans_out_fn: Callable[..., Any] = lift.id_fn, + init: bool = False, + mutable: bool = False, + rngs: PRNGSequenceFilter = True, + variables: CollectionFilter = True, + methods=None, +) -> Target: + """Map Variables inside a module. + + ``map_variables`` can be used to transform the variables inside a module + both before and after the module is applied. This is useful among other + things for masking the weights of a module without having to modify the + module itself. + + Example:: + + >>> import jax + >>> import jax.numpy as jnp + >>> import flax.linen as nn + ... + >>> class CausalDense(nn.Module): + ... '''A dense layer that masks the weights such that the output is + ... causal, i.e. output i only depends on input <= i. + ... ''' + ... features: int + ... + ... def apply_mask(self, variables): + ... return (jax.tree_util.tree_map(jnp.triu, variables) + ... if not self.is_initializing() else variables) + ... + ... def setup(self): + ... # temporary class + ... _CausalDense = nn.map_variables( + ... nn.Dense, 'params', self.apply_mask, init=self.is_initializing()) + ... self.dense = _CausalDense(features=self.features, use_bias=False) + ... + ... def __call__(self, x): + ... return self.dense(x) + ... + >>> module = CausalDense(features=5) + >>> variables = module.init(jax.random.key(0), jnp.ones((1, 5))) + + Args: + target: the module or function to be transformed. + mapped_collections: the collection(s) to be transformed. + trans_in_fn: modifies the variables before applying the module or function. + trans_out_fn: modifies the variables after applying the module or function, + it is only applied if either ``init`` or ``mutable`` are not False. + init: If True, variables are initialized before transformation. + mutable: If True, the mapped variable collections will be mutable. + rngs: PRNGSequences added to the transformed scope (default: all). + variables: Additional Variable collections added to the transformed scope. + Besides those specified by ``target`` (default: all). + methods: If ``target`` is a ``Module``, the methods of ``Module`` to map + variables for. + + Returns: + a wrapped version of ``target`` that will map the specified collections. + """ + + return lift_transform( + lift.map_variables, + target, + mapped_collections, + trans_in_fn, + trans_out_fn, + init, + mutable, + rngs, + variables, + methods=methods, + ) + + +def vjp( + fn: Callable[..., Any], + mdl: Module, + *primals, + has_aux: bool = False, + reduce_axes=(), + vjp_variables: CollectionFilter = 'params', + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, + multi_scope: bool = False, +): + """A lifted version of ``jax.vjp``. + + See ``jax.vjp`` for the unlifted vector-Jacobian product (backward gradient). + + Note that a gradient is returned for all variables in the collections + specified by ``vjp_variables``. However, the backward function only expects + a cotangent for the return value of ``fn``. 
If variables require a co-tangent + as well they can be returned from ``fn`` using ``Module.variables``. + + Example:: + + >>> import flax.linen as nn + >>> import jax.numpy as jnp + + >>> class LearnScale(nn.Module): + ... @nn.compact + ... def __call__(self, x, y): + ... p = self.param('scale', nn.initializers.zeros_init(), ()) + ... return p * x * y + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x, y): + ... z, bwd = nn.vjp(lambda mdl, x, y: mdl(x, y), LearnScale(), x, y) + ... params_grad, x_grad, y_grad = bwd(jnp.ones(z.shape)) + ... return z, params_grad, x_grad, y_grad + + Args: + fn: Function to be differentiated. Its arguments should be arrays, scalars, + or standard Python containers of arrays or scalars. It should return an + array, scalar, or standard Python container of arrays or scalars. It will + receive the scope and primals as arguments. + mdl: The module of which the variables will be differentiated. + *primals: A sequence of primal values at which the Jacobian of ``fn`` + should be evaluated. The length of ``primals`` should be equal to the + number of positional parameters to ``fn``. Each primal value should be a + tuple of arrays, scalar, or standard Python containers thereof. + has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default ``False``. + reduce_axes: Optional, tuple of axis names. If an axis is listed here, and + ``fn`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding gradient. Otherwise, the + VJP will be per-example over named axes. For example, if ``'batch'`` + is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will + create a VJP function that sums over the batch while ``vjp(f, *args)`` + will create a per-example VJP. + vjp_variables: The vjpfun will return a cotangent vector for all + variable collections specified by this filter. + variables: other variables collections that are available inside ``fn`` but + do not receive a cotangent. + rngs: the prngs that are available inside ``fn``. + multi_scope: for Modules containing multiple scopes from outside modules passed in, + allow for variable gradients to be returned for multiple scopes instead of erroring. + Returns: + If ``has_aux`` is ``False``, returns a ``(primals_out, vjpfun)`` pair, where + ``primals_out`` is ``fn(*primals)``. + ``vjpfun`` is a function from a cotangent vector with the same shape as + ``primals_out`` to a tuple of cotangent vectors with the same shape as + ``primals``, representing the vector-Jacobian product of ``fn`` evaluated at + ``primals``. If ``has_aux`` is ``True``, returns a + ``(primals_out, vjpfun, aux)`` tuple where ``aux`` is the auxiliary data + returned by ``fn``. + """ + return lift_direct_transform( + lift.vjp, + (fn,), + mdl, + *primals, + multi_scope=multi_scope, + has_aux=has_aux, + reduce_axes=reduce_axes, + vjp_variables=vjp_variables, + variables=variables, + rngs=rngs, + ) + + +def value_and_grad( + fn: Callable[..., Any], + mdl: Module, + *primals, + has_aux: bool = False, + reduce_axes=(), + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +): + """A limited, lifted equivalent of ``jax.value_and_grad``. + + Note that for this convenience function, gradients are only calculated for + the function inputs, and not with respect to any module variables. 
The + target function must return a scalar-valued output. For a more general + lifted vjp, see ``nn.vjp`` for the lifted vector-Jacobian product. + + Example:: + + class LearnScale(nn.Module): + @nn.compact + def __call__(self, x, y): + p = self.param('scale', nn.initializers.zeros_init(), ()) + return p * x * y + + class Foo(nn.Module): + @nn.compact + def __call__(self, x, y): + z, (x_grad, y_grad) = nn.value_and_grad( + lambda mdl, x, y: mdl(x, y), LearnScale(), x, y) + return z, x_grad, y_grad + + Args: + fn: Function to be differentiated. Its arguments should be arrays, scalars, + or standard Python containers of arrays or scalars. It should return an + array, scalar, or standard Python container of arrays or scalars. It will + receive the scope and primals as arguments. + mdl: The module of which the variables will be differentiated. + *primals: A sequence of primal values at which the Jacobian of ``fn`` + should be evaluated. The length of ``primals`` should be equal to the + number of positional parameters to ``fn``. Each primal value should be a + tuple of arrays, scalar, or standard Python containers thereof. + has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default ``False``. + reduce_axes: Optional, tuple of axis names. If an axis is listed here, and + ``fn`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding gradient. Otherwise, the + grad will be per-example over named axes. For example, if ``'batch'`` + is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will + create a grad function that sums over the batch while ``grad(f, *args)`` + will create a per-example grad. + variables: variables collections that are available inside ``fn`` but + do not receive a cotangent. + rngs: the prngs that are available inside ``fn``. + Returns: + If ``has_aux`` is ``False``, returns a ``primals_out, grads`` pair, where + ``primals_out`` is ``fn(*primals)``. ``grads`` are the gradients for the + corresponding primals and do not include the gradients for module variables. + If ``has_aux`` is ``True``, returns a + ``(primals_out, aux), grads`` tuple where ``aux`` is the auxiliary data + returned by ``fn``. + """ + + grad_partial = functools.partial( + lift_direct_transform, + lift.value_and_grad, + (fn,), + mdl, + *primals, + has_aux=has_aux, + reduce_axes=reduce_axes, + variables=variables, + rngs=rngs, + ) + + if has_aux: + out, aux, argument_grads = grad_partial() + if out.shape != (): + raise ValueError( + 'grad can only work on functions with ' + f'scalar-valued outputs. out shape={out.shape}' + ) + return (out, aux), argument_grads + else: + out, argument_grads = grad_partial() + if out.shape != (): + raise ValueError( + 'grad can only work on functions with ' + f'scalar-valued outputs. out shape={out.shape}' + ) + return out, argument_grads + + +def grad( + fn: Callable[..., Any], + mdl: Module, + *primals, + has_aux: bool = False, + reduce_axes=(), + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +): + """A limited, lifted equivalent of ``jax.grad``. + + Note that for this convenience function, gradients are only calculated for + the function inputs, and not with respect to any module variables. The + target function must return a scalar-valued output. 
For a more general + lifted vjp, see ``nn.vjp`` for the lifted vector-Jacobian product. + + Example:: + + class LearnScale(nn.Module): + @nn.compact + def __call__(self, x, y): + p = self.param('scale', nn.initializers.zeros_init(), ()) + return p * x * y + + class Foo(nn.Module): + @nn.compact + def __call__(self, x, y): + x_grad, y_grad = nn.grad( + lambda mdl, x, y: mdl(x, y), LearnScale(), x, y) + return x_grad, y_grad + + Args: + fn: Function to be differentiated. Its arguments should be arrays, scalars, + or standard Python containers of arrays or scalars. It should return an + array, scalar, or standard Python container of arrays or scalars. It will + receive the scope and primals as arguments. + mdl: The module of which the variables will be differentiated. + *primals: A sequence of primal values at which the Jacobian of ``fn`` + should be evaluated. The length of ``primals`` should be equal to the + number of positional parameters to ``fn``. Each primal value should be a + tuple of arrays, scalar, or standard Python containers thereof. + has_aux: Optional, bool. Indicates whether ``fn`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default ``False``. + reduce_axes: Optional, tuple of axis names. If an axis is listed here, and + ``fn`` implicitly broadcasts a value over that axis, the backward pass + will perform a ``psum`` of the corresponding gradient. Otherwise, the + grad will be per-example over named axes. For example, if ``'batch'`` + is a named batch axis, ``vjp(f, *args, reduce_axes=('batch',))`` will + create a grad function that sums over the batch while ``grad(f, *args)`` + will create a per-example grad. + variables: variables collections that are available inside ``fn`` but + do not receive a cotangent. + rngs: the prngs that are available inside ``fn``. + Returns: + If ``has_aux`` is ``False``, returns ``grads``, where ``grads`` are the + gradients for the corresponding primals and do not include the gradients + for module variables. + If ``has_aux`` is ``True``, returns a + ``(grads, aux)`` tuple where ``aux`` is the auxiliary data + returned by ``fn``. + """ + + value_and_grad_partial = functools.partial( + value_and_grad, + fn, + mdl, + *primals, + has_aux=has_aux, + reduce_axes=reduce_axes, + variables=variables, + rngs=rngs, + ) + + if has_aux: + (_, aux), argument_grads = value_and_grad_partial() + return argument_grads, aux + else: + _, argument_grads = value_and_grad_partial() + return argument_grads + + +def jvp( + fn: Callable[..., Any], + mdl: Module, + primals, + tangents, + variable_tangents, + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +) -> Union[Tuple[Any, Callable[..., Any]], Tuple[Any, Callable[..., Any], Any]]: + """A lifted version of ``jax.jvp``. + + See ``jax.jvp`` for the unlifted Jacobian-vector product (forward gradient). + + Note that no tangents are returned for variables. When variable tangents + are required their value should be returned explicitly by ``fn`` + using ``Module.variables``:: + + >>> import flax.linen as nn + >>> import jax.numpy as jnp + + >>> class LearnScale(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... p = self.param('test', nn.initializers._init(), ()) + ... return p * x + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... scale = LearnScale() + ... vars_t = jax.tree_util.tree_map(jnp.ones_like, + ... 
scale.variables.get('params', {})) + ... _, out_t = nn.jvp( + ... lambda mdl, x: mdl(x), scale, (x,), (jnp.zeros_like(x),), + ... variable_tangents={'params': vars_t}) + ... return out_t + + Example:: + + >>> def learn_scale(scope, x): + ... p = scope.param('scale', nn.initializers.zeros_init(), ()) + ... return p * x + + >>> def f(scope, x): + ... vars_t = jax.tree_util.tree_map(jnp.ones_like, scope.variables().get('params', {})) + ... x, out_t = lift.jvp( + ... learn_scale, scope, (x,), (jnp.zeros_like(x),), + ... variable_tangents={'params': vars_t}) + ... return out_t + + Args: + fn: Function to be differentiated. Its arguments should be arrays, scalars, + or standard Python containers of arrays or scalars. It should return an + array, scalar, or standard Python container of arrays or scalars. It will + receive the scope and primals as arguments. + mdl: The module of which the variables will be differentiated. + primals: The primal values at which the Jacobian of ``fun`` should be + evaluated. Should be either a tuple or a list of arguments, + and its length should be equal to the number of positional parameters of + ``fun``. + tangents: The tangent vector for which the Jacobian-vector product should be + evaluated. Should be either a tuple or a list of tangents, with the same + tree structure and array shapes as ``primals``. + variable_tangents: A dict or PyTree fo dicts with the same structure as + scopes. Each entry in the dict specifies the tangents for a variable + collection. Not specifying a collection in variable_tangents is + equivalent to passing a zero vector as the tangent. + variables: other variables collections that are available in ``fn`` but + do not receive a tangent. + rngs: the prngs that are available inside ``fn``. + + Returns: + A ``(primals_out, tangents_out)`` pair, where ``primals_out`` is + ``fun(*primals)``, and ``tangents_out`` is the Jacobian-vector product of + ``function`` evaluated at ``primals`` with ``tangents``. The + ``tangents_out`` value has the same Python tree structure and shapes as + ``primals_out``. + """ + return lift_direct_transform( + lift.jvp, + (fn,), + mdl, + primals, + tangents, + variable_tangents, + multi_scope=False, + variables=variables, + rngs=rngs, + ) + + +ModuleT = TypeVar('ModuleT', bound=Module) +C = TypeVar('C') + + +def while_loop( + cond_fn: Callable[[ModuleT, C], bool], + body_fn: Callable[[ModuleT, C], C], + mdl: ModuleT, + init: C, + carry_variables: CollectionFilter = False, + broadcast_variables: CollectionFilter = True, + split_rngs: Mapping[PRNGSequenceFilter, bool] = FrozenDict(), +) -> C: + """Lifted version of jax.lax.while_loop. + + The lifted scope is passed to ``cond_fn`` and ``body_fn``. + Broadcasted variables are immutable. The carry variable are + mutable but cannot change shape and dtype. + This also means you cannot initialize variables inside + the body. Consider calling ``body_fn`` once manually before + calling ``while_loop`` if variable initialization is required. + + Example:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + + >>> class WhileLoopExample(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... def cond_fn(mdl, c): + ... return mdl.variables['state']['acc'] < 10 + ... def body_fn(mdl, c): + ... acc = mdl.variable('state', 'acc', lambda: jnp.array(0)) + ... acc.value += 1 + ... y = nn.Dense(c.shape[-1])(c) + ... return y + ... c = x + ... if self.is_mutable_collection('params'): + ... return body_fn(self, c) + ... else: + ... 
return nn.while_loop(cond_fn, body_fn, self, c, + ... carry_variables='state') + + >>> k = jax.random.key(0) + >>> x = jnp.ones((2, 2)) + >>> initial_vars = WhileLoopExample().init(k, x) + >>> result, state = WhileLoopExample().apply(initial_vars, x, mutable=['state']) + + Args: + cond_fn: Should return True as long as the loop should continue. + body_fn: The body of the while loop. + mdl: The Module which should be lifted into the loop. + init: The initial state passed to the loop + carry_variables: collections that are carried through the loop + and are therefore mutable (default: none). + broadcast_variables: collections that are closed over and are + therefore read-only (default: all collections) + split_rngs: Split PRNG sequences will be different for each loop iterations. + If split is False the PRNGs will be the same across iterations. + Returns: + The final state after executing the while loop. + """ + return lift_direct_transform( + lift.while_loop, + (cond_fn, body_fn), + mdl, + init, + carry_variables, + broadcast_variables, + split_rngs, + ) + + +def _cond_wrapper(t_fn, f_fn, scope, pred, *ops, variables, rngs): + return lift.cond( + pred, t_fn, f_fn, scope, *ops, variables=variables, rngs=rngs + ) + + +def cond( + pred: Any, + true_fun: Callable[..., C], + false_fun: Callable[..., C], + mdl: Module, + *operands, + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +) -> C: + """Lifted version of ``jax.lax.cond``. + + The returned values from ``true_fun`` and ``false_fun`` + must have the same Pytree structure, shapes, and dtypes. + The variables created or updated inside the + branches must also have the same structure. + Note that this constraint is violated when + creating variables or submodules in only one branch. + Because initializing variables in just one branch + causes the parameter structure to be different. + + Example:: + + >>> import flax.linen as nn + + >>> class CondExample(nn.Module): + ... @nn.compact + ... def __call__(self, x, pred): + ... self.variable('state', 'true_count', lambda: 0) + ... self.variable('state', 'false_count', lambda: 0) + ... def true_fn(mdl, x): + ... mdl.variable('state', 'true_count').value += 1 + ... return nn.Dense(2, name='dense')(x) + ... def false_fn(mdl, x): + ... mdl.variable('state', 'false_count').value += 1 + ... return -nn.Dense(2, name='dense')(x) + ... return nn.cond(pred, true_fn, false_fn, self, x) + + Args: + pred: determines if true_fun or false_fun is evaluated. + true_fun: The function evaluated when ``pred`` is ``True``. + The signature is (module, *operands) -> T. + false_fun: The function evaluated when ``pred`` is ``False``. + The signature is (module, *operands) -> T. + mdl: A Module target to pass. + *operands: The arguments passed to ``true_fun`` and ``false_fun`` + variables: The variable collections passed to the conditional + branches (default: all) + rngs: The PRNG sequences passed to the conditionals (default: all) + Returns: + The result of the evaluated branch (``true_fun`` or ``false_fun``). + """ + return lift_direct_transform( + _cond_wrapper, + (true_fun, false_fun), + mdl, + pred, + *operands, + variables=variables, + rngs=rngs, + ) + + +def _switch_wrapper(*args, variables, rngs, n_branches): + # first n_branches arguments are branches. 
+ # then scope, index, and the rest are *operands + branches = args[:n_branches] + scope, index, *operands = args[n_branches:] + return lift.switch( + index, branches, scope, *operands, variables=variables, rngs=rngs + ) + + +def switch( + index: Any, + branches: Sequence[Callable[..., C]], + mdl: Module, + *operands, + variables: CollectionFilter = True, + rngs: PRNGSequenceFilter = True, +) -> C: + """Lifted version of ``jax.lax.switch``. + + The returned values from ``branches`` + must have the same Pytree structure, shapes, and dtypes. + The variables created or updated inside the + branches must also have the same structure. + Note that this constraint is violated when + creating variables or submodules in only one branch. + Because initializing variables in just one branch + causes the parameter structure to be different. + + Example:: + + >>> import flax.linen as nn + + >>> class SwitchExample(nn.Module): + ... @nn.compact + ... def __call__(self, x, index): + ... self.variable('state', 'a_count', lambda: 0) + ... self.variable('state', 'b_count', lambda: 0) + ... self.variable('state', 'c_count', lambda: 0) + ... def a_fn(mdl, x): + ... mdl.variable('state', 'a_count').value += 1 + ... return nn.Dense(2, name='dense')(x) + ... def b_fn(mdl, x): + ... mdl.variable('state', 'b_count').value += 1 + ... return -nn.Dense(2, name='dense')(x) + ... def c_fn(mdl, x): + ... mdl.variable('state', 'c_count').value += 1 + ... return nn.Dense(2, name='dense')(x) + ... return nn.switch(index, [a_fn, b_fn, c_fn], self, x) + + If you want to have a different parameter structure for each branch + you should run all branches on initialization before calling switch:: + + >>> class MultiHeadSwitchExample(nn.Module): + ... def setup(self) -> None: + ... self.heads = [ + ... nn.Sequential([nn.Dense(10), nn.Dense(7), nn.Dense(5)]), + ... nn.Sequential([nn.Dense(11), nn.Dense(5)]), + ... nn.Dense(5), + ... ] + ... + ... @nn.compact + ... def __call__(self, x, index): + ... def head_fn(i): + ... return lambda mdl, x: mdl.heads[i](x) + ... branches = [head_fn(i) for i in range(len(self.heads))] + ... + ... # run all branches on init + ... if self.is_mutable_collection('params'): + ... for branch in branches: + ... _ = branch(self, x) + ... + ... return nn.switch(index, branches, self, x) + + Args: + index: Integer scalar type, indicating which branch function to apply. + branches: Sequence of functions to be applied based on index. + The signature of each function is (module, *operands) -> T. + mdl: A Module target to pass. + *operands: The arguments passed to the branches. + variables: The variable collections passed to the conditional + branches (default: all) + rngs: The PRNG sequences passed to the conditionals (default: all) + Returns: + The result of the evaluated branch. + """ + return lift_direct_transform( + _switch_wrapper, + tuple(branches), + mdl, + index, + *operands, + variables=variables, + rngs=rngs, + n_branches=len(branches), + ) + + +# a version of lift.custom_vjp with a single scope function +# this avoids having to lift multiple functions in +# lift_transform. 
+def _custom_vjp_single_scope_fn( + fn: Callable[..., Any], + backward_fn: Callable[..., Any], + grad_vars: CollectionFilter = 'params', + nondiff_argnums=(), +): + nodiff_fn = functools.partial(fn, needs_residual=False) + forward_fn = functools.partial(fn, needs_residual=True) + return lift.custom_vjp( + nodiff_fn, forward_fn, backward_fn, grad_vars, nondiff_argnums + ) + + +def custom_vjp( + fn: Callable[..., Any], + forward_fn: Callable[..., Any], + backward_fn: Callable[..., Any], + grad_vars: CollectionFilter = 'params', + nondiff_argnums=(), +): + """Lifted version of ``jax.custom_vjp``. + + ``forward_fn`` and ``backward_fn`` together define a custom vjp for ``fn``. + The original ``fn`` will run in case a vjp (backward gradient) is not computed. + + The ``forward_fn`` receives the same arguments as ``fn`` but is expected to return + a tuple containing the output of ``fn(mdl, *args)`` and the residuals that are + passed to ``backward_fn``. + + The ``backward_fn`` receives the nondiff arguments, residuals, and the output + tangents. It should return a tuple containing the variable and input tangents. + + Note that the vjp function returned by ``nn.vjp`` can be passed as residual and + used in the ``backward_fn``. The scope is unavailable during the backward pass. + If the module is required in ``backward_fn``, a snapshot of the variables can + be taken and returned as a residual in the ``forward_fn``. + + Example:: + + >>> import flax.linen as nn + >>> import jax, jax.numpy as jnp + + >>> class Foo(nn.Module): + ... @nn.compact + ... def __call__(self, x): + ... def f(mdl, x): + ... return mdl(x) + ... + ... def fwd(mdl, x): + ... return nn.vjp(f, mdl, x) + ... + ... def bwd(vjp_fn, y_t): + ... params_t, *inputs_t = vjp_fn(y_t) + ... params_t = jax.tree_util.tree_map(jnp.sign, params_t) + ... return (params_t, *inputs_t) + ... + ... sign_grad = nn.custom_vjp( + ... f, forward_fn=fwd, backward_fn=bwd) + ... return sign_grad(nn.Dense(1), x).reshape(()) + + >>> x = jnp.ones((2,)) + >>> variables = Foo().init(jax.random.key(0), x) + >>> grad = jax.grad(Foo().apply)(variables, x) + + Args: + fn: The function to define a custom_vjp for. + forward_fn: A function with the same arguments as ``fn`` returning an tuple + with the original output and the residuals that will be passsed to + ``backward_fn``. + backward_fn: arguments are passed as + ``(*nondiff_args, residuals, tangents)`` The function should return a + tuple containing the tangents for the variable in the collections + specified by ``grad_vars`` and the input arguments (except the module and + nondiff args). + grad_vars: The collections for which a vjp will be computed + (default: "params"). + nondiff_argnums: arguments for which no vjp is computed. + Returns: + A function with the same signature as ``fn`` with the custom vjp. + """ + + def shared_forward_fn(*args, needs_residual, **kwargs): + if needs_residual: + return forward_fn(*args, **kwargs) + else: + return fn(*args, **kwargs) + + return decorator_lift_transform( + _custom_vjp_single_scope_fn, + shared_forward_fn, + backward_fn=backward_fn, + grad_vars=grad_vars, + nondiff_argnums=nondiff_argnums, + multi_scope=False, + ) + + +def named_call(class_fn, force=True): + """Labels a method for labelled traces in profiles. + + Note that it is better to use the `jax.named_scope` context manager directly + to add names to JAX's metadata name stack. + + Args: + class_fn: The class method to label. 
+ force: If True, the named_call transform is applied even if it is globally + disabled. (e.g.: by calling `flax.linen.disable_named_call()`) + Returns: + A wrapped version of ``class_fn`` that is labeled. + """ + + # We use JAX's dynamic name-stack named_call. No transform boundary needed! + @functools.wraps(class_fn) + def wrapped_fn(self, *args, **kwargs): + if (not force and not linen_module._use_named_call) or self._state.in_setup: # pylint: disable=protected-access # pylint: disable=protected-access + return class_fn(self, *args, **kwargs) + full_name = _derive_profiling_name(self, class_fn) + return jax.named_call(class_fn, name=full_name)(self, *args, **kwargs) + + return wrapped_fn + + +def add_metadata_axis( + target: Target, + variable_axes: Mapping[CollectionFilter, InOutAxis] = FrozenDict(), + metadata_params: Dict[Any, Any] = {}, +) -> Target: + """A helper to manipulate boxed axis metadata. + + This is a helper to manipulate the *metadata* in boxed variables, similar + to how lifted ``vmap`` and ``scan`` will handle the introduction and stripping + of the new metadata axis across a transform boundary. + + Args: + target: a ``Module`` or a function taking a ``Module`` + as its first argument. + variable_axes: the variable collections whose axis metadata is being + transformed. Use `None` to indicate a broadcasted collection or an integer + to specify an axis index for an introduced axis. + methods: If `target` is a `Module`, the methods of `Module` to vmap over. + metadata_params: arguments dict passed to AxisMetadata instances in the + variable tree. + Returns: + A transformed version of ``target`` that performs a transform of the + axis metadata on its variables. + """ + + def add_fn(axis): + return lambda x: meta.add_axis(x, axis, metadata_params) + + def remove_fn(axis): + return lambda x: meta.remove_axis(x, axis, metadata_params) + + for col_name, axis in variable_axes.items(): + target = map_variables( + target, + col_name, + trans_in_fn=remove_fn(axis), + trans_out_fn=add_fn(axis), + mutable=True, + ) + return target diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/training/train_state.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/training/train_state.py new file mode 100644 index 000000000..1188dedc0 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/flax/flax/training/train_state.py @@ -0,0 +1,138 @@ +# Copyright 2024 The Flax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Union + +import optax + +import jax +from flax import core, struct +from flax.linen.fp8_ops import OVERWRITE_WITH_GRADIENT + + +class TrainState(struct.PyTreeNode): + """Simple train state for the common case with a single Optax optimizer. 
+ + Example usage:: + + >>> import flax.linen as nn + >>> from flax.training.train_state import TrainState + >>> import jax, jax.numpy as jnp + >>> import optax + + >>> x = jnp.ones((1, 2)) + >>> y = jnp.ones((1, 2)) + >>> model = nn.Dense(2) + >>> variables = model.init(jax.random.key(0), x) + >>> tx = optax.adam(1e-3) + + >>> state = TrainState.create( + ... apply_fn=model.apply, + ... params=variables['params'], + ... tx=tx) + + >>> def loss_fn(params, x, y): + ... predictions = state.apply_fn({'params': params}, x) + ... loss = optax.l2_loss(predictions=predictions, targets=y).mean() + ... return loss + >>> loss_fn(state.params, x, y) + Array(3.3514676, dtype=float32) + + >>> grads = jax.grad(loss_fn)(state.params, x, y) + >>> state = state.apply_gradients(grads=grads) + >>> loss_fn(state.params, x, y) + Array(3.343844, dtype=float32) + + Note that you can easily extend this dataclass by subclassing it for storing + additional data (e.g. additional variable collections). + + For more exotic usecases (e.g. multiple optimizers) it's probably best to + fork the class and modify it. + + Args: + step: Counter starts at 0 and is incremented by every call to + ``.apply_gradients()``. + apply_fn: Usually set to ``model.apply()``. Kept in this dataclass for + convenience to have a shorter params list for the ``train_step()`` function + in your training loop. + params: The parameters to be updated by ``tx`` and used by ``apply_fn``. + tx: An Optax gradient transformation. + opt_state: The state for ``tx``. + """ + + step: Union[int, jax.Array] + apply_fn: Callable = struct.field(pytree_node=False) + params: core.FrozenDict[str, Any] = struct.field(pytree_node=True) + tx: optax.GradientTransformation = struct.field(pytree_node=False) + opt_state: optax.OptState = struct.field(pytree_node=True) + + def apply_gradients(self, *, grads, **kwargs): + """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value. + + Note that internally this function calls ``.tx.update()`` followed by a call + to ``optax.apply_updates()`` to update ``params`` and ``opt_state``. + + Args: + grads: Gradients that have the same pytree structure as ``.params``. + **kwargs: Additional dataclass attributes that should be ``.replace()``-ed. + + Returns: + An updated instance of ``self`` with ``step`` incremented by one, ``params`` + and ``opt_state`` updated by applying ``grads``, and additional attributes + replaced as specified by ``kwargs``. + """ + if OVERWRITE_WITH_GRADIENT in grads: + grads_with_opt = grads['params'] + params_with_opt = self.params['params'] + else: + grads_with_opt = grads + params_with_opt = self.params + + updates, new_opt_state = self.tx.update( + grads_with_opt, self.opt_state, params_with_opt + ) + new_params_with_opt = optax.apply_updates(params_with_opt, updates) + + # As implied by the OWG name, the gradients are used directly to update the + # parameters. + if OVERWRITE_WITH_GRADIENT in grads: + new_params = { + 'params': new_params_with_opt, + OVERWRITE_WITH_GRADIENT: grads[OVERWRITE_WITH_GRADIENT], + } + else: + new_params = new_params_with_opt + return self.replace( + step=self.step + 1, + params=new_params, + opt_state=new_opt_state, + **kwargs, + ) + + @classmethod + def create(cls, *, apply_fn, params, tx, **kwargs): + """Creates a new instance with ``step=0`` and initialized ``opt_state``.""" + # We exclude OWG params when present because they do not need opt states. 
+ params_with_opt = ( + params['params'] if OVERWRITE_WITH_GRADIENT in params else params + ) + opt_state = tx.init(params_with_opt) + return cls( + step=0, + apply_fn=apply_fn, + params=params, + tx=tx, + opt_state=opt_state, + **kwargs, + ) diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/attentions.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/attentions.py new file mode 100644 index 000000000..61f116375 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/attentions.py @@ -0,0 +1,1003 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Attentions Layers.""" + +import functools +import math +from typing import Optional, Sequence + +from flax import linen as nn +import jax +from jax import lax +from jax import random +from jax.ad_checkpoint import checkpoint_name +from jax.experimental import shard_map +from jax.experimental.pallas.ops import attention as pallas_attention +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask +from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel +import jax.numpy as jnp + +import common_types +from layers import embeddings +from layers import initializers +from layers import linears +from layers import quantizations + + +Array = common_types.Array +Config = common_types.Config +DType = common_types.DType +Mesh = common_types.Mesh +PRNGKey = common_types.PRNGKey + +DenseGeneral = linears.DenseGeneral +RotaryEmbedding = embeddings.RotaryEmbedding +NdInitializer = initializers.NdInitializer +Quant = quantizations.AqtQuantization + +AxisNames = common_types.AxisNames +BATCH = common_types.BATCH +LENGTH = common_types.LENGTH +HEAD = common_types.HEAD +D_KV = common_types.D_KV +DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max) + + +nd_dense_init = initializers.nd_dense_init +shard_map = shard_map.shard_map + +dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) + +# pylint: disable=line-too-long, g-doc-args, g-doc-return-or-yield, bad-continuation, g-inconsistent-quotes +# pytype: disable=attribute-error + + +def apply_mask_to_logits(logits: Array, mask: Array): + """Applies a floating-point mask to a set of logits. + + The mask is represented as a tensor with some dtype where 0 represents true and values + below a large negative number (here set to + get_large_negative_number(logits.dtype) / 2) represent false. Applying the mask + leaves the logits alone in the true case and replaces them by + get_large_negative_number(logits.dtype) in the false case. Previously, this was + done by adding the logits to the mask; however, this leads to a bad fusion + decision in the compiler that saves the values in memory rather than + just the predicate. This implementation avoids that problem. 
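A minimal, self-contained sketch of the masking convention described above (values are illustrative, not taken from the module): a mask entry of 0 marks a "true" position and leaves the logit untouched, while an entry at the large negative DEFAULT_MASK_VALUE marks a "false" position and replaces the logit.

import jax.numpy as jnp

DEFAULT_MASK_VALUE = -0.7 * float(jnp.finfo(jnp.dtype("float32")).max)
logits = jnp.array([[1.0, 2.0], [3.0, 4.0]])
mask = jnp.array([[0.0, DEFAULT_MASK_VALUE], [0.0, 0.0]])
# "True" positions (mask >= DEFAULT_MASK_VALUE / 2, i.e. the zeros) keep their
# logits; "false" positions are overwritten with DEFAULT_MASK_VALUE.
masked = jnp.where(mask >= DEFAULT_MASK_VALUE * 0.5, logits, DEFAULT_MASK_VALUE)
# masked == [[1.0, DEFAULT_MASK_VALUE], [3.0, 4.0]]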
+ + from https://github.com/google/praxis/blob/4712a6b9ee13e224b86e235ff55f7c6bab9fbab3/praxis/py_utils.py#L706 + + Args: + logits: A JTensor of logit values. + mask: A JTensor of mask values with the encoding described in the + function documentation. + + Returns: + Masked logits. + """ + return jnp.where((mask >= DEFAULT_MASK_VALUE * 0.5), logits, DEFAULT_MASK_VALUE) + + +def _maybe_aqt_einsum(quant: Quant): + """Maybe overwrite dot general with aqt_dot_general.""" + return jnp.einsum if quant is None else quant.einsum() + + +class AttentionOp(nn.Module): + mesh: Mesh + attention_kernel: str + max_target_length: int + num_query_heads: int + num_kv_heads: int + float32_qk_product: bool = False + max_prefill_predict_length: int = -1 + float32_logits: bool = False + flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) + dropout_rate: float = 0.0 + dtype: DType = jnp.float32 + quant: Optional[Quant] = None + quantize_kvcache: bool = False + + def check_attention_inputs(self, query: Array, key: Array, value: Array) -> None: + """Check attention inputs.""" + + assert key.ndim == value.ndim, "k, v must have same rank." + assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match." + assert key.shape[-2] == value.shape[-2], "k, v num_kv_heads must match." + assert key.shape[-3] == value.shape[-3], "k, v lengths must match." + assert query.shape[-1] == key.shape[-1], "q, k depths must match." + + # Following Pallas MHA Flash Attention Reference. + # https://github.com/google/jax/blob/main/jax/experimental/pallas/ops/tpu/flash_attention.py + # This mask models (1) separate sequences (decoder_segment_ids) and (2) causality + def generate_attention_mask(self, query, key, decoder_segment_ids: Array | None, model_mode: str) -> Array | None: + mask = None + if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + mask = decoder_segment_ids[:, None, None, None, :] == common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + elif decoder_segment_ids is not None: + mask = decoder_segment_ids[:, :, None] == decoder_segment_ids[:, None, :] + mask = mask[:, None, None, :, :] + + causal_mask = None + # We enforce causality except for AUTOREGRESSION + if model_mode != common_types.MODEL_MODE_AUTOREGRESSIVE: + _, q_seq_len, _, _ = query.shape + _, kv_seq_len, _, _ = key.shape + mask_shape = (q_seq_len, kv_seq_len) + row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0) + col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1) + causal_mask = (col_ids <= row_ids)[None, None, None, :, :] + + if (mask is not None) and (causal_mask is not None): + output_mask = jnp.logical_and(mask, causal_mask) + elif mask is not None: + output_mask = mask + elif causal_mask is not None: + output_mask = causal_mask + else: + output_mask = None + + return jnp.where(output_mask, 0.0, DEFAULT_MASK_VALUE) if output_mask is not None else None + + def apply_attention(self, query: Array, key: Array, value: Array, decoder_segment_ids: Array | None, model_mode: str): + self.check_attention_inputs(query, key, value) + length = query.shape[-3] + if ( + self.attention_kernel == "dot_product" + or (self.attention_kernel == "autoselected" and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE) + or (self.attention_kernel == "autoselected" and length < 128) + ): + return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode) + elif self.attention_kernel == "flash" or self.attention_kernel == "autoselected": + if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + raise 
ValueError( + """Decode not supported with flash attention. + Use `dot_product` instead.""" + ) + return self.tpu_flash_attention(query, key, value, decoder_segment_ids), None, None + elif self.attention_kernel == "cudnn_flash_te": + if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + raise ValueError( + """Decode not supported with flash attention. + Use `dot_product` instead.""" + ) + return self.cudnn_flash_attention(query, key, value, decoder_segment_ids, model_mode), None, None + else: + raise ValueError(f"Unexpected attention kernel {self.attention_kernel=}.") + + def tpu_flash_attention(self, query: Array, key: Array, value: Array, decoder_segment_ids: Array | None) -> Array: + """TPU Flash Attention.""" + # Transpose to ('batch', 'heads', 'length', 'kv') + query = jnp.transpose(query, axes=(0, 2, 1, 3)) + key = jnp.transpose(key, axes=(0, 2, 1, 3)) + value = jnp.transpose(value, axes=(0, 2, 1, 3)) + + if decoder_segment_ids is not None: + decoder_segment_ids = splash_attention_kernel.SegmentIds(decoder_segment_ids, decoder_segment_ids) + axis_names = nn.logical_to_mesh_axes(self.flash_axis_names) + segment_axis_names = nn.logical_to_mesh_axes((BATCH, "activation_length_no_heads")) + + @functools.partial( + shard_map, + mesh=self.mesh, + in_specs=( + axis_names, + axis_names, + axis_names, + segment_axis_names, + ), + out_specs=axis_names, + check_rep=False, + ) + def wrap_flash_attention(query, key, value, decoder_segment_ids): + if decoder_segment_ids is not None: + assert ( + query.shape[2] == decoder_segment_ids.q.shape[1] + ), "Sharding along sequence dimension not allowed in tpu kernel attention" + block_sizes = splash_attention_kernel.BlockSizes( + block_q=min(512, query.shape[2]), + block_kv_compute=min(512, key.shape[2]), + block_kv=min(512, key.shape[2]), + block_q_dkv=min(512, query.shape[2]), + block_kv_dkv=min(512, key.shape[2]), + block_kv_dkv_compute=min(512, query.shape[2]), + block_q_dq=min(512, query.shape[2]), + block_kv_dq=min(512, query.shape[2]), + ) + + masks = [splash_attention_mask.CausalMask(shape=(query.shape[2], query.shape[2])) for i in range(query.shape[1])] + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks) + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes + ) + + return jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids) + + devices_in_data_fsdp = self.mesh.shape["data"] * self.mesh.shape["fsdp"] + assert (query.shape[0] / devices_in_data_fsdp).is_integer(), ( + "Batch dimension should be shardable among the devices in data and fsdp" " axis" + ) + x = wrap_flash_attention(query, key, value, decoder_segment_ids) + x = jnp.transpose(x, axes=(0, 2, 1, 3)) + return x + + def cudnn_flash_attention( + self, + query: Array, + key: Array, + value: Array, + decoder_segment_ids: Array | None, + model_mode: str = common_types.MODEL_MODE_TRAIN, + ) -> Array: + """CUDNN Flash Attention with Transformer Engine. + 1. Stable API, supports GQA + 2. Supports head_dim till 128; head_dim=256 support will be added soon + """ + # These imports are only meant to work in a GPU build. 
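The causal component of the mask built in generate_attention_mask above can be reproduced in isolation; a minimal sketch with an illustrative sequence length of 4:

import jax
import jax.numpy as jnp

q_seq_len = kv_seq_len = 4  # illustrative lengths
row_ids = jax.lax.broadcasted_iota(jnp.int32, (q_seq_len, kv_seq_len), 0)
col_ids = jax.lax.broadcasted_iota(jnp.int32, (q_seq_len, kv_seq_len), 1)
# Lower-triangular boolean mask: query position t may attend to key position s iff s <= t.
causal_mask = col_ids <= row_ids
# The module then adds leading singleton axes and converts the boolean mask to the
# additive convention with jnp.where(mask, 0.0, DEFAULT_MASK_VALUE).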
+ from transformer_engine.jax.flax.transformer import DotProductAttention # pytype: disable=import-error + + _, _, _, head_dim = query.shape # pylint: disable=unused-variable + + # generate attn_mask + attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + + dpa_layer = DotProductAttention( + head_dim=head_dim, + num_attention_heads=self.num_query_heads, + num_gqa_groups=self.num_kv_heads, + attn_mask_type="causal", # 'causal' or 'padding' + attn_bias_type="NO_BIAS", # 'no_bias', 'pre_scale_bias' or 'post_scale_bias' + attention_dropout=self.dropout_rate, + dropout_rng_name="aqt", + dtype=self.dtype, + float32_logits=self.float32_logits, + qkv_layout="BSHD_BSHD_BSHD", # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD' + scale_factor=1.0 / math.sqrt(head_dim), + transpose_batch_sequence=False, + ) + return dpa_layer(query, key, value, mask=attn_mask) + + def compute_local_attention(self, attn_weights: Array, value: Array) -> tuple[Array, Array, Array]: + """Computes the attention of a local subset of the kv cache. + Local attention results will need to be combined with any other local attentions and normalized + Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py + + Args: + attn_weights (Array): Product of query and key + value (Array): Current value + aqt_rng (PRNGKey | None): Optional rng + + Returns: + (local_out, local_max,): where + local_out is local unnormalized output + local_max is the local max of exponentials + local_sum is the sum of exponentials for this chunk, divided by exp(local_max). + """ + local_max = jnp.max(attn_weights, axis=-1, keepdims=True) + local_exps = jnp.exp(attn_weights - local_max) + local_sum = jnp.sum(local_exps, axis=-1, keepdims=True) + + local_sum = jnp.moveaxis(local_sum, -2, 1) + local_max = jnp.moveaxis(local_max, -2, 1) + + local_max = jnp.reshape(local_max, (local_max.shape[0], local_max.shape[1], local_max.shape[2] * local_max.shape[3], 1)) + local_sum = jnp.reshape(local_sum, (local_sum.shape[0], local_sum.shape[1], local_sum.shape[2] * local_sum.shape[3], 1)) + + local_out = self.wv_product(local_exps, value) + return local_out, local_max, local_sum + + def apply_attention_dot( + self, + query: Array, + key: Array, + value: Array, + decoder_segment_ids: Array | None, + model_mode: str = common_types.MODEL_MODE_TRAIN, + ): + """Apply Attention.""" + # Casting qk_product and softmaxt computation for float32 for model stability. + if self.float32_qk_product: + query = query.astype(jnp.float32) + key = key.astype(jnp.float32) + + attn_weights = self.qk_product(query, key) + + # Casting softmaxt computation for float32 for model stability. + if self.float32_logits: + attn_weights = attn_weights.astype(jnp.float32) + attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) + if attn_mask is not None: + attn_weights = apply_mask_to_logits(attn_weights, attn_mask) + return self.compute_local_attention(attn_weights, value) + + def qk_product(self, query: Array, key: Array) -> Array: + """Query-Key product. + + Args: + query: Query projection, in shape of [b, t, n, d], where b: batch size, t: + query length, n: number of heads, d: project dimension. + key: Key projection in shape of [b, s, n_kv, d] for where s: key length, n_kv is + kv heads (sometimes k). The number of group for query is n // n_kv (sometimes g). + + Returns: + results in shape [b, n_kv, n // n_kv, t, s]. 
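A shape-only sketch of the grouped-query einsum this docstring describes, with illustrative sizes (n = 8 query heads grouped over n_kv = 2 KV heads, so g = 4):

import jax.numpy as jnp

b, t, s, n, n_kv, d = 2, 6, 6, 8, 2, 16
query = jnp.zeros((b, t, n, d)).reshape(b, t, n_kv, n // n_kv, d)
key = jnp.zeros((b, s, n_kv, d))
attn_weights = jnp.einsum("btkgd,bskd->bkgts", query, key)
assert attn_weights.shape == (b, n_kv, n // n_kv, t, s)  # (2, 2, 4, 6, 6)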
+ """ + b, t, n, d = query.shape + n_kv = key.shape[-2] + assert n_kv == self.num_kv_heads + query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d)) + result = jnp.einsum("btkgd,bskd->bkgts", query, key) + return result # (4, 8, 1, 1, 6) + + def wv_product(self, attn_weights: Array, value: Array) -> Array: + """weighted value product. + + Args: + attn_weights: Computed results of qk_einsum, in shape [batch_size, num_kv_heads, group_size, q_len, k_len]. + value: Value projection, in shape of [batch_size, v_len, num_kv_heads, kv_dim]. + + Returns: + result in shape [batch_size, q_len, num_kv_heads * group_size, kv_dim] + """ + out = jnp.einsum("bkgts,bskd->btkgd", attn_weights, value) + b, t, n_kv, g, d = out.shape + result = jnp.reshape(out, (b, t, n_kv * g, d)) + return result + + def revert_kvlen_axis(self, kv): + """Revert key/value length axis. + + Args: + kv: in shape [b, ..., n, d, s]. + + Returns: + reshaped kv as [b, ..., s, n, d] + """ + return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (1, 2, 0, 3)) + + def move_kvlen_axis(self, kv): + """Move key/value length axis to the end. + + Args: + kv: in shape [b, ..., s, n, d]. + + Returns: + reshaped kv as [b, ..., n, d, s] + """ + return jax.numpy.moveaxis(kv, (0, 1, 2, 3), (2, 0, 1, 3)) + + def cached_kv_shape(self, kv_shape): + """Cached KV shape. + + The key and value have dimension [batch, length, num_heads, head_dim], but + we cache them as [length, num_heads, batch, head_dim, ] for optimized read/write performance. + + Args: + kv_shape: shape of key or value for caching, as [b, ..., s, n, d]. + + Returns: + Swapped kv_shape as [b, ..., n, d, s] for cache. + """ + return (kv_shape[1], kv_shape[2], kv_shape[0], kv_shape[3]) + + def _get_prefill_cache(self, batch, heads, kv_head_size, quantize_kvcache): + dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16 + + kv_cache_layout = ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ) + cache_logical_shape = (batch, self.max_prefill_predict_length, heads, kv_head_size) + + cached_key = self.variable( + "cache", + "cached_prefill_key", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_value = self.variable( + "cache", + "cached_prefill_value", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_segment_id = self.variable( + "cache", + "cache_prefill_segment_id", + nn.with_logical_partitioning(jnp.zeros, ("cache_batch", "cache_sequence")), + (cache_logical_shape[0], self.max_prefill_predict_length), + jnp.int32, + ) + + if self.quantize_kvcache: + cache_logical_shape_scale = (batch, self.max_prefill_predict_length, heads, 1) + cached_key_scale_var = self.variable( + "cache", + "cached_prefill_key_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) + cached_value_scale_var = self.variable( + "cache", + "cached_prefill_value_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) + else: + cached_key_scale_var = None + cached_value_scale_var = None + + key_vars = (cached_key, cached_key_scale_var) + value_vars = (cached_value, cached_value_scale_var) + return key_vars, value_vars, cached_segment_id + + def _get_ar_cache(self, batch, heads, kv_head_size, quantize_kvcache): + dtype = jnp.int8 if quantize_kvcache else jnp.bfloat16 + cache_length = 
self.max_target_length - self.max_prefill_predict_length + kv_cache_layout = ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ) + cache_logical_shape = (batch, cache_length, heads, kv_head_size) + + # TODO(b/339703100): investigate the issue why with_logical_partitioning doesn't enforce sharding + cached_key = self.variable( + "cache", + "cached_ar_key", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_key.value = nn.with_logical_constraint( + cached_key.value, + ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ), + ) + + cached_value = self.variable( + "cache", + "cached_ar_value", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape), + dtype, + ) + cached_value.value = nn.with_logical_constraint( + cached_value.value, + ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ), + ) + + cached_segment_id = self.variable( + "cache", + "cache_ar_segment_id", + nn.with_logical_partitioning(jnp.zeros, ("cache_batch", "cache_sequence")), + (cache_logical_shape[0], cache_length), + jnp.int32, + ) + + if self.quantize_kvcache: + cache_logical_shape_scale = (batch, cache_length, heads, 1) + cached_key_scale_var = self.variable( + "cache", + "cached_ar_key_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) + cached_value_scale_var = self.variable( + "cache", + "cached_ar_value_scale", + nn.with_logical_partitioning(jnp.zeros, kv_cache_layout), + self.cached_kv_shape(cache_logical_shape_scale), + jnp.bfloat16, + ) + else: + cached_key_scale_var = None + cached_value_scale_var = None + + cache_index = self.variable("cache", "cache_ar_index", nn.with_logical_partitioning(jnp.zeros, ()), (1,), jnp.int32) + key_vars = (cached_key, cached_key_scale_var) + value_vars = (cached_value, cached_value_scale_var) + return key_vars, value_vars, cached_segment_id, cache_index + + def kv_cache_prefill( + self, + key: Array, + value: Array, + decoder_segment_ids: Array, + ): + """In prefill mode, we zero out the existing cache, run the computation and + prepare the cache as necessary. + + Args: + key: in shape [b, s, n, d]. + value: in shape [b, s, n, d]. + decoder_segment_ids: [b, s] -- marking segment ids for tokens + + Returns: + key, value, decoder_segment_id. + + """ + batch, sequence, heads, kv_head_size = key.shape + assert key.dtype == value.dtype, "Key and Value Dtypes should match." 
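The cache layout helpers above (move_kvlen_axis, revert_kvlen_axis, cached_kv_shape) amount to a fixed axis permutation between the model layout [batch, seq, heads, head_dim] and the cache layout [seq, heads, batch, head_dim] named by ("cache_sequence", "cache_heads", "cache_batch", "cache_kv"). A minimal round-trip sketch with illustrative shapes:

import jax.numpy as jnp

kv = jnp.zeros((2, 5, 3, 4))  # [batch, seq, heads, head_dim]
# move_kvlen_axis: model layout -> cache layout [seq, heads, batch, head_dim]
cached = jnp.moveaxis(kv, (0, 1, 2, 3), (2, 0, 1, 3))
assert cached.shape == (5, 3, 2, 4)
# revert_kvlen_axis: cache layout -> model layout
restored = jnp.moveaxis(cached, (0, 1, 2, 3), (1, 2, 0, 3))
assert restored.shape == kv.shape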
+ + cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache( + batch, heads, kv_head_size, self.quantize_kvcache + ) + self._get_ar_cache(batch, heads, kv_head_size, self.quantize_kvcache) # initialize it now + + key_shaped_for_cache = self.move_kvlen_axis(key) + value_shaped_for_cache = self.move_kvlen_axis(value) + + if self.quantize_kvcache: + key_shaped_for_cache, key_scale = quantizations.quantize_kv(key_shaped_for_cache) + value_shaped_for_cache, value_scale = quantizations.quantize_kv(value_shaped_for_cache) + cached_prefill_key_var[1].value = key_scale + cached_prefill_value_var[1].value = value_scale + + cached_prefill_key_var[0].value = key_shaped_for_cache + cached_prefill_value_var[0].value = value_shaped_for_cache + + if decoder_segment_ids is not None: + cached_prefill_segment_id.value = decoder_segment_ids + + return key, value, decoder_segment_ids + + def update_ar_key_value( + self, + one_token_key: Array, + one_token_value: Array, + cached_key_vars: tuple[nn.Variable, nn.Variable | None], + cached_value_vars: tuple[nn.Variable, nn.Variable | None], + one_hot_indices: Array, + ) -> tuple[Array, Array]: + """Adds a single token's results to the ar kv cache + + Args: + one_token_key (Array): Key of one token to add to the cache + one_token_value (Array): Value of one token to add to the cache + cached_ar_key (tuple[nn.Variable, nn.Variable|None],): Cached keys to add new token key to, possibly with scale + cached_ar_value (tuple[nn.Variable, nn.Variable|None],: Cached values to add new token value to, possible with scale + one_hot_indices (Array): Location of the new token within the cache + + Returns: + tuple[Array, Array]: Updated caches for key and value with new token info added + """ + + cached_key_var, cached_key_scale_var = cached_key_vars + cached_value_var, cached_value_scale_var = cached_value_vars + + # In order to update the key, value caches with the current key and + # value, we move the length axis to the back + one_token_key = self.move_kvlen_axis(one_token_key) + one_token_value = self.move_kvlen_axis(one_token_value) + + if self.quantize_kvcache: + one_token_key, one_token_key_scale = quantizations.quantize_kv(one_token_key) + one_token_value, one_token_value_scale = quantizations.quantize_kv(one_token_value) + + one_hot_indices = one_hot_indices.astype(int) + + ar_key = cached_key_var.value + ar_key = jax.lax.dynamic_update_index_in_dim(ar_key, one_token_key, jnp.squeeze(one_hot_indices), 0) + ar_key = nn.with_logical_constraint( + ar_key, + ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ), + ) + cached_key_var.value = ar_key + + ar_value = cached_value_var.value + ar_value = jax.lax.dynamic_update_index_in_dim(ar_value, one_token_value, jnp.squeeze(one_hot_indices), 0) + ar_value = nn.with_logical_constraint( + ar_value, + ( + "cache_sequence", + "cache_heads", + "cache_batch", + "cache_kv", + ), + ) + cached_value_var.value = ar_value + + if self.quantize_kvcache: + ar_key_scale = jax.lax.dynamic_update_index_in_dim( + cached_key_scale_var.value, one_token_key_scale, jnp.squeeze(one_hot_indices), 0 + ) + ar_value_scale = jax.lax.dynamic_update_index_in_dim( + cached_value_scale_var.value, one_token_value_scale, jnp.squeeze(one_hot_indices), 0 + ) + cached_key_scale_var.value = ar_key_scale + cached_value_scale_var.value = ar_value_scale + + ar_key = quantizations.unquantize_kv(cached_key_var.value, cached_key_scale_var.value, one_token_key.dtype) + ar_value = 
quantizations.unquantize_kv(cached_value_var.value, cached_value_scale_var.value, one_token_value.dtype) + + # Move the keys and values back to their original shapes. + return self.revert_kvlen_axis(ar_key), self.revert_kvlen_axis(ar_value) + + def prefill_cache_var_model_var(self, cache_var, target_dtype): + if not self.quantize_kvcache: + return self.revert_kvlen_axis(cache_var[0].value) + else: + raw_cache, quant_scale = cache_var + raw_cache_unquantized = quantizations.unquantize_kv(raw_cache.value, quant_scale.value, target_dtype) + return self.revert_kvlen_axis(raw_cache_unquantized) + + def kv_cache_autoregressive( + self, + key: Array, + value: Array, + ): + """In autoregressive mode, we update the cache for this entry and + then return the full cache. + + Args: + key: in shape [b, 1, n, d]. + value: in shape [b, 1, n, d]. + decoder_segment_ids: [b, 1] -- marking segment ids for tokens + + Returns: + tuple of (key, value, segment_id) for both prefill and ar cache, + Raises: + ValueError: when key/value shape is not [batch, 1, num_heads, heads_dim]. + """ + batch, sequence, heads, kv_head_size = key.shape + if sequence != 1: + raise ValueError(f"Sequence length should be 1 during autoregression, got {sequence=}") + is_initialized = self.has_variable("cache", "cache_ar_index") + if not is_initialized: + raise ValueError("Error, we can't do autoregression if we haven't seeded the KV Cache.") + + cached_ar_key_var, cached_ar_value_var, cached_ar_segment_id, cache_ar_index = self._get_ar_cache( + batch, heads, kv_head_size, self.quantize_kvcache + ) + + key = nn.with_logical_constraint(key, (BATCH, LENGTH, HEAD, D_KV)) + value = nn.with_logical_constraint(value, (BATCH, LENGTH, HEAD, D_KV)) + + ar_key, ar_value = self.update_ar_key_value(key, value, cached_ar_key_var, cached_ar_value_var, cache_ar_index.value) + active_indicator = jnp.zeros((batch, 1), dtype=jnp.int32) + common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR + cached_ar_segment_id.value = jax.lax.dynamic_update_index_in_dim( + cached_ar_segment_id.value, active_indicator, jnp.squeeze(cache_ar_index.value), 1 + ) + cache_ar_index.value = jnp.mod(cache_ar_index.value + 1, self.max_target_length - self.max_prefill_predict_length) + + # Prep and return both prefill and ar caches + cached_prefill_key_var, cached_prefill_value_var, cached_prefill_segment_id = self._get_prefill_cache( + self.max_target_length, heads, kv_head_size, self.quantize_kvcache + ) + + cached_prefill = ( + self.prefill_cache_var_model_var(cached_prefill_key_var, key.dtype), + self.prefill_cache_var_model_var(cached_prefill_value_var, value.dtype), + cached_prefill_segment_id.value, + ) + return cached_prefill, (ar_key, ar_value, cached_ar_segment_id.value) + + def kv_cache(self, key: Array, value: Array, decoder_segment_ids: Array, model_mode: str) -> tuple: + """KV cache takes the current state and updates the state accordingly. + + The key and value have dimension [batch, length, num_heads, head_dim], + but we cache them as [batch, num_heads, head_dim, length] as a TPU + fusion optimization. This also enables the "scatter via one-hot + broadcast" trick, which means we do a one-hot broadcast instead of a + scatter/gather operations, resulting in a 3-4x speedup in practice. + + Args: + key: in shape [b, s, n, d]. + value: in shape [b, s, n, d]. 
+ model_mode: model mode controlling model + + Returns: + two tuples of (k, v, decoder_segments) -- either can be Nones + + """ + if key.shape != value.shape: + raise ValueError(f"Can't KV cache with mismatched shapes {key.shape=}, {value.shape=}") + + if model_mode == common_types.MODEL_MODE_TRAIN: + return (key, value, decoder_segment_ids), None + elif model_mode == common_types.MODEL_MODE_PREFILL: + return self.kv_cache_prefill(key, value, decoder_segment_ids), None + elif model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + return self.kv_cache_autoregressive(key, value) + else: + raise ValueError(f"Model Mode isn't supported! {model_mode=}") + + def normalize_attention(self, local_outs, local_maxes, local_sums): + """Normalize across multiple localized attentions + + Args: + local_outs (list): List of unnormalized outputs entries for each local attention + local_maxes (list): List of max exponentials entries for each local attention + local_sums (list): List of exponential sum entries for each local attention + + Returns: + Array: Combined attention that has been normalized + """ + # Based on https://github.com/google-research/google-research/blob/master/scaling_transformer_inference_efficiency/attention.py + global_max = functools.reduce(jnp.maximum, local_maxes) + global_sum = sum( + [jnp.exp(local_max - global_max) * local_sum for (local_sum, local_max) in zip(local_sums, local_maxes)] + ) + + attn_out = 0 + for local_max, local_out in zip(local_maxes, local_outs): + local_normalizer = jnp.exp(local_max - global_max) / global_sum + attn_out += local_normalizer * local_out + return attn_out + + @nn.compact + def __call__(self, query, key, value, decoder_segment_ids, model_mode): + prefill_kv_cache, ar_kv_cache = self.kv_cache(key, value, decoder_segment_ids, model_mode) + + prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention( + query=query, + key=prefill_kv_cache[0], + value=prefill_kv_cache[1], + decoder_segment_ids=prefill_kv_cache[2], + model_mode=model_mode, + ) + + # Return the "prefill" cache if it actually the combined prefill+ar kv cache + if ar_kv_cache is None: + if prefill_exponentials_sum is not None: + return prefill_unnormalized_output / prefill_exponentials_sum + return prefill_unnormalized_output + + ar_unnormalized_output, ar_exponentials_max, ar_exponentials_sum = self.apply_attention( + query=query, + key=ar_kv_cache[0], + value=ar_kv_cache[1], + decoder_segment_ids=ar_kv_cache[2], + model_mode=model_mode, + ) + + unnormalized_outputs = [prefill_unnormalized_output, ar_unnormalized_output] + exponentials_maxes = [prefill_exponentials_max, ar_exponentials_max] + exponentials_sums = [prefill_exponentials_sum, ar_exponentials_sum] + return self.normalize_attention(unnormalized_outputs, exponentials_maxes, exponentials_sums) + + +class Attention(nn.Module): + """Generic Attention. + + Attributes: + num_query_heads: number of query attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + num_kv_heads: number of kv attention heads. + head_dim: dimension of each head. + mesh: Mesh, device mesh + attention_kernel: str, guidance on if we should use an attention kernel + dtype: the dtype of the computation. + weight_dtype: the dtype of the weights. + max_target_length: maximum target length + max_prefill_predict_length: size of the maximum prefill + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. 
+ float32_qk_product: bool, if True then compute logits via float32 qk_product to avoid + numerical issues with bfloat16. + float32_logits: bool, if True then cast logits to float32 before softmax to avoid + numerical issues with bfloat16. + quant: Quant, stores quantization parameters, defaults to None implying no quantization. + quantize_kvcache: bool, quantize the kv cache. + """ + + config: Config + num_query_heads: int + num_kv_heads: int + head_dim: int + max_target_length: int + mesh: Mesh + attention_kernel: str + dtype: DType = jnp.float32 + weight_dtype: DType = jnp.float32 + max_prefill_predict_length: int = -1 + dropout_rate: float = 0.0 + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal") + float32_qk_product: bool = False # computes logits in float32 for stability. + float32_logits: bool = False # cast logits in float32 for stability. + quant: Optional[Quant] = None + quantize_kvcache: bool = False + + query_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) + key_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) + value_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) + out_axis_names: AxisNames = (BATCH, LENGTH, HEAD, D_KV) + + def query_projection(self, inputs_q: Array) -> Array: + """Query projection.""" + + # NOTE: T5 does not explicitly rescale the attention logits by + # 1/sqrt(depth_kq)! This is folded into the initializers of the + # linear transformations, which is equivalent under Adafactor. + depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + + def query_init(*args): + # pylint: disable=no-value-for-parameter + return self.kernel_init(*args) / depth_scaling + + query_proj = DenseGeneral( + features=(self.num_query_heads, self.head_dim), + axis=-1, + kernel_init=query_init, + kernel_axes=("embed", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="query", + quant=self.quant, + )(inputs_q) + return query_proj + + def kv_projection(self, inputs_kv: Array, proj_name: str) -> Array: + """Projection for Key and Value. + + Args: + inputs_kv: inputs_kv: key/values of shape `[batch, kv_length, + num_kv_heads, kv_dim]`. + proj_name: name of projection, `key` or `value`. + + Returns: + Projection of key or value, in shape of `[batch, kv_length, head_dim]`. + """ + if self.num_kv_heads == -1: + raise ValueError("num_kv_heads is not defined.") + + if self.num_query_heads % self.num_kv_heads != 0: + raise ValueError("Invalid num_kv_heads for GQA.") + + kv_proj = DenseGeneral( + features=(self.num_kv_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name=proj_name, + quant=self.quant, + )(inputs_kv) + return kv_proj + + def qkv_projection(self, inputs: Array, proj_name: str): + """Fused QKV projection""" + + qkv_proj = DenseGeneral( + features=(3, self.num_query_heads, self.head_dim), + axis=-1, + kernel_init=self.kernel_init, + kernel_axes=("embed", "qkv", "heads", "kv"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name=proj_name, + quant=self.quant, + )(inputs) + qkv_proj = checkpoint_name(qkv_proj, "qkv_proj") + query, key, value = qkv_proj[:, :, 0, ...], qkv_proj[:, :, 1, ...], qkv_proj[:, :, 2, ...] 
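A minimal shape sketch, with hypothetical toy sizes, of how grouped-query attention pairs the kv_projection heads with the query heads; it mirrors the qk_product einsum earlier in this file, where num_query_heads query heads are split into groups that share the num_kv_heads key/value heads.

import jax.numpy as jnp

b, t, s, n, n_kv, d = 2, 4, 6, 8, 2, 16        # toy sizes: 8 query heads grouped onto 2 KV heads
query = jnp.zeros((b, t, n, d))
key = jnp.zeros((b, s, n_kv, d))
grouped_query = jnp.reshape(query, (b, t, n_kv, n // n_kv, d))
attn_weights = jnp.einsum("btkgd,bskd->bkgts", grouped_query, key)
print(attn_weights.shape)                      # (2, 2, 4, 4, 6): [batch, num_kv_heads, group_size, q_len, k_len]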
+ return query, key, value + + def out_projection(self, output_dim: int, out: Array) -> Array: + out_proj = DenseGeneral( + features=output_dim, + axis=(-2, -1), + kernel_init=self.kernel_init, + kernel_axes=("heads", "kv", "embed"), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + name="out", + quant=self.quant, + )(out) + return out_proj + + def key_rotary(self, key: Array, inputs_positions: Array): + """Apply Rotary Embedding to key.""" + key = RotaryEmbedding(embedding_dims=self.head_dim, name="key_rotary")(inputs=key, position=inputs_positions) + return key + + @nn.compact + def __call__( + self, + inputs_q: Array, + inputs_kv: Array, + inputs_positions: Array, + decoder_segment_ids: Array | None = None, + *, + model_mode: str = common_types.MODEL_MODE_TRAIN, + deterministic: bool = False, + ): + """Applies Attention on the input data. + + Projects the inputs into multi-headed query, key, and value vectors, + applies dot-product attention and project the results to an output vector. + + There are three modes: training, prefill and autoregression. During training, the KV cache + is ignored. During prefill, the cache is filled. During autoregression the cache is used. + + In the cache initialization call, `inputs_q` has a shape [batch, length, + q_features] and `inputs_kv`: [batch, length, kv_features]. During the + incremental decoding stage, query, key and value all have the shape [batch, + 1, qkv_features] corresponding to a single step. + + Args: + inputs_q: input queries of shape `[batch, q_length, q_features]`. + inputs_kv: key/values of shape `[batch, kv_length, kv_features]`. + model_mode: corresponding to train, prefill and decode. + deterministic: Disables dropout if set to True. + + Returns: + output of shape `[batch, length, q_features]`. + """ + # apply projection. + if self.config.fused_qkv: + query, key, value = self.qkv_projection(inputs_q, proj_name="qkv_proj") + else: + query = self.query_projection(inputs_q) + key = self.kv_projection(inputs_kv, proj_name="key") + value = self.kv_projection(inputs_kv, proj_name="value") + + # apply ROPE + query = RotaryEmbedding(embedding_dims=self.head_dim, name="query_rotary")(inputs=query, position=inputs_positions) + key = self.key_rotary(key, inputs_positions) + + # annotate with sharding constraint. + query = nn.with_logical_constraint(query, self.query_axis_names) + query = checkpoint_name(query, "query_proj") + key = nn.with_logical_constraint(key, self.key_axis_names) + key = checkpoint_name(key, "key_proj") + value = nn.with_logical_constraint(value, self.value_axis_names) + value = checkpoint_name(value, "value_proj") + + attention_op = AttentionOp( + mesh=self.mesh, + attention_kernel=self.attention_kernel, + max_target_length=self.max_target_length, + max_prefill_predict_length=self.max_prefill_predict_length, + float32_qk_product=self.float32_qk_product, + float32_logits=self.float32_logits, + quant=self.quant, + quantize_kvcache=self.quantize_kvcache, + num_query_heads=self.num_query_heads, + num_kv_heads=self.num_kv_heads, + dropout_rate=self.dropout_rate, + dtype=self.dtype, + ) + + out = attention_op(query, key, value, decoder_segment_ids, model_mode) + + out = nn.with_logical_constraint(out, self.out_axis_names) + + # apply output projection, output dim is set to the input dim. 
+ out = self.out_projection(inputs_q.shape[-1], out) + out = checkpoint_name(out, "out_proj") + return out diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/embeddings.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/embeddings.py new file mode 100644 index 000000000..9337986a0 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/embeddings.py @@ -0,0 +1,189 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Embedding Layers.""" + +from typing import Any, Optional + +from flax import linen as nn +import jax +from jax import lax +import jax.numpy as jnp +from layers import initializers + +Config = Any +Array = jnp.ndarray +DType = jnp.dtype + +Initializer = initializers.Initializer +default_embed_init = initializers.default_embed_init +with_logical_partitioning = nn.with_logical_partitioning + +_MAX_WAVELENGTH = 10_000 + + +class Embed(nn.Module): + """A parameterized function from integers [0, n) to d-dimensional vectors. + + Attributes: + num_embeddings: number of embeddings. + features: number of feature dimensions for each embedding. + dtype: the dtype of the embedding vectors (default: float32). + embedding_init: embedding initializer. + """ + + # pylint: disable=attribute-defined-outside-init + config: Config + num_embeddings: int + features: int + cast_input_dtype: Optional[DType] = None + dtype: DType = jnp.float32 + attend_dtype: Optional[DType] = None + embedding_init: Initializer = default_embed_init + + def setup(self): + self.embedding = self.param( + "embedding", + with_logical_partitioning(self.embedding_init, ("vocab", "embed")), + (self.num_embeddings, self.features), + self.config.weight_dtype, + ) + + def __call__(self, inputs: Array) -> Array: + """Embeds the inputs along the last dimension. + + Args: + inputs: input data, all dimensions are considered batch dimensions. + + Returns: + Output which is embedded input data. The output shape follows the input, + with an additional `features` dimension appended. + """ + cfg = self.config + if self.cast_input_dtype: + inputs = inputs.astype(self.cast_input_dtype) + if not jnp.issubdtype(inputs.dtype, jnp.integer): + raise ValueError("Input type must be an integer or unsigned integer.") + + if cfg.use_iota_embed: + iota = lax.iota(jnp.int32, self.num_embeddings) + one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype) + output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype)) + else: + output = jnp.asarray(self.embedding, self.dtype)[inputs] + output = nn.with_logical_constraint(output, ("activation_batch", "activation_length", "activation_embed")) + return output + + def attend(self, query: Array) -> Array: + """Attend over the embedding using a query array. + + Args: + query: array with last dimension equal the feature depth `features` of the + embedding. 
+ + Returns: + An array with final dim `num_embeddings` corresponding to the batched + inner-product of the array of query vectors against each embedding. + Commonly used for weight-sharing between embeddings and logit transform + in NLP models. + """ + dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype + return jnp.dot(query, jnp.asarray(self.embedding, jnp.bfloat16).T) + + +class RotaryEmbedding(nn.Module): + """RoPE + + Attributes: + min_timescale: Start of the geometric index. Determines the periodicity of + the added signal. + max_timescale: End of the geometric index. Determines the frequency of the + added signal. + embedding_dims: Dimension of the embedding to be generated. + """ + + min_timescale: int = 1 + max_timescale: int = 10_000 + embedding_dims: int = 0 + cast_as_fprop_dtype: bool = True + fprop_dtype: DType = jnp.bfloat16 + + def setup(self) -> None: + if self.embedding_dims % 2: + raise ValueError("Embedding dim for rotary position embedding must be a multiple of 2.") + + def __call__( + self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks + inputs: jax.Array, + position: jax.Array, + ) -> jax.Array: + """Generates a jax.Array of sinusoids with different frequencies. + + Args: + inputs: The input sequence on which to apply the Rotary position + embedding. Since rotary position embeddings are applied to query and + keys after projection, it is assumed of shape [B, S, N, H]. + position: Optional position jax.Array which denotes the position of each + token in the sequence. This only needs to be supplied when the sequence + is packed. It is of shape [B, S]. + + Returns: + a jax.Array of shape [B, S, N, H] which includes the inputs together with + the rotary position embedding incorporated in it. + """ + assert position is not None + if len(inputs.shape) != 4: + raise ValueError("Input is assumed to be a rank 4 tensor of shape" "[batch, sequence, heads, dims].") + if self.embedding_dims != inputs.shape[3]: + raise ValueError( + "The embedding dims of the rotary position embedding" "must match the hidden dimension of the inputs." 
+ ) + half_embedding_dim = self.embedding_dims // 2 + fraction = 2 * jnp.arange(0, half_embedding_dim) / self.embedding_dims + timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction + position = position[:, :, jnp.newaxis, jnp.newaxis] + sinusoid_inp = position / timescale + sin = jnp.sin(sinusoid_inp).astype(inputs.dtype) + cos = jnp.cos(sinusoid_inp).astype(inputs.dtype) + first_half, second_half = jnp.split(inputs, 2, axis=-1) + first_part = first_half * cos - second_half * sin + second_part = second_half * cos + first_half * sin + if self.cast_as_fprop_dtype: + first_part = first_part.astype(self.fprop_dtype) + second_part = second_part.astype(self.fprop_dtype) + x_out = jnp.concatenate((first_part, second_part), axis=-1) + return x_out + + +class PositionalEmbedding(nn.Module): + embedding_dims: int + max_wavelength: int = _MAX_WAVELENGTH + + def __call__( + self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks + input_embedding: jax.Array, + position: jax.Array, + ) -> jax.Array: + num_timescales = self.embedding_dims // 2 + log_timescale_increment = jnp.log(float(self.max_wavelength)) / jnp.maximum( + jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1 + ) + inv_timescales = jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment) + position = position[:, :, jnp.newaxis] + inv_timescales = inv_timescales[jnp.newaxis, jnp.newaxis, :] + scaled_time = position * inv_timescales + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1) + # signal = jnp.pad(signal, [[0, jnp.mod(self.embedding_dims, 2)]]) + position_embedding = signal.astype(jnp.float32) + return input_embedding + position_embedding diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/initializers.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/initializers.py new file mode 100644 index 000000000..5916ecb0c --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/initializers.py @@ -0,0 +1,44 @@ +# Copyright 2023 Google LLC + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# https://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Initializers.""" + +from typing import Callable, Tuple, Union + +from flax import linen as nn +import jax +import common_types + +Array = common_types.Array +DType = common_types.DType +PRNGKey = common_types.PRNGKey +Shape = common_types.Shape + +Initializer = Callable[[PRNGKey, Shape, DType], Array] +InitializerAxis = Union[int, Tuple[int, ...]] +NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array] + +default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0) + +default_bias_init = jax.nn.initializers.constant(0.0) + + +def nd_dense_init(scale, mode, distribution): + """Initializer with in_axis, out_axis set at call time.""" + + def init_fn(key, shape, dtype, in_axis, out_axis): + fn = jax.nn.initializers.variance_scaling(scale, mode, distribution, in_axis, out_axis) + return fn(key, shape, dtype) + + return init_fn diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/linears.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/linears.py new file mode 100644 index 000000000..e1a8634ba --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/linears.py @@ -0,0 +1,354 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Linear Layers.""" + +import functools +import operator +from typing import Any, Callable, Iterable, Sequence, Tuple, Union, Optional + +import flax.linen as nn +import jax +from jax import lax +import jax.numpy as jnp +import common_types +from layers import initializers +from layers import normalizations +from layers import quantizations +import numpy as np +from jax.ad_checkpoint import checkpoint_name + +Array = common_types.Array +Config = common_types.Config +DType = common_types.DType +NdInitializer = initializers.NdInitializer + +nd_dense_init = initializers.nd_dense_init +bias_init = initializers.default_bias_init + +RMSNorm = normalizations.RMSNorm +Quant = quantizations.AqtQuantization + + +def _convert_to_activation_function(fn_or_string: Union[str, Callable[..., Any]]) -> Callable[..., Any]: + """Convert a string to an activation function.""" + if fn_or_string == "linear": + return lambda x: x + elif isinstance(fn_or_string, str): + return getattr(nn, fn_or_string) + elif callable(fn_or_string): + return fn_or_string + else: + raise ValueError( + f"""Don't know how to convert {fn_or_string} + to an activation function""" + ) + + +def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]: + # A tuple by convention. len(axes_tuple) then also gives the rank efficiently. + return tuple(ax if ax >= 0 else ndim + ax for ax in axes) + + +def _canonicalize_tuple(x): + if isinstance(x, Iterable): + return tuple(x) + else: + return (x,) + + +class DenseGeneral(nn.Module): + """A linear transformation with flexible axes. + + Attributes: + features: tuple with numbers of output features. 
+ axis: tuple with axes to apply the transformation on. + weight_dtype: the dtype of the weights (default: float32). + dtype: the dtype of the computation (default: float32). + kernel_init: initializer function for the weight matrix. + use_bias: whether to add bias in linear transformation + quant: quantization config, defaults to None implying no quantization. + """ + + features: Union[Iterable[int], int] + axis: Union[Iterable[int], int] = -1 + weight_dtype: DType = jnp.float32 + dtype: DType = jnp.float32 + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal") + kernel_axes: Tuple[str, ...] = () + quant: Optional[Quant] = None + use_bias: bool = False + + @nn.compact + def __call__(self, inputs: Array) -> Array: + """Applies a linear transformation to the inputs along multiple dimensions. + + Args: + inputs: The nd-array to be transformed. + + Returns: + The transformed input. + """ + + def compute_dot_general(inputs, kernel, axis, contract_ind): + """Computes a dot_general operation that may be quantized.""" + dot_general = lax.dot_general + if self.quant: + dot_general_cls = self.quant.dot_general_cls() + dot_general = dot_general_cls() + return dot_general(inputs, kernel, ((axis, contract_ind), ((), ())), precision=None) + + features = _canonicalize_tuple(self.features) + axis = _canonicalize_tuple(self.axis) + + inputs = jnp.asarray(inputs, self.dtype) + axis = _normalize_axes(axis, inputs.ndim) + + kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features + kernel_in_axis = np.arange(len(axis)) + kernel_out_axis = np.arange(len(axis), len(axis) + len(features)) + if quantizations.in_serve_mode(self.quant): + # During aqt convert state we delete kernel weight from params to save memory. + # Instead they are retrieved from the tensors stored in the 'aqt' collection. + kernel = jnp.zeros(kernel_shape) + else: + kernel = self.param( + "kernel", + nn.with_logical_partitioning(self.kernel_init, self.kernel_axes), + kernel_shape, + self.weight_dtype, + kernel_in_axis, + kernel_out_axis, + ) + kernel = jnp.asarray(kernel, self.dtype) + + contract_ind = tuple(range(0, len(axis))) + output = compute_dot_general(inputs, kernel, axis, contract_ind) + + if self.use_bias: + bias_axes, bias_shape = self.kernel_axes[-len(features) :], kernel_shape[-len(features) :] + bias = self.param( + "bias", + nn.with_logical_partitioning(bias_init, bias_axes), + bias_shape, + self.weight_dtype, + ) + bias = jnp.asarray(bias, self.dtype) + output += bias + return output + + +class MlpBlock(nn.Module): + """Transformer MLP / feed-forward block. + + Attributes: + intermediate_dim: Shared dimension of hidden layers. + activations: Type of activations for each layer. Each element is either + 'linear', a string function name in flax.linen, or a function. + kernel_init: Kernel function, passed to the dense layers. + deterministic: Whether the dropout layers should be deterministic. + intermediate_dropout_rate: Dropout rate used after the intermediate layers. + dtype: computation data type for the dense layer. + weight_dtype: weight data type for the dense layer. + use_bias: whether to add bias in all feedforward layers. + use_pre_norm: whether to add pre layer norm in mlp layers. + quant: Optional quantization config, no quantization if None. 
+ """ + + config: Config + intermediate_dim: int = 2048 + activations: Sequence[Union[str, Callable[..., Any]]] = ("relu",) + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal") + intermediate_dropout_rate: float = 0.1 + dtype: Any = jnp.float32 + weight_dtype: Any = jnp.float32 + use_bias: bool = False + use_pre_norm: bool = False + quant: Optional[Quant] = None + + def get_norm_layer(self): + if self.config.decoder_block in ("default", "llama2", "mistral", "gemma"): + return RMSNorm + elif self.config.decoder_block == "gpt3": + from layers import gpt3 + + return functools.partial(gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=self.use_bias) + else: + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") + + @nn.compact + def __call__(self, inputs, decode: bool = False, deterministic: bool = False): + """Applies Transformer MlpBlock module.""" + cfg = self.config + + if self.use_pre_norm: + inputs = self.get_norm_layer()( + name="mlp_layer_norm", + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + kernel_axes=("norm",), + epsilon=cfg.normalization_layer_epsilon, + )(inputs) + + # Iterate over specified MLP input activation functions. + # e.g. ('relu',) or ('gelu', 'linear') for gated-gelu. + activations = [] + if cfg.fused_mlp: + x = DenseGeneral( + (len(self.activations), self.intermediate_dim), + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("embed", "num_activations", "mlp"), + name="wi", + quant=self.quant, + use_bias=self.use_bias, + )(inputs) + for idx, act_fn in enumerate(self.activations): + y = _convert_to_activation_function(act_fn)(x[:, :, idx, ...]) + activations.append(y) + else: + for idx, act_fn in enumerate(self.activations): + dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}" + x = DenseGeneral( + self.intermediate_dim, + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("embed", "mlp"), + name=dense_name, + quant=self.quant, + use_bias=self.use_bias, + )(inputs) + x = _convert_to_activation_function(act_fn)(x) + activations.append(x) + + # Take elementwise product of above intermediate activations. + x = functools.reduce(operator.mul, activations) + x = checkpoint_name(x, "mlpwi") + # Apply dropout and final dense output projection. + x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))( + x, deterministic=deterministic + ) # Broadcast along length. + x = nn.with_logical_constraint(x, ("activation_batch", "activation_length", "activation_mlp")) + output = DenseGeneral( + inputs.shape[-1], + dtype=self.dtype, + weight_dtype=self.weight_dtype, + kernel_init=self.kernel_init, + kernel_axes=("mlp", "embed"), + name="wo", + quant=self.quant, + use_bias=self.use_bias, + )(x) + + output = checkpoint_name(output, "mlpwo") + return output + + +class MoeBlock(nn.Module): + """Mixture of Experts (MoE) block. + + Attributes: + num_experts: Number of experts. + num_experts_per_tok: Number of experts for each token. + kernel_init: Kernel function, passed to the dense layers. + kernel_axes: Tuple with axes to apply kernel function. + weight_dtype: Type for the weights. + dtype: Type for the dense layer. + """ + + config: Config + num_experts: int + num_experts_per_tok: int + kernel_init: NdInitializer + kernel_axes: Tuple[str, ...] 
+ weight_dtype: DType = jnp.float32 + dtype: DType = jnp.float32 + + def generate_kernels(self, num_experts, base_emb_dim, mlp_dim): + + kernel_in_axis = np.arange(1) + kernel_out_axis = np.arange(1, 2) + kernel_init = nd_dense_init(1.0, 'fan_in', 'truncated_normal') + + kernel_axes = ('exp', 'embed', 'mlp') + wo_kernel_axes = ('exp', 'mlp', 'embed') + + w0_kernel = self.param( + 'wi_0', + nn.with_logical_partitioning(kernel_init, kernel_axes), + (num_experts, base_emb_dim, mlp_dim), + self.weight_dtype, + kernel_in_axis, + kernel_out_axis, + ) + w0_kernel = jnp.asarray(w0_kernel, self.dtype) + w1_kernel = self.param( + 'wi_1', + nn.with_logical_partitioning(kernel_init, kernel_axes), + (num_experts, base_emb_dim, mlp_dim), + self.weight_dtype, + kernel_in_axis, + kernel_out_axis, + ) + w1_kernel = jnp.asarray(w1_kernel, self.dtype) + wo_kernel = self.param( + 'wo', + nn.with_logical_partitioning(kernel_init, wo_kernel_axes), + (num_experts, mlp_dim, base_emb_dim), + self.weight_dtype, + kernel_in_axis, + kernel_out_axis, + ) + wo_kernel = jnp.asarray(wo_kernel, self.dtype) + return w0_kernel, w1_kernel, wo_kernel + + @nn.compact + def __call__(self, inputs): + cfg = self.config + inputs = inputs.astype(cfg.dtype) + gate_logits = DenseGeneral( + self.num_experts, + dtype=self.dtype, + kernel_init=self.kernel_init, + kernel_axes=self.kernel_axes, + name="gate")(inputs) + + top_k_weights, top_k_indices = jax.lax.top_k(gate_logits, self.num_experts_per_tok) + flattened_top_k_weights = top_k_weights.reshape(-1, self.num_experts_per_tok) + + softmax_probs = jax.nn.softmax(flattened_top_k_weights.astype(jnp.float32), axis=-1).astype(self.weight_dtype) + softmax_probs = softmax_probs.reshape(gate_logits.shape[:-1] + (self.num_experts_per_tok,)) + + weights = jnp.zeros_like(gate_logits) + index_update = (jnp.arange(gate_logits.shape[0])[:, None, None], jnp.arange(gate_logits.shape[1])[:, None], top_k_indices) + weights = weights.at[index_update].set(softmax_probs) + + w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, + cfg.base_emb_dim, + cfg.mlp_dim) + + with jax.named_scope("wi_0"): + layer_w0 = jnp.einsum("BLE,NEH -> BLNH", inputs, w0_kernel) + with jax.named_scope("wi_1"): + layer_w1 = jnp.einsum("BLE,NEH -> BLNH", inputs, w1_kernel) + layer_w0_act = _convert_to_activation_function(cfg.mlp_activations[0])(layer_w0) + layer_multiply = jnp.multiply(layer_w0_act, layer_w1) + with jax.named_scope("wo"): + intermediate_layer = jnp.einsum("BLNH,NHE -> BLNE", layer_multiply, wo_kernel) + with jax.named_scope("w_sum"): + output = jnp.einsum("BLNE,BLN -> BLE", intermediate_layer, weights) + + return output diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/models.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/models.py new file mode 100644 index 000000000..c778e2d38 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/models.py @@ -0,0 +1,393 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Transformer models.""" +# pylint: disable=arguments-differ +# pylint: disable=no-name-in-module + +from typing import Callable, Optional + + +from flax import linen as nn +import functools +import jax +import jax.numpy as jnp +import common_types +from layers import attentions +from layers import embeddings +from layers import linears +from layers import normalizations, quantizations + +Array = common_types.Array +Config = common_types.Config +DType = common_types.DType +Mesh = common_types.Mesh +ScanIn = common_types.ScanIn + +Embed = embeddings.Embed +Attention = attentions.Attention +RMSNorm = normalizations.RMSNorm +PositionalEmbedding = embeddings.PositionalEmbedding +Quant = quantizations.AqtQuantization + +# ------------------------------------------------------------------------------ +# The network: Decoder & Transformer Definitions +# ------------------------------------------------------------------------------ + + +class DecoderLayer(nn.Module): + """Transformer decoder layer that attends to the encoder.""" + + config: Config + mesh: Mesh + quant: Optional[Quant] = None + + @nn.compact + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ): + cfg = self.config + mesh = self.mesh + + inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + + # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] + lnx = RMSNorm( + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="pre_self_attention_norm", + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + )(inputs) + lnx = nn.with_logical_constraint(lnx, ("activation_batch", "activation_length", "activation_embed")) + + attention_layer = Attention( + config=self.config, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + name="self_attention", + quant=self.quant, + quantize_kvcache=cfg.quantize_kvcache, + ) + + attention_lnx = attention_layer( + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + ) + + attention_lnx = nn.with_logical_constraint(attention_lnx, ("activation_batch", "activation_length", "activation_embed")) + + # MLP block. 
+ mlp_lnx = linears.MlpBlock( + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="mlp", + config=cfg, + quant=self.quant, + )(lnx, deterministic=deterministic) + mlp_lnx = nn.with_logical_constraint(mlp_lnx, ("activation_batch", "activation_length", "activation_embed")) + + next_layer_addition = mlp_lnx + attention_lnx + + next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( + next_layer_addition, deterministic=deterministic + ) + + layer_output = next_layer_addition_dropped_out + inputs + layer_output = nn.with_logical_constraint( + layer_output, + ("activation_batch", "activation_length", "activation_embed"), + ) + + if cfg.record_internal_nn_metrics: + self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) + self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) + self.sow( + "intermediates", + "activation_fraction_zero", + jnp.sum(layer_output == 0) / jnp.size(layer_output), + ) + + return layer_output, None if cfg.scan_layers else layer_output + + +class Decoder(nn.Module): + """A stack of decoder layers as a part of an encoder-decoder architecture.""" + + config: Config + shared_embedding: nn.Module + mesh: Mesh + quant: Optional[Quant] = None + + def get_decoder_layer(self): + if self.config.decoder_block == "default": + return DecoderLayer + elif self.config.decoder_block == "llama2": + from layers import llama2 + + return llama2.LlamaDecoderLayer + elif self.config.decoder_block == "mistral": + # TODO(ranran): update to Mistral with sliding window attention + from layers import mistral + + return mistral.MistralDecoderLayer + elif self.config.decoder_block == "gemma": + from layers import gemma + + return gemma.GemmaDecoderLayer + elif self.config.decoder_block == "gpt3": + from layers import gpt3 + + return gpt3.Gpt3DecoderLayer + else: + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") + + def get_norm_layer(self): + if self.config.decoder_block in ("default", "llama2", "mistral", "gemma"): + return RMSNorm + elif self.config.decoder_block == "gpt3": + from layers import gpt3 + + return functools.partial(gpt3.Gpt3LayerNorm, reductions_in_fp32=False, use_bias=True) + else: + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block=}") + + @nn.compact + def __call__( + self, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + deterministic=False, + model_mode=common_types.MODEL_MODE_TRAIN, + ): + cfg = self.config + mesh = self.mesh + assert decoder_input_tokens.ndim == 2 # [batch, len] + + # [batch, length] -> [batch, length, emb_dim] + y = self.shared_embedding(decoder_input_tokens.astype("int32")) + y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + if cfg.use_untrainable_positional_embedding: + y = PositionalEmbedding(cfg.base_emb_dim)(y, decoder_positions) + + if cfg.trainable_position_size > 0: + y += Embed( + num_embeddings=cfg.trainable_position_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + name="position_embedder", + config=cfg, + )(decoder_positions) + + BlockLayer = self.get_decoder_layer() + + if cfg.remat_policy != "none": + if cfg.remat_policy == "minimal": + policy = jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + elif cfg.remat_policy == "save_dot_except_mlpwi": + policy = 
jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + "mlpwo", + ) + elif cfg.remat_policy == "save_dot_except_mlp": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + "out_proj", + ) + elif cfg.remat_policy == "save_qkv_proj": + policy = jax.checkpoint_policies.save_only_these_names( + "query_proj", + "value_proj", + "key_proj", + "qkv_proj", + ) + elif cfg.remat_policy == "qkv_proj_offloaded": + policy = jax.checkpoint_policies.save_and_offload_only_these_names( + names_which_can_be_saved=[], + names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"], + offload_src="device", + offload_dst="pinned_host", + ) + elif cfg.remat_policy == "minimal_offloaded": + policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims(offload_src="device", offload_dst="pinned_host") + elif cfg.remat_policy == "minimal_flash": + policy = jax.checkpoint_policies.save_from_both_policies( + jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, + jax.checkpoint_policies.save_only_these_names( + "context", + ), + ) + else: + assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" + policy = None + BlockLayer = nn.remat( # pylint: disable=invalid-name + BlockLayer, + prevent_cse=not cfg.scan_layers, + policy=policy, + static_argnums=(-1, -2, -3, -4, -5), + ) + if cfg.scan_layers: + initializing = self.is_mutable_collection("params") + params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) + cache_spec = 0 + y, _ = nn.scan( + BlockLayer, + variable_axes={ + "params": params_spec, + "cache": cache_spec, + "intermediates": 0, + "aqt": 0, + "_overwrite_with_gradient": 0, + }, + split_rngs={ + "params": True, + "dropout": cfg.enable_dropout, + }, + in_axes=( + nn.broadcast, + nn.broadcast, + nn.broadcast, + nn.broadcast, + ), + length=cfg.num_decoder_layers, + metadata_params={nn.PARTITION_NAME: "layers"}, + )(config=cfg, mesh=mesh, name="layers", quant=self.quant)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + else: + for lyr in range(cfg.num_decoder_layers): + y = BlockLayer(config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant)( + y, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + ) + + y = self.get_norm_layer()( + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + name="decoder_norm", + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + )(y) + y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) + + # [batch, length, emb_dim] -> [batch, length, vocab_size] + if cfg.logits_via_embedding: + # Use the transpose of embedding matrix for logit transform. + logits = self.shared_embedding.attend(y) + if self.config.normalize_embedding_logits: + # Correctly normalize pre-softmax logits for this shared case. + logits = logits / jnp.sqrt(y.shape[-1]) + else: + logits = linears.DenseGeneral( + cfg.vocab_size, + weight_dtype=cfg.weight_dtype, + dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability + kernel_axes=("embed", "vocab"), + name="logits_dense", + )( + y + ) # We do not quantize the logits matmul. 
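A minimal sketch, assuming toy sizes, of the weight-tied logits path above: when logits_via_embedding is set, the decoder reuses the token embedding table through Embed.attend, which is a dot product against the transposed table, optionally rescaled by sqrt(emb_dim) when normalize_embedding_logits is enabled.

import jax.numpy as jnp

y = jnp.ones((2, 3, 8))                    # decoder activations [batch, length, emb_dim], toy sizes
embedding = jnp.ones((11, 8))              # token embedding table [vocab_size, emb_dim]
logits = jnp.dot(y, embedding.T)           # [2, 3, 11], as in Embed.attend
logits = logits / jnp.sqrt(y.shape[-1])    # normalize_embedding_logits rescaling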
+ logits = nn.with_logical_constraint(logits, ("activation_batch", "activation_length", "activation_vocab")) + logits = logits.astype(jnp.float32) + return logits + + +class Transformer(nn.Module): + """An decoder-only Transformer model.""" + + # Make new attributes required, so that all Transformer dependencies (train, decode, compile, etc) will error instead of silently use defaults. + # pylint: disable=attribute-defined-outside-init + config: Config + mesh: Mesh + quant: Quant + + def setup(self): + """Initialize shared_embedding & decoder layers.""" + + cfg = self.config + mesh = self.mesh + self.shared_embedding = Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + attend_dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability + embedding_init=nn.initializers.normal(stddev=1.0), + name="token_embedder", + config=cfg, + ) + + self.decoder = Decoder(config=cfg, shared_embedding=self.shared_embedding, mesh=mesh, quant=self.quant) + + def __call__( + self, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + enable_dropout=True, + model_mode=common_types.MODEL_MODE_TRAIN, + ): + """Applies Transformer decoder-branch on encoded-input and target.""" + + if decoder_segment_ids is not None and model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE: + raise ValueError( + f"During autoregressive decoding we assume the tokens are in the active sequence" + f" which is always {common_types.DECODING_ACTIVE_SEQUENCE_INDICATOR}." + ) + + logits = self.decoder( + decoder_input_tokens=decoder_input_tokens, + decoder_positions=decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=not enable_dropout, + model_mode=model_mode, + ) + return logits diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/normalizations.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/normalizations.py new file mode 100644 index 000000000..862c586c9 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/layers/normalizations.py @@ -0,0 +1,51 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Normalization Layers.""" + +from typing import Any, Tuple + +from flax import linen as nn +from jax import lax +import jax.numpy as jnp +from layers import initializers + +Initializer = initializers.Initializer + + +class RMSNorm(nn.Module): + """RMS normalization.""" + + epsilon: float = 1e-6 + dtype: Any = jnp.float32 + weight_dtype: Any = jnp.float32 + kernel_axes: Tuple[str, ...] 
= () + scale_init: Initializer = nn.initializers.ones + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + """Applies layer normalization on the input.""" + x = jnp.asarray(x, jnp.float32) + features = x.shape[-1] + mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True) + y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype) + scale = self.param( + "scale", + nn.with_logical_partitioning(self.scale_init, self.kernel_axes), + (features,), + self.weight_dtype, + ) + + scale = jnp.asarray(scale, self.dtype) + return y * scale diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/max_utils.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/max_utils.py new file mode 100644 index 000000000..0e372e294 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/max_utils.py @@ -0,0 +1,671 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +""" Common Max Utils needed by multiple modules""" +import checkpointing +import common_types +import functools +import time +import socket +import subprocess + +import max_logging + +import numpy as np +import jax +import jax.numpy as jnp +from jax.experimental import mesh_utils + + +import json +import yaml +import flax +from flax.training import train_state +from flax import linen as nn +from flax.linen import partitioning as nn_partitioning + +import optax +import os +from typing import Tuple +from tensorboardX import writer + +from google.cloud import storage + + +def find_nans_and_infs(pytree): + def finder(x): + return jnp.any(jnp.isinf(x) | jnp.isnan(x)) + + bad_pytree = jax.tree_util.tree_map(finder, pytree) + return jax.tree_util.tree_flatten(bad_pytree) + + +def l2norm_pytree(x): + """L2 norm of a pytree of arrays.""" + return jnp.sqrt(jax.tree_util.tree_reduce(lambda x, y: x + jnp.sum(jnp.square(y)), x, initializer=0.0)) + + +def calculate_num_params_from_pytree(params): + params_sizes = jax.tree_util.tree_map(jax.numpy.size, params) + total_parameters = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes) + assert total_parameters >= 0 + return total_parameters + + +def calculate_total_params_per_chip(params): + def calculate_leaf_params_per_chip(arr): + shard = arr.addressable_shards[0] + return np.prod(shard.data.shape) + + params_sizes_per_chip = jax.tree_util.tree_map(calculate_leaf_params_per_chip, params) + total_parameters_per_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes_per_chip) + return total_parameters_per_chip + + +def calculate_bytes_from_pytree(params): + params_bytes = jax.tree_util.tree_map(lambda x: x.nbytes, params) + total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes) + return total_bytes + + +def summarize_size_from_pytree(params): + num_params = calculate_num_params_from_pytree(params) + num_bytes = calculate_bytes_from_pytree(params) + return num_params, num_bytes, num_bytes / num_params + +def 
initialize_summary_writer(config): + return writer.SummaryWriter(config.tensorboard_dir) if jax.process_index() == 0 else None + + +def close_summary_writer(summary_writer): + if jax.process_index() == 0: + summary_writer.close() + + +def _prepare_metrics_for_json(metrics, step, run_name): + """Converts metric dictionary into json supported types (e.g. float)""" + metrics_dict = {} + for val in metrics["scalar"]: + metrics_dict[val] = float(metrics["scalar"][val]) + metrics_dict["step"] = float(step) + metrics_dict["run_name"] = run_name + return metrics_dict + + +def write_metrics_locally(metrics, step, config, file): + """Writes metrics locally for testing""" + if step == 0: + file.truncate(0) + + metrics_dict = _prepare_metrics_for_json(metrics, step, config.run_name) + file.write(str(json.dumps(metrics_dict)) + "\n") + + if step == config.steps - 1: + file.close() + + +def add_config_to_summary_writer(config, summary_writer): + """Writes config params to tensorboard""" + if jax.process_index() == 0: + for key, value in config.get_keys().items(): + add_text_to_summary_writer(key, str(value), summary_writer) + + +def add_text_to_summary_writer(key, value, summary_writer): + """Writes given key-value pair to tensorboard as text/summary""" + if jax.process_index() == 0: + summary_writer.add_text(key, value) + + +def write_metrics_for_gcs(metrics, step, config, running_metrics): + """Writes metrics to gcs""" + metrics_dict_step = _prepare_metrics_for_json(metrics, step, config.run_name) + running_metrics.append(metrics_dict_step) + if (step + 1) % config.log_period == 0 or step == config.steps - 1: + start_step = (step // config.log_period) * config.log_period + metrics_filename = f"metrics_step_{start_step:06}_to_step_{step:06}.txt" + with open(metrics_filename, "w", encoding="utf8") as metrics_for_gcs: + for metrics_step in running_metrics: + metrics_for_gcs.write(str(json.dumps(metrics_step)) + "\n") + + metrics_for_gcs.close() + gcs_filename = os.path.join(config.metrics_dir, metrics_filename) + max_logging.log(f"Moving file {metrics_filename} to GCS...") + upload_blob(gcs_filename, metrics_filename) + max_logging.log(f"File {metrics_filename} moved successfully!") + running_metrics = [] # reset running_metrics to empty list + return running_metrics + + +def write_config_raw_keys_for_gcs(raw_keys): + """Writes config raw keys to GCS""" + if not raw_keys["save_config_to_gcs"] or jax.process_index() != 0: + return + max_logging.log("Writing config to GCS...") + + raw_keys_dict = dict(raw_keys) + filename = "config.yml" + with open(filename, "w", encoding="utf8") as config_for_gcs: + yaml.dump(raw_keys_dict, config_for_gcs) + config_for_gcs.close() + + gcs_filename = os.path.join(raw_keys["base_output_directory"], raw_keys["run_name"], filename) + max_logging.log(f"Moving file {filename} to GCS...") + upload_blob(gcs_filename, filename) + max_logging.log(f"File {filename} moved successfully!") + + +def parse_gcs_bucket_and_prefix(destination_gcs_name): + path_parts = destination_gcs_name.replace("gs://", "").split("/") + bucket = path_parts.pop(0) + key = "/".join(path_parts) + return bucket, key + + +def upload_blob(destination_gcs_name, source_file_name): + """Uploads a file to a GCS location""" + bucket_name, prefix_name = parse_gcs_bucket_and_prefix(destination_gcs_name) + storage_client = storage.Client() + bucket = storage_client.get_bucket(bucket_name) + blob = bucket.blob(prefix_name) + blob.upload_from_filename(source_file_name) + + +def 
maybe_initialize_jax_distributed_system(raw_keys): + """The best recipe to initialize the Jax Distributed System has varied over time. We keep a layer of + indirection in MaxText to avoid breaking the call sites unnecessarily. + + Currently jax.distributed.initialize() fully works as expected! + + For CPUs, we call jax.distributed.initialize() explicitly, with the specified arguments. + """ + if ( + raw_keys["enable_checkpointing"] and raw_keys["async_checkpointing"] and + raw_keys["compile_topology_num_slices"] == -1 and not raw_keys["enable_single_controller"] + ) or raw_keys["hardware"] == "gpu_multiprocess": + max_logging.log("Attempting to initialize the jax distributed system...") + jax.distributed.initialize() + max_logging.log("Jax distributed system initialized!") + elif is_gpu_backend(raw_keys): + max_logging.log("Attempting to initialize the jax distributed system for GPU backend...") + initialize_jax_for_gpu() + max_logging.log("Jax distributed system initialized on GPU!") + elif is_cpu_backend(raw_keys): + max_logging.log("Attempting to initialize the jax distributed system for CPU backend...") + initialize_jax_for_cpu() + max_logging.log("Jax distributed system initialized on CPUs!") + + +def initialize_jax_for_gpu(): + """Jax distributed initialize for GPUs.""" + if os.environ.get("JAX_COORDINATOR_IP") is not None: + coordinator_ip = str(os.getenv("JAX_COORDINATOR_IP")) + coordinator_port = str(os.getenv("JAX_COORDINATOR_PORT")) + jax.distributed.initialize( + coordinator_address=f"{coordinator_ip}:{coordinator_port}", + num_processes=int(os.getenv("NNODES")), + process_id=int(os.getenv("NODE_RANK")), + ) + max_logging.log(f"JAX global devices: {jax.devices()}") + + +def initialize_jax_for_cpu(): + """Jax distributed initialize for CPUs. Includes retries until the coordinator is ready.""" + coordinator_ip_address = get_coordinator_ip_address() + coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK + # Env variables to be set in XPK or otherwise + job_index = int(os.environ.get("JOB_INDEX")) + job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) + processes_in_job = int(os.environ.get("PROCESSES_IN_JOB")) + pid = job_index * processes_in_job + job_completion_index + max_logging.log(f" Jax process id is {pid} ") + # Explicit initialize is needed only for CPUs + jax.distributed.initialize( + coordinator_address=coordinator_address, process_id=pid, num_processes=int(os.environ.get("JAX_PROCESS_COUNT")) + ) + + +def is_cpu_backend(raw_keys): + """Determine whether Maxtext is intended to run on a CPU backend.""" + return raw_keys["hardware"] == "cpu" + + +def is_gpu_backend(raw_keys): + """Determine whether Maxtext is intended to run on a GPU backend.""" + return raw_keys["hardware"] == "gpu" + + +def get_coordinator_ip_address(): + """Get coordinator IP Address with retries""" + coordinator_address = "" + coordinator_ip_address = "" + if os.environ.get("JAX_COORDINATOR_ADDRESS") is not None: + coordinator_address = os.environ.get("JAX_COORDINATOR_ADDRESS") + coordinator_found = False + lookup_attempt = 1 + max_coordinator_lookups = 50 + while not coordinator_found and lookup_attempt <= max_coordinator_lookups: + try: + coordinator_ip_address = socket.gethostbyname(coordinator_address) + coordinator_found = True + except socket.gaierror: + max_logging.log( + f"Failed to recognize coordinator address {coordinator_address} on attempt {lookup_attempt}, retrying..." 
+ ) + lookup_attempt += 1 + time.sleep(5) + max_logging.log(f"Coordinator IP address: {coordinator_ip_address}") + return coordinator_ip_address + + +def fill_unspecified_mesh_axes(parallelism_vals, target_product, parallelism_type): + """Evaluates unspecified DCN/ICI parallelism values""" + if -1 in parallelism_vals: + assert ( + parallelism_vals.count(-1) == 1 + ), f"Found unspecified values (-1) for more than one {parallelism_type}\ + parallelism axis. At most one axis can be unspecified." + + determined_val = target_product / np.prod(parallelism_vals) * -1 + + assert ( + determined_val >= 1 and determined_val.is_integer + ), f"Unspecified value unable to be determined with the given\ + {parallelism_type} parallelism values" + + parallelism_vals[parallelism_vals.index(-1)] = int(determined_val) + + target_type = "slices" if parallelism_type == "DCN" else "devices per slice" + + assert ( + np.prod(parallelism_vals) == target_product + ), f"Number of {target_type} {target_product} does not match\ + the product of the {parallelism_type} parallelism {np.prod(parallelism_vals)}" + + return parallelism_vals + + +def create_device_mesh(config, devices=None): + """Creates a device mesh with each slice in its own data parallel group. If there is only one slice, uses two replicas""" + if devices is None: + devices = jax.devices() + num_devices = len(devices) + num_slices = config.num_slices + num_devices_per_slice = num_devices // num_slices + + multi_slice_env = num_slices > 1 + + dcn_parallelism = [ + config.dcn_data_parallelism, + config.dcn_fsdp_parallelism, + config.dcn_fsdp_transpose_parallelism, + config.dcn_sequence_parallelism, + config.dcn_tensor_parallelism, + config.dcn_autoregressive_parallelism, + ] + ici_parallelism = [ + config.ici_data_parallelism, + config.ici_fsdp_parallelism, + config.ici_fsdp_transpose_parallelism, + config.ici_sequence_parallelism, + config.ici_tensor_parallelism, + config.ici_autoregressive_parallelism, + ] + + # Find possible unspecified parallelisms + ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") + + if multi_slice_env: + dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN") + mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices) + else: + mesh = mesh_utils.create_device_mesh(ici_parallelism, devices) + + max_logging.log(f"Num_devices: {num_devices}, shape {mesh.shape}") + + return mesh + + +def unbox_logicallypartioned(boxed_pytree): + """Unboxes the flax.LogicallyPartitioned pieces + + Args: + boxed_pytree: a pytree that includes LogicallyPartitioned + leaves. + Returns: + a pytree where all all LogicallyPartitioned leaves have been unboxed. 
+ """ + return jax.tree_util.tree_map( + lambda x: x.unbox() if isinstance(x, flax.linen.spmd.LogicallyPartitioned) else x, + boxed_pytree, + is_leaf=lambda k: isinstance(k, flax.linen.spmd.LogicallyPartitioned), + ) + + +def init_decode_state(apply_fn, params): + """Init train state with null opt state for decode.""" + state = train_state.TrainState(step=0, apply_fn=apply_fn, params=params, tx=None, opt_state={}) # type: ignore + return state + + +def init_training_state(apply_fn, params, tx): + """Init train state with null opt state for decode.""" + state = train_state.TrainState.create(apply_fn=apply_fn, params=params, tx=tx) + return state + + +def init_initial_state(model, tx, config, is_training, key): + """ + We pass in "static" objects like model, tx, config as JAX compares them by + object hash, and instantiating them inside causes pjit top-level annotations + to fail to match as pytree prefixes if we re-instantiate. + + Args: model, tx, config, is_training, key + """ + input_shape = (config.global_batch_size_to_load, config.max_target_length) + model_vars = model.init( + {"params": key, "dropout": key, "aqt": key}, + jnp.ones(input_shape, dtype=jnp.int32), + jnp.ones(input_shape, dtype=jnp.int32), + ) + if is_training: + return init_training_state(model.apply, model_vars, tx) + return init_decode_state(model.apply, model_vars) + + +def load_decode_model_vars(model, config, rng, mesh): + state, _ = setup_decode_state(model, config, rng, mesh, None) + return state.params + + +def setup_decode_state(model, config, rng, mesh, checkpoint_manager): + is_training = False + state, state_mesh_annotations, _ = setup_initial_state( + model, None, None, config, rng, mesh, checkpoint_manager, is_training + ) + return state, state_mesh_annotations + + +def setup_training_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager): + is_training = True + return setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, is_training) + + +def setup_initial_state(model, data_iterator, tx, config, rng, mesh, checkpoint_manager, is_training=True): + """We initialize the model and optimizer state, and optionally load from a + checkpoint as necessary. + + Args: + model: the flax model to initialize + tx: the optax.GradientTransformation + config: config object + rng: jax.prng key + mesh: jax.devices() mesh + checkpoint_manager: an Orbax checkpointing.CheckpointManager object + is_training: True to initialize training state, False for decode state + + Returns: + state: the initialized train state + state_mesh_annotations: the mesh annotations for the train state + """ + + unboxed_abstract_state, state_mesh_annotations, state_mesh_shardings = get_abstract_state( + model, tx, config, rng, mesh, is_training + ) + + # Initialization + with nn_partitioning.axis_rules(config.logical_axis_rules): + restored, raw_params = checkpointing.load_state_if_possible( + checkpoint_manager, + data_iterator, + config.load_parameters_path, + config.load_full_state_path, + unboxed_abstract_state, + config.enable_single_replica_ckpt_restoring, + config.dataset_type, + ) + + if restored: + if "iter" in restored and restored["iter"] is not None: + data_iterator.local_iterator = restored["iter"] + state = restored["items"] + else: + init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) + state = jax.jit(init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings)(rng) + if raw_params: # If we loaded a partial state, we need to merge it. 
+ state = state.replace(params=raw_params) + + state = unbox_logicallypartioned(state) + return state, state_mesh_annotations, data_iterator + + +# Learning Rate Schedule +# ----------------------------------------------------------------------------- + + +def create_learning_rate_schedule(config): + """Creates a warmup and cosine decay learning rate schedule: + We take inspiration from Llama2's learning rate (LR) schedule, see https://arxiv.org/pdf/2307.09288.pdf section 2.2 + Learning rate schedule has either two or three parts: + 1) Linear warmup from 0 to [learning_rate] over steps 0 to [learning_rate_schedule_steps * warmup_steps_fraction] + 2) Cosine from [learning_rate] to [learning_rate * cosine_learning_rate_final_fraction] until learning_rate_schedule_steps + 3) Constant learning rate of 0 from learning_rate_schedule_steps to steps. + The zero learning rate section can be used to more accurately measure the fully trained model's performance. + """ + + def make_cos_schedule(init_lr, final_lr, len_steps): + def schedule(step): + pct = (step) / len_steps + a = 0.5 * (jnp.cos(jnp.pi * pct) + 1) + lr = init_lr * a + final_lr * (1 - a) + return lr + + return schedule + + lr = config.learning_rate + cos_final_lr = lr * config.cosine_learning_rate_final_fraction + + warmup_steps = int(config.learning_rate_schedule_steps * config.warmup_steps_fraction) + cos_steps = config.learning_rate_schedule_steps - warmup_steps + constant_zero_steps = config.steps - config.learning_rate_schedule_steps + + warmup_schedule = optax.linear_schedule(init_value=0.0, end_value=lr, transition_steps=warmup_steps) + cos_schedule = make_cos_schedule(lr, cos_final_lr, cos_steps) + constant_schedule = optax.constant_schedule(0.0) + + pieces = [warmup_schedule, cos_schedule] + boundaries = [ + warmup_steps, + warmup_steps + cos_steps, + ] + + if constant_zero_steps > 0: + pieces.append(constant_schedule) + boundaries.append(warmup_steps + cos_steps + constant_zero_steps) + + return optax.join_schedules(pieces, boundaries) + + +# Cross entropy implementation is taken from original T5X codebase: +# https://github.com/google-research/t5x/blob/ace831eea1e2742b4299cd1a9af7e4f302038351/t5x/losses.py#L25-L101 +@jax.custom_vjp +def cross_entropy_with_logits(logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float) -> Tuple[jnp.ndarray, jnp.ndarray]: + """Computes cross entropy loss with stable custom gradient. + Computes a stabilized-gradient version of: + -jnp.sum(targets * nn.log_softmax(logits), axis=-1) + If z_loss > 0, then an auxiliary loss equal to z_loss*log(z)^2 + will be added to the cross entropy loss (z = softmax normalization constant). + The two uses of z_loss are: + 1. To keep the logits from drifting too far from zero, which can cause + unacceptable roundoff errors in bfloat16. + 2. To encourage the logits to be normalized log-probabilities. + Args: + logits: [batch, length, num_classes] float array. + targets: categorical one-hot targets [batch, length, num_classes] float + array. + z_loss: coefficient for auxiliary z-loss loss term. + Returns: + tuple with the total loss and the z_loss, both + float arrays with shape [batch, length]. + """ + logits_sum = jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) + log_softmax = logits - logits_sum + loss = -jnp.sum(targets * log_softmax, axis=-1) + # Add auxiliary z-loss term. 
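+  # logits_sum above holds log(z), where z = sum_j exp(logit_j) is the softmax
+  # normalizer; squeezing the trailing axis gives log_z of shape [batch, length],
+  # and the penalty added below is z_loss * log(z)**2.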
+ log_z = jnp.squeeze(logits_sum, axis=-1) + total_z_loss = z_loss * jax.lax.square(log_z) + loss += total_z_loss + return loss, total_z_loss + + +def _cross_entropy_with_logits_fwd( + logits: jnp.ndarray, targets: jnp.ndarray, z_loss: float = 0.0 +) -> Tuple[ + Tuple[jnp.ndarray, jnp.ndarray], + Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], +]: + """Forward-mode of `cross_entropy_with_logits`.""" + max_logit = logits.max(axis=-1, keepdims=True) + shifted = logits - max_logit + exp_shifted = jnp.exp(shifted) + sum_exp = jnp.sum(exp_shifted, axis=-1, keepdims=True) + log_softmax = shifted - jnp.log(sum_exp) + loss = -jnp.sum(targets * log_softmax, axis=-1) + # Add auxiliary z-loss term. + log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1) + total_z_loss = z_loss * jax.lax.square(log_z) + loss += total_z_loss + return (loss, total_z_loss), ( + logits, + targets, + z_loss, + exp_shifted, + sum_exp, # pytype: disable=bad-return-type #jax-ndarray + log_softmax, + log_z, + ) + + +def _cross_entropy_with_logits_bwd( + res: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], + g: Tuple[jnp.ndarray, jnp.ndarray], +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Backward-mode of `cross_entropy_with_logits`.""" + g = g[0] # Ignore z_loss component as that is only used for logging. + logits, targets, z_loss, exp_shifted, sum_exp, log_softmax, log_z = res + # z-loss term adds the (2 * z_loss * log_z) factor. + deriv = jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - targets + g_logits = jnp.expand_dims(g, axis=-1) * deriv + g_targets = -jnp.expand_dims(g, axis=-1) * log_softmax + return ( + jnp.asarray(g_logits, logits.dtype), + jnp.asarray(g_targets, targets.dtype), + jnp.array(0.0), + ) # sets z-loss coeff gradient to 0 + + +cross_entropy_with_logits.defvjp(_cross_entropy_with_logits_fwd, _cross_entropy_with_logits_bwd) + + +def get_abstract_state(model, tx, config, rng, mesh, is_training=True): + """Get a shaped abstraction of the state (including optimizer)""" + init_state_partial = functools.partial(init_initial_state, model, tx, config, is_training) + + with nn_partitioning.axis_rules(config.logical_axis_rules): + abstract_state = jax.eval_shape(init_state_partial, rng) + + state_logical_annotations = nn.get_partition_spec(abstract_state) + + state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, config.logical_axis_rules) + + abstract_sharded_state = jax.jit(init_state_partial, in_shardings=None, out_shardings=state_mesh_shardings).eval_shape(rng) + + unboxed_abstract_sharded_state = unbox_logicallypartioned(abstract_sharded_state) + # Initialization + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) + return unboxed_abstract_sharded_state, state_mesh_annotations, state_mesh_shardings + + +def get_kv_cache_annotations(model, config, rng, mesh): + """Get a shaped abstraction of the state (including optimizer)""" + + def init_kv_cache(model, config): + input_shape = (config.global_batch_size_to_load, config.max_prefill_predict_length) + + model_vars = model.init( + {"params": rng, "dropout": rng, "aqt": rng}, + jnp.ones(input_shape), + jnp.ones(input_shape), + model_mode=common_types.MODEL_MODE_PREFILL, + ) + return model_vars["cache"] + + with nn_partitioning.axis_rules(config.logical_axis_rules): + init_kv_cache_partial = 
functools.partial(init_kv_cache, model, config) + abstract_state = jax.eval_shape(init_kv_cache_partial) + state_logical_annotations = nn.get_partition_spec(abstract_state) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations) + return state_mesh_annotations + + +def print_pytree_shape(print_str, ptree): + print("\n") + print(print_str) + print(jax.tree_util.tree_map(lambda x: x.shape, ptree)) + + +def print_model_vars(print_str, model_vars): + for k in model_vars: + print(f"{print_str} key{k}:") + print(f"\t {model_vars[k]}") + + +def get_project(): + completed_command = subprocess.run(["gcloud", "config", "get", "project"], check=True, capture_output=True) + project_outputs = completed_command.stdout.decode().strip().split("\n") + if len(project_outputs) < 1 or project_outputs[-1] == "": + max_logging.log("You must specify config.vertex_tensorboard_project or set 'gcloud config set project '") + return None + return project_outputs[-1] + + +def delete_pytree(p): + def delete_leaf(leaf): + if isinstance(leaf, jax.Array): + leaf.delete() + del leaf + + jax.tree_util.tree_map(delete_leaf, p) + + +def summarize_pytree_data(params, name="Params", raw=False): + """Generate basic metrics of a given Pytree.""" + num_params, total_param_size, avg_param_size = summarize_size_from_pytree(params) + if not raw: + num_params_in_billions = num_params / 1e9 + total_param_size_in_gb = total_param_size / 1e9 + print(f"{name} stats: \n" + f"\tTotal number of params: {num_params_in_billions:.3f} billion \n" + f"\tTotal memory usage: {total_param_size_in_gb:.3f} GB \n" + f"\tAvg size: {avg_param_size:.3f} bytes\n") + else: + print(f"{name} stats: \n" + f"\tTotal number of params: {num_params:.3f} \n" + f"\tTotal memory usage: {total_param_size:.3f} bytes \n" + f"\tAvg size: {avg_param_size:.3f} bytes\n") + return num_params, total_param_size, avg_param_size diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/maxtext_utils.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/maxtext_utils.py new file mode 100644 index 000000000..b07a80df5 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/maxtext_utils.py @@ -0,0 +1,230 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# pylint: disable=bare-except, consider-using-generator +"""Utils that are only interesting to MaxText. 
""" + +import jax +import optax +import max_utils +from jax.sharding import PartitionSpec as P +from jax.experimental.serialize_executable import deserialize_and_load + + +import pickle +import functools +from input_pipeline import input_pipeline_interface + +OVERWRITE_WITH_GRADIENT = "_overwrite_with_gradient" + +def get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, config): + """Get the shardings (both state and data) for train_step""" + functional_train = get_functional_train_step(train_step, model, config) + functional_train.__name__ = "train_step" + data_pspec = P(*config.data_sharding) + state_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) + data_sharding = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + out_shardings = (state_mesh_shardings, None) # State, metrics + static_argnums = () # We partial out the static argnums of model and config + donate_argnums = 0 # This is the index of the state - we allow the compiler to make use of this memory. + return functional_train, in_shardings, out_shardings, static_argnums, donate_argnums + + +def get_functional_train_step(train_step, model, config): + return functools.partial(train_step, model, config) + + +def get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations, model, config): + """Get the shardings (both state and data) for eval_step""" + functional_eval = get_functional_eval_step(eval_step, model, config) + functional_eval.__name__ = "eval_step" + data_pspec = P(*config.data_sharding) + state_mesh_shardings = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations) + data_sharding = jax.tree_util.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec) + in_shardings = (state_mesh_shardings, data_sharding, None) # State, batch, rng + out_shardings = None # metrics + static_argnums = () # We partial out the static argnums of model, config + donate_argnums = () # state will be kept instead of being donated in eval_step + return functional_eval, in_shardings, out_shardings, static_argnums, donate_argnums + + +def get_functional_eval_step(eval_step, model, config): + return functools.partial(eval_step, model, config) + + +def load_compiled(config, partial_train, state): + """# Loading a serialized compiled train step function.""" + + # Currently partial_train and state are needed to reconstruct + # input/output shapes to construct the in_trees and out_trees for load API + # Parker is working on a serializing these + def load_serialized_compiled(save_name): + with open(save_name, "rb") as f: + serialized_compiled = pickle.load(f) + return serialized_compiled + + def get_train_input_output_trees(func, input_args, input_kwargs): + _, in_tree_recreated = jax.tree_util.tree_flatten((input_args, input_kwargs)) + out_shaped = jax.eval_shape(func, *input_args, **input_kwargs) + _, out_tree_recreated = jax.tree_util.tree_flatten(out_shaped) + return in_tree_recreated, out_tree_recreated + + serialized_compiled = load_serialized_compiled(config.compiled_trainstep_file) + shaped_batch = input_pipeline_interface.get_shaped_batch(config) + example_rng = jax.random.PRNGKey(0) + shaped_input_args = (state, shaped_batch, example_rng) + shaped_input_kwargs = {} + in_tree, out_tree = get_train_input_output_trees(partial_train, shaped_input_args, shaped_input_kwargs) + 
p_train_step = deserialize_and_load(serialized_compiled, in_tree, out_tree) + return p_train_step + + +def calculate_tflops_training_per_device(config, log=True): + """Calculate training TFLOP""" + ffn1_flops = ( + 2 + * config.per_device_batch_size + * config.max_target_length + * config.mlp_dim + * config.emb_dim + * len(config.mlp_activations) + ) + ffn2_flops = 2 * config.per_device_batch_size * config.max_target_length * config.mlp_dim * config.emb_dim + total_ffn_flops = ffn1_flops + ffn2_flops + + if config.num_experts > 1: + # MoE: brute force implementation + gate_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_experts + total_ffn_flops = gate_flops + config.num_experts_per_tok * total_ffn_flops + + qkv_flops = ( + 2 + * config.per_device_batch_size + * config.max_target_length + * config.emb_dim + * (config.num_query_heads + 2 * config.num_kv_heads) + * config.head_dim + ) + attention_flops = ( + 4 * config.per_device_batch_size * config.max_target_length**2 * config.num_query_heads * config.head_dim + ) + projection_flops = ( + 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.num_query_heads * config.head_dim + ) + embedding_flops = 2 * config.per_device_batch_size * config.max_target_length * config.emb_dim * config.vocab_size + + # multiply by 3 for both feed forward and back proporgation flops + learnable_weight_tflops = ( + ((total_ffn_flops + qkv_flops + projection_flops) * config.num_decoder_layers + embedding_flops) * 3 / 10**12 + ) + # megatron tflops calculation does not account for causality in attention + attention_tflops = ( + attention_flops * config.num_decoder_layers * 3 / 10**12 + ) + + total_tflops = learnable_weight_tflops + attention_tflops + + if log: + print( + "Per train step:\n", + f"Total TFLOPs: {total_tflops:.2f} \n", + f"split as {100 * learnable_weight_tflops/total_tflops:.2f}% learnable weight flops", + f"and {100 * attention_tflops/total_tflops:.2f}% attention flops", + ) + return total_tflops, learnable_weight_tflops, attention_tflops + + +# https://arxiv.org/pdf/2204.02311.pdf Appendix B +def calculate_prefill_tflops_per_device(num_model_parameters, prefill_length, config, log=True): + """Calculate training TFLOP""" + learnable_weight_tflops = 2 * num_model_parameters * prefill_length / jax.device_count() / 1e12 + noncasual_attention_flops = ( + 4 + * config.num_query_heads + * config.num_decoder_layers + * config.head_dim + * prefill_length**2 + / jax.device_count() + / 1e12 + ) + causal_attention_tflops = noncasual_attention_flops / 2 # due to causality in attention + total_tflops = learnable_weight_tflops + causal_attention_tflops + + if log: + print( + "Per prefill step per device: \n", + f"\tTotal TFLOPs: {total_tflops:.2f} \n", + f"\t\tLearnable weight TFLOPs: {learnable_weight_tflops:.2f} ", + f"({100 * learnable_weight_tflops/total_tflops:.2f})% of Total\n", + f"\t\tCausal attention TFLOPs: {causal_attention_tflops:.2f} ", + f"({100 * causal_attention_tflops/total_tflops:.2f})% of Total", + ) + return total_tflops, learnable_weight_tflops, causal_attention_tflops + + +def assert_params_sufficiently_sharded(params, mesh, tolerance=0.02): + """Checks whether most params are sharded across sharding axis. + + This function determines whether the majority of parameters are distributed + across a specified sharding axes with an acceptable tolerance. 
It compares the + current distribution to a scenario where all parameters are fully sharded + across the 'fsdp', 'fsdp_transpose', 'sequence', and 'tensor' axes. + + Args: + params: params of the model state + mesh: mesh constructed from config + tolerance: float between 0.0 and 1.0 representing the allowed percentage of + non-sharded parameters. + Returns: + bool: True if the majority of parameters are sufficiently sharded + """ + total_num_params = max_utils.calculate_num_params_from_pytree(params) + product_num_devices_for_weight_sharding = 1 + for axis in ["fsdp", "fsdp_transpose", "sequence", "tensor"]: + product_num_devices_for_weight_sharding *= mesh.shape[axis] + total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params) + perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding + assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, ( + "Number of parameters per chip must not be less than in the ideal sharded " + "scenario across `fsdp`, `fsdp_transpose`,`sequence`, `tensor` axes." + ) + assert total_num_params_per_chip / perfectly_sharded_params_per_chip - 1 < tolerance, ( + f"Number of unsharded parameters exceeds tolerance {tolerance * 100}% " "of total parameters." + ) + +def apply_gradient_clipping(raw_grads, state, clipping_threshold): + """Applies gradient clipping to raw gradients, with special handing for FLAX fp8 stats. + + Args: + raw_grads: A pytree of raw gradients. + state: The current optimizer state. + clipping_threshold: The gradient clipping threshold. + + Returns: + A pytree of clipped gradients. + """ + gradient_clip_transformation = optax.clip_by_global_norm(clipping_threshold) + if OVERWRITE_WITH_GRADIENT in raw_grads: + # Scales + Amax History for Delayed Tensor Scaling SHOULD NOT be clipped or affect clipping + fp8_stats = raw_grads.pop(OVERWRITE_WITH_GRADIENT) + grads, _ = gradient_clip_transformation.update(raw_grads, state, None) + grads[OVERWRITE_WITH_GRADIENT] = fp8_stats # pytype: disable=unsupported-operands + raw_grads[OVERWRITE_WITH_GRADIENT] = fp8_stats # pytype: disable=unsupported-operands + else: + grads, _ = gradient_clip_transformation.update(raw_grads, state, None) + + return grads diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/train.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/train.py new file mode 100644 index 000000000..b929a9f22 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/opt/maxtext/MaxText/train.py @@ -0,0 +1,575 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +# pylint: disable=g-bad-todo, abstract-method, consider-using-with, ungrouped-imports +"""Training loop and Decoding of the model.""" + +# Calling jax.device_count here prevents a "TPU platform already registered" error. 
+# See github.com/google/maxtext/issues/20 for more + +import datetime +import os +import sys +import functools + +from typing import Sequence +from absl import app +from flax import linen as nn +from flax.linen import partitioning as nn_partitioning +import grain.python as grain +import jax +import numpy as np +import orbax.checkpoint + +import checkpointing +import max_utils +import maxtext_utils +import max_logging +import optimizers +import profiler +import pyconfig +# pylint: disable-next=unused-import +import register_jax_proxy_backend +from vertex_tensorboard import VertexTensorboardManager + +from input_pipeline.input_pipeline_interface import create_data_iterator_with_tokenizer +from layers import models + +import jax.numpy as jnp +from jax import random +from jax.sharding import Mesh +from jax.experimental import checkify + +from cloud_tpu_diagnostics import diagnostic +from cloud_tpu_diagnostics.configuration import debug_configuration +from cloud_tpu_diagnostics.configuration import diagnostic_configuration +from cloud_tpu_diagnostics.configuration import stack_trace_configuration + +from layers import quantizations + +from ml_goodput_measurement import goodput + +Transformer = models.Transformer +EPS = 1e-8 + + +def validate_train_config(config): + """Validates the configuration is set correctly for train.py""" + + assert config.run_name, "Erroring out, need a real run_name" + if not config.dataset_path.startswith("gs://"): + max_logging.log("WARNING: 'dataset_path' might be pointing your local file system") + if not config.base_output_directory.startswith("gs://"): + max_logging.log("WARNING: 'base_output_directory' might be pointing your local file system") + assert config.steps > 0, "You must set steps or learning_rate_schedule_steps to a positive integer." + + +def get_first_step(state): + with jax.spmd_mode("allow_all"): + return int(state.step) + + +def load_next_batch(train_iter, example_batch, config): + """Loads the next batch. Can keep reusing the same batch for performance reasons""" + + if config.reuse_example_batch and example_batch is not None: + return example_batch + else: + return next(train_iter) + + +def record_scalar_metrics(metrics, step_time_delta, per_device_tflops, lr): + """Records scalar metrics to be written to tensorboard""" + metrics["scalar"].update({"perf/step_time_seconds": step_time_delta.total_seconds()}) + metrics["scalar"].update({"perf/per_device_tflops": per_device_tflops}) + metrics["scalar"].update({"perf/per_device_tflops_per_sec": per_device_tflops / step_time_delta.total_seconds()}) + metrics["scalar"].update({"learning/current_learning_rate": lr}) + + +_buffered_step = None +_buffered_metrics = None + + +def write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config): + """Entry point for all metrics writing in Train's Main. + TODO: would be better as a Class in the future (that initialized all state!) + + To avoid introducing an unnecessary dependency, we "double buffer" -- we hold + onto the last metrics and step and only publish when we receive a new metrics and step. + The logic is that this ensures that Jax is able to queues train_steps and we + don't block when turning "lazy" Jax arrays into real Python numbers. 
+ """ + global _buffered_step, _buffered_metrics + + if _buffered_metrics is not None: + if _buffered_step is None: + raise ValueError(f"When writing metrics, {_buffered_step=} was none") + write_metrics_to_tensorboard(writer, _buffered_metrics, _buffered_step, config) + + if config.metrics_file: + max_utils.write_metrics_locally(_buffered_metrics, _buffered_step, config, local_metrics_file) + + if config.gcs_metrics and jax.process_index() == 0: + running_gcs_metrics = max_utils.write_metrics_for_gcs(_buffered_metrics, _buffered_step, config, running_gcs_metrics) + + _buffered_step = step + _buffered_metrics = metrics + + +def write_metrics_to_tensorboard(writer, metrics, step, config): + """Writes metrics to tensorboard""" + with jax.spmd_mode("allow_all"): + if jax.process_index() == 0: + for metric_name in metrics.get("scalar", []): + writer.add_scalar(metric_name, np.array(metrics["scalar"][metric_name]), step) + for metric_name in metrics.get("scalars", []): + writer.add_scalars(metric_name, metrics["scalars"][metric_name], step) + + full_log = step % config.log_period == 0 + + max_logging.log( + f"completed step: {step}, seconds: {metrics['scalar']['perf/step_time_seconds']:.3f}, " + f"TFLOP/s/device: {metrics['scalar']['perf/per_device_tflops_per_sec']:.3f}, " + f"loss: {metrics['scalar']['learning/loss']:.3f}" + ) + + if full_log and jax.process_index() == 0: + max_logging.log(f"To see full metrics 'tensorboard --logdir={config.tensorboard_dir}'") + writer.flush() + + +def save_checkpoint(checkpoint_manager, step, state, dataset_type="c4", data_iterator=None): + """Wrapper for saving checkpoint""" + if dataset_type == "c4-array_record": + return checkpoint_manager.save( + step, + args=orbax.checkpoint.args.Composite( + items=orbax.checkpoint.args.PyTreeSave(item=state), + iter=grain.PyGrainCheckpointSave(data_iterator.local_iterator), + ), + ) + else: + return checkpoint_manager.save( + step, args=orbax.checkpoint.args.Composite(items=orbax.checkpoint.args.PyTreeSave(item=state)) + ) + + +# ----------------------------------------------------------------------------- +# Top-level Functions +# ----------------------------------------------------------------------------- + + +def record_activation_metrics(output_metrics, intermediate_outputs, config): + """Adds the activation metrics to the metrics dict""" + + if config.scan_layers: + metrics_dict = intermediate_outputs["intermediates"]["decoder"]["decoder"] + + for layer_num in range(config.num_decoder_layers): + output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = metrics_dict["activation_fraction_zero"][0][ + layer_num + ] + output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = metrics_dict["activation_mean"][0][layer_num] + output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = metrics_dict["activation_stdev"][0][layer_num] + else: + for layer_num in range(config.num_decoder_layers): + layer = intermediate_outputs["intermediates"]["decoder"][f"layers_{layer_num}"] + output_metrics["scalar"][f"activ_fraction_zero/layer_{layer_num:03d}"] = layer["activation_fraction_zero"][0] + output_metrics["scalar"][f"activ_mean/layer_{layer_num:03d}"] = layer["activation_mean"][0] + output_metrics["scalar"][f"activ_stdev/layer_{layer_num:03d}"] = layer["activation_stdev"][0] + + +def loss_fn(model, config, data, dropout_rng, params, is_train=True): + """loss_fn for both train and eval. 
+ + Args: + model: A nn.Module + config: Config of parameters + data: Batch of data to apply to the model + dropout_rng: A key to use to generate rng for dropout + params: Model params + is_train: True for train_step and False for eval_step + + Returns: + loss: average loss + aux: a dictionary including intermediate_outputs, total_loss, and total_weights + """ + # inputs, targets, segments, positions = apply_args + rng1, aqt_rng = jax.random.split(dropout_rng) + + # decimate proportion of data when per_device_batch_size<1 + if is_train: + for k, v in data.items(): + data[k] = v[: config.global_batch_size_to_train_on, :] + + logits, intermediate_outputs = model.apply( + params, + data["inputs"], + data["inputs_position"], + decoder_segment_ids=data["inputs_segmentation"], + enable_dropout=config.enable_dropout if is_train else False, + rngs={"dropout": rng1, "params": aqt_rng}, + mutable="intermediates", + ) + one_hot_targets = jax.nn.one_hot(data["targets"], config.vocab_size) + xent, _ = max_utils.cross_entropy_with_logits(logits, one_hot_targets, 0.0) + xent = nn.with_logical_constraint(xent, ("activation_batch", "activation_length")) + # Mask out paddings at the end of each example. + xent = xent * (data["targets_segmentation"] != 0) + total_loss = jnp.sum(xent) + total_weights = jnp.sum(data["targets_segmentation"] != 0) + loss = total_loss / (total_weights + EPS) + aux = { + "intermediate_outputs": intermediate_outputs, + "total_loss": total_loss, + "total_weights": total_weights, + } + return loss, aux + + +def train_step(model, config, state, data, dropout_rng): + """ + + Args: + model: A nn.Module + state: A pytree of the current state of the model + data: Batch of data to apply to the model + dropout_rng: A key to use to generate rng for dropout + + Returns: + new_state: Same format as state. + metrics: Dictionary of model metrics such as loss, training rate, etc. + rng2: A new rng key that can be used in future calls. 
+ + """ + train_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=True) + grad_fn = jax.value_and_grad(train_loss_fn, has_aux=True) + (loss, aux), raw_grads = grad_fn(state.params) + intermediate_outputs = aux["intermediate_outputs"] + + if config.gradient_clipping_threshold > 0: + grads = maxtext_utils.apply_gradient_clipping(raw_grads, state, config.gradient_clipping_threshold) + else: + grads = raw_grads + new_state = state.apply_gradients(grads=grads) + metrics = { + "scalar": { + "learning/loss": loss, + "learning/grad_norm": max_utils.l2norm_pytree(grads), + "learning/raw_grad_norm": max_utils.l2norm_pytree(raw_grads), + "learning/param_norm": max_utils.l2norm_pytree(new_state.params), + }, + "scalars": {}, + } + + if config.record_internal_nn_metrics: + record_activation_metrics(metrics, intermediate_outputs, config) + + return new_state, metrics + + +def eval_step(model, config, state, data, dropout_rng): + """eval_step no backprop and new state compared with train_step.""" + eval_loss_fn = functools.partial(loss_fn, model, config, data, dropout_rng, is_train=False) + loss, aux = eval_loss_fn(state.params) + total_loss = aux["total_loss"] + total_weights = aux["total_weights"] + metrics = { + "scalar": {"evaluation/loss": loss, "evaluation/total_loss": total_loss, "evaluation/total_weights": total_weights} + } + + return metrics + + +def create_goodput_recorder(config): + if config.enable_goodput_recording: + logger_name = f"goodput_{config.run_name}" + recorder = goodput.GoodputRecorder(config.run_name, logger_name, jax.process_index() == 0) + return recorder + return None + + +def record_goodput(recorder, config, step=None, job_start=False, job_end=False): + if recorder and config.enable_goodput_recording: + if job_start and step is None: + recorder.record_job_start_time() + if job_end and step is None: + recorder.record_job_end_time() + if step is not None: + recorder.record_step_start_time(step) + +def check_example_batch(config, example_batch): + if config.max_checkify: + jittable_f = checkify.checkify( + lambda x: checkify.check(jnp.any(x > -1), "Batch contains bad synthetic data!") + ) + # Check if inputs in batch contains bad synthetic data. 
+ err, _ = jax.jit(jittable_f)(example_batch['inputs'][: config.global_batch_size_to_train_on, :]) + err.throw() + +def setup_mesh_and_model(config): + """Set up the mesh and the model for training + + Args: + config + + Returns: + init_rng: RNG key + writer: Summary writer for tensorboard + checkpoint_manager: Orbax checkpointer + state_mesh_annotations: the mesh annotations for the train state + model: + mesh: + learning_rate_schedule: + tx: + """ + + init_rng = random.PRNGKey(config.init_weights_seed) + writer = max_utils.initialize_summary_writer(config) + logger = checkpointing.setup_checkpoint_logger(config) + checkpoint_manager = checkpointing.create_orbax_checkpoint_manager( + config.checkpoint_dir, + config.enable_checkpointing, + config.async_checkpointing, + config.checkpoint_period, + config.dataset_type, + logger, + ) + # Mesh definition + devices_array = max_utils.create_device_mesh(config) + mesh = Mesh(devices_array, config.mesh_axes) + + # Model and Optimizer definition + quant = quantizations.configure_quantization(config) + model = Transformer(config, mesh, quant=quant) + learning_rate_schedule = max_utils.create_learning_rate_schedule(config) + tx = optimizers.get_optimizer(config, learning_rate_schedule) + return init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx + + +def setup_train_loop(config): + """Set up prerequisites for the training loop - + checkpoint_manager, PRNG keys, Mesh, Model and optimizer. + Set up data iterator and tokenizer, initialize the model. + + Args: + config + + Returns: + init_rng: + writer: Summary writer for tensorboard + checkpoint_manager: Orbax checkpointer + state_mesh_annotations: the mesh annotations for the train state + model: + mesh: + learning_rate_schedule: + data_iterator: + state: the initialized train state + """ + init_rng, writer, checkpoint_manager, mesh, model, learning_rate_schedule, tx = setup_mesh_and_model(config) + data_iterator, eval_data_iterator, _ = create_data_iterator_with_tokenizer(config, mesh) + + state, state_mesh_annotations, data_iterator = max_utils.setup_training_state( + model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager + ) + + maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh) + + return ( + init_rng, + writer, + checkpoint_manager, + state_mesh_annotations, + model, + mesh, + learning_rate_schedule, + data_iterator, + eval_data_iterator, + state, + ) + + +def train_loop(config, state=None): + """Main Training loop. 
+ Args: + config: + state: + ckpt_path: + Returns: + """ + # Create a GoodputRecorder to log information + recorder = create_goodput_recorder(config) + record_goodput(recorder, config, job_start=True) + + ( + init_rng, + writer, + checkpoint_manager, + state_mesh_annotations, + model, + mesh, + learning_rate_schedule, + data_iterator, + eval_data_iterator, + state, + ) = setup_train_loop(config) + # pylint: disable=line-too-long + ( + functional_train, + in_shard_train, + out_shard_train, + static_argnums_train, + donate_argnums_train, + ) = maxtext_utils.get_functional_train_with_signature(train_step, mesh, state_mesh_annotations, model, config) + + if eval_data_iterator: + # pylint: disable=line-too-long + ( + functional_eval, + in_shard_eval, + out_shard_eval, + static_argnums_eval, + donate_argnums_eval, + ) = maxtext_utils.get_functional_eval_with_signature(eval_step, mesh, state_mesh_annotations, model, config) + + num_model_parameters = max_utils.calculate_num_params_from_pytree(state.params) + max_logging.log(f"number parameters: {num_model_parameters/1e9:.3f} billion") + per_device_tflops, _, _ = maxtext_utils.calculate_tflops_training_per_device(config) + + # Write train config params, num model params, and XLA flags to tensorboard + max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer) + max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer) + max_utils.add_config_to_summary_writer(config, writer) + + # Define the compilation of functional_train, either by loading the compiled version or wrapping a new one in a jit + if config.compiled_trainstep_file != "": + print("Loading the compiled function...", flush=True) + # Need to pass train signature and state to determine i/o shapes of train_state for now. 
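+    # load_compiled (maxtext_utils) unpickles the serialized AOT executable and
+    # rebuilds it with deserialize_and_load; the in/out trees derived from
+    # (state, shaped_batch, rng) must match those used when it was serialized.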
+ p_train_step = maxtext_utils.load_compiled(config, functional_train, state) + # TODO: p_eval_step is not yet supported in load_compiled + p_eval_step = None + print("Loaded compiled function!", flush=True) + else: + p_train_step = jax.jit( + functional_train, + in_shardings=in_shard_train, + out_shardings=out_shard_train, + static_argnums=static_argnums_train, + donate_argnums=donate_argnums_train, + ) + + if eval_data_iterator: + p_eval_step = jax.jit( + functional_eval, + in_shardings=in_shard_eval, + out_shardings=out_shard_eval, + static_argnums=static_argnums_eval, + donate_argnums=donate_argnums_eval, + ) + else: + p_eval_step = None + + local_metrics_file = open(config.metrics_file, "a", encoding="utf8") if config.metrics_file else None + running_gcs_metrics = [] if config.gcs_metrics else None + + start_step = get_first_step(state) # this is the start_step for training + first_profiling_step = start_step + config.skip_first_n_steps_for_profiler + if config.profiler != "" and first_profiling_step >= config.steps: + raise ValueError("Profiling requested but initial profiling step set past training final step") + last_profiling_step = np.clip(first_profiling_step + config.profiler_steps - 1, first_profiling_step, config.steps - 1) + + example_batch = None + last_step_completion = datetime.datetime.now() + prof = profiler.Profiler(config) + for step in np.arange(start_step, config.steps): + if step == first_profiling_step: + prof.activate() + + with jax.profiler.StepTraceAnnotation("train", step_num=step): + example_batch = load_next_batch(data_iterator, example_batch, config) + check_example_batch(config, example_batch=example_batch) + nextrng = jax.jit(jax.random.fold_in)(init_rng, step) + record_goodput(recorder, config, step=step) + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + state, metrics = p_train_step(state, example_batch, nextrng) + + new_time = datetime.datetime.now() + record_scalar_metrics(metrics, new_time - last_step_completion, per_device_tflops, learning_rate_schedule(step)) + last_step_completion = new_time + + if checkpoint_manager is not None: + if save_checkpoint(checkpoint_manager, int(step), state, config.dataset_type, data_iterator): + max_logging.log(f"saved a checkpoint at step {step}") + + # Upon preemption, exit when and only when all ongoing saves are complete. 
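+      # reached_preemption(step) reports whether Orbax has observed a preemption
+      # signal; wait_until_finished() blocks until in-flight async checkpoint
+      # saves complete, so no partial checkpoint is left behind before sys.exit().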
+ if checkpoint_manager.reached_preemption(step): + checkpoint_manager.wait_until_finished() + sys.exit() + + write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, step, config) + + if config.eval_interval > 0 and step > start_step and step % config.eval_interval == 0: + assert eval_data_iterator + cumulative_eval_metrics = {"total_loss": 0.0, "total_weights": 0.0} + for eval_batch in eval_data_iterator: + with mesh, nn_partitioning.axis_rules(config.logical_axis_rules): + eval_metrics = p_eval_step(state, eval_batch, nextrng) + cumulative_eval_metrics["total_loss"] += float(eval_metrics["scalar"]["evaluation/total_loss"]) + cumulative_eval_metrics["total_weights"] += float(eval_metrics["scalar"]["evaluation/total_weights"]) + eval_loss = cumulative_eval_metrics["total_loss"] / (cumulative_eval_metrics["total_weights"] + EPS) + max_logging.log(f"average loss after {step=}: {eval_loss=}, total_weights={cumulative_eval_metrics['total_weights']}") + if eval_loss <= config.target_eval_loss: + max_logging.log(f"Early stop and exit loop after reaching {config.target_eval_loss=}") + prof.deactivate() + break + + if step == last_profiling_step: + prof.deactivate() + + if checkpoint_manager is not None: + checkpoint_manager.wait_until_finished() + write_metrics(writer, local_metrics_file, running_gcs_metrics, metrics, config.steps - 1, config) # final step metrics + max_utils.close_summary_writer(writer) + record_goodput(recorder, config, job_end=True) + return state + + +def main(argv: Sequence[str]) -> None: + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS", "") + " --xla_tpu_spmd_rng_bit_generator_unsafe=true" + pyconfig.initialize(argv) + config = pyconfig.config + validate_train_config(config) + os.environ["TFDS_DATA_DIR"] = config.dataset_path + vertex_tensorboard_manager = VertexTensorboardManager() + if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): + vertex_tensorboard_manager.configure_vertex_tensorboard(config) + + debug_config = debug_configuration.DebugConfig( + stack_trace_config=stack_trace_configuration.StackTraceConfig( + collect_stack_trace=config.collect_stack_trace, + stack_trace_to_cloud=config.stack_trace_to_cloud, + stack_trace_interval_seconds=config.stack_trace_interval_seconds, + ) + ) + diagnostic_config = diagnostic_configuration.DiagnosticConfig(debug_config) + with diagnostic.diagnose(diagnostic_config): + train_loop(config) + + +if __name__ == "__main__": + app.run(main) diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/absl/app.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/absl/app.py new file mode 100644 index 000000000..d12397b31 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/absl/app.py @@ -0,0 +1,480 @@ +# Copyright 2017 The Abseil Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic entry point for Abseil Python applications. + +To use this module, define a ``main`` function with a single ``argv`` argument +and call ``app.run(main)``. For example:: + + def main(argv): + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + if __name__ == '__main__': + app.run(main) +""" + +import collections +import errno +import os +import pdb +import sys +import textwrap +import traceback + +from absl import command_name +from absl import flags +from absl import logging + +try: + import faulthandler +except ImportError: + faulthandler = None + +FLAGS = flags.FLAGS + +flags.DEFINE_boolean('run_with_pdb', False, 'Set to true for PDB debug mode') +flags.DEFINE_boolean('pdb_post_mortem', False, + 'Set to true to handle uncaught exceptions with PDB ' + 'post mortem.') +flags.DEFINE_alias('pdb', 'pdb_post_mortem') +flags.DEFINE_boolean('run_with_profiling', False, + 'Set to true for profiling the script. ' + 'Execution will be slower, and the output format might ' + 'change over time.') +flags.DEFINE_string('profile_file', None, + 'Dump profile information to a file (for python -m ' + 'pstats). Implies --run_with_profiling.') +flags.DEFINE_boolean('use_cprofile_for_profiling', True, + 'Use cProfile instead of the profile module for ' + 'profiling. This has no effect unless ' + '--run_with_profiling is set.') +flags.DEFINE_boolean('only_check_args', False, + 'Set to true to validate args and exit.', + allow_hide_cpp=True) + + +# If main() exits via an abnormal exception, call into these +# handlers before exiting. +EXCEPTION_HANDLERS = [] + + +class Error(Exception): + pass + + +class UsageError(Error): + """Exception raised when the arguments supplied by the user are invalid. + + Raise this when the arguments supplied are invalid from the point of + view of the application. For example when two mutually exclusive + flags have been supplied or when there are not enough non-flag + arguments. It is distinct from flags.Error which covers the lower + level of parsing and validating individual flags. + """ + + def __init__(self, message, exitcode=1): + super(UsageError, self).__init__(message) + self.exitcode = exitcode + + +class HelpFlag(flags.BooleanFlag): + """Special boolean flag that displays usage and raises SystemExit.""" + NAME = 'help' + SHORT_NAME = '?' + + def __init__(self): + super(HelpFlag, self).__init__( + self.NAME, False, 'show this help', + short_name=self.SHORT_NAME, allow_hide_cpp=True) + + def parse(self, arg): + if self._parse(arg): + usage(shorthelp=True, writeto_stdout=True) + # Advertise --helpfull on stdout, since usage() was on stdout. 
+ print() + print('Try --helpfull to get a list of all flags.') + sys.exit(1) + + +class HelpshortFlag(HelpFlag): + """--helpshort is an alias for --help.""" + NAME = 'helpshort' + SHORT_NAME = None + + +class HelpfullFlag(flags.BooleanFlag): + """Display help for flags in the main module and all dependent modules.""" + + def __init__(self): + super(HelpfullFlag, self).__init__( + 'helpfull', False, 'show full help', allow_hide_cpp=True) + + def parse(self, arg): + if self._parse(arg): + usage(writeto_stdout=True) + sys.exit(1) + + +class HelpXMLFlag(flags.BooleanFlag): + """Similar to HelpfullFlag, but generates output in XML format.""" + + def __init__(self): + super(HelpXMLFlag, self).__init__( + 'helpxml', False, 'like --helpfull, but generates XML output', + allow_hide_cpp=True) + + def parse(self, arg): + if self._parse(arg): + flags.FLAGS.write_help_in_xml_format(sys.stdout) + sys.exit(1) + + +def parse_flags_with_usage(args): + """Tries to parse the flags, print usage, and exit if unparsable. + + Args: + args: [str], a non-empty list of the command line arguments including + program name. + + Returns: + [str], a non-empty list of remaining command line arguments after parsing + flags, including program name. + """ + try: + return FLAGS(args) + except flags.Error as error: + message = str(error) + if '\n' in message: + final_message = 'FATAL Flags parsing error:\n%s\n' % textwrap.indent( + message, ' ') + else: + final_message = 'FATAL Flags parsing error: %s\n' % message + sys.stderr.write(final_message) + sys.stderr.write('Pass --helpshort or --helpfull to see help on flags.\n') + sys.exit(1) + + +_define_help_flags_called = False + + +def define_help_flags(): + """Registers help flags. Idempotent.""" + # Use a global to ensure idempotence. + global _define_help_flags_called + + if not _define_help_flags_called: + flags.DEFINE_flag(HelpFlag()) + flags.DEFINE_flag(HelpshortFlag()) # alias for --help + flags.DEFINE_flag(HelpfullFlag()) + flags.DEFINE_flag(HelpXMLFlag()) + _define_help_flags_called = True + + +def _register_and_parse_flags_with_usage( + argv=None, + flags_parser=parse_flags_with_usage, +): + """Registers help flags, parses arguments and shows usage if appropriate. + + This also calls sys.exit(0) if flag --only_check_args is True. + + Args: + argv: [str], a non-empty list of the command line arguments including + program name, sys.argv is used if None. + flags_parser: Callable[[List[Text]], Any], the function used to parse flags. + The return value of this function is passed to `main` untouched. + It must guarantee FLAGS is parsed after this function is called. + + Returns: + The return value of `flags_parser`. When using the default `flags_parser`, + it returns the following: + [str], a non-empty list of remaining command line arguments after parsing + flags, including program name. + + Raises: + Error: Raised when flags_parser is called, but FLAGS is not parsed. + SystemError: Raised when it's called more than once. + """ + if _register_and_parse_flags_with_usage.done: + raise SystemError('Flag registration can be done only once.') + + define_help_flags() + + original_argv = sys.argv if argv is None else argv + args_to_main = flags_parser(original_argv) + if not FLAGS.is_parsed(): + raise Error('FLAGS must be parsed after flags_parser is called.') + + # Exit when told so. + if FLAGS.only_check_args: + sys.exit(0) + # Immediately after flags are parsed, bump verbosity to INFO if the flag has + # not been set. 
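+  # In absl.logging, verbosity 0 corresponds to INFO (1 is DEBUG, -1 is WARNING),
+  # so the assignment below enables INFO-level logging by default.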
+ if FLAGS['verbosity'].using_default_value: + FLAGS.verbosity = 0 + _register_and_parse_flags_with_usage.done = True + + return args_to_main + +_register_and_parse_flags_with_usage.done = False + + +def _run_main(main, argv): + """Calls main, optionally with pdb or profiler.""" + if FLAGS.run_with_pdb: + sys.exit(pdb.runcall(main, argv)) + elif FLAGS.run_with_profiling or FLAGS.profile_file: + # Avoid import overhead since most apps (including performance-sensitive + # ones) won't be run with profiling. + # pylint: disable=g-import-not-at-top + import atexit + if FLAGS.use_cprofile_for_profiling: + import cProfile as profile + else: + import profile + profiler = profile.Profile() + if FLAGS.profile_file: + atexit.register(profiler.dump_stats, FLAGS.profile_file) + else: + atexit.register(profiler.print_stats) + sys.exit(profiler.runcall(main, argv)) + else: + sys.exit(main(argv)) + + +def _call_exception_handlers(exception): + """Calls any installed exception handlers.""" + for handler in EXCEPTION_HANDLERS: + try: + if handler.wants(exception): + handler.handle(exception) + except: # pylint: disable=bare-except + try: + # We don't want to stop for exceptions in the exception handlers but + # we shouldn't hide them either. + logging.error(traceback.format_exc()) + except: # pylint: disable=bare-except + # In case even the logging statement fails, ignore. + pass + + +def run( + main, + argv=None, + flags_parser=parse_flags_with_usage, +): + """Begins executing the program. + + Args: + main: The main function to execute. It takes an single argument "argv", + which is a list of command line arguments with parsed flags removed. + The return value is passed to `sys.exit`, and so for example + a return value of 0 or None results in a successful termination, whereas + a return value of 1 results in abnormal termination. + For more details, see https://docs.python.org/3/library/sys#sys.exit + argv: A non-empty list of the command line arguments including program name, + sys.argv is used if None. + flags_parser: Callable[[List[Text]], Any], the function used to parse flags. + The return value of this function is passed to `main` untouched. + It must guarantee FLAGS is parsed after this function is called. + Should be passed as a keyword-only arg which will become mandatory in a + future release. + - Parses command line flags with the flag module. + - If there are any errors, prints usage(). + - Calls main() with the remaining arguments. + - If main() raises a UsageError, prints usage and the error message. + """ + try: + args = _run_init( + sys.argv if argv is None else argv, + flags_parser, + ) + while _init_callbacks: + callback = _init_callbacks.popleft() + callback() + try: + _run_main(main, args) + except UsageError as error: + usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode) + except: + exc = sys.exc_info()[1] + # Don't try to post-mortem debug successful SystemExits, since those + # mean there wasn't actually an error. In particular, the test framework + # raises SystemExit(False) even if all tests passed. + if isinstance(exc, SystemExit) and not exc.code: + raise + + # Check the tty so that we don't hang waiting for input in an + # non-interactive scenario. + if FLAGS.pdb_post_mortem and sys.stdout.isatty(): + traceback.print_exc() + print() + print(' *** Entering post-mortem debugging ***') + print() + pdb.post_mortem() + raise + except Exception as e: + _call_exception_handlers(e) + raise + +# Callbacks which have been deferred until after _run_init has been called. 
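+# They are drained in FIFO order by run() just before main() is invoked.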
+_init_callbacks = collections.deque() + + +def call_after_init(callback): + """Calls the given callback only once ABSL has finished initialization. + + If ABSL has already finished initialization when ``call_after_init`` is + called then the callback is executed immediately, otherwise `callback` is + stored to be executed after ``app.run`` has finished initializing (aka. just + before the main function is called). + + If called after ``app.run``, this is equivalent to calling ``callback()`` in + the caller thread. If called before ``app.run``, callbacks are run + sequentially (in an undefined order) in the same thread as ``app.run``. + + Args: + callback: a callable to be called once ABSL has finished initialization. + This may be immediate if initialization has already finished. It + takes no arguments and returns nothing. + """ + if _run_init.done: + callback() + else: + _init_callbacks.append(callback) + + +def _run_init( + argv, + flags_parser, +): + """Does one-time initialization and re-parses flags on rerun.""" + if _run_init.done: + return flags_parser(argv) + command_name.make_process_name_useful() + # Set up absl logging handler. + logging.use_absl_handler() + args = _register_and_parse_flags_with_usage( + argv=argv, + flags_parser=flags_parser, + ) + if faulthandler: + try: + faulthandler.enable() + except Exception: # pylint: disable=broad-except + # Some tests verify stderr output very closely, so don't print anything. + # Disabled faulthandler is a low-impact error. + pass + _run_init.done = True + return args + + +_run_init.done = False + + +def usage(shorthelp=False, writeto_stdout=False, detailed_error=None, + exitcode=None): + """Writes __main__'s docstring to stderr with some help text. + + Args: + shorthelp: bool, if True, prints only flags from the main module, + rather than all flags. + writeto_stdout: bool, if True, writes help message to stdout, + rather than to stderr. + detailed_error: str, additional detail about why usage info was presented. + exitcode: optional integer, if set, exits with this status code after + writing help. + """ + if writeto_stdout: + stdfile = sys.stdout + else: + stdfile = sys.stderr + + doc = sys.modules['__main__'].__doc__ + if not doc: + doc = '\nUSAGE: %s [flags]\n' % sys.argv[0] + doc = flags.text_wrap(doc, indent=' ', firstline_indent='') + else: + # Replace all '%s' with sys.argv[0], and all '%%' with '%'. + num_specifiers = doc.count('%') - 2 * doc.count('%%') + try: + doc %= (sys.argv[0],) * num_specifiers + except (OverflowError, TypeError, ValueError): + # Just display the docstring as-is. + pass + if shorthelp: + flag_str = FLAGS.main_module_help() + else: + flag_str = FLAGS.get_help() + try: + stdfile.write(doc) + if flag_str: + stdfile.write('\nflags:\n') + stdfile.write(flag_str) + stdfile.write('\n') + if detailed_error is not None: + stdfile.write('\n%s\n' % detailed_error) + except IOError as e: + # We avoid printing a huge backtrace if we get EPIPE, because + # "foo.par --help | less" is a frequent use case. + if e.errno != errno.EPIPE: + raise + if exitcode is not None: + sys.exit(exitcode) + + +class ExceptionHandler(object): + """Base exception handler from which other may inherit.""" + + def wants(self, exc): + """Returns whether this handler wants to handle the exception or not. + + This base class returns True for all exceptions by default. Override in + subclass if it wants to be more selective. + + Args: + exc: Exception, the current exception. + """ + del exc # Unused. 
+ return True + + def handle(self, exc): + """Do something with the current exception. + + Args: + exc: Exception, the current exception + + This method must be overridden. + """ + raise NotImplementedError() + + +def install_exception_handler(handler): + """Installs an exception handler. + + Args: + handler: ExceptionHandler, the exception handler to install. + + Raises: + TypeError: Raised when the handler was not of the correct type. + + All installed exception handlers will be called if main() exits via + an abnormal exception, i.e. not one of SystemExit, KeyboardInterrupt, + FlagsError or UsageError. + """ + if not isinstance(handler, ExceptionHandler): + raise TypeError('handler of type %s does not inherit from ExceptionHandler' + % type(handler)) + EXCEPTION_HANDLERS.append(handler) diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/base.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/base.py new file mode 100644 index 000000000..38fc609aa --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/base.py @@ -0,0 +1,339 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base interfaces and datatypes.""" + +from typing import Any, Callable, NamedTuple, Optional, Protocol, runtime_checkable, Sequence, Union + +import chex +import jax +import jax.numpy as jnp + +NO_PARAMS_MSG = ( + 'You are using a transformation that requires the current value of ' + 'parameters, but you are not passing `params` when calling `update`.') + +PyTree = Any +Shape = Sequence[int] + +OptState = chex.ArrayTree # States are arbitrary nests of `jnp.ndarrays`. +Params = chex.ArrayTree # Parameters are arbitrary nests of `jnp.ndarrays`. +Updates = Params # Gradient updates are of the same type as parameters. + +Schedule = Callable[[chex.Numeric], chex.Numeric] +ScheduleState = Any +ScalarOrSchedule = Union[float, jax.Array, Schedule] + + +@runtime_checkable +class StatefulSchedule(Protocol): + """Base interface for stateful schedules.""" + + def init( + self + ) -> ScheduleState: + """Initialize the state of the stateful schedule.""" + + def update( + self, + state: ScheduleState, + **extra_args, + ) -> ScheduleState: + """Updates the current schedule state.""" + + def __call__( + self, + state: ScheduleState, + **extra_args, + ) -> chex.Numeric: + """Computes the current schedule value.""" + + +class TransformInitFn(Protocol): + """A callable type for the `init` step of a `GradientTransformation`. + + The `init` step takes a tree of `params` and uses these to construct an + arbitrary structured initial `state` for the gradient transformation. This + may hold statistics of the past updates or any other non static information. 
+ """ + + def __call__(self, params: Params) -> OptState: + """The `init` function. + + Args: + params: The initial value of the parameters. + + Returns: + The initial state of the gradient transformation. + """ + + +class TransformUpdateFn(Protocol): + """A callable type for the `update` step of a `GradientTransformation`. + + The `update` step takes a tree of candidate parameter `updates` (e.g. their + gradient with respect to some loss), an arbitrary structured `state`, and the + current `params` of the model being optimised. The `params` argument is + optional, it must however be provided when using transformations that require + access to the current values of the parameters. + + For the case where additional arguments are required, an alternative interface + may be used, see ``TransformUpdateExtraArgsFn`` for details. + """ + + def __call__( + self, + updates: Updates, + state: OptState, + params: Optional[Params] = None + ) -> tuple[Updates, OptState]: + """The `update` function. + + Args: + updates: A tree of candidate updates. + state: The state of the gradient transformation. + params: (Optionally) the current value of the parameters. + + Returns: + The transformed updates, and the updated state. + """ + + +class TransformUpdateExtraArgsFn(Protocol): + """An update function accepting additional keyword arguments.""" + + def __call__( + self, + updates: Updates, + state: OptState, + params: Optional[Params] = None, + **extra_args: Any, + ) -> tuple[Updates, OptState]: + """Update function with optional extra arguments. + + For example, an update function that requires an additional loss parameter + (which might be useful for implementing learning rate schedules that depend + on the current loss value) could be expressed as follows: + + >>> def update(updates, state, params=None, *, loss, **extra_args): + ... del extra_args + ... # use loss value + + Note that the loss value is keyword only, (it follows a ``*`` in the + signature of the function). This implies users will get explicit errors if + they try to use this gradient transformation without providing the required + argument. + + Args: + updates: The gradient updates passed to this transformation. + state: The state associated with this transformation + params: Optional params. + **extra_args: Additional keyword arguments passed to this transform. All + implementors of this interface should accept arbitrary keyword + arguments, ignoring those that are not needed for the current + transformation. Transformations that require specific extra args should + specify these using keyword-only arguments. + Returns: + Transformed updates, and an updated value for the state. + """ + + +class GradientTransformation(NamedTuple): + """A pair of pure functions implementing a gradient transformation. + + Optax optimizers are all implemented as _gradient transformations_. + A gradient transformation is defined to be a pair of pure functions, which + are combined together in a `NamedTuple` so that they can be referred to by + name. + + Note that an extended API is provided for users wishing to build optimizers + that take additional arguments during the update step. For more details, + see ``GradientTransoformationExtraArgs``. + + Since gradient transformations do not contain any internal state, all stateful + optimizer properties (such as the current step count when using optimizer + scheduels, or momemtum values) are passed through optax gradient + transformations by using the optimizer _state_ pytree. 
Each time a gradient + transformation is applied, a new state is computed and returned, ready to be + passed to the next call to the gradient transformation. + + Since gradient transformations are pure, idempotent functions, the only way + to change the behaviour of a gradient transformation between steps, is to + change the values in the optimizer state. To see an example of mutating the + optimizer state in order to control the behaviour of an optax gradient + transformation, see the meta-learning example in the optax documentation. + + Attributes: + init: A pure function which, when called with an example instance of the + parameters whose gradients will be transformed, returns a pytree + containing the initial value for the optimizer state. + update: A pure function which takes as input a pytree of updates (with the + same tree structure as the original params pytree passed to init), the + previous optimizer state (which may have been initialized using the init + function), and optionally the current params. The update function then + returns the computed gradient updates, and a new optimizer state. + """ + init: TransformInitFn + update: TransformUpdateFn + + +class GradientTransformationExtraArgs(GradientTransformation): + """A specialization of GradientTransformation that supports extra args. + + Extends the existing GradientTransformation interface by adding support for + passing extra arguments to the update function. + + Note that if no extra args are provided, then the API of this function is + identical to the case of ``TransformUpdateFn``. This means that we can safely + wrap any gradient transformation (that does not support extra args) as one + that does. The new gradient transformation will accept (and ignore) any + extra arguments that a user might pass to it. This is the behavior implemented + by ``optax.with_extra_args_support()``. + + Attributes: + update: Overrides the type signature of the update in the base type to + accept extra arguments. + """ + update: TransformUpdateExtraArgsFn + + +class EmptyState(NamedTuple): + """An empty state for the simplest stateless transformations.""" + + +def identity() -> GradientTransformation: + """Stateless identity transformation that leaves input gradients untouched. + + This function passes through the *gradient updates* unchanged. + + Note, this should not to be confused with `set_to_zero`, which maps the input + updates to zero - which is the transform required for the *model parameters* + to be left unchanged when the updates are applied to them. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(_): + return EmptyState() + + def update_fn(updates, state, params=None): + del params + return updates, state + + return GradientTransformation(init_fn, update_fn) + + +def set_to_zero() -> GradientTransformation: + """Stateless transformation that maps input gradients to zero. + + The resulting update function, when called, will return a tree of zeros + matching the shape of the input gradients. This means that when the updates + returned from this transformation are applied to the model parameters, the + model parameters will remain unchanged. + + This can be used in combination with `multi_transform` or `masked` to freeze + (i.e. keep fixed) some parts of the tree of model parameters while applying + gradient updates to other parts of the tree. 
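+
+  For example, a minimal sketch (the parameter tree below is an arbitrary
+  placeholder)::
+
+    import jax
+    import jax.numpy as jnp
+    import optax
+
+    params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros((2,))}
+    grads = jax.tree_util.tree_map(jnp.ones_like, params)
+    tx = optax.set_to_zero()
+    state = tx.init(params)
+    updates, state = tx.update(grads, state)
+    new_params = optax.apply_updates(params, updates)  # numerically equal to `params`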
+ + When updates are set to zero inside the same jit-compiled function as the + calculation of gradients, optax transformations, and application of updates to + parameters, unnecessary computations will in general be dropped. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return EmptyState() + + def update_fn(updates, state, params=None): + del params # Unused by the zero transform. + return jax.tree_util.tree_map(jnp.zeros_like, updates), state + + return GradientTransformation(init_fn, update_fn) + + +def stateless( + f: Callable[[Updates, Optional[Params]], Updates], +) -> GradientTransformation: + """Creates a stateless transformation from an update-like function. + + This wrapper eliminates the boilerplate needed to create a transformation that + does not require saved state between iterations. + + Args: + f: Update function that takes in updates (e.g. gradients) and parameters + and returns updates. The parameters may be `None`. + + Returns: + An `optax.GradientTransformation`. + """ + + def init_fn(_): + return EmptyState() + + def update_fn(updates, state, params=None): + del state + return f(updates, params), EmptyState() + + return GradientTransformation(init_fn, update_fn) + + +def stateless_with_tree_map( + f: Callable[[chex.Array, Optional[chex.Array]], chex.Array], +) -> GradientTransformation: + """Creates a stateless transformation from an update-like function for arrays. + + This wrapper eliminates the boilerplate needed to create a transformation that + does not require saved state between iterations, just like optax.stateless. + In addition, this function will apply the tree_map over update/params for you. + + Args: + f: Update function that takes in an update array (e.g. gradients) and + parameter array and returns an update array. The parameter array may be + `None`. + + Returns: + An `optax.GradientTransformation`. + """ + + def init_fn(_): + return EmptyState() + + def update_fn(updates, state, params=None): + del state + if params is not None: + return jax.tree_util.tree_map(f, updates, params), EmptyState() + else: + f_ = lambda u: f(u, None) + return jax.tree_util.tree_map(f_, updates), EmptyState() + + return GradientTransformation(init_fn, update_fn) + + +def with_extra_args_support( + tx: GradientTransformation, +) -> GradientTransformationExtraArgs: + """Wraps a gradient transformation, so that it ignores extra args.""" + + if isinstance(tx, GradientTransformationExtraArgs): + return tx + + def update(updates, state, params=None, **extra_args): + del extra_args + return tx.update(updates, state, params) + + return GradientTransformationExtraArgs(tx.init, update) diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/clipping.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/clipping.py new file mode 100644 index 000000000..0893cdd82 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/clipping.py @@ -0,0 +1,307 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradient clipping transformations. + +Note that complex numbers are also supported, see +https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 +""" + +import chex +import jax +import jax.numpy as jnp + +from optax._src import base +from optax._src import linear_algebra +from optax._src import numerics + +ClipState = base.EmptyState + + +def clip(max_delta: chex.Numeric) -> base.GradientTransformation: + """Clips updates element-wise, to be in ``[-max_delta, +max_delta]``. + + Args: + max_delta: The maximum absolute value for each element in the update. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return ClipState() + + def update_fn(updates, state, params=None): + del params + updates = jax.tree_util.tree_map( + lambda g: jnp.clip(g, -max_delta, max_delta), updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +def clip_by_block_rms(threshold: float) -> base.GradientTransformation: + """Clips updates to a max rms for the gradient of each param vector or matrix. + + A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix + (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree. + + Args: + threshold: The maximum rms for the gradient of each param vector or matrix. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return base.EmptyState() + + def update_fn(updates, state, params=None): + del params + + def _clip_fn(u): + clip_denom = jnp.maximum( + 1.0, + jnp.sqrt(jnp.mean(numerics.abs_sq(u))) / threshold) + return u / clip_denom + + updates = jax.tree_util.tree_map(_clip_fn, updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +ClipByGlobalNormState = base.EmptyState + + +def clip_by_global_norm(max_norm: float) -> base.GradientTransformation: + """Clips updates using their global norm. + + References: + [Pascanu et al, 2012](https://arxiv.org/abs/1211.5063) + + Args: + max_norm: The maximum global norm for an update. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return ClipByGlobalNormState() + + def update_fn(updates, state, params=None): + del params + g_norm = linear_algebra.global_norm(updates) + # TODO(b/163995078): revert back to the following (faster) implementation + # once analysed how it affects backprop through update (e.g. meta-gradients) + # g_norm = jnp.maximum(max_norm, g_norm) + # updates = jax.tree_util.tree_map( + # lambda t: (t / g_norm) * max_norm, updates) + trigger = jnp.squeeze(g_norm < max_norm) + chex.assert_shape(trigger, ()) # A scalar. 
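+    # Leave each leaf unchanged when the global norm is below `max_norm`;
+    # otherwise rescale it by max_norm / g_norm.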
+ + def clip_fn(t): + return jax.lax.select(trigger, t, (t / g_norm.astype(t.dtype)) * max_norm) + + updates = jax.tree_util.tree_map(clip_fn, updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +def per_example_global_norm_clip( + grads: list[chex.Array], l2_norm_clip: float +) -> tuple[list[chex.Array], jax.Array]: + """Applies gradient clipping per-example using their global norm. + + References: + [Abadi et al, 2016](https://arxiv.org/abs/1607.00133) + + Args: + grads: flattened update; the function expects these to have a batch + dimension on the 0th axis. + l2_norm_clip: maximum L2 norm of the per-example gradients. + + Returns: + A tuple containing sum of the clipped per-example grads, and the number of + per-example grads that were clipped. + """ + bsize = grads[0].shape[0] + + if any(g.ndim == 0 or bsize != g.shape[0] for g in grads): + raise ValueError( + 'Unlike other transforms, `per_example_global_norm_clip` expects' + ' `grads` to have a batch dimension in the 0th axis.') + + global_grad_norms = jax.vmap(linear_algebra.global_norm)(grads) + divisors = jnp.maximum(global_grad_norms / l2_norm_clip, 1.0) + num_clipped = jnp.greater(divisors, 1.0).sum() + clipped_sum = [(jnp.moveaxis(g, 0, -1) / divisors).sum(-1) for g in grads] + return clipped_sum, num_clipped + + +def per_example_layer_norm_clip( + grads: list[chex.Array], + global_l2_norm_clip: float, + uniform: bool = True, + eps: float = 1e-8, +) -> tuple[list[chex.Array], list[chex.Array]]: + """Applies gradient clipping per-example using per-layer norms. + + References: + [McMahan et al, 2012](https://arxiv.org/abs/1710.06963)] + + Args: + grads: flattened update; i.e. a list of gradients in which each item is + the gradient for one layer; the function expects these to have a batch + dimension on the 0th axis. + global_l2_norm_clip: overall L2 clip norm to use. + uniform: If `True`, per-layer clip norm is global_l2_norm_clip/sqrt(L), + where L is the number of layers. Otherwise, per-layer clip norm is + global_l2_norm_clip * sqrt(f), where f is the fraction of total model + parameters that are in this layer. + eps: Small positive value to add to norms to avoid possible division by + zero. + + Let C = `global_l2_norm_clip value`. Then per-layer clipping is done as + follows: + (1) If `uniform` is `True`, each of the K layers has an individual clip + norm of C / sqrt(K). + (2) If `uniform` is `False`, each of the K layers has an individual clip + norm of C * sqrt(D_i / D) where D_i is the number of parameters in + layer i, and D is the total number of parameters in the model. + + Returns: + A tuple containing sum of the clipped per-example grads and the number of + per-example grads that were clipped for each layer. + """ + bsize = grads[0].shape[0] + + if any(g.ndim == 0 or bsize != g.shape[0] for g in grads): + raise ValueError( + 'Unlike other transforms, `per_example_layer_norm_clip` expects' + ' `grads` to have a batch dimension in the 0th axis; got shapes:' + f' {(g.shape for g in grads)}.' + ) + + num_layers = len(grads) + + # Compute per-layer clip norms, based on whether we are using uniform + # variant or not. + if uniform: + # Create list of length `num_layers` of per-layer clip norm. 
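+    # (Each entry is global_l2_norm_clip / sqrt(num_layers).)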
+ layer_clip_norms = ( + global_l2_norm_clip * (1.0 / num_layers) ** 0.5, + ) * num_layers + else: + total_params = sum(g[0].size for g in grads) + layer_clip_norms = tuple( + global_l2_norm_clip * (g[0].size / float(total_params)) ** 0.5 + for g in grads + ) + + # Compute per-layer grad norms. + def map_layer_norm(grads_list): + return [jnp.linalg.norm(g, ord=None, axis=None) for g in grads_list] + + layer_grad_norms_per_example = jax.vmap(map_layer_norm)(grads) + + # Perform clipping. + divisors = ( + tuple( + jnp.maximum( + layer_grad_norm / (layer_clip_norm + eps), 1.0 + ) + for layer_grad_norm, layer_clip_norm in zip( + layer_grad_norms_per_example, layer_clip_norms + ) + ) + ) + num_clipped = [jnp.greater(divisor, 1.0).sum() for divisor in divisors] + clipped_sum = [ + (g / jnp.expand_dims(d, axis=[i for i in range(1, g.ndim)])).sum(0) + for g, d in zip(grads, divisors) + ] + return clipped_sum, num_clipped + + +def unitwise_norm(x: chex.Array) -> chex.Array: + """Computes norms of each output unit separately.""" + if jnp.squeeze(x).ndim <= 1: # Scalars and vectors + squared_norm = jnp.sum(numerics.abs_sq(x), keepdims=True) + # Note that this assumes parameters with a shape of length 3 are multihead + # linear parameters--if you wish to apply AGC to 1D convs, you may need + # to modify this line. + elif x.ndim in (2, 3): # Linear layers of shape IO or multihead linear + squared_norm = jnp.sum(numerics.abs_sq(x), axis=0, keepdims=True) + elif x.ndim == 4: # Conv kernels of shape HWIO + squared_norm = jnp.sum(numerics.abs_sq(x), axis=(0, 1, 2), keepdims=True) + else: + raise ValueError( + f'Expected parameter with shape in {1, 2, 3, 4}, got {x.shape}.') + chex.assert_is_broadcastable(squared_norm.shape, x.shape) + return jnp.broadcast_to(jnp.sqrt(squared_norm), x.shape) + + +def unitwise_clip(g_norm: chex.Array, + max_norm: chex.Array, + grad: chex.Array, + div_eps: float = 1e-6) -> chex.Array: + """Applies gradient clipping unit-wise.""" + # This little max(., div_eps) is distinct from the normal eps and just + # prevents division by zero. It technically should be impossible to engage. + clipped_grad = grad * (max_norm / jnp.maximum(g_norm, div_eps)) + chex.assert_equal_shape((g_norm, max_norm, grad, clipped_grad)) + return jnp.where(g_norm < max_norm, grad, clipped_grad) + + +AdaptiveGradClipState = base.EmptyState + + +def adaptive_grad_clip(clipping: float, + eps: float = 1e-3) -> base.GradientTransformation: + """Clips updates to be at most ``clipping * parameter_norm``, unit-wise. + + References: + [Brock, Smith, De, Simonyan 2021] High-Performance Large-Scale Image + Recognition Without Normalization. (https://arxiv.org/abs/2102.06171) + + Args: + clipping: The maximum allowed ratio of update norm to parameter norm. + eps: An epsilon term to prevent clipping of zero-initialized params. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return AdaptiveGradClipState() + + def update_fn(updates, state, params): + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + g_norm, p_norm = jax.tree_util.tree_map(unitwise_norm, (updates, params)) + # Maximum allowable norm. + max_norm = jax.tree_util.tree_map( + lambda x: clipping * jnp.maximum(x, eps), p_norm) + # If grad norm > clipping * param_norm, rescale. 
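+    # unitwise_clip is mapped leaf-wise over the g_norm, max_norm and updates trees.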
+ updates = jax.tree_util.tree_map(unitwise_clip, g_norm, max_norm, updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/combine.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/combine.py new file mode 100644 index 000000000..c75532d5b --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/combine.py @@ -0,0 +1,248 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flexibly compose gradient transformations.""" + +from typing import Callable, NamedTuple, Union, Mapping, Hashable + +import jax + +from optax._src import base +from optax._src import wrappers + + +def chain( + *args: base.GradientTransformation, +) -> base.GradientTransformationExtraArgs: + """Applies a list of chainable update transformations. + + This function creates a new :func:`optax.GradientTransformation` that applies + a sequence of gradient transformations in order. The ``init`` function of the + new transformation constructs the optimizer state by concatenating the states + of the individual transforms, while the ``update`` function applies the + updates in the given order. + + Examples: + + A transform that scales by -0.1 the adam update: + + >>> import optax + >>> transform1 = optax.scale_by_adam() + >>> transform2 = optax.scale(-0.1) + >>> chained_transform = optax.chain(transform1, transform2) + >>> params = {'a': 1.0} + >>> state = chained_transform.init(params) + >>> updates = {'a': -0.5} + >>> updates, new_state = chained_transform.update(updates, state, params) + + Args: + *args: a sequence of chainable (init_fn, update_fn) tuples. + + Returns: + A :func:`GradientTransformationExtraArgs`, created by chaining the input + transformations. Note that independent of the argument types, the resulting + transformation always supports extra args. Any extra arguments passed to the + returned transformation will be passed only to those transformations in the + chain that support extra args. + """ + + transforms = [base.with_extra_args_support(t) for t in args] + init_fns, update_fns = zip(*transforms) + + def init_fn(params): + return tuple(fn(params) for fn in init_fns) + + def update_fn(updates, state, params=None, **extra_args): + if len(update_fns) != len(state): + raise ValueError('The number of updates and states has to be the same in ' + 'chain! 
Make sure you have called init first!') + + new_state = [] + for s, fn in zip(state, update_fns): + updates, new_s = fn(updates, s, params, **extra_args) + new_state.append(new_s) + return updates, tuple(new_state) + + # We opt to always return the GradientTransformationExtraArgs type here, + # instead of selecting the return type based on the arguments, since it works + # much better with the currently available type checkers. It also means that + # users will not get unexpected signature errors if they remove all of the + # transformations in a chain accepting extra args. + return base.GradientTransformationExtraArgs(init_fn, update_fn) + + +def named_chain( + *transforms: tuple[str, base.GradientTransformation] +) -> base.GradientTransformationExtraArgs: + """Chains optax gradient transformations. + + The `transforms` are `(name, transformation)` pairs, constituted of a string + `name` and an associated gradient transformation `transformation`. The + gradient transformation must be an instance of either `GradientTransformation` + or `GradientTransformationExtraArgs`. + + Each `name` is used as key for the state of the corresponding transformation + within the `named_chain` state. Thus the state of the gradient transformation + with a given `name` can be retrieved as `opt_state[name]`. + + Example: + + # tx1 is a GradientTransformation with no extra_args. + # tx2 is a GradientTransformationExtraArgs that requires `loss`. + # tx3 is a GradientTransformationExtraArgs that requires `temperature`. + + tx = named_chain(('one', tx1), ('two', tx2), ('three', tx3)) + extra_args={'loss': 0.3, 'temperature': 0.01} + tx.init(params) + tx.update(grads, state, params, **extra_args) + + Args: + *transforms: an arbitrary number of `(name, tx)` pairs, constituted of a + string `name` and an associated gradient transformation `tx`. The latter + is a `GradientTransformation` or `GradientTransformationExtraArgs`. + + Returns: + A single (init_fn, update_fn) tuple. + """ + + names = [name for name, _ in transforms] + + if len(names) != len(set(names)): + raise ValueError( + f'Named transformations must have unique names, but got {names}') + + transforms = [ + (name, base.with_extra_args_support(t)) + for name, t in transforms] + + def init_fn(params): + states = {} + for (name, tx) in transforms: + states[name] = tx.init(params) + return states + def update_fn(updates, state, params=None, **extra_args): + new_state = {} + for (name, tx) in transforms: + updates, new_state[name] = tx.update( + updates, state[name], params, **extra_args) + return updates, new_state + + return base.GradientTransformationExtraArgs(init_fn, update_fn) + + +class MultiTransformState(NamedTuple): + inner_states: Mapping[Hashable, base.OptState] + + +def multi_transform( + transforms: Mapping[Hashable, base.GradientTransformation], + param_labels: Union[base.PyTree, Callable[[base.PyTree], base.PyTree]], + *, + mask_compatible_extra_args: bool = False, +) -> base.GradientTransformationExtraArgs: + """Partitions params and applies a different transformation to each subset. 
+ + Below is an example where we apply Adam to the weights and SGD to the biases + of a 2-layer neural network:: + + import optax + import jax + import jax.numpy as jnp + + def map_nested_fn(fn): + '''Recursively apply `fn` to the key-value pairs of a nested dict.''' + def map_fn(nested_dict): + return {k: (map_fn(v) if isinstance(v, dict) else fn(k, v)) + for k, v in nested_dict.items()} + return map_fn + + params = {'linear_1': {'w': jnp.zeros((5, 6)), 'b': jnp.zeros(5)}, + 'linear_2': {'w': jnp.zeros((6, 1)), 'b': jnp.zeros(1)}} + gradients = jax.tree_util.tree_map(jnp.ones_like, params) # dummy gradients + + label_fn = map_nested_fn(lambda k, _: k) + tx = optax.multi_transform({'w': optax.adam(1.0), 'b': optax.sgd(1.0)}, + label_fn) + state = tx.init(params) + updates, new_state = tx.update(gradients, state, params) + new_params = optax.apply_updates(params, updates) + + Instead of providing a ``label_fn``, you may provide a PyTree of labels + directly. Also, this PyTree may be a prefix of the parameters PyTree. This + is demonstrated in the GAN pseudocode below:: + + generator_params = ... + discriminator_params = ... + all_params = (generator_params, discriminator_params) + param_labels = ('generator', 'discriminator') + + tx = optax.multi_transform( + {'generator': optax.adam(0.1), 'discriminator': optax.adam(0.5)}, + param_labels) + + If you would like to not optimize some parameters, you may wrap + :func:`optax.multi_transform` with :func:`optax.masked`. + + Args: + transforms: A mapping from labels to transformations. Each transformation + will be only be applied to parameters with the same label. + param_labels: A PyTree that is the same shape or a prefix of the + parameters/updates (or a function that returns one given the parameters as + input). The leaves of this PyTree correspond to the keys of the transforms + (therefore the values at the leaves must be a subset of the keys). + mask_compatible_extra_args: Whether to also apply the same masking to + extra_arg fields with the same tree structure as params/updates. + + Returns: + A :func:`optax.GradientTransformationExtraArgs` that implements an ``init`` + and ``update`` function. 
+ """ + + transforms = { + k: base.with_extra_args_support(v) + for k, v in transforms.items() + } + + def make_mask(labels, group): + return jax.tree_util.tree_map(lambda label: label == group, labels) + + def init_fn(params): + labels = param_labels(params) if callable(param_labels) else param_labels + + label_set = set(jax.tree_util.tree_leaves(labels)) + if not label_set.issubset(transforms.keys()): + raise ValueError('Some parameters have no corresponding transformation.\n' + f'Parameter labels: {list(sorted(label_set))} \n' + f'Transforms keys: {list(sorted(transforms.keys()))} \n') + + inner_states = { + group: wrappers.masked( + tx, make_mask(labels, group), + mask_compatible_extra_args=mask_compatible_extra_args).init(params) + for group, tx in transforms.items() + } + return MultiTransformState(inner_states) + + def update_fn(updates, state, params=None, **extra_args): + labels = param_labels(updates) if callable(param_labels) else param_labels + new_inner_state = {} + for group, tx in transforms.items(): + masked_tx = wrappers.masked( + tx, make_mask(labels, group), + mask_compatible_extra_args=mask_compatible_extra_args) + updates, new_inner_state[group] = masked_tx.update( + updates, state.inner_states[group], params, **extra_args) + return updates, MultiTransformState(new_inner_state) + + return base.GradientTransformationExtraArgs(init_fn, update_fn) diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/linear_algebra.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/linear_algebra.py new file mode 100644 index 000000000..20222fca4 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/linear_algebra.py @@ -0,0 +1,257 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Linear algebra utilities used in optimisation.""" + +import functools +from typing import Callable, Optional, Union + +import chex +import jax +from jax import lax +import jax.numpy as jnp +from optax import tree_utils as otu +from optax._src import base +from optax._src import numerics + + +def _normalize_tree(x): + # divide by the L2 norm of the tree weights. 
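+  # (x / ||x||_2 across all leaves; there is no epsilon guard, so the norm is
+  # assumed to be non-zero.)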
+ return otu.tree_scalar_mul(1.0 / otu.tree_l2_norm(x), x) + + +def global_norm(updates: base.PyTree) -> chex.Array: + """Compute the global norm across a nested structure of tensors.""" + return jnp.sqrt(sum( + jnp.sum(numerics.abs_sq(x)) for x in jax.tree_util.tree_leaves(updates))) + + +def _power_iteration_cond_fun(error_tolerance, num_iters, loop_vars): + normalized_eigvec, unnormalized_eigvec, eig, iter_num = loop_vars + residual = otu.tree_sub( + unnormalized_eigvec, otu.tree_scalar_mul(eig, normalized_eigvec) + ) + residual_norm = otu.tree_l2_norm(residual) + converged = jnp.abs(residual_norm / eig) < error_tolerance + return ~converged & (iter_num < num_iters) + + +def power_iteration( + matrix: Union[chex.Array, Callable[[chex.ArrayTree], chex.ArrayTree]], + *, + v0: Optional[chex.ArrayTree] = None, + num_iters: int = 100, + error_tolerance: float = 1e-6, + precision: lax.Precision = lax.Precision.HIGHEST, + key: Optional[chex.PRNGKey] = None, +) -> tuple[chex.Numeric, chex.ArrayTree]: + r"""Power iteration algorithm. + + This algorithm computes the dominant eigenvalue and its associated eigenvector + of a diagonalizable matrix. This matrix can be given as an array or as a + callable implementing a matrix-vector product. + + References: + Wikipedia contributors. `Power iteration + `_. + + Args: + matrix: a square matrix, either as an array or a callable implementing a + matrix-vector product. + v0: initial vector approximating the dominiant eigenvector. If ``matrix`` + is an array of size (n, n), v0 must be a vector of size (n,). If instead + ``matrix`` is a callable, then v0 must be a tree with the same structure + as the input of this callable. If this argument is None and ``matrix`` is + an array, then a random vector sampled from a uniform distribution in + [-1, 1] is used as initial vector. + num_iters: Number of power iterations. + error_tolerance: Iterative exit condition. The procedure stops when the + relative error of the estimate of the dominant eigenvalue is below this + threshold. + precision: precision XLA related flag, the available options are: a) + lax.Precision.DEFAULT (better step time, but not precise); b) + lax.Precision.HIGH (increased precision, slower); c) lax.Precision.HIGHEST + (best possible precision, slowest). + key: random key for the initialization of ``v0`` when not given + explicitly. When this argument is None, `jax.random.PRNGKey(0)` is used. + + Returns: + A pair (eigenvalue, eigenvector), where eigenvalue is the dominant + eigenvalue of ``matrix`` and eigenvector is its associated eigenvector. + + .. versionchanged:: 0.2.2 + ``matrix`` can be a callable. Reversed the order of the return parameters, + from (eigenvector, eigenvalue) to (eigenvalue, eigenvector). + """ + if callable(matrix): + mvp = matrix + if v0 is None: + # v0 must be given as we don't know the underlying pytree structure. 
+ raise ValueError('v0 must be provided when `matrix` is a callable.') + else: + mvp = lambda v: jnp.matmul(matrix, v, precision=precision) + if v0 is None: + if key is None: + key = jax.random.PRNGKey(0) + # v0 is uniformly distributed in [-1, 1] + v0 = jax.random.uniform( + key, + shape=matrix.shape[-1:], + dtype=matrix.dtype, + minval=-1.0, + maxval=1.0, + ) + + v0 = _normalize_tree(v0) + + cond_fun = functools.partial( + _power_iteration_cond_fun, + error_tolerance, + num_iters, + ) + + def _body_fun(loop_vars): + _, z, _, iter_num = loop_vars + eigvec = _normalize_tree(z) + z = mvp(eigvec) + eig = otu.tree_vdot(eigvec, z) + return eigvec, z, eig, iter_num + 1 + + init_vars = (v0, mvp(v0), jnp.asarray(0.0), jnp.asarray(0)) + _, unormalized_eigenvector, dominant_eigenvalue, _ = ( + jax.lax.while_loop(cond_fun, _body_fun, init_vars) + ) + normalized_eigenvector = _normalize_tree(unormalized_eigenvector) + return dominant_eigenvalue, normalized_eigenvector + + +def matrix_inverse_pth_root(matrix: chex.Array, + p: int, + num_iters: int = 100, + ridge_epsilon: float = 1e-6, + error_tolerance: float = 1e-6, + precision: lax.Precision = lax.Precision.HIGHEST): + """Computes `matrix^(-1/p)`, where `p` is a positive integer. + + This function uses the Coupled newton iterations algorithm for + the computation of a matrix's inverse pth root. + + + References: + [Functions of Matrices, Theory and Computation, + Nicholas J Higham, Pg 184, Eq 7.18]( + https://epubs.siam.org/doi/book/10.1137/1.9780898717778) + + Args: + matrix: the symmetric PSD matrix whose power it to be computed + p: exponent, for p a positive integer. + num_iters: Maximum number of iterations. + ridge_epsilon: Ridge epsilon added to make the matrix positive definite. + error_tolerance: Error indicator, useful for early termination. + precision: precision XLA related flag, the available options are: + a) lax.Precision.DEFAULT (better step time, but not precise); + b) lax.Precision.HIGH (increased precision, slower); + c) lax.Precision.HIGHEST (best possible precision, slowest). + + Returns: + matrix^(-1/p) + """ + + # We use float32 for the matrix inverse pth root. + # Switch to f64 if you have hardware that supports it. + matrix_size = matrix.shape[0] + alpha = jnp.asarray(-1.0 / p, jnp.float32) + identity = jnp.eye(matrix_size, dtype=jnp.float32) + max_ev, _ = power_iteration( + matrix=matrix, num_iters=100, + error_tolerance=1e-6, precision=precision) + ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, 1e-16) + + def _unrolled_mat_pow_1(mat_m): + """Computes mat_m^1.""" + return mat_m + + def _unrolled_mat_pow_2(mat_m): + """Computes mat_m^2.""" + return jnp.matmul(mat_m, mat_m, precision=precision) + + def _unrolled_mat_pow_4(mat_m): + """Computes mat_m^4.""" + mat_pow_2 = _unrolled_mat_pow_2(mat_m) + return jnp.matmul( + mat_pow_2, mat_pow_2, precision=precision) + + def _unrolled_mat_pow_8(mat_m): + """Computes mat_m^4.""" + mat_pow_4 = _unrolled_mat_pow_4(mat_m) + return jnp.matmul( + mat_pow_4, mat_pow_4, precision=precision) + + def mat_power(mat_m, p): + """Computes mat_m^p, for p == 1, 2, 4 or 8. + + Args: + mat_m: a square matrix + p: a positive integer + + Returns: + mat_m^p + """ + # We unrolled the loop for performance reasons. 
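+    # log2(p) maps p in {1, 2, 4, 8} to the branch index {0, 1, 2, 3} for lax.switch.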
+ exponent = jnp.round(jnp.log2(p)) + return lax.switch( + jnp.asarray(exponent, jnp.int32), [ + _unrolled_mat_pow_1, + _unrolled_mat_pow_2, + _unrolled_mat_pow_4, + _unrolled_mat_pow_8, + ], (mat_m)) + + def _iter_condition(state): + (i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, + run_step) = state + error_above_threshold = jnp.logical_and( + error > error_tolerance, run_step) + return jnp.logical_and(i < num_iters, error_above_threshold) + + def _iter_body(state): + (i, mat_m, mat_h, unused_old_mat_h, error, unused_run_step) = state + mat_m_i = (1 - alpha) * identity + alpha * mat_m + new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision) + new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision) + new_error = jnp.max(jnp.abs(new_mat_m - identity)) + # sometimes error increases after an iteration before decreasing and + # converging. 1.2 factor is used to bound the maximal allowed increase. + return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, + new_error < error * 1.2) + + if matrix_size == 1: + resultant_mat_h = (matrix + ridge_epsilon)**alpha + error = 0 + else: + damped_matrix = matrix + ridge_epsilon * identity + + z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix)) + new_mat_m_0 = damped_matrix * z + new_error = jnp.max(jnp.abs(new_mat_m_0 - identity)) + new_mat_h_0 = identity * jnp.power(z, 1.0 / p) + init_state = tuple( + [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, True]) + _, mat_m, mat_h, old_mat_h, error, convergence = lax.while_loop( + _iter_condition, _iter_body, init_state) + error = jnp.max(jnp.abs(mat_m - identity)) + is_converged = jnp.asarray(convergence, old_mat_h.dtype) + resultant_mat_h = is_converged * mat_h + (1 - is_converged) * old_mat_h + resultant_mat_h = jnp.asarray(resultant_mat_h, matrix.dtype) + return resultant_mat_h, error diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/numerics.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/numerics.py new file mode 100644 index 000000000..c80bc2ff9 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/numerics.py @@ -0,0 +1,118 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to ensure the implementation is safe wrt numerical issues. + +Note that complex numbers are also supported, see +https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 +""" + +from typing import Optional, Union + +import chex +import jax.numpy as jnp +import numpy as np + + +# TODO(jscholz) Promote these functions to jax core lib? + + +def abs_sq(x: chex.Array) -> chex.Array: + """Returns the squared norm of a (maybe complex) array. 
+ + For real `x`, JAX generates the same HLO from this, `jnp.square(x)`, `x * x`, + or `x**2`. + + Args: + x: a (maybe complex) array. + + Returns: + The squared norm of `x`. + """ + if not isinstance(x, (np.ndarray, jnp.ndarray)): + raise ValueError(f"`abs_sq` accepts only NDarrays, got: {x}.") + return (x.conj() * x).real + + +def safe_norm(x: chex.Array, + min_norm: chex.Numeric, + ord: Optional[Union[int, float, str]] = None, # pylint: disable=redefined-builtin + axis: Union[None, tuple[int, ...], int] = None, + keepdims: bool = False) -> chex.Array: + """Returns jnp.maximum(jnp.linalg.norm(x), min_norm) with correct gradients. + + The gradients of `jnp.maximum(jnp.linalg.norm(x), min_norm)` at 0.0 is `NaN`, + because jax will evaluate both branches of the `jnp.maximum`. This function + will instead return the correct gradient of 0.0 also in such setting. + + Args: + x: jax array. + min_norm: lower bound for the returned norm. + ord: {non-zero int, inf, -inf, ‘fro’, ‘nuc’}, optional. Order of the norm. + inf means numpy’s inf object. The default is None. + axis: {None, int, 2-tuple of ints}, optional. If axis is an integer, it + specifies the axis of x along which to compute the vector norms. If axis + is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix + norms of these matrices are computed. If axis is None then either a vector + norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The + default is None. + keepdims: bool, optional. If this is set to True, the axes which are normed + over are left in the result as dimensions with size one. With this option + the result will broadcast correctly against the original x. + + Returns: + The safe norm of the input vector, accounting for correct gradient. + """ + norm = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=True) + x = jnp.where(norm <= min_norm, jnp.ones_like(x), x) + norm = jnp.squeeze(norm, axis=axis) if not keepdims else norm + masked_norm = jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims) + return jnp.where(norm <= min_norm, min_norm, masked_norm) + + +def safe_root_mean_squares(x: chex.Array, min_rms: chex.Numeric) -> chex.Array: + """Returns `maximum(sqrt(mean(abs_sq(x))), min_norm)` with correct grads. + + The gradients of `maximum(sqrt(mean(abs_sq(x))), min_norm)` at 0.0 + is `NaN`, because jax will evaluate both branches of the `jnp.maximum`. This + function will instead return the correct gradient of 0.0 also in such setting. + + Args: + x: jax array. + min_rms: lower bound for the returned norm. + + Returns: + The safe RMS of the input vector, accounting for correct gradient. + """ + rms = jnp.sqrt(jnp.mean(abs_sq(x))) + x = jnp.where(rms <= min_rms, jnp.ones_like(x), x) + return jnp.where(rms <= min_rms, min_rms, jnp.sqrt(jnp.mean(abs_sq(x)))) + + +def safe_int32_increment(count: chex.Numeric) -> chex.Numeric: + """Increments int32 counter by one. + + Normally `max_int + 1` would overflow to `min_int`. This functions ensures + that when `max_int` is reached the counter stays at `max_int`. + + Args: + count: a counter to be incremented. + + Returns: + A counter incremented by 1, or max_int if the maximum precision is reached. 
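+
+  For example (a minimal sketch; the starting value is an arbitrary choice)::
+
+    import jax.numpy as jnp
+
+    count = jnp.asarray(jnp.iinfo(jnp.int32).max, dtype=jnp.int32)
+    safe_int32_increment(count)  # stays at the int32 maximum instead of wrapping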
+ """ + chex.assert_type(count, jnp.int32) + max_int32_value = jnp.iinfo(jnp.int32).max + one = jnp.array(1, dtype=jnp.int32) + return jnp.where(count < max_int32_value, count + one, max_int32_value) diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/transform.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/transform.py new file mode 100644 index 000000000..c8f92422b --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/transform.py @@ -0,0 +1,1464 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gradient transformations.""" + +import functools +from typing import Any, Callable, NamedTuple, Optional, Union + +import chex +import jax +import jax.numpy as jnp + +from optax import tree_utils +from optax._src import base +from optax._src import numerics +from optax._src import utils +from optax._src import wrappers + +# pylint:disable=no-value-for-parameter + +_abs_sq = numerics.abs_sq + + +def _init_empty_state(params: base.Params) -> base.EmptyState: + """Init function for an empty state.""" + del params + return base.EmptyState() + + +class TraceState(NamedTuple): + """Holds an aggregation of past updates.""" + trace: base.Params + + +def trace( + decay: float, + nesterov: bool = False, + accumulator_dtype: Optional[Any] = None, +) -> base.GradientTransformation: + """Compute a trace of past updates. + + Note: `trace` and `ema` have very similar but distinct updates; + `trace = decay * trace + t`, while `ema = decay * ema + (1-decay) * t`. + Both are frequently found in the optimization literature. + + Args: + decay: Decay rate for the trace of past updates. + nesterov: Whether to use Nesterov momentum. + accumulator_dtype: Optional `dtype` to be used for the accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + A `GradientTransformation` object. 
+ """ + + accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) + + def init_fn(params): + return TraceState( + trace=jax.tree_util.tree_map( + lambda t: jnp.zeros_like(t, dtype=accumulator_dtype), params)) + + def update_fn(updates, state, params=None): + del params + f = lambda g, t: g + decay * t + new_trace = jax.tree_util.tree_map(f, updates, state.trace) + updates = ( + jax.tree_util.tree_map(f, updates, new_trace) if nesterov + else new_trace) + new_trace = utils.cast_tree(new_trace, accumulator_dtype) + return updates, TraceState(trace=new_trace) + + return base.GradientTransformation(init_fn, update_fn) + + +def update_moment(updates, moments, decay, order): + """Compute the exponential moving average of the `order`-th moment.""" + return jax.tree_util.tree_map( + lambda g, t: (1 - decay) * (g ** order) + decay * t, updates, moments) + + +def update_infinity_moment(updates, moments, decay, eps): + """Compute the exponential moving average of the infinity norm.""" + return jax.tree_util.tree_map( + lambda g, t: jnp.maximum(jnp.abs(g) + eps, decay * t), updates, moments) + + +def update_moment_per_elem_norm(updates, moments, decay, order): + """Compute the EMA of the `order`-th moment of the element-wise norm.""" + + def orderth_norm(g): + if jnp.isrealobj(g): + return g ** order + else: + half_order = order / 2 + # JAX generates different HLO for int and float `order` + if half_order.is_integer(): + half_order = int(half_order) + return _abs_sq(g) ** half_order + + return jax.tree_util.tree_map( + lambda g, t: (1 - decay) * orderth_norm(g) + decay * t, updates, moments) + + +@functools.partial(jax.jit, inline=True) +def bias_correction(moment, decay, count): + """Performs bias correction. It becomes a no-op as count goes to infinity.""" + # The conversion to the data type of the moment ensures that bfloat16 remains + # bfloat16 in the optimizer state. This conversion has to be done after + # `bias_correction_` is calculated as calculating `decay**count` in low + # precision can result in it being rounded to 1 and subsequently a + # "division by zero" error. + bias_correction_ = 1 - decay**count + + # Perform division in the original precision. + return jax.tree_util.tree_map( + lambda t: t / bias_correction_.astype(t.dtype), moment) + + +def _reject_complex(params): + if any(jnp.iscomplexobj(x) for x in jax.tree_util.tree_leaves(params)): + raise ValueError('This transformation does not support complex parameters.') + + +class EmaState(NamedTuple): + """Holds an exponential moving average of past updates.""" + count: chex.Array # shape=(), dtype=jnp.int32. + ema: base.Params + + +def ema( + decay: float, + debias: bool = True, + accumulator_dtype: Optional[Any] = None +) -> base.GradientTransformation: + """Compute an exponential moving average of past updates. + + Note: `trace` and `ema` have very similar but distinct updates; + `ema = decay * ema + (1-decay) * t`, while `trace = decay * trace + t`. + Both are frequently found in the optimization literature. + + Args: + decay: Decay rate for the exponential moving average. + debias: Whether to debias the transformed gradient. + accumulator_dtype: Optional `dtype` to used for the accumulator; if `None` + then the `dtype` is inferred from `params` and `updates`. + + Returns: + A `GradientTransformation` object. 
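+
+  A minimal usage sketch, assuming the public `optax.ema` alias::
+
+    import jax.numpy as jnp
+    import optax
+
+    tx = optax.ema(decay=0.99, debias=True)
+    state = tx.init({'w': jnp.zeros(2)})
+    updates, state = tx.update({'w': jnp.ones(2)}, state)
+    # With debiasing, the first update equals the incoming gradient.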
+ """ + + accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) + + def init_fn(params): + return EmaState( + count=jnp.zeros([], jnp.int32), + ema=jax.tree_util.tree_map( + lambda t: jnp.zeros_like(t, dtype=accumulator_dtype), params)) + + def update_fn(updates, state, params=None): + del params + updates = new_ema = update_moment(updates, state.ema, decay, order=1) + count_inc = utils.safe_int32_increment(state.count) + if debias: + updates = bias_correction(new_ema, decay, count_inc) + state_ema = utils.cast_tree(new_ema, accumulator_dtype) + return updates, EmaState(count=count_inc, ema=state_ema) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByRssState(NamedTuple): + """State holding the sum of gradient squares to date.""" + sum_of_squares: base.Updates + + +def scale_by_rss( + initial_accumulator_value: float = 0.1, + eps: float = 1e-7 +) -> base.GradientTransformation: + """Rescale updates by the root of the sum of all squared gradients to date. + + References: + [Duchi et al, 2011](https://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) + [McMahan et al., 2010](https://arxiv.org/abs/1002.4908) + + Args: + initial_accumulator_value: Starting value for accumulators, must be >= 0. + eps: A small floating point value to avoid zero denominator. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + sum_of_squares = jax.tree_util.tree_map( + lambda t: jnp.full_like(t, initial_accumulator_value), params) + return ScaleByRssState(sum_of_squares=sum_of_squares) + + def update_fn(updates, state, params=None): + del params + sum_of_squares = jax.tree_util.tree_map( + lambda g, t: _abs_sq(g) + t, updates, state.sum_of_squares) + inv_sqrt_g_square = jax.tree_util.tree_map( + lambda t: jnp.where(t > 0, jax.lax.rsqrt(t + eps), 0.0), sum_of_squares) + updates = jax.tree_util.tree_map( + lambda scale, g: scale * g, inv_sqrt_g_square, updates) + return updates, ScaleByRssState(sum_of_squares=sum_of_squares) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByRmsState(NamedTuple): + """State for exponential root mean-squared (RMS)-normalized updates.""" + nu: base.Updates + + +def scale_by_rms( + decay: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0. +) -> base.GradientTransformation: + r"""Rescale updates by the root of the exp. moving avg of the square. + + WARNING: PyTorch and optax's RMSprop implementations differ and could impact + performance. In the denominator, optax uses $\sqrt{v + \epsilon}$ whereas + PyTorch uses $\sqrt{v} + \epsilon$. See + https://github.com/google-deepmind/optax/issues/532 for more detail. + + References: + [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) + + Args: + decay: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + initial_scale: Initial value for second moment. + + Returns: + A `GradientTransformation` object. 
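+
+  A minimal usage sketch, assuming the public `optax.scale_by_rms` and
+  `optax.scale` aliases; chaining with a negative step size yields an
+  RMSProp-style update::
+
+    import jax.numpy as jnp
+    import optax
+
+    tx = optax.chain(optax.scale_by_rms(decay=0.9), optax.scale(-1e-3))
+    params = {'w': jnp.ones(3)}
+    state = tx.init(params)
+    updates, state = tx.update({'w': jnp.ones(3)}, state)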
+ """ + + def init_fn(params): + nu = jax.tree_util.tree_map( + lambda n: jnp.full_like(n, initial_scale), params) # second moment + return ScaleByRmsState(nu=nu) + + def update_fn(updates, state, params=None): + del params + nu = update_moment_per_elem_norm(updates, state.nu, decay, 2) + updates = jax.tree_util.tree_map( + lambda g, n: g * jax.lax.rsqrt(n + eps), updates, nu) + return updates, ScaleByRmsState(nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByRStdDevState(NamedTuple): + """State for centered exponential moving average of squares of updates.""" + mu: base.Updates + nu: base.Updates + + +def scale_by_stddev( + decay: float = 0.9, + eps: float = 1e-8, + initial_scale: float = 0. +) -> base.GradientTransformation: + """Rescale updates by the root of the centered exp. moving average of squares. + + References: + [Hinton](www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf) + + Args: + decay: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + initial_scale: Initial value for second moment. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_util.tree_map( + lambda n: jnp.full_like(n, initial_scale), params) # second moment + return ScaleByRStdDevState(mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, decay, 1) + nu = update_moment_per_elem_norm(updates, state.nu, decay, 2) + updates = jax.tree_util.tree_map( + lambda g, m, n: g * jax.lax.rsqrt(n - _abs_sq(m) + eps), + updates, mu, nu) + return updates, ScaleByRStdDevState(mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByAdamState(NamedTuple): + """State for the Adam algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: base.Updates + nu: base.Updates + + +def scale_by_adam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + mu_dtype: Optional[chex.ArrayDType] = None, + *, + nesterov: bool = False +) -> base.GradientTransformation: + """Rescale updates according to the Adam algorithm. + + References: + Kingma et al, `Adam: A Method for Stochastic Optimization + `_, 2014 + + Dozat, `Incorporating Nesterov Momentum into Adam + `_ 2016 + + .. warning:: + PyTorch and optax's adam follow Algorithm 1 of the Kingma + and Ba's Adam paper, if reproducing old results note that TensorFlow + used instead the formulation just before Section 2.1 of the paper. + See https://github.com/deepmind/optax/issues/571 for more detail. + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + mu_dtype: Optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + nesterov: Whether to use Nesterov momentum. The variant of Adam with + Nesterov momentum is described in [Dozat 2016] + + Returns: + A `GradientTransformation` object. 
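+
+  A minimal usage sketch, assuming the public `optax.scale_by_adam` alias;
+  chaining with a negative step size is the usual way an Adam-style update
+  rule is assembled::
+
+    import jax.numpy as jnp
+    import optax
+
+    tx = optax.chain(optax.scale_by_adam(b1=0.9, b2=0.999), optax.scale(-1e-3))
+    params = {'w': jnp.zeros(3)}
+    state = tx.init(params)
+    updates, state = tx.update({'w': jnp.ones(3)}, state)
+    params = optax.apply_updates(params, updates)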
+ """ + + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = jax.tree_util.tree_map( # First moment + lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) + nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, b1, 1) + nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) + count_inc = numerics.safe_int32_increment(state.count) + if nesterov: + mu_hat = jax.tree_util.tree_map( + lambda m, g: b1 * m + (1 - b1) * g, + bias_correction(mu, b1, numerics.safe_int32_increment(count_inc)), + bias_correction(updates, b1, count_inc)) + else: + mu_hat = bias_correction(mu, b1, count_inc) + # Dozat 2016 https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ + # Algorithm 2 further multiplies Adam's standard nu_hat by b2. It is + # unclear why. Other Nadam implementations also omit the extra b2 factor. + nu_hat = bias_correction(nu, b2, count_inc) + updates = jax.tree_util.tree_map( + lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat) + mu = utils.cast_tree(mu, mu_dtype) + return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByAmsgradState(NamedTuple): + """State for the AMSGrad algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: base.Updates + nu: base.Updates + nu_max: base.Updates + + +def scale_by_amsgrad( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + mu_dtype: Optional[chex.ArrayDType] = None, +) -> base.GradientTransformation: + """Rescale updates according to the AMSGrad algorithm. + + References: + [Reddi et al, 2018](https://openreview.net/forum?id=ryQu7f-RZ) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + mu_dtype: Optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + A `GradientTransformation` object. 
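+
+  A minimal usage sketch; this assumes `scale_by_amsgrad` is re-exported at
+  the top level of `optax`, like the other transforms in this module::
+
+    import jax.numpy as jnp
+    import optax
+
+    tx = optax.chain(optax.scale_by_amsgrad(), optax.scale(-1e-3))
+    state = tx.init({'w': jnp.zeros(3)})
+    updates, state = tx.update({'w': jnp.ones(3)}, state)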
+ """ + + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = jax.tree_util.tree_map( # First moment + lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) + nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment + nu_max = jax.tree_util.tree_map(jnp.zeros_like, params) + return ScaleByAmsgradState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu, + nu_max=nu_max) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, b1, 1) + nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) + count_inc = numerics.safe_int32_increment(state.count) + mu_hat = bias_correction(mu, b1, count_inc) + nu_hat = bias_correction(nu, b2, count_inc) + nu_max = jax.tree_util.tree_map(jnp.maximum, state.nu_max, nu_hat) + updates = jax.tree_util.tree_map( + lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_max) + mu = utils.cast_tree(mu, mu_dtype) + return updates, ScaleByAmsgradState(count=count_inc, mu=mu, nu=nu, + nu_max=nu_max) + + return base.GradientTransformation(init_fn, update_fn) + + +def scale_by_adamax( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8 +) -> base.GradientTransformation: + """Rescale updates according to the Adamax algorithm. + + References: + [Kingma et al, 2014](https://arxiv.org/abs/1412.6980) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted maximum of grads. + eps: Term added to the denominator to improve numerical stability. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Infinite moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + count_inc = numerics.safe_int32_increment(state.count) + mu = update_moment(updates, state.mu, b1, 1) + nu = update_infinity_moment(updates, state.nu, b2, eps) + # Bias correction for mean. No bias correction needed for infinity moment. + mu_hat = bias_correction(mu, b1, count_inc) + updates = jax.tree_util.tree_map(lambda m, v: m / v, mu_hat, nu) + return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByLionState(NamedTuple): + """State for the Lion algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: base.Updates + + +def scale_by_lion( + b1: float = 0.9, + b2: float = 0.99, + mu_dtype: Optional[chex.ArrayDType] = None, +) -> base.GradientTransformation: + """Rescale updates according to the Lion algorithm. + + References: + [Chen et al, 2023](https://arxiv.org/abs/2302.06675) + + Args: + b1: Rate for combining the momentum and the current grad. + b2: Decay rate for the exponentially weighted average of grads. + mu_dtype: Optional `dtype` to be used for the momentum; if + `None` then the `dtype is inferred from `params` and `updates`. + + Returns: + A `GradientTransformation` object. + """ + + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = jax.tree_util.tree_map( # moment + lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) + return ScaleByLionState(count=jnp.zeros([], jnp.int32), mu=mu) + + def update_fn(updates, state, params=None): + del params + updates_new = jax.tree_util.tree_map( + lambda g, m: jnp.sign((1. 
- b1) * g + b1 * m), updates, state.mu) + mu = update_moment(updates, state.mu, b2, 1) + mu = utils.cast_tree(mu, mu_dtype) + count_inc = numerics.safe_int32_increment(state.count) + return updates_new, ScaleByLionState(count=count_inc, mu=mu) + + return base.GradientTransformation(init_fn, update_fn) + + +ScaleState = base.EmptyState + + +def scale( + step_size: float +) -> base.GradientTransformation: + """Scale updates by some fixed scalar `step_size`. + + Args: + step_size: A scalar corresponding to a fixed scaling factor for updates. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return ScaleState() + + def update_fn(updates, state, params=None): + del params + updates = jax.tree_util.tree_map(lambda g: step_size * g, updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +def scale_by_param_block_norm( + min_scale: float = 1e-3 +) -> base.GradientTransformation: + """Scale updates for each param block by the norm of that block's parameters. + + A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix + (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree. + + Args: + min_scale: Minimum scaling factor. + + Returns: + A `GradientTransformation` object. + """ + + def update_fn(updates, state, params): + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + updates = jax.tree_util.tree_map( + lambda u, p: u * numerics.safe_norm(p, min_scale), + updates, params) + return updates, state + + return base.GradientTransformation(_init_empty_state, update_fn) + + +def scale_by_param_block_rms( + min_scale: float = 1e-3 +) -> base.GradientTransformation: + """Scale updates by rms of the gradient for each param vector or matrix. + + A `block` is here a weight vector (e.g. in a Linear layer) or a weight matrix + (e.g. in a convolutional layer) appearing as a leaf in the grads/param pytree. + + Args: + min_scale: Minimum scaling factor. + + Returns: + A `GradientTransformation` object. + """ + + def update_fn(updates, state, params): + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + updates = jax.tree_util.tree_map( + lambda u, p: u * numerics.safe_root_mean_squares(p, min_scale), + updates, params) + return updates, state + + return base.GradientTransformation(_init_empty_state, update_fn) + + +class ScaleByAdaDeltaState(NamedTuple): + """State for the rescaling by Adadelta algoritm.""" + + e_g: base.Updates + e_x: base.Updates + + +def scale_by_adadelta( + rho: float = 0.9, eps: float = 1e-6 +) -> base.GradientTransformation: + """Rescale updates according to the Adadelta algorithm. + + References: + [Matthew D. Zeiler, 2012](https://arxiv.org/pdf/1212.5701.pdf) + + Args: + rho: A coefficient used for computing a running average of squared + gradients. + eps: Term added to the denominator to improve numerical stability. + + Returns: + A `GradientTransformation` object. 
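+
+  A minimal usage sketch, assuming the public `optax.scale_by_adadelta`
+  alias; a step size close to 1.0 is conventional for Adadelta::
+
+    import jax.numpy as jnp
+    import optax
+
+    tx = optax.chain(optax.scale_by_adadelta(rho=0.9), optax.scale(-1.0))
+    state = tx.init({'w': jnp.zeros(3)})
+    updates, state = tx.update({'w': jnp.ones(3)}, state)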
+ """ + + def init_fn(params): + e_g = jax.tree_util.tree_map(jnp.zeros_like, params) # E[squared gradient] + e_x = jax.tree_util.tree_map(jnp.zeros_like, params) # E[squared update] + return ScaleByAdaDeltaState(e_g=e_g, e_x=e_x) + + def update_fn(updates, state, params=None): + del params + e_g = update_moment(updates, state.e_g, rho, 2) + updates = jax.tree_util.tree_map( + lambda g, cur_e_g, prev_e_x: ( + jnp.sqrt(prev_e_x + eps) / jnp.sqrt(cur_e_g + eps) + ) + * g, + updates, + e_g, + state.e_x, + ) + e_x = update_moment(updates, state.e_x, rho, 2) + return updates, ScaleByAdaDeltaState(e_g=e_g, e_x=e_x) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByBeliefState(NamedTuple): + """State for the rescaling by AdaBelief algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. + mu: base.Updates + nu: base.Updates + + +def scale_by_belief( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-16, + eps_root: float = 1e-16 +) -> base.GradientTransformation: + """Rescale updates according to the AdaBelief algorithm. + + References: + [Zhuang et al, 2020](https://arxiv.org/abs/2010.07468) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of variance of grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the second moment of the prediction error to + improve numerical stability. If backpropagating gradients through the + gradient transformation (e.g. for meta-learning), this must be non-zero. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment + s = jax.tree_util.tree_map(jnp.zeros_like, params) # Second Central moment + return ScaleByBeliefState(count=jnp.zeros([], jnp.int32), mu=mu, nu=s) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, b1, 1) + prediction_error = jax.tree_util.tree_map( + lambda g, m: g-m, updates, state.mu) + nu = update_moment_per_elem_norm(prediction_error, state.nu, b2, 2) + nu = jax.tree_util.tree_map(lambda v: v + eps_root, nu) + count_inc = numerics.safe_int32_increment(state.count) + mu_hat = bias_correction(mu, b1, count_inc) + nu_hat = bias_correction(nu, b2, count_inc) + updates = jax.tree_util.tree_map( + lambda m, v: m / (jnp.sqrt(v) + eps), mu_hat, nu_hat) + return updates, ScaleByBeliefState(count=count_inc, mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +def scale_by_yogi( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-3, + eps_root: float = 0.0, + initial_accumulator_value: float = 1e-6 +) -> base.GradientTransformation: + """Rescale updates according to the Yogi algorithm. + + Supports complex numbers, see + https://gist.github.com/wdphy16/118aef6fb5f82c49790d7678cf87da29 + + References: + [Zaheer et al, 2018](https://papers.nips.cc/paper/2018/hash/90365351ccc7437a1309dc64e4db32a3-Abstract.html) #pylint:disable=line-too-long + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of variance of grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + initial_accumulator_value: The starting value for accumulators. + Only positive values are allowed. 
+ + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + value_like = lambda p: jnp.full_like(p, initial_accumulator_value) + mu = jax.tree_util.tree_map(value_like, params) # First moment + nu = jax.tree_util.tree_map(value_like, params) # Second Central moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, b1, 1) + nu = jax.tree_util.tree_map( + lambda g, v: v - (1 - b2) * jnp.sign(v - _abs_sq(g)) * _abs_sq(g), + updates, state.nu) + count_inc = numerics.safe_int32_increment(state.count) + mu_hat = bias_correction(mu, b1, count_inc) + nu_hat = bias_correction(nu, b2, count_inc) + updates = jax.tree_util.tree_map( + lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat) + return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +def scale_by_radam( + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, + eps_root: float = 0.0, + threshold: float = 5.0 +) -> base.GradientTransformation: + """Rescale updates according to the Rectified Adam algorithm. + + References: + [Liu et al, 2020](https://arxiv.org/abs/1908.03265) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. + eps_root: Term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + threshold: Threshold for variance tractability. + + Returns: + A `GradientTransformation` object. + """ + + ro_inf = 2./(1 - b2) - 1 + def _radam_update(params): + ro = params[0] + mu_hat = params[1] + nu_hat = params[2] + r = jnp.sqrt((ro - 4)*(ro - 2)*ro_inf/((ro_inf - 4)*(ro_inf - 2)*ro)) + updates = jax.tree_util.tree_map( + lambda m, v: r*m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat) + return updates + + def init_fn(params): + mu = jax.tree_util.tree_map(jnp.zeros_like, params) # First moment + nu = jax.tree_util.tree_map(jnp.zeros_like, params) # Second moment + return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def update_fn(updates, state, params=None): + del params + mu = update_moment(updates, state.mu, b1, 1) + nu = update_moment_per_elem_norm(updates, state.nu, b2, 2) + count_inc = numerics.safe_int32_increment(state.count) + b2t = b2**count_inc + ro = ro_inf - 2 * count_inc * b2t / (1 - b2t) + mu_hat = bias_correction(mu, b1, count_inc) + nu_hat = bias_correction(nu, b2, count_inc) + updates = jax.lax.cond( + ro >= threshold, _radam_update, lambda _: mu_hat, + (ro, mu_hat, nu_hat)) + return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByRpropState(NamedTuple): + step_sizes: base.Updates + prev_updates: base.Updates + + +def scale_by_rprop( + learning_rate: float, + eta_minus: float = 0.5, + eta_plus: float = 1.2, + min_step_size: float = 1e-6, + max_step_size: float = 50.0, +) -> base.GradientTransformation: + """Scale with the Rprop optimizer. + + Rprop, short for resillient backpropogation, is a first order variant of + gradient descent. It responds only to the sign of the gradient by increasing + or decreasing the step size selected per parameter exponentially to speed up + convergence and avoid oscillations. 
+ + References: + PyTorch implementation: + https://pytorch.org/docs/stable/generated/torch.optim.Rprop.html + Riedmiller and Braun, 1993: https://ieeexplore.ieee.org/document/298623 + Igel and Hüsken, 2003: + https://www.sciencedirect.com/science/article/abs/pii/S0925231201007007 + + Args: + learning_rate: The initial step size. + eta_minus: Multiplicative factor for decreasing step size. This is applied + when the gradient changes sign from one step to the next. + eta_plus: Multiplicative factor for increasing step size. This is applied + when the gradient has the same sign from one step to the next. + min_step_size: Minimum allowed step size. Smaller steps will be clipped to + this value. + max_step_size: Maximum allowed step size. Larger steps will be clipped to + this value. + + Returns: + The corresponding `GradientTransformation`. + """ + + def init_fn(params): + step_sizes = jax.tree_util.tree_map( + lambda p: learning_rate * jnp.ones_like(p), params) + prev_updates = jax.tree_util.tree_map(jnp.zeros_like, params) + return ScaleByRpropState(step_sizes, prev_updates) + + def update_fn(updates, state, params=None): + del params + sign = jax.tree_util.tree_map( + lambda g, prev_g: g * prev_g, updates, state.prev_updates) + step_sizes = jax.tree_util.tree_map( + lambda s, step_size: jnp.where( + s == 0, + step_size, + jnp.clip( + step_size * jnp.where(s > 0, eta_plus, eta_minus), + a_min=min_step_size, a_max=max_step_size + ) + ), + sign, state.step_sizes + ) + prev_updates = jax.tree_util.tree_map( + lambda s, g, step_size: jnp.where( + s < 0, jnp.zeros_like(g), step_size * jnp.sign(g)), + sign, updates, step_sizes) + updates = jax.tree_util.tree_map( + lambda s, g, prev_g: jnp.where(s < 0, jnp.zeros_like(prev_g), prev_g), + sign, prev_updates, state.prev_updates) + return updates, ScaleByRpropState(step_sizes, prev_updates) + + return base.GradientTransformation(init_fn, update_fn) + + +AddDecayedWeightsState = base.EmptyState + + +def add_decayed_weights( + weight_decay: Union[float, jax.Array] = 0.0, + mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None +) -> base.GradientTransformation: + """Add parameter scaled by `weight_decay`. + + Args: + weight_decay: A scalar weight decay rate. + mask: A tree with same structure as (or a prefix of) the params PyTree, + or a Callable that returns such a pytree given the params/updates. + The leaves should be booleans, `True` for leaves/subtrees you want to + apply the transformation to, and `False` for those you want to skip. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return AddDecayedWeightsState() + + def update_fn(updates, state, params): + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + updates = jax.tree_util.tree_map( + lambda g, p: g + weight_decay * p, updates, params) + return updates, state + + # If mask is not `None`, apply mask to the gradient transformation. + # E.g. it is common to skip weight decay on bias units and batch stats. + if mask is not None: + return wrappers.masked( + base.GradientTransformation(init_fn, update_fn), mask) + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByScheduleState(NamedTuple): + """Maintains count for scale scheduling.""" + count: chex.Array # shape=(), dtype=jnp.int32 + + +def scale_by_learning_rate( + learning_rate: base.ScalarOrSchedule, + *, + flip_sign: bool = True, +) -> base.GradientTransformation: + """Scale by the (negative) learning rate (either as scalar or as schedule). 
+ + Args: + learning_rate: Can either be a scalar or a schedule (i.e. a callable that + maps an (int) step to a float). + flip_sign: When set to True (the default) this corresponds to scaling by the + negative learning rate. + + Returns: + An optax.GradientTransformation that corresponds to multiplying the gradient + with `-learning_rate` (if flip_sign is True) or with `learning_rate` (if + flip_sign is False). + """ + m = -1 if flip_sign else 1 + if callable(learning_rate): + return scale_by_schedule(lambda count: m * learning_rate(count)) + return scale(m * learning_rate) + + +def scale_by_schedule( + step_size_fn: base.Schedule +) -> base.GradientTransformation: + """Scale updates using a custom schedule for the `step_size`. + + Args: + step_size_fn: A function that takes an update count as input and proposes + the step_size to multiply the updates by. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return ScaleByScheduleState(count=jnp.zeros([], jnp.int32)) + + def update_fn(updates, state, params=None): + del params + step_size = step_size_fn(state.count) + updates = jax.tree_util.tree_map( + lambda g: jnp.array(step_size, dtype=g.dtype) * g, updates) + return updates, ScaleByScheduleState( + count=numerics.safe_int32_increment(state.count)) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByTrustRatioState(NamedTuple): + """The scale and decay trust ratio transformation is stateless.""" + + +def scale_by_trust_ratio( + min_norm: float = 0.0, + trust_coefficient: float = 1., + eps: float = 0., +) -> base.GradientTransformation: + """Scale updates by `trust ratio`. + + References: + [You et. al 2020](https://arxiv.org/abs/1904.00962) + + Args: + min_norm: Minimum norm for params and gradient norms; by default is zero. + trust_coefficient: A multiplier for the trust ratio. + eps: Additive constant added to the denominator for numerical stability. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return ScaleByTrustRatioState() + + def update_fn(updates, state, params): + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + + def _scale_update(update, param): + + # Clip norms to minimum value, by default no clipping. + param_norm = numerics.safe_norm(param, min_norm) + update_norm = numerics.safe_norm(update, min_norm) + trust_ratio = trust_coefficient * param_norm / (update_norm + eps) + + # If no minimum norm clipping is used + # Set trust_ratio to 1 in case where parameters would never be updated. + zero_norm = jnp.logical_or(param_norm == 0., update_norm == 0.) + safe_trust_ratio = jnp.where( + zero_norm, jnp.array(1.0, dtype=param.dtype), trust_ratio) + + return update * safe_trust_ratio + + updates = jax.tree_util.tree_map(_scale_update, updates, params) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +class AddNoiseState(NamedTuple): + """State for adding gradient noise. Contains a count for annealing.""" + count: chex.Array + rng_key: chex.PRNGKey + + +def add_noise( + eta: float, + gamma: float, + seed: int +) -> base.GradientTransformation: + """Add gradient noise. + + References: + [Neelakantan et al, 2014](https://arxiv.org/abs/1511.06807) + + Args: + eta: Base variance of the gaussian noise added to the gradient. + gamma: Decay exponent for annealing of the variance. + seed: Seed for random number generation. + + Returns: + A `GradientTransformation` object. 
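+
+  A minimal usage sketch, assuming the public `optax.add_noise` alias;
+  chaining noise injection with a step size, in the spirit of noisy SGD::
+
+    import jax.numpy as jnp
+    import optax
+
+    tx = optax.chain(optax.add_noise(eta=0.01, gamma=0.55, seed=0),
+                     optax.scale(-1e-3))
+    state = tx.init({'w': jnp.zeros(3)})
+    updates, state = tx.update({'w': jnp.ones(3)}, state)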
+ """ + + def init_fn(params): + del params + return AddNoiseState( + count=jnp.zeros([], jnp.int32), rng_key=jax.random.PRNGKey(seed)) + + def update_fn(updates, state, params=None): # pylint: disable=missing-docstring + del params + num_vars = len(jax.tree_util.tree_leaves(updates)) + treedef = jax.tree_util.tree_structure(updates) + count_inc = numerics.safe_int32_increment(state.count) + variance = eta / count_inc**gamma + standard_deviation = jnp.sqrt(variance) + all_keys = jax.random.split(state.rng_key, num=num_vars + 1) + noise = jax.tree_util.tree_map( + lambda g, k: jax.random.normal(k, shape=g.shape, dtype=g.dtype), + updates, jax.tree_util.tree_unflatten(treedef, all_keys[1:])) + updates = jax.tree_util.tree_map( + lambda g, n: g + standard_deviation.astype(g.dtype) * n, + updates, noise) + return updates, AddNoiseState(count=count_inc, rng_key=all_keys[0]) + + return base.GradientTransformation(init_fn, update_fn) + + +class ApplyEvery(NamedTuple): + """Contains a counter and a gradient accumulator.""" + count: chex.Array + grad_acc: base.Updates + + +def apply_every( + k: int = 1 +) -> base.GradientTransformation: + """Accumulate gradients and apply them every k steps. + + Note that if this transformation is part of a chain, the states of the other + transformations will still be updated at every step. In particular, using + `apply_every` with a batch size of N/2 and k=2 is not necessarily equivalent + to not using `apply_every` with a batch size of N. If this equivalence is + important for you, consider using the `optax.MultiSteps`. + + Args: + k: Emit non-zero gradients every k steps, otherwise accumulate them. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + grad_acc = jax.tree_util.tree_map(jnp.zeros_like, params) + return ApplyEvery(count=jnp.zeros([], jnp.int32), grad_acc=grad_acc) + + def update_fn(updates, state, params=None): + del params + c = state.count % k + acc = c != 0 + grad_acc = jax.tree_util.tree_map( + lambda g, ga: acc * ga + g, updates, state.grad_acc) + emit = c == (k - 1) + updates = jax.tree_util.tree_map(lambda ga: emit * ga, grad_acc) + count_inc = numerics.safe_int32_increment(state.count) + return updates, ApplyEvery(count=count_inc % k, grad_acc=grad_acc) + + return base.GradientTransformation(init_fn, update_fn) + + +def _subtract_mean(g): + if len(g.shape) > 1: + return g - g.mean(tuple(range(1, len(g.shape))), keepdims=True) + else: + return g + + +CentralState = base.EmptyState + + +def centralize() -> base.GradientTransformation: + """Centralize gradients. + + References: + [Yong et al, 2020](https://arxiv.org/abs/2004.01461) + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + del params + return CentralState() + + def update_fn(updates, state, params=None): + del params + updates = jax.tree_util.tree_map(_subtract_mean, updates) + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleBySM3State(NamedTuple): + """State for the SM3 algorithm.""" + mu: base.Updates + nu: base.Updates + + +def scale_by_sm3( + b1: float = 0.9, + b2: float = 1.0, + eps: float = 1e-8 +) -> base.GradientTransformation: + """Scale updates by `sm3`. + + References: + [Anil et. al 2019](https://arxiv.org/abs/1901.11150) + + Args: + b1: Decay rate for the exponentially weighted average of grads. + b2: Decay rate for the exponentially weighted average of squared grads. + eps: Term added to the denominator to improve numerical stability. 
+ + Returns: + A `GradientTransformation` object. + """ + + def zeros_for_dim(p): + return [jnp.zeros([s]) for s in p.shape] + + def init_fn(params): + _reject_complex(params) + mu = jax.tree_util.tree_map(zeros_for_dim, params) + nu = jax.tree_util.tree_map(jnp.zeros_like, params) + return ScaleBySM3State(mu, nu) + + def _expanded_shape(shape, axis): + # Replaces a `shape` of [M, N, K] with 1 in all dimensions except for i. + # For eg: i = 1 returns [1, N, 1]. + rank = len(shape) + return [1] * axis + [shape[axis]] + [1] * (rank - axis - 1) + + def _new_accum(g, v): + coeffs = ((1.0 - b2) if b2 != 1.0 else 1.0, b2) + if g.ndim < 2: + return coeffs[0]*g**2 + coeffs[1]*v[0] + else: + return coeffs[0]*g**2 + coeffs[1]*functools.reduce(jnp.minimum, v) + + def _new_mu(g, i): + if g.ndim < 2: + return g + else: + return jnp.max(g, axis=other_axes(i, g.ndim)) + + def other_axes(idx, ndim): + return list(range(idx)) + list(range(idx+1, ndim)) + + def update_fn(updates, state, params=None): + del params + mu = jax.tree_util.tree_map( + lambda g, v: # pylint:disable=g-long-lambda + [jnp.reshape(v[i], _expanded_shape(g.shape, i)) for i in range(g.ndim)], + updates, state.mu) + accum = jax.tree_util.tree_map(_new_accum, updates, mu) + accum_inv_sqrt = jax.tree_util.tree_map( + lambda t: jnp.where(t > 0, jax.lax.rsqrt(t + eps), 0.0), accum) + up = jax.tree_util.tree_map(lambda g, a: g*a, updates, accum_inv_sqrt) + nu = update_moment(up, state.nu, b1, 1) + mu = jax.tree_util.tree_map( + lambda g: [_new_mu(g, i) for i in range(g.ndim)], accum) + + return nu, ScaleBySM3State(mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByNovogradState(NamedTuple): + """State for Novograd.""" + count: chex.Array + mu: base.Updates + nu: base.Updates + + +def scale_by_novograd( + b1: float = 0.9, + b2: float = 0.25, + eps: float = 1e-8, + eps_root: float = 0.0, + weight_decay: float = 0.0, + mu_dtype: Optional[chex.ArrayDType] = None, +) -> base.GradientTransformation: + """Computes NovoGrad updates. + + References: + [Ginsburg et al, 2019](https://arxiv.org/abs/1905.11286) + + Args: + b1: A decay rate for the exponentially weighted average of grads. + b2: A decay rate for the exponentially weighted average of squared grads. + eps: A term added to the denominator to improve numerical stability. + eps_root: A term added to the denominator inside the square-root to improve + numerical stability when backpropagating gradients through the rescaling. + weight_decay: A scalar weight decay rate. + mu_dtype: An optional `dtype` to be used for the first order accumulator; if + `None` then the `dtype` is inferred from `params` and `updates`. + + Returns: + The corresponding `GradientTransformation`. 
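+
+  A minimal usage sketch, assuming the public `optax.scale_by_novograd`
+  alias. Note that the update rule needs the current parameters because of
+  the built-in weight decay term::
+
+    import jax.numpy as jnp
+    import optax
+
+    tx = optax.chain(optax.scale_by_novograd(b1=0.9, b2=0.25),
+                     optax.scale(-1e-3))
+    params = {'w': jnp.ones(3)}
+    state = tx.init(params)
+    updates, state = tx.update({'w': jnp.ones(3)}, state, params)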
+ """ + + mu_dtype = utils.canonicalize_dtype(mu_dtype) + + def init_fn(params): + mu = jax.tree_util.tree_map( # First moment + lambda t: jnp.zeros_like(t, dtype=mu_dtype), params) + nu = jax.tree_util.tree_map(lambda _: 0.0, params) # Second moment + return ScaleByNovogradState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) + + def nu_addition(grads): + return jnp.linalg.norm(grads)**2 + + def mu_addition(grads, params, nu): + return grads / (jnp.sqrt(nu + eps_root) + eps) + weight_decay * params + + def init_nu(grads, nu): + del nu + return jax.tree_util.tree_map(nu_addition, grads) + + def update_nu(grads, nu): + updates = jax.tree_util.tree_map(nu_addition, grads) + return update_moment(updates, nu, b2, 1) + + def init_mu(grads, params, mu, nu): + del mu + return jax.tree_util.tree_map(mu_addition, grads, params, nu) + + def update_mu(grads, params, mu, nu): + updates = jax.tree_util.tree_map(mu_addition, grads, params, nu) + return jax.tree_util.tree_map(lambda m, u: b1 * m + u, mu, updates) + + # Second moment + def update_fn(updates, state, params): + count_inc = numerics.safe_int32_increment(state.count) + + nu = jax.lax.cond(count_inc == 1, init_nu, update_nu, updates, state.nu) + + mu = jax.lax.cond(count_inc == 1, init_mu, update_mu, updates, params, + state.mu, nu) + + mu = utils.cast_tree(mu, mu_dtype) + updates = mu + return updates, ScaleByNovogradState(count=count_inc, mu=mu, nu=nu) + + return base.GradientTransformation(init_fn, update_fn) + + +def scale_by_optimistic_gradient(alpha: float = 1.0, + beta: float = 1.0 + ) -> base.GradientTransformation: + """Compute generalized optimistic gradients. + + References: + [Mokhtari et al, 2019](https://arxiv.org/abs/1901.08511v2) + + Args: + alpha: Coefficient for generalized optimistic gradient descent. + beta: Coefficient for negative momentum. + + Returns: + A `GradientTransformation` object. + """ + + def init_fn(params): + prev_grads = jax.tree_util.tree_map(jnp.zeros_like, params) + return TraceState(trace=prev_grads) + + def update_fn(updates, state, params=None): + del params + + new_updates = jax.tree_util.tree_map( + lambda grad_t, grad_tm1: (alpha + beta) * grad_t - beta * grad_tm1, + updates, state.trace) + return new_updates, TraceState(trace=updates) + + return base.GradientTransformation(init_fn, update_fn) + + +class ScaleByDistanceOverGradientsState(NamedTuple): + """State for scale_by_distance_over_gradients.""" + + max_dist: base.OptState + grad_sum_of_squares: base.OptState + init_params: base.OptState + + +def scale_by_distance_over_gradients( + reps_rel=1e-6, eps=1e-8, param_dtype=jnp.float32, global_scale=1.0 +) -> base.GradientTransformation: + """Distance-over-gradients learning rate-free optimizer. + + This implementation stores a single copy of the model parameters, plus two + scalars per parameter array. It is equivalent to "Layer-wise DoG" (LDoG) + in the paper. + + The authors recommend using model averaging with this optimizer. + + References: + ["DoG is SGD's Best Friend: A Parameter-Free Dynamic Step Size + Schedule"](https://arxiv.org/pdf/2302.12022.pdf) + + Args: + reps_rel: Used to compute initial learning rate. Recommended values are 1e-4 + for models using batch norm, 1e-6 otherwise. + eps: Small loading term to avoid divide-by-zero errors. + param_dtype: dtype for storing initial parameters. + global_scale: Global scale factor, typically 1.0 or -1.0 + + Returns: + A `GradientTransformation` object. 
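+
+  A minimal usage sketch; since the public alias is less established, this
+  imports the transform from `optax._src.transform` directly. The update rule
+  needs the current parameters in order to track the distance moved::
+
+    import jax.numpy as jnp
+    from optax._src import transform
+
+    tx = transform.scale_by_distance_over_gradients()
+    params = {'w': jnp.ones(3)}
+    state = tx.init(params)
+    updates, state = tx.update({'w': jnp.ones(3)}, state, params)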
+ """ + + def _l2(x, y=0.0): + return jnp.sqrt(jnp.square(x - y).sum()) + + def init_fn(params): + return ScaleByDistanceOverGradientsState( + # Initial distance (needed to prevent zero step sizes). + jax.tree_util.tree_map(lambda x: reps_rel * (1 + _l2(x)), params), + # Initial gradient sum-of-squares. + jax.tree_util.tree_map(lambda x: jnp.zeros(1), params), + # Initial params, cast to preferred precision. + jax.tree_map(lambda x: x.astype(param_dtype), params), + ) + + def update_fn(updates, state: ScaleByDistanceOverGradientsState, params): + # update max distance + max_dist = jax.tree_map( + lambda d, x, y: jnp.maximum(d, _l2(x, y)), + state.max_dist, + params, + state.init_params, + ) + + # update gradient sum-of-squares + g_sos = jax.tree_map( + lambda x, y: x + jnp.square(y).sum(), state.grad_sum_of_squares, updates + ) + + def _tx(g, d, g_sos): + """Apply the transformation.""" + eta = global_scale * (d / jnp.sqrt(g_sos + eps)) + return eta * g + + updates = jax.tree_map(_tx, max_dist, g_sos, updates) + + # new state + state = ScaleByDistanceOverGradientsState( + max_dist, g_sos, state.init_params + ) + + return updates, state + + return base.GradientTransformation(init_fn, update_fn) + + +def scale_by_polyak( + f_min: float = 0.0, + max_learning_rate: float = 1.0, + eps: float = 0.0, +) -> base.GradientTransformationExtraArgs: + """Scales the update by Polyak's step-size.""" + + def update_fn( + updates: base.Updates, + state: base.EmptyState, + params: Optional[base.Params] = None, + *, + value: float, + **extra_args, + ) -> tuple[base.Updates, base.EmptyState]: + """Scales the update by the Polyak step-size. + + Args: + updates: the updates to be scaled. + state: the state of the transformation. + params: the parameters of the model. + value: the value of the loss function. + **extra_args: additional keyword arguments. They are ignored by this + transformation. + Returns: + The scaled updates and the state of the transformation. + """ + del params, extra_args + grad_sq_norm = tree_utils.tree_l2_norm(updates, squared=True) + # avoid division by zero + step = jnp.where( + grad_sq_norm + eps <= jnp.finfo(float).eps, + jnp.array(0.0), + jnp.minimum( + (value - f_min) / (grad_sq_norm + eps), max_learning_rate + ), + ) + updates = tree_utils.tree_scalar_mul(step, updates) + return updates, state + + return base.GradientTransformationExtraArgs(_init_empty_state, update_fn) diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/update.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/update.py new file mode 100644 index 000000000..e810f62a9 --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/_src/update.py @@ -0,0 +1,103 @@ +# Copyright 2019 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Apply transformed gradient updates to parameters.""" + +import chex +import jax +import jax.numpy as jnp + +from optax._src import base + + +def apply_updates(params: base.Params, updates: base.Updates) -> base.Params: + """Applies an update to the corresponding parameters. + + This is a utility functions that applies an update to a set of parameters, and + then returns the updated parameters to the caller. As an example, the update + may be a gradient transformed by a sequence of`GradientTransformations`. This + function is exposed for convenience, but it just adds updates and parameters; + you may also apply updates to parameters manually, using `tree_map` + (e.g. if you want to manipulate updates in custom ways before applying them). + + Args: + params: a tree of parameters. + updates: a tree of updates, the tree structure and the shape of the leaf + nodes must match that of `params`. + + Returns: + Updated parameters, with same structure, shape and type as `params`. + """ + return jax.tree_util.tree_map( + lambda p, u: jnp.asarray(p + u).astype(jnp.asarray(p).dtype), + params, updates) + + +def incremental_update( + new_tensors: base.Params, + old_tensors: base.Params, + step_size: chex.Numeric +) -> base.Params: + """Incrementally update parameters via polyak averaging. + + Polyak averaging tracks an (exponential moving) average of the past + parameters of a model, for use at test/evaluation time. + + References: + [Polyak et al, 1991](https://epubs.siam.org/doi/10.1137/0330046) + + Args: + new_tensors: the latest value of the tensors. + old_tensors: a moving average of the values of the tensors. + step_size: the step_size used to update the polyak average on each step. + + Returns: + an updated moving average `step_size*new+(1-step_size)*old` of the params. + """ + return jax.tree_util.tree_map( + lambda new, old: step_size * new + (1.0 - step_size) * old, + new_tensors, old_tensors) + + +def periodic_update( + new_tensors: base.Params, + old_tensors: base.Params, + steps: chex.Array, + update_period: int +) -> base.Params: + """Periodically update all parameters with new values. + + A slow copy of a model's parameters, updated every K actual updates, can be + used to implement forms of self-supervision (in supervised learning), or to + stabilise temporal difference learning updates (in reinforcement learning). + + References: + [Grill et al., 2020](https://arxiv.org/abs/2006.07733) + [Mnih et al., 2015](https://arxiv.org/abs/1312.5602) + + Args: + new_tensors: the latest value of the tensors. + old_tensors: a slow copy of the model's parameters. + steps: number of update steps on the "online" network. + update_period: every how many steps to update the "target" network. + + Returns: + a slow copy of the model's parameters, updated every `update_period` steps. + """ + return jax.lax.cond( + jnp.mod(steps, update_period) == 0, + lambda _: new_tensors, + lambda _: old_tensors, + None) + diff --git a/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/schedules/_join.py b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/schedules/_join.py new file mode 100644 index 000000000..fab6ccccc --- /dev/null +++ b/.github/workflows/nsys-jax/maxtext_fsdp4_test_data/sources/usr/local/lib/python3.10/dist-packages/optax/schedules/_join.py @@ -0,0 +1,45 @@ +# Copyright 2019 DeepMind Technologies Limited. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to join schedules.""" + +from typing import Sequence + +import chex +import jax.numpy as jnp + +from optax._src import base + + +def join_schedules( + schedules: Sequence[base.Schedule], + boundaries: Sequence[int] +) -> base.Schedule: + """Sequentially apply multiple schedules. + + Args: + schedules: A list of callables (expected to be optax schedules). Each + schedule will receive a step count indicating the number of steps since + the previous boundary transition. + boundaries: A list of integers (of length one less than schedules) that + indicate when to transition between schedules. + Returns: + schedule: A function that maps step counts to values. + """ + def schedule(step: chex.Numeric) -> chex.Numeric: + output = schedules[0](step) + for boundary, schedule in zip(boundaries, schedules[1:]): + output = jnp.where(step < boundary, output, schedule(step - boundary)) + return output + return schedule
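+
+
+# A minimal usage sketch of `join_schedules` (assuming the public
+# `optax.linear_schedule` and `optax.constant_schedule` aliases): warm up
+# linearly for 100 steps, then hold the value constant.
+#
+#   import optax
+#
+#   warmup = optax.linear_schedule(
+#       init_value=0.0, end_value=1e-3, transition_steps=100)
+#   lr = optax.join_schedules(
+#       [warmup, optax.constant_schedule(1e-3)], boundaries=[100])
+#   lr(50)   # half-way through the warmup: 5e-4
+#   lr(150)  # past the boundary: 1e-3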