Minor tweaks to ease profiling for Vidur Simulator (#6)
nitinkedia7 authored May 14, 2024
1 parent 8d8c986 commit 7fe4bb0
Showing 11 changed files with 48 additions and 33 deletions.
7 changes: 3 additions & 4 deletions sarathi/benchmark/benchmark_runner.py
@@ -232,8 +232,8 @@ def _get_replica_resource_mapping(self) -> ReplicaResourceMapping:
if self._config.replica_resource_mapping:
replica_resource_mapping = json.loads(
self._config.replica_resource_mapping)
print("Replica resource mapping:")
print(replica_resource_mapping)
logger.info(
f"Replica resource mapping: {replica_resource_mapping}")
return replica_resource_mapping

cluster_resources_keys = list(ray.available_resources().keys())
@@ -273,8 +273,7 @@ def _get_replica_resource_mapping(self) -> ReplicaResourceMapping:
replica_resource_mapping[str(replica_id)].append(
available_gpus.pop(0))

print("Replica resource mapping:")
print(replica_resource_mapping)
logger.info(f"Replica resource mapping: {replica_resource_mapping}")

return replica_resource_mapping

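The hunks above assemble the replica-to-GPU assignment that the runner now logs. A minimal sketch of that round-robin assignment, assuming GPUs are tracked as (node IP, GPU index) pairs handed out in the order Ray reported them (the helper name and type alias below are hypothetical, not code from the repository):

```python
from typing import Dict, List, Tuple

# Hypothetical alias; the real ReplicaResourceMapping lives in sarathi.benchmark.types.
ReplicaResourceMapping = Dict[str, List[Tuple[str, int]]]

def build_replica_resource_mapping(
    available_gpus: List[Tuple[str, int]],  # e.g. [("10.0.0.1", 0), ("10.0.0.1", 1)]
    num_replicas: int,
    gpus_per_replica: int,
) -> ReplicaResourceMapping:
    assert len(available_gpus) >= num_replicas * gpus_per_replica, "not enough GPUs"
    mapping: ReplicaResourceMapping = {str(i): [] for i in range(num_replicas)}
    for replica_id in range(num_replicas):
        for _ in range(gpus_per_replica):
            # pop(0) mirrors the hunk above: GPUs are assigned in discovery order.
            mapping[str(replica_id)].append(available_gpus.pop(0))
    return mapping

print(build_replica_resource_mapping([("10.0.0.1", 0), ("10.0.0.1", 1)], 1, 2))
```
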
18 changes: 10 additions & 8 deletions sarathi/benchmark/capacity_search/capacity_search.py
@@ -9,6 +9,7 @@
import ray
import wandb

from sarathi.logger import init_logger
from sarathi.benchmark.capacity_search.config import (
JobConfig,
BenchmarkConfig,
@@ -19,14 +20,16 @@
)
from sarathi.benchmark.types import ReplicaResourceMapping

logger = init_logger(__name__)


def release_resources_on_failure(func):

def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except Exception as e:
print(f"Error in search: {e}", flush=True)
logger.error(f"Error in search: {e}")
self.release_resources()

return wrapper
@@ -61,8 +64,7 @@ def _generate_run_command(
):
resource_mapping_arg = f"--replica_resource_mapping '{json.dumps(self.resource_mapping)}'"
command = f"python -m sarathi.benchmark.main {benchmark_config.to_args()} {resource_mapping_arg}"
if self.args.debug:
print(f"Running command: {command}", flush=True)
logger.debug(f"Running command: {command}")

return command

@@ -92,7 +94,7 @@ def _is_under_sla(
scheduling_delay <= self.args.scheduling_delay_slo_value
and tbt <= self.args.tbt_slo_value)

print(
logger.info(
f"{benchmark_config.to_human_readable_name()} - "
f"Scheduling delay (P{self.args.scheduling_delay_slo_quantile}): {scheduling_delay}"
f" - TBT (P{self.args.tbt_slo_quantile}): {tbt}",
@@ -149,7 +151,7 @@ def search(self):
"""
Perform binary search to find the maximum QPS under the SLO
"""
print(
logger.info(
f"Starting search for {self.job_config.get_human_readable_name()}",
)
@@ -166,7 +168,7 @@ def search(self):
found_valid_qps = False

for _ in range(self.args.max_iterations):
print(f"Searching between {left} and {right}", flush=True)
logger.info(f"Searching between {left} and {right}", flush=True)
# stopping condition - we have reached the minimum granularity
if abs(left -
right) < self.args.min_search_granularity * qps / 100:
@@ -209,13 +211,13 @@ def search(self):
min_qps_over_sla = min(min_qps_over_sla, qps)

if not found_valid_qps:
print(
logger.info(
f"No valid QPS found for {self.job_config.get_human_readable_name()}",
)
return {}

print(
logger.info(
f"Max QPS under SLO for {self.job_config.get_human_readable_name()}: "
f"{max_qps_under_sla}, Scheduling delay: {scheduling_delay_at_max_qps}, TBT: {tbt_at_max_qps}",
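
The hunks above log the progress of a binary search over request rate: the search narrows a [left, right] QPS interval until the iteration cap or the minimum granularity is reached, keeping the highest QPS that still met the scheduling-delay and TBT SLOs. A simplified sketch of that loop, with the benchmark-and-check step abstracted into a callback (the function and parameter names are illustrative, not the repository's API):

```python
from typing import Callable

def find_max_qps_under_slo(
    is_under_sla: Callable[[float], bool],   # runs a benchmark at the given QPS
    left: float = 0.0,
    right: float = 32.0,
    max_iterations: int = 10,
    min_search_granularity: float = 2.5,     # percent of the current QPS
) -> float:
    max_qps_under_sla = 0.0
    for _ in range(max_iterations):
        qps = (left + right) / 2
        # Stop once the interval is narrower than the minimum granularity.
        if abs(left - right) < min_search_granularity * qps / 100:
            break
        if is_under_sla(qps):
            max_qps_under_sla = max(max_qps_under_sla, qps)
            left = qps    # SLOs met: probe a higher load
        else:
            right = qps   # SLOs violated: back off
    return max_qps_under_sla

print(find_max_qps_under_slo(lambda qps: qps <= 6.0))  # prints 6.0 for this toy SLO check
```
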
9 changes: 6 additions & 3 deletions sarathi/benchmark/capacity_search/main.py
@@ -15,8 +15,11 @@

import yaml

from sarathi.logger import init_logger
from sarathi.benchmark.capacity_search.search_manager import SearchManager

logger = init_logger(__name__)


def get_args():
parser = argparse.ArgumentParser()
@@ -67,11 +70,11 @@ def get_args():

os.makedirs(args.output_dir, exist_ok=True)

print("Starting capacity search", flush=True)
logger.info("Starting capacity search", flush=True)

# merge the config with the args
config.update(vars(args))
print(f"Config: {config}", flush=True)
logger.info(f"Config: {config}", flush=True)

# store the config and args
json.dump(config, open(f"{args.output_dir}/config.json", "w"))
@@ -93,4 +96,4 @@ def get_args():

end_time = time.time()

print(f"Benchmarking took time: {end_time - start_time}", flush=True)
logger.info(f"Benchmarking took time: {end_time - start_time}", flush=True)
5 changes: 4 additions & 1 deletion sarathi/benchmark/capacity_search/search_manager.py
@@ -2,6 +2,7 @@

import ray

from sarathi.logger import init_logger
from sarathi.benchmark.capacity_search.capacity_search import CapacitySearch
from sarathi.benchmark.capacity_search.config import JobConfig
from sarathi.benchmark.capacity_search.ray_utils import (
@@ -10,6 +11,8 @@
)
from sarathi.benchmark.types import ReplicaResourceMapping

logger = init_logger(__name__)


def run_search(
job_config: JobConfig,
@@ -42,7 +45,7 @@ def run(self):
job_configs = JobConfig.generate_job_configs(self.config)

for job_config in job_configs:
print(f"Running search for {job_config}")
logger.info(f"Running search for {job_config}")

ray_parallel_runner = RayParallelRunner()

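
SearchManager fans the per-job capacity searches out across the cluster through RayParallelRunner, which is not shown in this diff. A hedged sketch of that dispatch pattern using plain Ray tasks (the task body is a placeholder; the real runner also tracks which GPUs each search may use):

```python
import ray

@ray.remote
def run_search_task(job_name: str) -> dict:
    # Placeholder body: the real task builds a CapacitySearch and calls search().
    return {"job": job_name, "max_qps_under_sla": 0.0}

def run_all_searches(job_names):
    ray.init(ignore_reinit_error=True)
    futures = [run_search_task.remote(name) for name in job_names]
    return ray.get(futures)  # blocks until every search finishes
```
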
6 changes: 5 additions & 1 deletion sarathi/benchmark/config/config.py
@@ -4,8 +4,11 @@

import yaml

from sarathi.logger import init_logger
from sarathi.benchmark.constants import DEFAULT_CONFIG_FILE

logger = init_logger(__name__)


def custom_bool(val):
if val.lower() in ('yes', 'true', 't', 'y', '1'):
@@ -35,6 +38,8 @@ def __init__(self, config_file=DEFAULT_CONFIG_FILE):
self._args = None
self._load_yaml(config_file)
self._parse_args()
logger.info(f"Starting benchmark with config: {self._args}")

self._add_derived_args()
self._write_yaml_to_file()

@@ -47,7 +52,6 @@ def _parse_args(self):
self._args = self._parser.parse_args()

def _add_derived_args(self):
print(self._args)
self._args.output_dir = f"{self._args.output_dir}/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f')}"
os.makedirs(self._args.output_dir, exist_ok=True)

6 changes: 3 additions & 3 deletions sarathi/config.py
@@ -512,12 +512,12 @@ def _get_and_verify_max_len(
derived_max_model_len *= scaling_factor

if max_model_len is None:
print(
logger.info(
f"Using the derived maximum model length: {derived_max_model_len}")
max_model_len = derived_max_model_len
elif max_model_len > derived_max_model_len:
print(f"Applying rope_scaling to the maximum model length: "
f"{derived_max_model_len} -> {max_model_len}")
logger.info(f"Applying rope_scaling to the maximum model length: "
f"{derived_max_model_len} -> {max_model_len}")
# force rope_scaling
scaling_factor = max_model_len / derived_max_model_len
rope_scaling = {"type": "linear", "factor": scaling_factor}
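
The hunk above forces linear RoPE scaling whenever the requested max_model_len exceeds the length derived from the model config. A small worked example of that arithmetic (the numbers are illustrative):

```python
derived_max_model_len = 4096   # e.g. taken from the Hugging Face model config
max_model_len = 16384          # requested by the user

if max_model_len > derived_max_model_len:
    # Stretch rotary position embeddings so the longer context maps onto the
    # range the model was trained for: 16384 / 4096 = 4.0.
    scaling_factor = max_model_len / derived_max_model_len
    rope_scaling = {"type": "linear", "factor": scaling_factor}
    print(rope_scaling)  # {'type': 'linear', 'factor': 4.0}
```
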
7 changes: 4 additions & 3 deletions sarathi/engine/base_llm_engine.py
@@ -126,8 +126,9 @@ def _get_worker_impl(self):

def _init_workers_ray(self, **ray_remote_kwargs):
replica_resource_mapping = self.parallel_config.replica_resource_mapping
print("Starting workers with resource mapping:")
print(replica_resource_mapping)
logger.info(
f"Starting workers with resource mapping: {replica_resource_mapping}"
)

self.workers: List[RayWorker] = []

@@ -167,7 +168,7 @@ def _init_workers_ray(self, **ray_remote_kwargs):
# In case port is already in use, this will fail.
distributed_init_method = f"tcp://{driver_ip}:{get_random_port()}"

print(
logger.info(
f"Initializing workers with distributed init method: {distributed_init_method}"
)

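
The workers are initialized with a tcp://<driver_ip>:<port> endpoint built from get_random_port(), a helper defined elsewhere in the codebase. A common way to implement such a helper — and the reason the comment above warns that an already-used port makes initialization fail — is to bind an ephemeral socket to port 0 and read back the assigned port (this is an assumption about the helper, not its actual source):

```python
import socket

def get_random_port() -> int:
    # Ask the OS for a free ephemeral port by binding to port 0.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

driver_ip = "127.0.0.1"  # placeholder; the engine resolves the real driver IP
print(f"tcp://{driver_ip}:{get_random_port()}")
```
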
8 changes: 4 additions & 4 deletions sarathi/model_executor/attention/__init__.py
@@ -1,4 +1,4 @@
from enum import Enum, auto
from enum import Enum
from typing import Union

from sarathi.model_executor.attention.flashinfer_attention_wrapper import FlashinferAttentionWrapper
@@ -7,9 +7,9 @@


class AttentionBackend(Enum):
FLASHINFER = auto()
FLASH_ATTENTION = auto()
NO_OP = auto()
FLASHINFER = "FLASHINFER"
FLASH_ATTENTION = "FLASH_ATTENTION"
NO_OP = "NO_OP"


ATTENTION_BACKEND = AttentionBackend.NO_OP
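
Switching AttentionBackend from auto() to string values means the backend can be looked up directly from a config or CLI string, and the recorded value stays stable across runs. A short usage sketch (not code from the repository):

```python
from enum import Enum

class AttentionBackend(Enum):
    FLASHINFER = "FLASHINFER"
    FLASH_ATTENTION = "FLASH_ATTENTION"
    NO_OP = "NO_OP"

# Look up by value (what a YAML/CLI string provides) or by name; here they coincide.
backend = AttentionBackend("FLASH_ATTENTION")
assert backend is AttentionBackend["FLASH_ATTENTION"]
print(backend.value)  # "FLASH_ATTENTION" — stable across runs, unlike auto() integers
```
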
@@ -11,11 +11,11 @@
import torch.nn.init as init
from torch.nn.parameter import Parameter

from sarathi.logger import init_logger
from sarathi.model_executor.parallel_utils.parallel_state import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sarathi.metrics.constants import OperationMetrics
from sarathi.metrics.cuda_timer import CudaTimer
from .mappings import (
gather_from_tensor_model_parallel_region,
@@ -28,6 +28,9 @@
VocabUtility,
)

logger = init_logger(__name__)


_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
'partition_dim': -1,
'partition_stride': 1}
@@ -341,8 +344,8 @@ def __init__(self, input_size, output_size, *,
self.create_weights(params_dtype)

if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
logger.warning("When not reduce the results, adding bias to the "
"results can lead to incorrect results")

if bias:
self.bias = Parameter(torch.empty(
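
Downgrading the ValueError to a warning presumably keeps the constructor usable for profiling configurations, but the underlying hazard is real: when reduce_results is off, each rank holds only a partial sum, so adding the bias before the cross-rank reduction counts it once per rank. A tiny numeric illustration (plain tensors, no actual parallelism):

```python
import torch

world_size = 4
partial_outputs = [torch.tensor([1.0]) for _ in range(world_size)]  # per-rank partial sums
bias = torch.tensor([0.5])

# Correct: reduce the partial sums first, then add the bias once.
correct = sum(partial_outputs) + bias                 # tensor([4.5])

# Incorrect: each rank adds the bias to its partial before the all-reduce,
# so the bias is counted world_size times.
incorrect = sum(p + bias for p in partial_outputs)    # tensor([6.0])
print(correct, incorrect)
```
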
4 changes: 2 additions & 2 deletions sarathi/transformers_utils/tokenizer.py
@@ -104,7 +104,7 @@ def detokenize_incrementally(
all_input_ids[-6:], skip_special_tokens=skip_special_tokens)
except ValueError as e:
new_tokens = ["[UNK]"] * 6
print(f"Warning: {e}", flush=True)
logger.warning(f"Warning: {e}")

output_tokens = new_tokens
# 5 is an arbitrary value that should work for all
@@ -119,7 +119,7 @@
[new_token_id], skip_special_tokens=skip_special_tokens)
except ValueError as e:
new_tokens = [prev_tokens[-1]]
print(f"Warning: {e}", flush=True)
logger.warning(f"Warning: {e}")
output_tokens = prev_tokens + new_tokens

# The prefix text is necessary only to defeat cleanup algorithms in
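
Both detokenization fallbacks now log through the module logger instead of printing. A hedged sketch of the guard they implement — recover with placeholder tokens rather than failing the whole request (the helper name below is illustrative):

```python
import logging

logger = logging.getLogger(__name__)

def safe_convert_ids_to_tokens(tokenizer, token_ids, skip_special_tokens=False):
    try:
        return tokenizer.convert_ids_to_tokens(
            token_ids, skip_special_tokens=skip_special_tokens)
    except ValueError as e:
        # Substitute unknown-token placeholders instead of aborting detokenization.
        logger.warning(f"Detokenization failed, substituting [UNK]: {e}")
        return ["[UNK]"] * len(token_ids)
```
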
2 changes: 1 addition & 1 deletion sarathi/worker/base_worker.py
@@ -86,7 +86,7 @@ def init_model(self):
# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)

print(f"Worker {self.rank} is using device {self.local_rank}")
logger.info(f"Worker {self.rank} is using device {self.local_rank}")
self.device = torch.device(f"cuda:{self.local_rank}")
torch.cuda.set_device(self.device)

