llm-swarm backend integration for slurm clusters (#142)
IlyasMoutawwakil authored Mar 4, 2024
1 parent 28c89c7 commit 2c58b76
Showing 7 changed files with 151 additions and 2 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/test_cli_cpu_py_tgi.yaml
@@ -15,6 +15,15 @@ jobs:
run_cli_cpu_py_tgi_tests:
runs-on: ubuntu-latest
steps:
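      # free up runner disk space before the tests pull large docker images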
- name: Free disk space
uses: jlumbroso/free-disk-space@main
with:
dotnet: true
android: true
haskell: true
docker-images: true
large-packages: false

- name: Checkout
uses: actions/checkout@v3

94 changes: 94 additions & 0 deletions optimum_benchmark/backends/llm_swarm/backend.py
@@ -0,0 +1,94 @@
import asyncio
import gc
from logging import getLogger
from typing import Any, Dict, List

import torch
from huggingface_hub import AsyncInferenceClient
from llm_swarm import LLMSwarm
from llm_swarm import LLMSwarmConfig as LLMSwarmCfg

from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
from .config import LLMSwarmConfig

# backend logger
LOGGER = getLogger("llm-swarm")


class LLMSwarmBackend(Backend[LLMSwarmConfig]):
NAME: str = "llm-swarm"

def __init__(self, config: LLMSwarmConfig) -> None:
super().__init__(config)
self.validate_task()

LOGGER.info("\t+ Downloading pretrained model")
self.download_pretrained_model()
LOGGER.info("\t+ Preparing generation config")
self.prepare_generation_config()
LOGGER.info("\t+ Loading pretrained model")
self.load_model_from_pretrained()

def validate_task(self) -> None:
if self.config.task not in TEXT_GENERATION_TASKS:
raise NotImplementedError(f"LLM Swarm does not support task {self.config.task}")

def load_model_from_pretrained(self) -> None:
self.llm_swarm_config = LLMSwarmCfg(
gpus=self.config.gpus,
model=self.config.model,
instances=self.config.instances,
inference_engine=self.config.inference_engine,
slurm_template_path=self.config.slurm_template_path,
load_balancer_template_path=self.config.load_balancer_template_path,
per_instance_max_parallel_requests=self.config.per_instance_max_parallel_requests,
revision=self.config.hub_kwargs.get("revision", "main"),
debug_endpoint=self.config.debug_endpoint,
)
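        # enter the context manager manually so the swarm stays up after this method returns; clean() calls __exit__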
self.llm_swarm = LLMSwarm(self.llm_swarm_config).__enter__()
self.client = AsyncInferenceClient(self.llm_swarm.endpoint)

def download_pretrained_model(self) -> None:
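        # instantiating on the meta device downloads checkpoint files into the cache without allocating real tensors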
with torch.device("meta"):
self.automodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs)

def prepare_generation_config(self) -> None:
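        # -100 matches no real token id, so generation never stops early and always produces max_new_tokens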
self.generation_config.eos_token_id = -100
self.generation_config.pad_token_id = -100
model_cache_folder = f"models/{self.config.model}".replace("/", "--")
model_cache_path = f"{self.config.volume}/{model_cache_folder}"
snapshot_file = f"{model_cache_path}/refs/{self.config.hub_kwargs.get('revision', 'main')}"
        with open(snapshot_file, "r") as f:
            snapshot_ref = f.read().strip()
model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}"
LOGGER.info("\t+ Saving new pretrained generation config")
self.generation_config.save_pretrained(save_directory=model_snapshot_path)

def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
if "inputs" in inputs:
return {"prompt": self.pretrained_processor.batch_decode(inputs["inputs"].tolist())}
elif "input_ids" in inputs:
return {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())}
else:
            raise ValueError("inputs must contain either 'inputs' or 'input_ids'")

async def single_client_call(self, prompt: str, kwargs: Dict[str, Any]) -> str:
return await self.client.text_generation(prompt, max_new_tokens=kwargs.get("max_new_tokens", 1))

async def batch_client_call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
return await asyncio.gather(*(self.single_client_call(p, kwargs) for p in inputs["prompt"]))

def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
return asyncio.run(self.batch_client_call(inputs, kwargs))

def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
return asyncio.run(self.batch_client_call(inputs, kwargs))

def clean(self) -> None:
super().clean()

if hasattr(self, "llm_swarm"):
LOGGER.info("Cleaning up LLM Swarm")
self.llm_swarm.__exit__(None, None, None)

gc.collect()
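
For reference, the concurrent fan-out that `batch_client_call` performs can be reproduced standalone. A minimal sketch, assuming a reachable text-generation-inference endpoint (the URL below is a placeholder, not something this commit provides):

```python
import asyncio
from typing import List

from huggingface_hub import AsyncInferenceClient


async def generate_batch(endpoint: str, prompts: List[str], max_new_tokens: int = 1) -> List[str]:
    # a single shared async client; all prompts are dispatched concurrently
    client = AsyncInferenceClient(endpoint)
    return await asyncio.gather(
        *(client.text_generation(p, max_new_tokens=max_new_tokens) for p in prompts)
    )


# hypothetical endpoint, e.g. the one exposed by the swarm's load balancer
print(asyncio.run(generate_batch("http://localhost:8080", ["Hello", "Hi"])))
```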
31 changes: 31 additions & 0 deletions optimum_benchmark/backends/llm_swarm/config.py
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import Optional

from ...import_utils import llm_swarm_version
from ..config import BackendConfig


@dataclass
class LLMSwarmConfig(BackendConfig):
name: str = "llm-swarm"
version: Optional[str] = llm_swarm_version()
_target_: str = "optimum_benchmark.backends.llm_swarm.backend.LLMSwarmBackend"

# optimum benchmark specific
no_weights: bool = False

# llm-swarm specific
gpus: int = 8
instances: int = 1
inference_engine: str = "tgi"
volume: str = "/fsx/ilyas/.cache"
per_instance_max_parallel_requests: int = 500
slurm_template_path: str = "/fsx/ilyas/swarm-templates/tgi_h100.template.slurm"
load_balancer_template_path: str = "/fsx/ilyas/swarm-templates/nginx.template.conf"
debug_endpoint: Optional[str] = None

def __post_init__(self):
super().__post_init__()

        # store downloaded artifacts in the shared volume so every instance reads the same cache
self.hub_kwargs["cache_dir"] = self.volume
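
As a sketch of how the new config composes, the cluster-specific defaults can be overridden at construction time. The field names are the ones defined above; treating `model` and `task` as inherited `BackendConfig` fields is an assumption about the base class:

```python
from optimum_benchmark.backends.llm_swarm.config import LLMSwarmConfig

# illustrative overrides; the defaults above hard-code /fsx paths that only
# exist on the author's cluster
config = LLMSwarmConfig(
    model="gpt2",                  # assumed inherited from BackendConfig
    task="text-generation",        # assumed inherited from BackendConfig
    gpus=4,
    instances=2,
    volume="/scratch/hf-cache",    # shared cache visible to all Slurm nodes
)

# __post_init__ redirects Hub downloads into the shared volume
assert config.hub_kwargs["cache_dir"] == "/scratch/hf-cache"
```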
2 changes: 2 additions & 0 deletions optimum_benchmark/cli.py
@@ -6,6 +6,7 @@
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf

from .backends.llm_swarm.config import LLMSwarmConfig
from .backends.neural_compressor.config import INCConfig
from .backends.onnxruntime.config import ORTConfig
from .backends.openvino.config import OVConfig
@@ -34,6 +35,7 @@
cs.store(group="backend", name=TRTLLMConfig.name, node=TRTLLMConfig)
cs.store(group="backend", name=INCConfig.name, node=INCConfig)
cs.store(group="backend", name=PyTGIConfig.name, node=PyTGIConfig)
cs.store(group="backend", name=LLMSwarmConfig.name, node=LLMSwarmConfig)
# benchmarks configurations
cs.store(group="benchmark", name=TrainingConfig.name, node=TrainingConfig)
cs.store(group="benchmark", name=InferenceConfig.name, node=InferenceConfig)
2 changes: 2 additions & 0 deletions optimum_benchmark/experiment.py
@@ -174,6 +174,7 @@ def launch(experiment_config: ExperimentConfig) -> BenchmarkReport:
launcher: Launcher = launcher_factory(launcher_config)
except Exception as e:
LOGGER.error(f"Error during launcher allocation: {e}")
os.chdir(original_dir)
tmpdir.cleanup()
raise e

@@ -184,6 +185,7 @@ def launch(experiment_config: ExperimentConfig) -> BenchmarkReport:
output = launcher.launch(run, benchmark_config, backend_config)
except Exception as e:
LOGGER.error(f"Error during experiment launching: {e}")
os.chdir(original_dir)
tmpdir.cleanup()
raise e

10 changes: 10 additions & 0 deletions optimum_benchmark/import_utils.py
@@ -29,6 +29,11 @@
_optimum_benchmark_available = importlib.util.find_spec("optimum_benchmark") is not None
_py_tgi_available = importlib.util.find_spec("py_tgi") is not None
_pyrsmi_available = importlib.util.find_spec("pyrsmi") is not None
_llm_swarm_available = importlib.util.find_spec("llm_swarm") is not None


def is_llm_swarm_available():
return _llm_swarm_available


def is_pyrsmi_available():
@@ -198,6 +203,11 @@ def py_tgi_version():
return importlib.metadata.version("py_tgi")


def llm_swarm_version():
if _llm_swarm_available:
return importlib.metadata.version("llm_swarm")


def get_git_revision_hash(package_name: str) -> Optional[str]:
"""
Returns the git commit SHA of a package installed from a git repository.
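
The new helpers follow the package's existing optional-dependency pattern; a minimal usage sketch:

```python
from optimum_benchmark.import_utils import is_llm_swarm_available, llm_swarm_version

# guard optional code paths the same way the rest of the package does
if is_llm_swarm_available():
    print(f"llm_swarm {llm_swarm_version()} is installed")
else:
    print("llm_swarm not found; the llm-swarm backend will be unavailable")
```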
5 changes: 3 additions & 2 deletions setup.py
@@ -46,9 +46,10 @@
"onnxruntime-gpu": [f"optimum[onnxruntime-gpu]>={MIN_OPTIMUM_VERSION}"],
"neural-compressor": [f"optimum[neural-compressor]>={MIN_OPTIMUM_VERSION}"],
"torch-ort": ["torch-ort", "onnxruntime-training", f"optimum>={MIN_OPTIMUM_VERSION}"],
# docker-based backends
# other backends
"llm-swarm": ["llm-swarm@git+https://github.com/huggingface/llm-swarm.git"],
"py-tgi": ["py-tgi==0.1.3"],
# third-party features
# optional dependencies
"codecarbon": ["codecarbon"],
"deepspeed": ["deepspeed"],
"diffusers": ["diffusers"],
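
With this extra declared, the backend's dependency can be pulled in from a checkout of the repository with `pip install -e .[llm-swarm]`, which installs llm-swarm directly from its git repository.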
