Skip to content

Commit

Permalink
experiment mixin
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Feb 14, 2024
1 parent b125566 commit c542546
Show file tree
Hide file tree
Showing 10 changed files with 201 additions and 127 deletions.
2 changes: 1 addition & 1 deletion optimum_benchmark/benchmarks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import ClassVar, Generic

from ..backends.base import Backend
from .report import BenchmarkReport
from ..report import BenchmarkReport
from .config import BenchmarkConfigT

LOGGER = getLogger("benchmark")
Expand Down
27 changes: 22 additions & 5 deletions optimum_benchmark/benchmarks/inference/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from logging import getLogger
from dataclasses import dataclass

from ..base import Benchmark
from .config import InferenceConfig
from ...trackers.memory import MemoryTracker
from ...report import BenchmarkReport, Measurements
from ...backends.base import Backend, BackendConfigT
from ...generators.input_generator import InputGenerator
from ...trackers.energy import EnergyTracker, Efficiency
from ...trackers.latency import LatencyTracker, Throughput
from ...import_utils import is_torch_distributed_available
from ...task_utils import TEXT_GENERATION_TASKS, IMAGE_DIFFUSION_TASKS
from .report import InferenceReport, TextGenerationReport, ImageDiffusionReport

if is_torch_distributed_available():
import torch.distributed
Expand Down Expand Up @@ -44,6 +45,22 @@
CALL_EFFICIENCY_UNIT = "images/kWh"


@dataclass
class InferenceReport(BenchmarkReport):
forward: Measurements = Measurements()


@dataclass
class ImageDiffusionReport(BenchmarkReport):
call: Measurements = Measurements()


@dataclass
class TextGenerationReport(BenchmarkReport):
prefill: Measurements = Measurements()
decode: Measurements = Measurements()


class InferenceBenchmark(Benchmark[InferenceConfig]):
NAME = "inference"

Expand Down Expand Up @@ -157,29 +174,29 @@ def run_text_generation_memory_tracking(self, backend: Backend):
with self.memory_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)

self.report.prefill.max_memory = self.memory_tracker.get_max_memory()
self.report.prefill.memory = self.memory_tracker.get_max_memory()

self.memory_tracker.reset()
with self.memory_tracker.track():
_ = backend.generate(self.generate_input, self.config.generate_kwargs)

self.report.decode.max_memory = self.memory_tracker.get_max_memory()
self.report.decode.memory = self.memory_tracker.get_max_memory()

def run_image_diffusion_memory_tracking(self, backend: Backend):
LOGGER.info("\t+ Running memory tracking")
self.memory_tracker.reset()
with self.memory_tracker.track():
_ = backend.call(self.diffuse_input, self.config.forward_kwargs)

self.report.call.max_memory = self.memory_tracker.get_max_memory()
self.report.call.memory = self.memory_tracker.get_max_memory()

def run_inference_memory_tracking(self, backend: Backend):
LOGGER.info("\t+ Running memory tracking")
self.memory_tracker.reset()
with self.memory_tracker.track():
_ = backend.forward(self.forward_inputs, self.config.forward_kwargs)

self.report.forward.max_memory = self.memory_tracker.get_max_memory()
self.report.forward.memory = self.memory_tracker.get_max_memory()

## Latency tracking
def run_text_generation_latency_tracking(self, backend: Backend):
Expand Down
22 changes: 0 additions & 22 deletions optimum_benchmark/benchmarks/inference/report.py

This file was deleted.

17 changes: 12 additions & 5 deletions optimum_benchmark/benchmarks/training/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from logging import getLogger
from contextlib import ExitStack
from dataclasses import dataclass

from ..base import Benchmark
from .config import TrainingConfig
from .report import TrainingReport
from ...trackers.memory import MemoryTracker
from ...report import BenchmarkReport, Measurements
from ...backends.base import Backend, BackendConfigT
from ...trackers.energy import EnergyTracker, Efficiency
from ...generators.dataset_generator import DatasetGenerator
Expand All @@ -18,6 +19,13 @@
TRAIN_EFFICIENCY_UNIT = "samples/kWh"


@dataclass
class TrainingReport(BenchmarkReport):
overall: Measurements = Measurements()
warmup: Measurements = Measurements()
train: Measurements = Measurements()


class TrainingBenchmark(Benchmark[TrainingConfig]):
NAME = "training"

