Skip to content

Commit

Permalink
Better hub utils (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil authored Mar 4, 2024
1 parent 99c4ad8 commit a3cd823
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 128 deletions.
72 changes: 7 additions & 65 deletions optimum_benchmark/benchmarks/report.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,14 @@
import os
from dataclasses import asdict, dataclass
from json import dump
from dataclasses import dataclass
from logging import getLogger
from typing import Any, Dict, List, Optional, Union

import pandas as pd
from flatten_dict import flatten
from transformers.configuration_utils import PushToHubMixin
from typing import List, Optional

from ..hub_utils import PushToHubMixin
from ..trackers.energy import Efficiency, Energy
from ..trackers.latency import Latency, Throughput
from ..trackers.memory import Memory

LOGGER = getLogger("report")

REPORT_FILE_NAME = "benchmark_report.json"


@dataclass
class BenchmarkMeasurements:
Expand Down Expand Up @@ -60,61 +53,6 @@ def aggregate(benchmark_measurements: List["BenchmarkMeasurements"]) -> "Benchma

@dataclass
class BenchmarkReport(PushToHubMixin):
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
config_file_name: Optional[Union[str, os.PathLike]] = None,
push_to_hub: bool = False,
**kwargs,
):
use_auth_token = kwargs.pop("use_auth_token", None)

if use_auth_token is not None:
kwargs["token"] = use_auth_token

config_file_name = config_file_name if config_file_name is not None else REPORT_FILE_NAME

if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

os.makedirs(save_directory, exist_ok=True)

if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

output_config_file = os.path.join(save_directory, config_file_name)
self.to_json(output_config_file, flat=False)

if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=kwargs.get("token")
)

def to_dict(self) -> Dict[str, Any]:
return asdict(self)

def to_flat_dict(self) -> Dict[str, Any]:
report_dict = self.to_dict()
return flatten(report_dict, reducer="dot")

def to_json(self, path: str, flat: bool = False) -> None:
if flat:
with open(path, "w") as f:
dump(self.to_flat_dict(), f, indent=4)
else:
with open(path, "w") as f:
dump(self.to_dict(), f, indent=4)

def to_dataframe(self) -> pd.DataFrame:
flat_report_dict = self.to_flat_dict()
return pd.DataFrame.from_dict(flat_report_dict, orient="index")

def to_csv(self, path: str) -> None:
self.to_dataframe().to_csv(path, index=False)

def log_memory(self):
for target in self.to_dict().keys():
benchmark_measurements: BenchmarkMeasurements = getattr(self, target)
Expand Down Expand Up @@ -167,3 +105,7 @@ def aggregate(cls, reports: List["BenchmarkReport"]) -> "BenchmarkReport":
aggregated_measurements[target] = BenchmarkMeasurements.aggregate(benchmark_measurements)

return cls(**aggregated_measurements)

@property
def file_name(self) -> str:
return "benchmark_report.json"
77 changes: 14 additions & 63 deletions optimum_benchmark/experiment.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
from dataclasses import asdict, dataclass, field
from dataclasses import dataclass, field
from logging import getLogger
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Type

from .backends.config import BackendConfig
from .benchmarks.config import BenchmarkConfig
from .benchmarks.report import BenchmarkReport
from .hub_utils import PushToHubMixin
from .import_utils import get_hf_libs_info
from .launchers.config import LauncherConfig
from .system_utils import get_system_info
Expand All @@ -19,17 +20,10 @@
from .benchmarks.base import Benchmark
from .launchers.base import Launcher

from json import dump

import pandas as pd
from flatten_dict import flatten
from hydra.utils import get_class
from transformers.configuration_utils import PushToHubMixin

LOGGER = getLogger("experiment")

EXPERIMENT_FILE_NAME = "experiment_config.json"


@dataclass
class ExperimentConfig(PushToHubMixin):
Expand All @@ -51,63 +45,16 @@ class ExperimentConfig(PushToHubMixin):
# ENVIRONMENT CONFIGURATION
environment: Dict = field(default_factory=lambda: {**get_system_info(), **get_hf_libs_info()})

def to_dict(self) -> Dict[str, Any]:
return asdict(self)

def to_flat_dict(self) -> Dict[str, Any]:
report_dict = self.to_dict()
return flatten(report_dict, reducer="dot")

