llm-swarm backend integration for slurm clusters #142

Merged · 4 commits · Mar 4, 2024
9 changes: 9 additions & 0 deletions .github/workflows/test_cli_cpu_py_tgi.yaml
@@ -15,6 +15,15 @@ jobs:
  run_cli_cpu_py_tgi_tests:
    runs-on: ubuntu-latest
    steps:
      - name: Free disk space
        uses: jlumbroso/free-disk-space@main
        with:
          dotnet: true
          android: true
          haskell: true
          docker-images: true
          large-packages: false

      - name: Checkout
        uses: actions/checkout@v3

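(Aside, not from the PR: the `jlumbroso/free-disk-space` action strips the runner's preinstalled .NET, Android and Haskell toolchains and cached Docker images, presumably to make room for the TGI images this workflow pulls; `large-packages` is left in place, likely because that removal step is slow.)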
94 changes: 94 additions & 0 deletions optimum_benchmark/backends/llm_swarm/backend.py
@@ -0,0 +1,94 @@
import asyncio
import gc
from logging import getLogger
from typing import Any, Dict, List

import torch
from huggingface_hub import AsyncInferenceClient
from llm_swarm import LLMSwarm
from llm_swarm import LLMSwarmConfig as LLMSwarmCfg

from ...task_utils import TEXT_GENERATION_TASKS
from ..base import Backend
from .config import LLMSwarmConfig

# backend logger
LOGGER = getLogger("llm-swarm")


class LLMSwarmBackend(Backend[LLMSwarmConfig]):
    NAME: str = "llm-swarm"

    def __init__(self, config: LLMSwarmConfig) -> None:
        super().__init__(config)
        self.validate_task()

        LOGGER.info("\t+ Downloading pretrained model")
        self.download_pretrained_model()
        LOGGER.info("\t+ Preparing generation config")
        self.prepare_generation_config()
        LOGGER.info("\t+ Loading pretrained model")
        self.load_model_from_pretrained()

    def validate_task(self) -> None:
        if self.config.task not in TEXT_GENERATION_TASKS:
            raise NotImplementedError(f"LLM Swarm does not support task {self.config.task}")

    def load_model_from_pretrained(self) -> None:
        self.llm_swarm_config = LLMSwarmCfg(
            gpus=self.config.gpus,
            model=self.config.model,
            instances=self.config.instances,
            inference_engine=self.config.inference_engine,
            slurm_template_path=self.config.slurm_template_path,
            load_balancer_template_path=self.config.load_balancer_template_path,
            per_instance_max_parallel_requests=self.config.per_instance_max_parallel_requests,
            revision=self.config.hub_kwargs.get("revision", "main"),
            debug_endpoint=self.config.debug_endpoint,
        )
        # enter the LLMSwarm context manually; clean() calls __exit__ to tear it down
        self.llm_swarm = LLMSwarm(self.llm_swarm_config).__enter__()
        self.client = AsyncInferenceClient(self.llm_swarm.endpoint)

    def download_pretrained_model(self) -> None:
        # meta device: fetch the weights into the cache without materializing them
        with torch.device("meta"):
            self.automodel_class.from_pretrained(self.config.model, **self.config.hub_kwargs)

    def prepare_generation_config(self) -> None:
        # -100 never matches a real token, so generation runs for exactly max_new_tokens
        self.generation_config.eos_token_id = -100
        self.generation_config.pad_token_id = -100
        model_cache_folder = f"models/{self.config.model}".replace("/", "--")
        model_cache_path = f"{self.config.volume}/{model_cache_folder}"
        snapshot_file = f"{model_cache_path}/refs/{self.config.hub_kwargs.get('revision', 'main')}"
        with open(snapshot_file, "r") as f:
            snapshot_ref = f.read().strip()
        model_snapshot_path = f"{model_cache_path}/snapshots/{snapshot_ref}"
        LOGGER.info("\t+ Saving new pretrained generation config")
        self.generation_config.save_pretrained(save_directory=model_snapshot_path)

    def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        if "inputs" in inputs:
            return {"prompt": self.pretrained_processor.batch_decode(inputs["inputs"].tolist())}
        elif "input_ids" in inputs:
            return {"prompt": self.pretrained_processor.batch_decode(inputs["input_ids"].tolist())}
        else:
            raise ValueError("inputs must contain either `inputs` or `input_ids`")

    async def single_client_call(self, prompt: str, kwargs: Dict[str, Any]) -> str:
        return await self.client.text_generation(prompt, max_new_tokens=kwargs.get("max_new_tokens", 1))

    async def batch_client_call(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
        return await asyncio.gather(*(self.single_client_call(p, kwargs) for p in inputs["prompt"]))

    def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
        return asyncio.run(self.batch_client_call(inputs, kwargs))

    def generate(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> List[str]:
        return asyncio.run(self.batch_client_call(inputs, kwargs))

    def clean(self) -> None:
        super().clean()

        if hasattr(self, "llm_swarm"):
            LOGGER.info("Cleaning up LLM Swarm")
            self.llm_swarm.__exit__(None, None, None)

        gc.collect()
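For reviewers who want to poke at this, a minimal driving sketch (not part of the PR): the `task` and `model` keyword fields are assumed from how `self.config` is read above, the model id is a placeholder, and it presumes a Slurm cluster matching the default templates.

```python
import torch

from optimum_benchmark.backends.llm_swarm.backend import LLMSwarmBackend
from optimum_benchmark.backends.llm_swarm.config import LLMSwarmConfig

# assumed fields: `task` and `model` come from the base BackendConfig,
# as read by validate_task() and load_model_from_pretrained() above
config = LLMSwarmConfig(
    task="text-generation",
    model="HuggingFaceH4/zephyr-7b-beta",  # placeholder model id
    gpus=8,
    instances=1,
)

backend = LLMSwarmBackend(config)  # blocks while LLMSwarm brings up TGI on Slurm
try:
    # prepare_inputs() decodes token ids back into text prompts for the HTTP client
    prepared = backend.prepare_inputs({"input_ids": torch.tensor([[1, 2, 3]])})
    outputs = backend.generate(prepared, {"max_new_tokens": 32})
finally:
    backend.clean()  # calls LLMSwarm.__exit__ and tears the Slurm jobs down
```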
31 changes: 31 additions & 0 deletions optimum_benchmark/backends/llm_swarm/config.py
@@ -0,0 +1,31 @@
from dataclasses import dataclass
from typing import Optional

from ...import_utils import llm_swarm_version
from ..config import BackendConfig


@dataclass
class LLMSwarmConfig(BackendConfig):
    name: str = "llm-swarm"
    version: Optional[str] = llm_swarm_version()
    _target_: str = "optimum_benchmark.backends.llm_swarm.backend.LLMSwarmBackend"

    # optimum benchmark specific
    no_weights: bool = False

    # llm-swarm specific
    gpus: int = 8
    instances: int = 1
    inference_engine: str = "tgi"
    volume: str = "/fsx/ilyas/.cache"
    per_instance_max_parallel_requests: int = 500
    slurm_template_path: str = "/fsx/ilyas/swarm-templates/tgi_h100.template.slurm"
    load_balancer_template_path: str = "/fsx/ilyas/swarm-templates/nginx.template.conf"
    debug_endpoint: Optional[str] = None

    def __post_init__(self):
        super().__post_init__()

        # so that downloaded artifacts are stored in the same place
        self.hub_kwargs["cache_dir"] = self.volume
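One design note worth flagging: the defaults bake in cluster-specific `/fsx/ilyas/...` paths, so anyone else would override `volume` and the two template paths. A tiny sketch of the `__post_init__` side effect (assuming, as in the backend sketch above, that the base config accepts `task`/`model` and that its own `__post_init__` accepts the remaining defaults):

```python
from optimum_benchmark.backends.llm_swarm.config import LLMSwarmConfig

# hypothetical overrides for a cluster without the /fsx/ilyas defaults
cfg = LLMSwarmConfig(
    task="text-generation",
    model="gpt2",
    volume="/scratch/hub-cache",
    slurm_template_path="/scratch/templates/tgi.template.slurm",
    load_balancer_template_path="/scratch/templates/nginx.template.conf",
)
# __post_init__ routes the Hub cache onto the shared volume, so every Slurm
# node (and prepare_generation_config in backend.py) sees the same snapshot
assert cfg.hub_kwargs["cache_dir"] == "/scratch/hub-cache"
```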
2 changes: 2 additions & 0 deletions optimum_benchmark/cli.py
@@ -6,6 +6,7 @@
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig, OmegaConf

from .backends.llm_swarm.config import LLMSwarmConfig
from .backends.neural_compressor.config import INCConfig
from .backends.onnxruntime.config import ORTConfig
from .backends.openvino.config import OVConfig
@@ -34,6 +35,7 @@
cs.store(group="backend", name=TRTLLMConfig.name, node=TRTLLMConfig)
cs.store(group="backend", name=INCConfig.name, node=INCConfig)
cs.store(group="backend", name=PyTGIConfig.name, node=PyTGIConfig)
cs.store(group="backend", name=LLMSwarmConfig.name, node=LLMSwarmConfig)
# benchmarks configurations
cs.store(group="benchmark", name=TrainingConfig.name, node=TrainingConfig)
cs.store(group="benchmark", name=InferenceConfig.name, node=InferenceConfig)
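(Note, not from the PR: with this `ConfigStore` registration the backend becomes selectable by name like the others, e.g. something along the lines of `optimum-benchmark backend=llm-swarm backend.model=gpt2 ...` on the hydra command line; hydra then instantiates `LLMSwarmBackend` through the config's `_target_`.)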
2 changes: 2 additions & 0 deletions optimum_benchmark/experiment.py
@@ -174,6 +174,7 @@ def launch(experiment_config: ExperimentConfig) -> BenchmarkReport:
        launcher: Launcher = launcher_factory(launcher_config)
    except Exception as e:
        LOGGER.error(f"Error during launcher allocation: {e}")
        os.chdir(original_dir)
        tmpdir.cleanup()
        raise e

@@ -184,6 +185,7 @@ def launch(experiment_config: ExperimentConfig) -> BenchmarkReport:
        output = launcher.launch(run, benchmark_config, backend_config)
    except Exception as e:
        LOGGER.error(f"Error during experiment launching: {e}")
        os.chdir(original_dir)
        tmpdir.cleanup()
        raise e

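For context on these two one-line additions, a standalone sketch of the pattern they complete (an illustration, not the PR's code): `launch()` works inside a temporary directory, and on failure the process must hop back to the original directory *before* `tmpdir.cleanup()`, so the error path does not leave it sitting in a deleted working directory.

```python
import os
import tempfile

def do_work() -> int:
    return 42  # hypothetical workload standing in for launcher.launch(...)

original_dir = os.getcwd()
tmpdir = tempfile.TemporaryDirectory()
os.chdir(tmpdir.name)
try:
    result = do_work()
except Exception as e:
    os.chdir(original_dir)  # the fix: restore cwd *before* removing the tmpdir
    tmpdir.cleanup()
    raise e
else:
    os.chdir(original_dir)
    tmpdir.cleanup()
```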
10 changes: 10 additions & 0 deletions optimum_benchmark/import_utils.py
@@ -29,6 +29,11 @@
_optimum_benchmark_available = importlib.util.find_spec("optimum_benchmark") is not None
_py_tgi_available = importlib.util.find_spec("py_tgi") is not None
_pyrsmi_available = importlib.util.find_spec("pyrsmi") is not None
_llm_swarm_available = importlib.util.find_spec("llm_swarm") is not None


def is_llm_swarm_available():
    return _llm_swarm_available


def is_pyrsmi_available():
@@ -198,6 +203,11 @@ def py_tgi_version():
return importlib.metadata.version("py_tgi")


def llm_swarm_version():
    if _llm_swarm_available:
        return importlib.metadata.version("llm_swarm")


def get_git_revision_hash(package_name: str) -> Optional[str]:
"""
Returns the git commit SHA of a package installed from a git repository.
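The two new helpers mirror the existing `py_tgi` pair; a quick sketch of how they compose:

```python
from optimum_benchmark.import_utils import is_llm_swarm_available, llm_swarm_version

if is_llm_swarm_available():
    print(f"llm_swarm {llm_swarm_version()} detected")
else:
    print("llm_swarm not installed; the llm-swarm backend will be unavailable")
```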
5 changes: 3 additions & 2 deletions setup.py
@@ -46,9 +46,10 @@
"onnxruntime-gpu": [f"optimum[onnxruntime-gpu]>={MIN_OPTIMUM_VERSION}"],
"neural-compressor": [f"optimum[neural-compressor]>={MIN_OPTIMUM_VERSION}"],
"torch-ort": ["torch-ort", "onnxruntime-training", f"optimum>={MIN_OPTIMUM_VERSION}"],
# docker-based backends
# other backends
"llm-swarm": ["llm-swarm@git+https://github.com/huggingface/llm-swarm.git"],
"py-tgi": ["py-tgi==0.1.3"],
# third-party features
# optional dependencies
"codecarbon": ["codecarbon"],
"deepspeed": ["deepspeed"],
"diffusers": ["diffusers"],
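(Usage note, not from the PR: with this extra declared, the backend's dependency can be pulled in with `pip install "optimum-benchmark[llm-swarm]"`, which installs llm-swarm straight from its git repository per the requirement spec above.)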