[misc] CUDA Time Layerwise Profiler (vllm-project#8337)
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Michael Goin <michael@neuralmagic.com>
3 people authored Oct 17, 2024
1 parent 390be74 commit 9d30a05
Showing 8 changed files with 1,390 additions and 4 deletions.
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
@@ -184,6 +184,7 @@ steps:
- python3 offline_inference_vision_language_multi_image.py
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
- python3 offline_inference_encoder_decoder.py
- python3 offline_profile.py --model facebook/opt-125m

- label: Prefix Caching Test # 9min
#mirror_hardwares: [amd]
282 changes: 282 additions & 0 deletions examples/offline_profile.py
@@ -0,0 +1,282 @@
import inspect
import json
import os
import sys
from argparse import RawTextHelpFormatter
from dataclasses import asdict, dataclass
from typing import Optional

import torch

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.profiler import layerwise_profile
from vllm.utils import FlexibleArgumentParser

BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256
OUTPUT_LEN_DEFAULT = 2


@dataclass
class ProfileContext:
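    """Configuration for a single profiling run.

    engine_args is passed through to the vLLM engine; the remaining fields
    describe the synthetic batch of requests that gets profiled.
    """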
engine_args: EngineArgs
prompt_len: int
output_len: int
batch_size: int
save_chrome_traces_folder: Optional[str]


def get_dtype(dtype: str):
if dtype == "torch.float":
return torch.float
else:
return dtype


def run_profile(context: ProfileContext, csv_output: Optional[str],
json_output: Optional[str]):
print("Run profile with:")
for key, value in asdict(context).items():
print(f" {key} = {value}")

    # Create sampling params
    sampling_params = SamplingParams(temperature=0.8,
                                     top_p=0.95,
                                     max_tokens=context.output_len,
                                     ignore_eos=True)

# Create LLM
llm = LLM(**asdict(context.engine_args))
batch_size = context.batch_size
prompt_len = context.prompt_len
output_len = context.output_len

scheduler_config = llm.llm_engine.scheduler_config
max_model_len = llm.llm_engine.model_config.max_model_len
max_num_batched_tokens = scheduler_config.max_num_batched_tokens
max_num_seqs = scheduler_config.max_num_seqs

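    # Sanity checks: the whole batch has to fit into a single prefill step and
    # a single decode step, otherwise the per-step profiles below would mix
    # several engine steps together.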
if batch_size * prompt_len > max_num_batched_tokens:
print(f"ERROR: chosen batch_size * prompt_len "
f"({batch_size} * {prompt_len} = {batch_size * prompt_len}) is "
f"larger than max_num_batched_tokens ({max_num_batched_tokens}) "
f"and therefore cannot be run in a single profile step, please "
f"choose a smaller batch size or prompt length, or increase "
f"--max-num-batched-tokens")
sys.exit(-1)
    if batch_size >= max_num_seqs:
        print(
            f"ERROR: chosen batch_size ({batch_size}) is larger than or equal "
            f"to max_num_seqs ({max_num_seqs}) and therefore cannot be run in "
            f"a single profile step, please choose a smaller batch size")
sys.exit(-1)
print("llm.llm_engine.model_config.max_model_len: ",
llm.llm_engine.model_config.max_model_len)
if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
print(
f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
f"{output_len} = {prompt_len + output_len}) is larger than the "
f"model's max_model_len ({max_model_len}), please choose a smaller "
f"prompt_len or output_len, or increase --max-model-len")
sys.exit(-1)

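    # Each request uses random prompt_token_ids of length prompt_len;
    # ignore_eos in the sampling params ensures every request keeps decoding
    # for the full output_len tokens.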
def add_requests():
for i in range(batch_size):
prompt_token_ids = torch.randint(
llm.llm_engine.model_config.get_vocab_size(),
size=(prompt_len, )).tolist()

llm.llm_engine.add_request(
request_id=f"seq{i}",
prompt={'prompt_token_ids': prompt_token_ids},
params=sampling_params)

def abort_requests():
for i in range(batch_size):
llm.llm_engine.abort_request(f"seq{i}")

# Warm up run
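    # (one prefill step + one decode step) so that one-time initialization
    # work, e.g. CUDA context setup and kernel autotuning, does not show up
    # in the profiled steps below.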
print("Warm up run ...")
add_requests()
llm.llm_engine.step() # Prefill
llm.llm_engine.step() # Decode
abort_requests()

print("Profile run ...")
add_requests()

with layerwise_profile() as prefill_prof:
llm.llm_engine.step() # First step is prefill

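    # Every remaining engine step decodes one token per sequence; each decode
    # step gets its own layerwise profile so per-step tables and JSON entries
    # can be emitted.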
decode_profs = []
    for _ in range(context.output_len - 1):
with layerwise_profile() as decode_prof:
llm.llm_engine.step()
decode_profs.append(decode_prof)

decode_results_list = [prof.results for prof in decode_profs]
prefill_results = prefill_prof.results
has_decode = len(decode_results_list) > 0

LINE_WIDTH = 80
print("=" * LINE_WIDTH)
print(f"= Prefill Model Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH)
print()
prefill_results.print_model_table()

if has_decode:
print()
print("=" * LINE_WIDTH)
print(f"= First Decode Step Model Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH)
print()
decode_results_list[0].print_model_table()

print()
print("=" * LINE_WIDTH)
print(f"= Prefill Summary Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH)
print()
prefill_results.print_summary_table()

if has_decode:
print()
print("=" * LINE_WIDTH)
print(f"= First Decode Step Summary Table "
f"(prompt_len={prompt_len}, batch_size={batch_size})")
print("=" * LINE_WIDTH)
print()
decode_results_list[0].print_summary_table()

if csv_output:
        # Strip a trailing ".csv" suffix (str.rstrip removes a character set,
        # not a suffix).
        csv_filename_base = (csv_output[:-len(".csv")]
                             if csv_output.endswith(".csv") else csv_output)
prefill_results.export_model_stats_table_csv(
csv_filename_base + "_prefill_model_table.csv")
prefill_results.export_summary_stats_table_csv(
csv_filename_base + "_prefill_summary_table.csv")

if has_decode:
            decode_results_list[0].export_model_stats_table_csv(
                csv_filename_base + "_decode_model_table.csv")
decode_results_list[0].export_summary_stats_table_csv(
csv_filename_base + "_decode_summary_table.csv")

if json_output:
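        # Record the environment (Python/torch/CUDA versions and the GPUs)
        # alongside the results so traces from different machines can be
        # compared.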
cuda_devices = [
torch.cuda.get_device_properties(dev_idx)
for dev_idx in range(torch.cuda.device_count())
]

json_dict = {
"context": {
"python_version": f"{sys.version}",
"torch_version": f"{torch.__version__}",
"torch_cuda_version": f"{torch.version.cuda}",
"cuda_devices": f"{cuda_devices}",
**asdict(context)
},
"prefill": prefill_results.convert_stats_to_dict(),
}

        if has_decode:
            for idx, dr in enumerate(decode_results_list):
                json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()

        json_filename = (json_output if json_output.endswith(".json") else
                         json_output + ".json")
        with open(json_filename, "w+") as f:
            json.dump(json_dict, f, indent=2)

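    # Optionally also dump the raw torch profiler traces in Chrome trace
    # format; these can be opened in chrome://tracing or https://ui.perfetto.dev.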
if context.save_chrome_traces_folder is not None:
os.makedirs(context.save_chrome_traces_folder, exist_ok=True)
prefill_prof.profiler.export_chrome_trace(
context.save_chrome_traces_folder + "/prefill.json")
for idx, decode_prof in enumerate(decode_profs):
decode_prof.profiler.export_chrome_trace(
context.save_chrome_traces_folder + f"/decode_{idx + 1}.json")
print("Traces saved as prefill.json and decode_1.json, etc."
f" in folder {context.save_chrome_traces_folder}")


if __name__ == "__main__":
parser = FlexibleArgumentParser(description="""
Profile a model
example:
```
python examples/offline_profile.py \\
--model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
--prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
--enforce-eager
```
then you can use various tools to analyze the json output
terminal ascii tables:
```
python tools/profiler/print_layerwise_table.py \\
--json-trace Llama31-8b-FP8.json --phase prefill --table summary
```
or create matplotlib stacked bar charts:
```
python tools/profiler/visualize_layerwise_profile.py \\
--json-trace Llama31-8b-FP8.json \\
--output-directory profile_breakdown --plot-metric pct_cuda_time
```
""",
formatter_class=RawTextHelpFormatter)
parser.add_argument(
"--csv",
type=str,
default=None,
        help="Export the results as multiple CSV files. This should be the "
        "root filename; it will create <filename>_prefill_model_table.csv, "
"<filename>_prefill_summary_table.csv, "
"<filename>_decode_model_table.csv, and "
"<filename>_decode_summary_table.csv")
parser.add_argument(
"--json",
type=str,
default=None,
        help="Export the results as a JSON file. This should be the filename")
parser.add_argument("--save-chrome-traces-folder",
type=str,
                        help="Save Chrome traces for the prefill and decode "
                        "steps; traces are saved as prefill.json, "
                        "decode_1.json, etc. inside this folder")
parser.add_argument(
"--prompt-len",
type=int,
default=PROMPT_LEN_DEFAULT,
        help=f"Length of the random prompt to use when profiling; all batched "
        f"requests use the same prompt_len, default={PROMPT_LEN_DEFAULT}")
parser.add_argument("--batch-size",
type=int,
default=BATCH_SIZE_DEFAULT,
help=f"Number of requests to run as a single batch, "
f"default={BATCH_SIZE_DEFAULT}")
parser.add_argument(
"--output-len",
type=int,
default=OUTPUT_LEN_DEFAULT,
        help="Number of llm steps to run (includes prefill and decode) "
        f"- default={OUTPUT_LEN_DEFAULT}")

EngineArgs.add_cli_args(parser)

args = parser.parse_args()

context = ProfileContext(
engine_args=EngineArgs.from_cli_args(args),
**{
k: v
for k, v in vars(args).items()
if k in inspect.signature(ProfileContext).parameters
})
run_profile(context, csv_output=args.csv, json_output=args.json)
77 changes: 77 additions & 0 deletions tools/profiler/print_layerwise_table.py
@@ -0,0 +1,77 @@
import argparse
import json
from typing import Dict

from vllm.profiler.layerwise_profile import ModelStatsEntry, SummaryStatsEntry
from vllm.profiler.utils import TablePrinter, indent_string


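# The JSON trace stores each table as a list of trees whose nodes look like
# {"entry": {...column values...}, "children": [...]}. flatten_entries walks
# each tree depth-first and returns (entry, depth) pairs so the table printer
# can indent nested entries.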
def flatten_entries(entry_cls, profile_dict: Dict):
entries_and_depth = []

def get_entries(node, curr_depth=0):
entries_and_depth.append((entry_cls(**node["entry"]), curr_depth))

for child in node["children"]:
get_entries(
child,
curr_depth=curr_depth + 1,
)

for root in profile_dict:
get_entries(root)

return entries_and_depth


if __name__ == "__main__":
parser = argparse.ArgumentParser()

parser.add_argument("--json-trace",
type=str,
required=True,
help="json trace file output by "
"examples/offline_profile.py")
parser.add_argument("--phase",
type=str,
choices=["prefill", "decode_1"],
required=True,
help="The phase to print the table for.")
parser.add_argument("--table",
type=str,
choices=["summary", "model"],
default="summary",
help="Which table to print, the summary table or the "
"layerwise model table")

args = parser.parse_args()

with open(args.json_trace, "r") as f:
profile_data = json.load(f)

if args.table == "summary":
entries_and_depths = flatten_entries(
SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
column_widths = dict(name=80,
cuda_time_us=12,
pct_cuda_time=12,
invocations=15)
elif args.table == "model":
entries_and_depths = flatten_entries(
ModelStatsEntry, profile_data[args.phase]["model_stats"])
column_widths = dict(name=60,
cpu_time_us=12,
cuda_time_us=12,
pct_cuda_time=12,
trace=60)

# indent entry names based on the depth
entries = []
for entry, depth in entries_and_depths:
entry.name = indent_string(
entry.name,
indent=depth,
indent_style=lambda indent: "|" + "-" * indent + " ")
entries.append(entry)

TablePrinter(type(entries[0]), column_widths).print_table(entries)
