diff --git a/optimum_benchmark/benchmarks/base.py b/optimum_benchmark/benchmarks/base.py index 84495a1a..55f0477b 100644 --- a/optimum_benchmark/benchmarks/base.py +++ b/optimum_benchmark/benchmarks/base.py @@ -3,7 +3,7 @@ from typing import ClassVar, Generic from ..backends.base import Backend -from .report import BenchmarkReport +from ..report import BenchmarkReport from .config import BenchmarkConfigT LOGGER = getLogger("benchmark") diff --git a/optimum_benchmark/benchmarks/inference/benchmark.py b/optimum_benchmark/benchmarks/inference/benchmark.py index f145dbba..ad24dc92 100644 --- a/optimum_benchmark/benchmarks/inference/benchmark.py +++ b/optimum_benchmark/benchmarks/inference/benchmark.py @@ -1,15 +1,16 @@ from logging import getLogger +from dataclasses import dataclass from ..base import Benchmark from .config import InferenceConfig from ...trackers.memory import MemoryTracker +from ...report import BenchmarkReport, Measurements from ...backends.base import Backend, BackendConfigT from ...generators.input_generator import InputGenerator from ...trackers.energy import EnergyTracker, Efficiency from ...trackers.latency import LatencyTracker, Throughput from ...import_utils import is_torch_distributed_available from ...task_utils import TEXT_GENERATION_TASKS, IMAGE_DIFFUSION_TASKS -from .report import InferenceReport, TextGenerationReport, ImageDiffusionReport if is_torch_distributed_available(): import torch.distributed @@ -44,6 +45,22 @@ CALL_EFFICIENCY_UNIT = "images/kWh" +@dataclass +class InferenceReport(BenchmarkReport): + forward: Measurements = Measurements() + + +@dataclass +class ImageDiffusionReport(BenchmarkReport): + call: Measurements = Measurements() + + +@dataclass +class TextGenerationReport(BenchmarkReport): + prefill: Measurements = Measurements() + decode: Measurements = Measurements() + + class InferenceBenchmark(Benchmark[InferenceConfig]): NAME = "inference" @@ -157,13 +174,13 @@ def run_text_generation_memory_tracking(self, backend: Backend): with self.memory_tracker.track(): _ = backend.forward(self.forward_inputs, self.config.forward_kwargs) - self.report.prefill.max_memory = self.memory_tracker.get_max_memory() + self.report.prefill.memory = self.memory_tracker.get_max_memory() self.memory_tracker.reset() with self.memory_tracker.track(): _ = backend.generate(self.generate_input, self.config.generate_kwargs) - self.report.decode.max_memory = self.memory_tracker.get_max_memory() + self.report.decode.memory = self.memory_tracker.get_max_memory() def run_image_diffusion_memory_tracking(self, backend: Backend): LOGGER.info("\t+ Running memory tracking") @@ -171,7 +188,7 @@ def run_image_diffusion_memory_tracking(self, backend: Backend): with self.memory_tracker.track(): _ = backend.call(self.diffuse_input, self.config.forward_kwargs) - self.report.call.max_memory = self.memory_tracker.get_max_memory() + self.report.call.memory = self.memory_tracker.get_max_memory() def run_inference_memory_tracking(self, backend: Backend): LOGGER.info("\t+ Running memory tracking") @@ -179,7 +196,7 @@ def run_inference_memory_tracking(self, backend: Backend): with self.memory_tracker.track(): _ = backend.forward(self.forward_inputs, self.config.forward_kwargs) - self.report.forward.max_memory = self.memory_tracker.get_max_memory() + self.report.forward.memory = self.memory_tracker.get_max_memory() ## Latency tracking def run_text_generation_latency_tracking(self, backend: Backend): diff --git a/optimum_benchmark/benchmarks/inference/report.py b/optimum_benchmark/benchmarks/inference/report.py deleted file mode 100644 index 1f2edd52..00000000 --- a/optimum_benchmark/benchmarks/inference/report.py +++ /dev/null @@ -1,22 +0,0 @@ -from dataclasses import dataclass -from logging import getLogger - -from ..report import BenchmarkReport, BenchmarkMeasurements - -LOGGER = getLogger("report") - - -@dataclass -class InferenceReport(BenchmarkReport): - forward: BenchmarkMeasurements = BenchmarkMeasurements() - - -@dataclass -class ImageDiffusionReport(BenchmarkReport): - call: BenchmarkMeasurements = BenchmarkMeasurements() - - -@dataclass -class TextGenerationReport(BenchmarkReport): - prefill: BenchmarkMeasurements = BenchmarkMeasurements() - decode: BenchmarkMeasurements = BenchmarkMeasurements() diff --git a/optimum_benchmark/benchmarks/training/benchmark.py b/optimum_benchmark/benchmarks/training/benchmark.py index ffac8975..312be76b 100644 --- a/optimum_benchmark/benchmarks/training/benchmark.py +++ b/optimum_benchmark/benchmarks/training/benchmark.py @@ -1,10 +1,11 @@ from logging import getLogger from contextlib import ExitStack +from dataclasses import dataclass from ..base import Benchmark from .config import TrainingConfig -from .report import TrainingReport from ...trackers.memory import MemoryTracker +from ...report import BenchmarkReport, Measurements from ...backends.base import Backend, BackendConfigT from ...trackers.energy import EnergyTracker, Efficiency from ...generators.dataset_generator import DatasetGenerator @@ -18,6 +19,13 @@ TRAIN_EFFICIENCY_UNIT = "samples/kWh" +@dataclass +class TrainingReport(BenchmarkReport): + overall: Measurements = Measurements() + warmup: Measurements = Measurements() + train: Measurements = Measurements() + + class TrainingBenchmark(Benchmark[TrainingConfig]): NAME = "training" @@ -69,10 +77,9 @@ def run(self, backend: Backend[BackendConfigT]) -> None: ) if self.config.memory: - # it's the same - self.report.overall.max_memory = memory_tracker.get_max_memory() - self.report.warmup.max_memory = memory_tracker.get_max_memory() - self.report.train.max_memory = memory_tracker.get_max_memory() + self.report.overall.memory = memory_tracker.get_max_memory() + self.report.warmup.memory = memory_tracker.get_max_memory() + self.report.train.memory = memory_tracker.get_max_memory() self.report.log_memory() diff --git a/optimum_benchmark/benchmarks/training/report.py b/optimum_benchmark/benchmarks/training/report.py deleted file mode 100644 index 90cd91f2..00000000 --- a/optimum_benchmark/benchmarks/training/report.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass -from logging import getLogger - -from ..report import BenchmarkReport, BenchmarkMeasurements - -LOGGER = getLogger("report") - - -@dataclass -class TrainingReport(BenchmarkReport): - overall: BenchmarkMeasurements = BenchmarkMeasurements() - warmup: BenchmarkMeasurements = BenchmarkMeasurements() - train: BenchmarkMeasurements = BenchmarkMeasurements() diff --git a/optimum_benchmark/experiment.py b/optimum_benchmark/experiment.py index 2fd0345e..9843b85d 100644 --- a/optimum_benchmark/experiment.py +++ b/optimum_benchmark/experiment.py @@ -1,15 +1,12 @@ import os from logging import getLogger from tempfile import TemporaryDirectory -from dataclasses import dataclass, field -from typing import Any, Dict, Type, Optional, TYPE_CHECKING - -from hydra.utils import get_class -from transformers.configuration_utils import PushToHubMixin +from dataclasses import dataclass, field, asdict +from typing import Any, Dict, Type, Optional, Union, TYPE_CHECKING +from .report import BenchmarkReport from .env_utils import get_system_info from .import_utils import get_hf_libs_info -from .benchmarks.report import BenchmarkReport from .benchmarks.config import BenchmarkConfig from .launchers.config import LauncherConfig from .backends.config import BackendConfig @@ -22,9 +19,15 @@ from .launchers.base import Launcher from .backends.base import Backend +from json import dump +from flatten_dict import flatten +from hydra.utils import get_class +from transformers.configuration_utils import PushToHubMixin LOGGER = getLogger("experiment") +EXPERIMENT_FILE_NAME = "experiment_config.json" + @dataclass class ExperimentConfig(PushToHubMixin): @@ -46,6 +49,58 @@ class ExperimentConfig(PushToHubMixin): # ENVIRONMENT CONFIGURATION environment: Dict = field(default_factory=lambda: {**get_system_info(), **get_hf_libs_info()}) + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + def to_flat_dict(self) -> Dict[str, Any]: + report_dict = self.to_dict() + return flatten(report_dict, reducer="dot") + + def to_json(self, path: str, flat: bool = False) -> None: + if flat: + with open(path, "w") as f: + dump(self.to_flat_dict(), f, indent=4) + else: + with open(path, "w") as f: + dump(self.to_dict(), f, indent=4) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + config_file_name: Optional[Union[str, os.PathLike]] = None, + push_to_hub: bool = False, + **kwargs, + ): + use_auth_token = kwargs.pop("use_auth_token", None) + + if use_auth_token is not None: + kwargs["token"] = use_auth_token + + config_file_name = config_file_name if config_file_name is not None else EXPERIMENT_FILE_NAME + + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + if push_to_hub: + commit_message = kwargs.pop("commit_message", None) + repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1]) + repo_id = self._create_repo(repo_id, **kwargs) + files_timestamps = self._get_files_timestamps(save_directory) + + output_config_file = os.path.join(save_directory, config_file_name) + self.to_json(output_config_file, flat=False) + + if push_to_hub: + self._upload_modified_files( + save_directory, + repo_id, + files_timestamps, + commit_message=commit_message, + token=kwargs.get("token"), + ) + def run(benchmark_config: BenchmarkConfig, backend_config: BackendConfig) -> BenchmarkReport: try: diff --git a/optimum_benchmark/launchers/torchrun/launcher.py b/optimum_benchmark/launchers/torchrun/launcher.py index b3003bd7..30761486 100644 --- a/optimum_benchmark/launchers/torchrun/launcher.py +++ b/optimum_benchmark/launchers/torchrun/launcher.py @@ -6,9 +6,9 @@ from ..base import Launcher from .config import TorchrunConfig +from ...report import BenchmarkReport from ...logging_utils import setup_logging from ..isolation_utils import device_isolation -from ...benchmarks.report import BenchmarkReport import torch.distributed from torch.distributed import FileStore @@ -65,12 +65,15 @@ def launch(self, worker: Callable, *worker_args) -> Dict[str, Any]: while not queue.empty(): outputs.append(queue.get()) - if len(outputs) == 1: - report: BenchmarkReport = outputs[0] - else: + if len(outputs) > 1: LOGGER.info(f"\t+ Merging benchmark reports from {len(outputs)} workers") - report: BenchmarkReport = outputs[0].aggregate(outputs) - report.log_all() + report = outputs[0].aggregate(outputs) + elif len(outputs) == 1: + report = outputs[0] + else: + raise ValueError("No benchmark report was returned by the workers") + + report.log_all() return report diff --git a/optimum_benchmark/benchmarks/report.py b/optimum_benchmark/report.py similarity index 71% rename from optimum_benchmark/benchmarks/report.py rename to optimum_benchmark/report.py index e44a0848..7317c4dc 100644 --- a/optimum_benchmark/benchmarks/report.py +++ b/optimum_benchmark/report.py @@ -1,12 +1,12 @@ +from typing import Optional, Union, List, Dict, Any from dataclasses import dataclass, asdict -from typing import Optional, Union, List from logging import getLogger from json import dump import os -from ..trackers.latency import Latency, Throughput -from ..trackers.energy import Energy, Efficiency -from ..trackers.memory import Memory +from .trackers.latency import Latency, Throughput +from .trackers.energy import Energy, Efficiency +from .trackers.memory import Memory from transformers.configuration_utils import PushToHubMixin from flatten_dict import flatten @@ -14,9 +14,11 @@ LOGGER = getLogger("report") +REPORT_FILE_NAME = "benchmark_report.json" + @dataclass -class BenchmarkMeasurements: +class Measurements: memory: Optional[Memory] = None latency: Optional[Latency] = None throughput: Optional[Throughput] = None @@ -24,14 +26,14 @@ class BenchmarkMeasurements: efficiency: Optional[Efficiency] = None @staticmethod - def aggregate(measurements: List["BenchmarkMeasurements"]) -> "BenchmarkMeasurements": + def aggregate(measurements: List["Measurements"]) -> "Measurements": memory = Memory.aggregate([m.memory for m in measurements if m.memory is not None]) latency = Latency.aggregate([m.latency for m in measurements if m.latency is not None]) throughput = Throughput.aggregate([m.throughput for m in measurements if m.throughput is not None]) energy = Energy.aggregate([m.energy for m in measurements if m.energy is not None]) efficiency = Efficiency.aggregate([m.efficiency for m in measurements if m.efficiency is not None]) - return BenchmarkMeasurements( + return Measurements( memory=memory, latency=latency, throughput=throughput, @@ -54,7 +56,7 @@ def save_pretrained( if use_auth_token is not None: kwargs["token"] = use_auth_token - config_file_name = config_file_name if config_file_name is not None else "benchmark_report.json" + config_file_name = config_file_name if config_file_name is not None else REPORT_FILE_NAME if os.path.isfile(save_directory): raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") @@ -68,7 +70,7 @@ def save_pretrained( files_timestamps = self._get_files_timestamps(save_directory) output_config_file = os.path.join(save_directory, config_file_name) - self.to_json(output_config_file) + self.to_json(output_config_file, flat=False) if push_to_hub: self._upload_modified_files( @@ -79,10 +81,10 @@ def save_pretrained( token=kwargs.get("token"), ) - def to_dict(self) -> dict: + def to_dict(self) -> Dict[str, Any]: return asdict(self) - def to_flat_dict(self) -> dict: + def to_flat_dict(self) -> Dict[str, Any]: report_dict = self.to_dict() return flatten(report_dict, reducer="dot") @@ -103,33 +105,33 @@ def to_csv(self, path: str) -> None: def log_memory(self): for target in self.to_dict().keys(): - benchmark_measurements: BenchmarkMeasurements = getattr(self, target) - if benchmark_measurements.memory is not None: - benchmark_measurements.memory.log(prefix=target) + measurements: Measurements = getattr(self, target) + if measurements.memory is not None: + measurements.memory.log(prefix=target) def log_latency(self): for target in self.to_dict().keys(): - benchmark_measurements: BenchmarkMeasurements = getattr(self, target) - if benchmark_measurements.latency is not None: - benchmark_measurements.latency.log(prefix=target) + measurements: Measurements = getattr(self, target) + if measurements.latency is not None: + measurements.latency.log(prefix=target) def log_throughput(self): for target in self.to_dict().keys(): - benchmark_measurements: BenchmarkMeasurements = getattr(self, target) - if benchmark_measurements.throughput is not None: - benchmark_measurements.throughput.log(prefix=target) + measurements: Measurements = getattr(self, target) + if measurements.throughput is not None: + measurements.throughput.log(prefix=target) def log_energy(self): for target in self.to_dict().keys(): - benchmark_measurements: BenchmarkMeasurements = getattr(self, target) - if benchmark_measurements.energy is not None: - benchmark_measurements.energy.log(prefix=target) + measurements: Measurements = getattr(self, target) + if measurements.energy is not None: + measurements.energy.log(prefix=target) def log_efficiency(self): for target in self.to_dict().keys(): - benchmark_measurements: BenchmarkMeasurements = getattr(self, target) - if benchmark_measurements.efficiency is not None: - benchmark_measurements.efficiency.log(prefix=target) + measurements: Measurements = getattr(self, target) + if measurements.efficiency is not None: + measurements.efficiency.log(prefix=target) def log_all(self): self.log_memory() @@ -143,7 +145,7 @@ def aggregate(cls, reports: List["BenchmarkReport"]) -> "BenchmarkReport": aggregated_report = cls() for target in aggregated_report.to_dict().keys(): measurements = [getattr(report, target) for report in reports] - aggregated_measurements = BenchmarkMeasurements.aggregate(measurements) + aggregated_measurements = Measurements.aggregate(measurements) setattr(aggregated_report, target, aggregated_measurements) return aggregated_report diff --git a/optimum_benchmark/trackers/latency.py b/optimum_benchmark/trackers/latency.py index 85ae6ccd..ff42f6c5 100644 --- a/optimum_benchmark/trackers/latency.py +++ b/optimum_benchmark/trackers/latency.py @@ -6,6 +6,10 @@ import time from .utils import compute_mean, compute_stdev +from ..import_utils import is_torch_distributed_available + +if is_torch_distributed_available(): + import torch.distributed from transformers import TrainerCallback, LogitsProcessor import torch @@ -111,20 +115,25 @@ def __init__(self, device: str, backend: str): self.device = device self.backend = backend + if is_torch_distributed_available() and torch.distributed.is_initialized(): + self.distributed = True + else: + self.distributed = False + self.start_events: List[Union[float, torch.cuda.Event]] = [] self.end_events: List[Union[float, torch.cuda.Event]] = [] self.start_time: float = time.perf_counter() - def reset(self): - self.start_time = time.perf_counter() - self.start_events = [] - self.end_events = [] - if self.backend == "pytorch" and self.device == "cuda": LOGGER.info("\t+ Tracking Pytorch CUDA latency") else: LOGGER.info("\t+ Tracking CPU latency") + def reset(self): + self.start_time = time.perf_counter() + self.start_events = [] + self.end_events = [] + @contextmanager def track(self): if self.backend == "pytorch" and self.device == "cuda": @@ -158,8 +167,10 @@ def get_elapsed_time(self) -> float: def get_latency(self) -> Latency: if self.backend == "pytorch" and self.device == "cuda": - # synchronize the device to make sure all events have been recorded - torch.cuda.synchronize() + # synchronize the last event to make sure it has been recorded + self.start_events[-1].synchronize() + self.end_events[-1].synchronize() + latencies_list = [ self.start_events[i].elapsed_time(self.end_events[i]) / 1e3 for i in range(len(self.start_events)) ] diff --git a/optimum_benchmark/trackers/memory.py b/optimum_benchmark/trackers/memory.py index 86df0816..f788ac46 100644 --- a/optimum_benchmark/trackers/memory.py +++ b/optimum_benchmark/trackers/memory.py @@ -50,13 +50,13 @@ def __add__(self, other: "Memory") -> "Memory": ) @staticmethod - def aggregate(max_memories: List["Memory"]) -> "Memory": - if len(max_memories) == 0 or all(memory is None for memory in max_memories): + def aggregate(memories: List["Memory"]) -> "Memory": + if len(memories) == 0 or all(memory is None for memory in memories): return None - elif any(memory is None for memory in max_memories): + elif any(memory is None for memory in memories): raise ValueError("Some memory measurements are missing") - return reduce(lambda x, y: x + y, max_memories) + return reduce(lambda x, y: x + y, memories) def log(self, prefix: str = "forward"): LOGGER.info(f"\t\t+ {prefix} max RAM memory: {self.max_ram:f} ({self.unit})") @@ -124,11 +124,11 @@ def _cuda_pytorch_memory(self): yield from self._cuda_memory() - self.max_allocated_memory = ( - sum(torch.cuda.max_memory_allocated(device=device) for device in range(torch.cuda.device_count())) / 1e6 + self.max_allocated_memory = sum( + torch.cuda.max_memory_allocated(device=device) / 1e6 for device in range(torch.cuda.device_count()) ) - self.max_reserved_memory = ( - sum(torch.cuda.max_memory_reserved(device=device) for device in range(torch.cuda.device_count())) / 1e6 + self.max_reserved_memory = sum( + torch.cuda.max_memory_reserved(device=device) / 1e6 for device in range(torch.cuda.device_count()) ) def _cuda_memory(self): @@ -227,12 +227,16 @@ def monitor_gpu_vram_memory( LOGGER.warning(f"Could not get process list for device {device_id}: {e}") continue for device_process in device_processes: - if device_process.pid == process_id or ( - psutil.pid_exists(device_process.pid) - and psutil.Process(device_process.pid).parent().pid == process_id - ): - # only memory usage of the process and its children is tracked + if device_process.pid == process_id: current_used_memory += device_process.usedGpuMemory + else: + try: + cpu_process = psutil.Process(device_process.pid) + except Exception as e: + LOGGER.warning(f"Could not get process info for process {device_process.pid}: {e}") + continue + if cpu_process.parent() is not None and cpu_process.parent().pid == process_id: + current_used_memory += device_process.usedGpuMemory max_memory = max(max_memory, current_used_memory) stop = connection.poll(interval) @@ -255,24 +259,29 @@ def monitor_gpu_vram_memory( for device_id in device_ids: device_handle = devices_handles[device_id] try: - device_process = amdsmi.amdsmi_get_gpu_process_list(device_handle) + processes_handles = amdsmi.amdsmi_get_gpu_process_list(device_handle) except Exception as e: LOGGER.warning(f"Could not get process list for device {device_id}: {e}") continue - - for process_handle in device_process: + for process_handle in processes_handles: try: - process_info = amdsmi.amdsmi_get_gpu_process_info(device_handle, process_handle) + gpu_process_info = amdsmi.amdsmi_get_gpu_process_info(device_handle, process_handle) except Exception as e: LOGGER.warning(f"Could not get process info for process {process_handle}: {e}") continue - - if process_info["pid"] == process_id or ( - psutil.pid_exists(process_info["pid"]) - and psutil.Process(process_info["pid"]).parent().pid == process_id - ): - # only memory usage of the monitored process and its children is tracked - current_used_memory += process_info["memory_usage"]["vram_mem"] + # only memory usage of the monitored process and its children is tracked + if gpu_process_info["pid"] == process_id: + current_used_memory += gpu_process_info["memory_usage"]["vram_mem"] + else: + try: + cpu_process_info = psutil.Process(gpu_process_info["pid"]) + except Exception as e: + LOGGER.warning( + f"Could not get process info for process {gpu_process_info['pid']}: {e}" + ) + continue + if cpu_process_info.parent() is not None and cpu_process_info.parent().pid == process_id: + current_used_memory += gpu_process_info["memory_usage"]["vram_mem"] max_memory = max(max_memory, current_used_memory) stop = connection.poll(interval) @@ -283,24 +292,29 @@ def monitor_gpu_vram_memory( for device_id in device_ids: device_handle = devices_handles[device_id] try: - device_process = amdsmi.amdsmi_get_process_list(device_handle) + processes_handles = amdsmi.amdsmi_get_process_list(device_handle) except Exception as e: LOGGER.warning(f"Could not get process list for device {device_id}: {e}") continue - - for process_handle in device_process: + for process_handle in processes_handles: try: - process_info = amdsmi.amdsmi_get_process_info(device_handle, process_handle) + gpu_process_info = amdsmi.amdsmi_get_process_info(device_handle, process_handle) except Exception as e: LOGGER.warning(f"Could not get process info for process {process_handle}: {e}") continue - - if process_info["pid"] == process_id or ( - psutil.pid_exists(process_info["pid"]) - and psutil.Process(process_info["pid"]).parent().pid == process_id - ): - # only memory usage of the monitored process and its children is tracked - current_used_memory += process_info["memory_usage"]["vram_mem"] + # only memory usage of the monitored process and its children is tracked + if gpu_process_info["pid"] == process_id: + current_used_memory += gpu_process_info["memory_usage"]["vram_mem"] + else: + try: + cpu_process_info = psutil.Process(gpu_process_info["pid"]) + except Exception as e: + LOGGER.warning( + f"Could not get process info for process {gpu_process_info['pid']}: {e}" + ) + continue + if cpu_process_info.parent() is not None and cpu_process_info.parent().pid == process_id: + current_used_memory += gpu_process_info["memory_usage"]["vram_mem"] max_memory = max(max_memory, current_used_memory) stop = connection.poll(interval)