Expand Down Expand Up @@ -69,10 +77,9 @@ def run(self, backend: Backend[BackendConfigT]) -> None:
)

if self.config.memory:
# it's the same
self.report.overall.max_memory = memory_tracker.get_max_memory()
self.report.warmup.max_memory = memory_tracker.get_max_memory()
self.report.train.max_memory = memory_tracker.get_max_memory()
self.report.overall.memory = memory_tracker.get_max_memory()
self.report.warmup.memory = memory_tracker.get_max_memory()
self.report.train.memory = memory_tracker.get_max_memory()

self.report.log_memory()

Expand Down
13 changes: 0 additions & 13 deletions optimum_benchmark/benchmarks/training/report.py

This file was deleted.

67 changes: 61 additions & 6 deletions optimum_benchmark/experiment.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import os
from logging import getLogger
from tempfile import TemporaryDirectory
from dataclasses import dataclass, field
from typing import Any, Dict, Type, Optional, TYPE_CHECKING

from hydra.utils import get_class
from transformers.configuration_utils import PushToHubMixin
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, Type, Optional, Union, TYPE_CHECKING

from .report import BenchmarkReport
from .env_utils import get_system_info
from .import_utils import get_hf_libs_info
from .benchmarks.report import BenchmarkReport
from .benchmarks.config import BenchmarkConfig
from .launchers.config import LauncherConfig
from .backends.config import BackendConfig
Expand All @@ -22,9 +19,15 @@
from .launchers.base import Launcher
from .backends.base import Backend

from json import dump
from flatten_dict import flatten
from hydra.utils import get_class
from transformers.configuration_utils import PushToHubMixin

LOGGER = getLogger("experiment")

EXPERIMENT_FILE_NAME = "experiment_config.json"


@dataclass
class ExperimentConfig(PushToHubMixin):
Expand All @@ -46,6 +49,58 @@ class ExperimentConfig(PushToHubMixin):
# ENVIRONMENT CONFIGURATION
environment: Dict = field(default_factory=lambda: {**get_system_info(), **get_hf_libs_info()})

def to_dict(self) -> Dict[str, Any]:
return asdict(self)

def to_flat_dict(self) -> Dict[str, Any]:
report_dict = self.to_dict()
return flatten(report_dict, reducer="dot")

def to_json(self, path: str, flat: bool = False) -> None:
if flat:
with open(path, "w") as f:
dump(self.to_flat_dict(), f, indent=4)
else:
with open(path, "w") as f:
dump(self.to_dict(), f, indent=4)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
config_file_name: Optional[Union[str, os.PathLike]] = None,
push_to_hub: bool = False,
**kwargs,
):
use_auth_token = kwargs.pop("use_auth_token", None)

if use_auth_token is not None:
kwargs["token"] = use_auth_token

config_file_name = config_file_name if config_file_name is not None else EXPERIMENT_FILE_NAME

if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

os.makedirs(save_directory, exist_ok=True)

if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

output_config_file = os.path.join(save_directory, config_file_name)
self.to_json(output_config_file, flat=False)

if push_to_hub:
self._upload_modified_files(
save_directory,
repo_id,
files_timestamps,
commit_message=commit_message,
token=kwargs.get("token"),
)


def run(benchmark_config: BenchmarkConfig, backend_config: BackendConfig) -> BenchmarkReport:
try:
Expand Down
15 changes: 9 additions & 6 deletions optimum_benchmark/launchers/torchrun/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

from ..base import Launcher
from .config import TorchrunConfig
from ...report import BenchmarkReport
from ...logging_utils import setup_logging
from ..isolation_utils import device_isolation
from ...benchmarks.report import BenchmarkReport

import torch.distributed
from torch.distributed import FileStore
Expand Down Expand Up @@ -65,12 +65,15 @@ def launch(self, worker: Callable, *worker_args) -> Dict[str, Any]:
while not queue.empty():
outputs.append(queue.get())

if len(outputs) == 1:
report: BenchmarkReport = outputs[0]
else:
if len(outputs) > 1:
LOGGER.info(f"\t+ Merging benchmark reports from {len(outputs)} workers")
report: BenchmarkReport = outputs[0].aggregate(outputs)
report.log_all()
report = outputs[0].aggregate(outputs)
elif len(outputs) == 1:
report = outputs[0]
else:
raise ValueError("No benchmark report was returned by the workers")

report.log_all()

return report

Expand Down
Loading

0 comments on commit c542546

Please sign in to comment.