diff --git a/.github/workflows/test_cli_cpu_py_tgi.yaml b/.github/workflows/test_cli_cpu_py_tgi.yaml
index 1ec01c84..b7fc1c5a 100644
--- a/.github/workflows/test_cli_cpu_py_tgi.yaml
+++ b/.github/workflows/test_cli_cpu_py_tgi.yaml
@@ -15,6 +15,15 @@ jobs:
   run_cli_cpu_py_tgi_tests:
     runs-on: ubuntu-latest
     steps:
+      - name: Free disk space
+        uses: jlumbroso/free-disk-space@main
+        with:
+          dotnet: true
+          android: true
+          haskell: true
+          docker-images: true
+          large-packages: false
+
       - name: Checkout
         uses: actions/checkout@v3
 
diff --git a/optimum_benchmark/backends/llm_swarm/backend.py b/optimum_benchmark/backends/llm_swarm/backend.py
new file mode 100644
index 00000000..dba63433
--- /dev/null
+++ b/optimum_benchmark/backends/llm_swarm/backend.py
@@ -0,0 +1,94 @@
+import asyncio
+import gc
+from logging import getLogger
+from typing import Any, Dict, List
+
+import torch
+from huggingface_hub import AsyncInferenceClient
+from llm_swarm import LLMSwarm
+from llm_swarm import LLMSwarmConfig as LLMSwarmCfg
+
+from ...task_utils import TEXT_GENERATION_TASKS
+from ..base import Backend
+from .config import LLMSwarmConfig
+
+# backend logger
+LOGGER = getLogger("llm-swarm")
+
+
+class LLMSwarmBackend(Backend[LLMSwarmConfig]):
+    NAME: str = "llm-swarm"
+
+    def __init__(self, config: LLMSwarmConfig) -> None:
+        super().__init__(config)
+        self.validate_task()
+
+        LOGGER.info("\t+ Downloading pretrained model")
+        self.download_pretrained_model()
+        LOGGER.info("\t+ Preparing generation config")
+        self.prepare_generation_config()
+        LOGGER.info("\t+ Loading pretrained model")
+        self.load_model_from_pretrained()
+
+    def validate_task(self) -> None:
+        if self.config.task not in TEXT_GENERATION_TASKS:
+            raise NotImplementedError(f"LLM Swarm does not support task {self.config.task}")
+
+    def load_model_from_pretrained(self) -> None:
+        self.llm_swarm_config = LLMSwarmCfg(
+            gpus=self.config.gpus,
+            model=self.config.model,
+            instances=self.config.instances,
+            inference_engine=self.config.inference_engine,
+            slurm_template_path=self.config.slurm_template_path,
+            load_balancer_template_path=self.config.load_balancer_template_path,
+            per_instance_max_parallel_requests=self.config.per_instance_max_parallel_requests,
+            revision=self.config.hub_kwargs.get("revision", "main"),
+            debug_endpoint=self.config.debug_endpoint,
+        )
+        self.llm_swarm = LLMSwarm(self.llm_swarm_config).__enter__()
+        self.client = AsyncInferenceClient(self.llm_swarm.endpoint)
+
+    def download_pretrained_model(self) -> None:
+        with torch.device("meta"):
+            self.automodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs)
+
+    def prepare_generation_config(self) -> None:
+        self.generation_config.eos_token_id = -100
+        self.generation_config.pad_token_id = -100
+        model_cache_folder = f"models/{self.config.model}".replace("/", "--")
+        model_cache_path = f"{self.config.volume}/{model_cache_folder}"
+        snapshot_file = f"{model_cache_path}/refs/{self.config.hub_kwargs.get('revision', 'main')}"
+        snapshot_ref = open(snapshot_file, "r").read().strip()
+        model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}"
+        LOGGER.info("\t+ Saving new pretrained generation config")
+        self.generation_config.save_pretrained(save_directory=model_snapshot_path)
+
+    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
+        if "inputs" in inputs:
+            return {"prompt": self.pretrained_processor.batch_decode(inputs["inputs"].tolist())}
+        elif "input_ids" in inputs:
+            return {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())}
+        else:
+            raise ValueError("inputs must contain either input_ids or inputs")
+
+    async def single_client_call(self, prompt: str, kwargs: Dict[str, Any]) -> str:
+        return await self.client.text_generation(prompt, max_new_tokens=kwargs.get("max_new_tokens", 1))
+
+    async def batch_client_call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
+        return await asyncio.gather(*(self.single_client_call(p, kwargs) for p in inputs["prompt"]))
+
+    def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
+        return asyncio.run(self.batch_client_call(inputs, kwargs))
+
+    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
+        return asyncio.run(self.batch_client_call(inputs, kwargs))
+
+    def clean(self) -> None:
+        super().clean()
+
+        if hasattr(self, "llm_swarm"):
+            LOGGER.info("Cleaning up LLM Swarm")
+            self.llm_swarm.__exit__(None, None, None)
+
+        gc.collect()
diff --git a/optimum_benchmark/backends/llm_swarm/config.py b/optimum_benchmark/backends/llm_swarm/config.py
new file mode 100644
index 00000000..745cdd3f
--- /dev/null
+++ b/optimum_benchmark/backends/llm_swarm/config.py
@@ -0,0 +1,31 @@
+from dataclasses import dataclass
+from typing import Optional
+
+from ...import_utils import llm_swarm_version
+from ..config import BackendConfig
+
+
+@dataclass
+class LLMSwarmConfig(BackendConfig):
+    name: str = "llm-swarm"
+    version: Optional[str] = llm_swarm_version()
+    _target_: str = "optimum_benchmark.backends.llm_swarm.backend.LLMSwarmBackend"
+
+    # optimum benchmark specific
+    no_weights: bool = False
+
+    # llm-swarm specific
+    gpus: int = 8
+    instances: int = 1
+    inference_engine: str = "tgi"
+    volume: str = "/fsx/ilyas/.cache"
+    per_instance_max_parallel_requests: int = 500
+    slurm_template_path: str = "/fsx/ilyas/swarm-templates/tgi_h100.template.slurm"
+    load_balancer_template_path: str = "/fsx/ilyas/swarm-templates/nginx.template.conf"
+    debug_endpoint: Optional[str] = None
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        # so that downloaded artifacts are stored in the same place
+        self.hub_kwargs["cache_dir"] = self.volume
diff --git a/optimum_benchmark/cli.py b/optimum_benchmark/cli.py
index 0a9254ab..0806c09e 100644
--- a/optimum_benchmark/cli.py
+++ b/optimum_benchmark/cli.py
@@ -6,6 +6,7 @@
 from hydra.core.config_store import ConfigStore
 from omegaconf import DictConfig, OmegaConf
 
