Skip to content

Commit

Permalink
nsys-jax: all-to-all and repeated thunk support
Browse files Browse the repository at this point in the history
Also run notebook in CI
  • Loading branch information
olupton committed Jun 4, 2024
1 parent e28ad9b commit 769b5dc
Show file tree
Hide file tree
Showing 153 changed files with 23,778 additions and 28 deletions.
28 changes: 3 additions & 25 deletions .github/container/jax_nsys/Analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,9 @@
" # program, there may be different sub-groupings that are participating in smaller\n",
" # collectives in the strict/NCCL sense. TODO: it would be better to identify those\n",
" # sub-groupings and group them, but we currently lack the relevant information.\n",
" collective_df = df.groupby([\"ProgramId\", \"Name\", \"ModuleExecution\"])\n",
" collective_df = df.groupby(\n",
" [\"ProgramId\", \"Name\", \"ModuleExecution\", \"ThunkExecution\"]\n",
" )\n",
" # Take the fastest device kernel as a proxy for the actual bandwidth of the\n",
" # collective.\n",
" bandwidth_df = collective_df.agg(\n",
Expand Down Expand Up @@ -468,30 +470,6 @@
"axs[2].set_xscale(\"log\")\n",
"axs[2].set_yscale(\"log\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "710100ec-55e6-4b6f-b8aa-5e6bda15fda1",
"metadata": {},
"outputs": [],
"source": [
"compute_times = (\n",
" thunk_df[~thunk_df[\"Communication\"]]\n",
" .groupby([\"ProgramId\", \"Name\"])\n",
" .agg({\"ProjDurNs\": [\"mean\", \"std\"]})\n",
" .sort_values((\"ProjDurNs\", \"mean\"))\n",
")\n",
"plt.plot(\n",
" compute_times[(\"ProjDurNs\", \"mean\")],\n",
" compute_times[(\"ProjDurNs\", \"std\")] / compute_times[(\"ProjDurNs\", \"mean\")],\n",
" \"o\",\n",
")\n",
"# plt.errorbar([\"{}:{}\".format(*x) for x in compute_times.index], compute_times[(\"ProjDurNs\", \"mean\")], compute_times[(\"ProjDurNs\", \"std\")], marker=\"o\")\n",
"# plt.xlabel(\n",
"plt.xscale(\"log\")\n",
"# compute_times[(\"ProjDurNs\", \"std\")]"
]
}
],
"metadata": {
Expand Down
12 changes: 9 additions & 3 deletions .github/container/jax_nsys/install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# The expectation is that those archives will be copied and extracted on a
# laptop or workstation, and this installation script will be run there, while
# the `nsys-jax` wrapper is executed on a remote GPU cluster.
set -ex
SCRIPT_DIR=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
VIRTUALENV="${SCRIPT_DIR}/nsys_jax_venv"
if [[ ! -d "${VIRTUALENV}" ]]; then
Expand All @@ -18,12 +19,17 @@ if [[ ! -d "${VIRTUALENV}" ]]; then
. "${VIRTUALENV}/bin/activate"
python -m pip install -U pip
"${SCRIPT_DIR}/nsys-jax-ensure-protobuf"
python -m pip install jupyterlab
# matplotlib is a dependency of Analysis.ipynb but not jax_nsys
python -m pip install jupyterlab matplotlib
python -m pip install -e "${SCRIPT_DIR}/python/jax_nsys"
curl -o "${VIRTUALENV}/bin/flamegraph.pl" https://raw.githubusercontent.com/brendangregg/FlameGraph/master/flamegraph.pl
chmod 755 "${VIRTUALENV}/bin/flamegraph.pl"
else
echo "Virtual environment already exists, not installing anything..."
fi
echo "Launching: cd ${SCRIPT_DIR} && ${VIRTUALENV}/bin/python -m jupyterlab Analysis.ipynb"
cd "${SCRIPT_DIR}" && "${VIRTUALENV}/bin/python" -m jupyterlab Analysis.ipynb
if [ -z ${NSYS_JAX_INSTALL_SKIP_LAUNCH+x} ]; then
echo "Launching: cd ${SCRIPT_DIR} && ${VIRTUALENV}/bin/python -m jupyterlab Analysis.ipynb"
cd "${SCRIPT_DIR}" && "${VIRTUALENV}/bin/python" -m jupyterlab Analysis.ipynb
else
echo "Skipping launch of jupyterlab due to NSYS_JAX_INSTALL_SKIP_LAUNCH"
fi
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ def _collective_correction(kind: str, size: int) -> tuple[float, float]:
return (size, (size - 1) / size)
case "all-reduce":
return (1, 2 * (size - 1) / size)
case "all-to-all":
# https://github.com/NVIDIA/nccl-tests/blob/a1efb427e764241bc43d2d91be875c9f55da03a5/src/alltoall.cu#L44
return (1, (size - 1) / size)
case "collective-broadcast":
return (1, 1)
case "collective-permute":
Expand Down Expand Up @@ -71,6 +74,7 @@ def get_message_size(program_id: int, instruction: str) -> pd.Series:
in {
"all-gather-start",
"all-reduce-start",
"all-to-all",
"collective-broadcast",
"collective-permute-start",
"reduce-scatter",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,14 @@ def clean_data_frame(d, extra_columns=[]):
value=r"\2",
regex=True,
)
# Add a new column describing which (0th, 1st, ...) execution of the thunk
# within the given module execution this is. For example, while loops in the
# HLO can lead to the same thunk being executed multiple times within the same
# module execution.
thunk_df["ThunkExecution"] = thunk_df.groupby(
["TID", "ProgramId", "Name", "ModuleExecution"]
).cumcount()

# Classify thunks as communication/computation and save to output
output["thunk"] = _classify_comms(thunk_df, prefix)

Expand Down
27 changes: 27 additions & 0 deletions .github/workflows/nsys-jax.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,33 @@ jobs:
run: |
export MYPYPATH="${PWD}/compiled_stubs"
./venv/bin/mypy ${NSYS_JAX_PYTHON_FILES}
notebook:
runs-on: ubuntu-22.04
steps:
- name: Check out the repository under ${GITHUB_WORKSPACE}
uses: actions/checkout@v4
- name: Mock up the structure of an extracted .zip file
shell: bash -x -e {0}
run: |
# Get the actual test data from a real archive, minus the .nsys-rep file
mv .github/workflows/nsys-jax/maxtext_fsdp4_test_data/* .github/container/jax_nsys/
- name: "Setup Python 3.10"
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Run the install script, but skip launching Jupyter Lab
shell: bash -x -e {0}
run: |
pip install virtualenv
NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh
- name: Test the Jupyter Lab installation and execute the notebook
shell: bash -x -e {0}
run: |
pushd .github/container/jax_nsys
./nsys_jax_venv/bin/python -m jupyterlab --version
./nsys_jax_venv/bin/ipython Analysis.ipynb
ruff:
runs-on: ubuntu-24.04
steps:
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// This proto intends to match format expected by pprof tool.
syntax = "proto3";

package tensorflow.tfprof.pprof;

message Profile {
repeated ValueType sample_type = 1;
repeated Sample sample = 2;
repeated Mapping mapping = 3;
repeated Location location = 4;
repeated Function function = 5;
repeated string string_table = 6;
int64 drop_frames = 7;
int64 keep_frames = 8;
int64 time_nanos = 9;
int64 duration_nanos = 10;
ValueType period_type = 11;
int64 period = 12;
repeated int64 comment = 13;
int64 default_sample_type = 14;
}

message ValueType {
int64 type = 1;
int64 unit = 2;
}

message Sample {
repeated uint64 location_id = 1;
repeated int64 value = 2;
repeated Label label = 3;
}

message Label {
int64 key = 1;
int64 str = 2;
int64 num = 3;
}

message Mapping {
uint64 id = 1;
uint64 memory_start = 2;
uint64 memory_limit = 3;
uint64 file_offset = 4;
int64 filename = 5;
int64 build_id = 6;
bool has_functions = 7;
bool has_filenames = 8;
bool has_line_numbers = 9;
bool has_inline_frames = 10;
}

message Location {
uint64 id = 1;
uint64 mapping_id = 2;
uint64 address = 3;
repeated Line line = 4;
}

message Line {
uint64 function_id = 1;
int64 line = 2;
}

message Function {
uint64 id = 1;
int64 name = 2;
int64 system_name = 3;
int64 filename = 4;
int64 start_line = 5;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

syntax = "proto3";

package tensorflow.profiler;

// Next ID: 3
message ProfiledInstructionsProto {
message InstructionCost {
string name = 1;
double cost_us = 2;
}
message Latency {
string source = 1;
string target = 2;
double latency_us = 3;
}
repeated InstructionCost costs = 1;
repeated Latency latencies = 2;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
syntax = "proto3";

package tensorflow;

import "tsl/profiler/protobuf/profiler_service.proto";

message NewProfileSessionRequest {
ProfileRequest request = 1;
// The place where we will dump profile data. We will normally use
// MODEL_DIR/plugins/profile as the repository root.
string repository_root = 2;
repeated string hosts = 3; // host or host:port, port will be ignored.
string session_id = 4;
}

message NewProfileSessionResponse {
// Auxiliary error_message.
string error_message = 1;

// Whether all hosts had returned a empty trace.
bool empty_trace = 2;
}

message EnumProfileSessionsAndToolsRequest {
string repository_root = 1;
}

message ProfileSessionInfo {
string session_id = 1;
// Which tool data is available for consumption.
repeated string available_tools = 2;
}

message EnumProfileSessionsAndToolsResponse {
// Auxiliary error_message.
string error_message = 1;
// If success, the returned sessions information are stored here.
repeated ProfileSessionInfo sessions = 2;
}

message ProfileSessionDataRequest {
// The place where we will read profile data. We will normally use
// MODEL_DIR/plugins/profile as the repository root.
string repository_root = 1;
string session_id = 2;
// Which host the data is associated. if empty, data from all hosts are
// aggregated.
string host_name = 5;
// Which tool
string tool_name = 3;
// Tool's specific parameters. e.g. TraceViewer's viewport etc
map<string, string> parameters = 4;
}

message ProfileSessionDataResponse {
// Auxiliary error_message.
string error_message = 1;

// Output format. e.g. "json" or "proto" or "blob"
string output_format = 2;

// TODO(jiesun): figure out whether to put bytes or oneof tool specific proto.
bytes output = 3;
}
////////////////////////////////////////////////////////////////////////////////
// ProfileAnalysis service provide entry point for profiling TPU and for
// serving profiled data to TensorBoard through GRPC
////////////////////////////////////////////////////////////////////////////////
service ProfileAnalysis {
// Starts a profiling session, blocks until it completes.
// TPUProfileAnalysis service delegate this to TPUProfiler service.
// Populate the profiled data in repository, then return status to caller.
rpc NewSession(NewProfileSessionRequest) returns (NewProfileSessionResponse) {
}
// Enumerate existing sessions and return available profile tools.
rpc EnumSessions(EnumProfileSessionsAndToolsRequest)
returns (EnumProfileSessionsAndToolsResponse) {}
// Retrieve specific tool's data for specific session.
rpc GetSessionToolData(ProfileSessionDataRequest)
returns (ProfileSessionDataResponse) {}
}
Loading

0 comments on commit 769b5dc

Please sign in to comment.