def to_json(self, path: str, flat: bool = False) -> None:
if flat:
with open(path, "w") as f:
dump(self.to_flat_dict(), f, indent=4)
else:
with open(path, "w") as f:
dump(self.to_dict(), f, indent=4)

def to_dataframe(self) -> pd.DataFrame:
flat_report_dict = self.to_flat_dict()
return pd.DataFrame.from_dict(flat_report_dict, orient="index")

def to_csv(self, path: str) -> None:
self.to_dataframe().to_csv(path, index=False)

def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
config_file_name: Optional[Union[str, os.PathLike]] = None,
push_to_hub: bool = False,
**kwargs,
):
use_auth_token = kwargs.pop("use_auth_token", None)

if use_auth_token is not None:
kwargs["token"] = use_auth_token

config_file_name = config_file_name if config_file_name is not None else EXPERIMENT_FILE_NAME

if os.path.isfile(save_directory):
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")

os.makedirs(save_directory, exist_ok=True)

if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = self._create_repo(repo_id, **kwargs)
files_timestamps = self._get_files_timestamps(save_directory)

output_config_file = os.path.join(save_directory, config_file_name)
self.to_json(output_config_file, flat=False)

if push_to_hub:
self._upload_modified_files(
save_directory, repo_id, files_timestamps, commit_message=commit_message, token=kwargs.get("token")
)
@property
def file_name(self) -> str:
return "experiment_config.json"


def run(benchmark_config: BenchmarkConfig, backend_config: BackendConfig) -> BenchmarkReport:
"""
Runs a benchmark using specified backend and benchmark configurations
"""

try:
# Allocate requested backend
backend_factory: Type[Backend] = get_class(backend_config._target_)
Expand Down Expand Up @@ -144,6 +91,10 @@ def run(benchmark_config: BenchmarkConfig, backend_config: BackendConfig) -> Ben


def launch(experiment_config: ExperimentConfig) -> BenchmarkReport:
"""
Runs an experiment using specified launcher configuration/logic
"""

# fix backend until deprecated model and device are removed
if experiment_config.task is not None:
LOGGER.warning("`task` is deprecated in experiment config. Use `backend.task` instead.")
Expand Down
69 changes: 69 additions & 0 deletions optimum_benchmark/hub_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import tempfile
from dataclasses import asdict
from json import dump
from logging import getLogger
from typing import Any, Dict, Optional, Union

import pandas as pd
from flatten_dict import flatten
from huggingface_hub import create_repo, upload_file

LOGGER = getLogger(__name__)


class PushToHubMixin:
"""
A Mixin to push artifacts to the Hugging Face Hub
"""

def to_dict(self) -> Dict[str, Any]:
return asdict(self)

def to_flat_dict(self) -> Dict[str, Any]:
report_dict = self.to_dict()
return flatten(report_dict, reducer="dot")

def to_json(self, path: str, flat: bool = False) -> None:
if flat:
with open(path, "w") as f:
dump(self.to_flat_dict(), f, indent=4)
else:
with open(path, "w") as f:
dump(self.to_dict(), f, indent=4)

def to_dataframe(self) -> pd.DataFrame:
flat_report_dict = self.to_flat_dict()
return pd.DataFrame.from_dict(flat_report_dict, orient="index")

def to_csv(self, path: str) -> None:
self.to_dataframe().to_csv(path, index=False)

def push_to_hub(
self,
repo_id: str,
file_name: Optional[Union[str, os.PathLike]] = None,
path_in_repo: Optional[str] = None,
flat: bool = False,
**kwargs,
) -> str:
token = kwargs.get("token", None)
private = kwargs.get("private", False)
repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id

with tempfile.TemporaryDirectory() as tmpdir:
file_name = file_name or self.file_name
path_or_fileobj = os.path.join(tmpdir, file_name)
path_in_repo = path_in_repo or file_name
self.to_json(path_or_fileobj, flat=flat)

upload_file(
path_or_fileobj=path_or_fileobj,
path_in_repo=path_in_repo,
repo_id=repo_id,
**kwargs,
)

@property
def file_name(self) -> str:
return "config.json"

0 comments on commit a3cd823

Please sign in to comment.