+from .backends.llm_swarm.config import LLMSwarmConfig
 from .backends.neural_compressor.config import INCConfig
 from .backends.onnxruntime.config import ORTConfig
 from .backends.openvino.config import OVConfig
@@ -34,6 +35,7 @@
 cs.store(group="backend", name=TRTLLMConfig.name, node=TRTLLMConfig)
 cs.store(group="backend", name=INCConfig.name, node=INCConfig)
 cs.store(group="backend", name=PyTGIConfig.name, node=PyTGIConfig)
+cs.store(group="backend", name=LLMSwarmConfig.name, node=LLMSwarmConfig)
 # benchmarks configurations
 cs.store(group="benchmark", name=TrainingConfig.name, node=TrainingConfig)
 cs.store(group="benchmark", name=InferenceConfig.name, node=InferenceConfig)
diff --git a/optimum_benchmark/experiment.py b/optimum_benchmark/experiment.py
index c9a556cc..268daaa1 100644
--- a/optimum_benchmark/experiment.py
+++ b/optimum_benchmark/experiment.py
@@ -174,6 +174,7 @@ def launch(experiment_config: ExperimentConfig) -> BenchmarkReport:
         launcher: Launcher = launcher_factory(launcher_config)
     except Exception as e:
         LOGGER.error(f"Error during launcher allocation: {e}")
{e}") + os.chdir(original_dir) tmpdir.cleanup() raise e @@ -184,6 +185,7 @@ def launch(experiment_config: ExperimentConfig) -> BenchmarkReport: output = launcher.launch(run, benchmark_config, backend_config) except Exception as e: LOGGER.error(f"Error during experiment launching: {e}") + os.chdir(original_dir) tmpdir.cleanup() raise e diff --git a/optimum_benchmark/import_utils.py b/optimum_benchmark/import_utils.py index 7ee853bf..75cfec66 100644 --- a/optimum_benchmark/import_utils.py +++ b/optimum_benchmark/import_utils.py @@ -29,6 +29,11 @@ _optimum_benchmark_available = importlib.util.find_spec("optimum_benchmark") is not None _py_tgi_available = importlib.util.find_spec("py_tgi") is not None _pyrsmi_available = importlib.util.find_spec("pyrsmi") is not None +_llm_swarm_available = importlib.util.find_spec("llm_swarm") is not None + + +def is_llm_swarm_available(): + return _llm_swarm_available def is_pyrsmi_available(): @@ -198,6 +203,11 @@ def py_tgi_version(): return importlib.metadata.version("py_tgi") +def llm_swarm_version(): + if _llm_swarm_available: + return importlib.metadata.version("llm_swarm") + + def get_git_revision_hash(package_name: str) -> Optional[str]: """ Returns the git commit SHA of a package installed from a git repository. diff --git a/setup.py b/setup.py index dba055ff..50dc0528 100644 --- a/setup.py +++ b/setup.py @@ -46,9 +46,10 @@ "onnxruntime-gpu": [f"optimum[onnxruntime-gpu]>={MIN_OPTIMUM_VERSION}"], "neural-compressor": [f"optimum[neural-compressor]>={MIN_OPTIMUM_VERSION}"], "torch-ort": ["torch-ort", "onnxruntime-training", f"optimum>={MIN_OPTIMUM_VERSION}"], - # docker-based backends + # other backends + "llm-swarm": ["llm-swarm@git+https://github.com/huggingface/llm-swarm.git"], "py-tgi": ["py-tgi==0.1.3"], - # third-party features + # optional dependencies "codecarbon": ["codecarbon"], "deepspeed": ["deepspeed"], "diffusers": ["diffusers"],