feat: add nccl support for multi-gpu tensor parallelism (#91)
* first commit

* first commit

* add tests mod

* first commit

* refactor repository

* refactor the llm service logic to be able to communicate with the axum service

* fmt

* message content format and parse RequestBody into GenerateRequest

* config comments

* minor mods

* add unit tests for messages to prompt

* improve docs, resolve clippy, add remaining logic to handle responses back to the user

* refactor tests

* refactor tests for llm

* first commit

* llama-nccl

* add clap args

* add features derive to clap

* remove comments

* resolve a few issues with finished reason parsing

* address PR comments

* add llama models enums

* add llama models enums

* add llama models enums

* correct meta hf string

* correct meta hf string

* add changes

* handle compilation issues

* update candle versions

* clippy

* fix compilation issues

* minor changes

* add llama_nccl to vllm

* add changes for compilation

* new changes

* adjust engine tests

* resolve bug

* solve a few issues

* add changes

* resolve a few minor bugs and add info logs for cache engine load times

* add nccl feature

* update features on server crate

* update features dependencies on backends with nccl

* add changes

* address PR comments

* add small changes

* add small changes

* add small changes

* remove unnecessary feature

* remove unnecessary feature flags from code

* remove unnecessary feature flags from code

* add changes

* add imports

* add imports

* add imports

* add feature gating to llama tests

* only allocate device ids memory

* only allocate device ids memory
jorgeantonio21 authored Oct 4, 2024
1 parent 986c723 commit 036cc85
Showing 22 changed files with 1,560 additions and 214 deletions.
11 changes: 11 additions & 0 deletions Cargo.toml
@@ -26,6 +26,17 @@ candle-examples = "0.7.2"
clap = "4.5.18"
config = "0.14.0"
csrc = { path = "csrc" }
cudarc = { version = "0.12.0", features = [
"std",
"cublas",
"cublaslt",
"curand",
"driver",
"nvrtc",
"f16",
"cuda-version-from-build-system",
"dynamic-linking",
], default-features = false }
cuda-runtime-sys = "0.3.0-alpha.1"
dotenv = "0.15.0"
expect-test = "1.5"
1 change: 1 addition & 0 deletions backends/Cargo.toml
@@ -9,3 +9,4 @@ serde = { workspace = true, features = ["derive"] }

[features]
vllm = ["dep:atoma-vllm-backend"]
nccl = ["vllm", "atoma-vllm-backend/nccl"]
3 changes: 3 additions & 0 deletions backends/vllm/Cargo.toml
@@ -8,6 +8,7 @@ candle-core = { workspace = true, features = ["cuda"] }
candle-nn = { workspace = true, features = ["cuda"] }
candle-transformers = { workspace = true, features = ["cuda"] }
config.workspace = true
cudarc = { workspace = true, optional = true }
cuda-runtime-sys.workspace = true
dotenv.workspace = true
futures.workspace = true
@@ -29,3 +30,5 @@ rand.workspace = true
tracing-subscriber.workspace = true
tokenizers = { workspace = true, features = ["http"] }

[features]
nccl = ["dep:cudarc", "cudarc/nccl", "models/nccl"]
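
Taken together, these manifests form a feature chain: enabling `nccl` on the `backends` crate turns on `vllm` and `atoma-vllm-backend/nccl`, which in turn activates `dep:cudarc`, `cudarc/nccl`, and `models/nccl`, so a default build never pulls in the CUDA/NCCL dependencies. Below is a minimal, editor-added sketch of how downstream code can be gated on that feature; the module and function names are illustrative only and do not appear in this diff.

```rust
// Editor-added sketch: gating NCCL-only code behind the `nccl` cargo feature.
// The module name and function below are illustrative, not part of the diff.

#[cfg(feature = "nccl")]
mod parallelism {
    // Only compiled when building with `--features nccl`; in that configuration
    // the `cudarc` dependency (with its own `nccl` feature) is available.
    pub fn mode() -> &'static str {
        "multi-GPU tensor parallelism over NCCL"
    }
}

#[cfg(not(feature = "nccl"))]
mod parallelism {
    // Default build: no cudarc/NCCL, single-device execution path.
    pub fn mode() -> &'static str {
        "single-GPU execution"
    }
}

fn main() {
    println!("compiled for: {}", parallelism::mode());
}
```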
22 changes: 8 additions & 14 deletions backends/vllm/src/config.rs
@@ -163,6 +163,7 @@ impl CacheConfig {
num_kv_heads: usize,
hidden_dim: usize,
num_hidden_layers: usize,
device_ids: &[usize],
) -> Result<Self, CacheConfigError> {
let builder = Config::builder().add_source(config::File::with_name(
config_file_path.as_ref().to_str().unwrap(),
@@ -199,6 +200,7 @@
num_kv_heads,
hidden_dim,
dtype,
device_ids,
)?;
this.num_gpu_blocks = Some(num_gpu_blocks);
}
@@ -583,24 +585,16 @@ pub(crate) mod utils {
num_kv_heads: usize,
hidden_dim: usize,
dtype: DType,
device_ids: &[usize],
) -> Result<usize, CacheConfigError> {
unsafe {
let mut device_count = 0;
let result = cudaGetDeviceCount(&mut device_count);
if result != cudaError::cudaSuccess || device_count == 0 {
return Err(CacheConfigError::GpuMemoryQueryError(format!(
"Failed to get device count: {:?}",
result
)));
}

let mut per_device_memory = Vec::with_capacity(device_count as usize);
for device in 0..device_count {
let result = cudaSetDevice(device);
let mut per_device_memory = Vec::with_capacity(device_ids.len());
for device_id in device_ids.iter() {
let result = cudaSetDevice(*device_id as i32);
if result != cudaError::cudaSuccess {
return Err(CacheConfigError::GpuMemoryQueryError(format!(
"Failed to set device {}: {:?}",
device, result
device_id, result
)));
}

@@ -610,7 +604,7 @@
if result != cudaError::cudaSuccess {
return Err(CacheConfigError::GpuMemoryQueryError(format!(
"Failed to get memory info for device {}: {:?}",
device, result
device_id, result
)));
}

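
The key change in `compute_num_gpu_blocks` is that the memory probe no longer enumerates every GPU reported by `cudaGetDeviceCount`; it only visits the devices the caller passes in `device_ids`. Below is a condensed, editor-added sketch of that loop. It reuses the raw `cuda-runtime-sys` calls visible in the diff (`cudaSetDevice`, `cudaError::cudaSuccess`) plus `cudaMemGetInfo`, whose exact binding signature (`*mut usize` out-parameters) is assumed here; errors are reduced to plain `String`s instead of `CacheConfigError`.

```rust
// Editor-added sketch of the per-device free-memory probe enabled by the new
// `device_ids` parameter. Assumes cuda-runtime-sys exposes `cudaMemGetInfo`
// with `*mut usize` out-parameters; error handling is simplified to `String`.
use cuda_runtime_sys::{cudaError, cudaMemGetInfo, cudaSetDevice};

fn free_memory_per_device(device_ids: &[usize]) -> Result<Vec<usize>, String> {
    let mut per_device_memory = Vec::with_capacity(device_ids.len());
    for device_id in device_ids {
        unsafe {
            // Bind the calling thread to this device only; GPUs that are not
            // listed in `device_ids` are never touched during profiling.
            let result = cudaSetDevice(*device_id as i32);
            if result != cudaError::cudaSuccess {
                return Err(format!("Failed to set device {device_id}: {result:?}"));
            }
            let (mut free, mut total) = (0usize, 0usize);
            // Query free/total memory on the currently bound device.
            let result = cudaMemGetInfo(&mut free, &mut total);
            if result != cudaError::cudaSuccess {
                return Err(format!(
                    "Failed to get memory info for device {device_id}: {result:?}"
                ));
            }
            per_device_memory.push(free);
        }
    }
    Ok(per_device_memory)
}

fn main() {
    match free_memory_per_device(&[0]) {
        Ok(free) => println!("free bytes per device: {free:?}"),
        Err(e) => eprintln!("{e}"),
    }
}
```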
7 changes: 4 additions & 3 deletions backends/vllm/src/llm_engine.rs
@@ -167,9 +167,10 @@ impl LlmEngine {
let response_sender = self
.response_senders
.remove(&request_output.request_id)
.ok_or(EngineError::SendResponseError(
format!("Failed to get response sender for request with id = {}", request_output.request_id),
))?;
.ok_or(EngineError::SendResponseError(format!(
"Failed to get response sender for request with id = {}",
request_output.request_id
)))?;
response_sender
.send(request_output)
.map_err(|out| EngineError::SendResponseError(out.request_id))?;
30 changes: 20 additions & 10 deletions backends/vllm/src/llm_service.rs
@@ -1,19 +1,22 @@
use std::{path::Path, str::FromStr, time::Instant};
use std::{path::Path, str::FromStr, sync::Arc, time::Instant};

use crate::{
config::{
CacheConfig, CacheConfigError, ModelConfig, SchedulerConfig, SchedulerConfigError,
ValidationConfig,
},
llm_engine::{EngineError, GenerateRequestOutput, LlmEngine},
model_executor::{ModelExecutor, ModelLoaderError, ModelThreadDispatcher, ModelThreadError},
model_executor::{
Config, ConfigError, ModelExecutor, ModelLoaderError, ModelThreadDispatcher,
ModelThreadError,
},
scheduler::{Scheduler, SchedulerError},
sequence::{Sequence, SequenceError, SequenceGroup},
tokenizer::{TokenizerError, TokenizerWorker},
types::GenerateRequest,
validation::{ValidGenerateRequest, Validation, ValidationError},
};
use candle_core::{DType, DTypeParseError, Device, Error as CandleError};
use candle_core::{DType, DTypeParseError, Error as CandleError};
use candle_transformers::generation::{LogitsProcessor, Sampling};
use metrics::{counter, gauge};
use thiserror::Error;
@@ -91,20 +94,24 @@ impl LlmService {
model_config.model_name.clone(),
model_config.revision.clone(),
)?;
let config = M::C::from_file_path(&file_paths.config_path)?;
let tokenizer = Tokenizer::from_file(&file_paths.tokenizer_path)?;
// NOTE: we load the model on GPU memory, as to properly compute the number of blocks
// during the system profiling stage. See `compute_num_gpu_blocks` comments
// in the file `config.rs` for more details.
// TODO: support multi-GPUs
let device = Device::new_cuda(model_config.device_ids[0])?;
let devices_ids = model_config.device_ids.clone();
let dtype = DType::from_str(&model_config.dtype)?;
let model = M::load(device.clone(), dtype, &file_paths)?;
let num_kv_heads = config.num_kv_heads();
let hidden_dim = config.hidden_dim();
let num_hidden_layers = config.num_hidden_layers();

let cache_config = CacheConfig::from_file_path(
config_path.as_ref(),
model.num_kv_heads(),
model.hidden_dim(),
model.num_hidden_layers(),
num_kv_heads,
hidden_dim,
num_hidden_layers,
&devices_ids,
)?;
let scheduler_config = SchedulerConfig::from_file_path(config_path.as_ref())?;
let validation_config = ValidationConfig::from_file_path(config_path.as_ref());
@@ -136,9 +143,10 @@

let model_thread_dispatcher = ModelThreadDispatcher::start::<M>(
cache_config,
device,
config,
devices_ids,
dtype,
model,
Arc::new(file_paths),
scheduler_config,
)?;

Expand Down Expand Up @@ -338,6 +346,8 @@ pub enum LlmServiceError {
CacheConfigError(#[from] CacheConfigError),
#[error("Candle error: `{0}`")]
CandleError(#[from] CandleError),
#[error("Config error: `{0}`")]
ConfigError(#[from] ConfigError),
#[error("DType parse error: `{0}`")]
DTypeParseError(#[from] DTypeParseError),
#[error("Model loader error: `{0}`")]
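
With this change the service no longer instantiates the model on `device_ids[0]` just to read its shape (the removed `Device::new_cuda(...)` / `M::load(...)` path and its "TODO: support multi-GPUs" comment); instead it parses the model configuration directly via `M::C::from_file_path` and hands the full device list, dtype, and file paths to `ModelThreadDispatcher::start`. The following editor-added sketch shows the shape such a `Config` trait appears to take, inferred from these call sites; the bounds, doc comments, and error type are assumptions, not copied from `model_executor.rs`.

```rust
// Editor-added sketch of a `Config` abstraction matching the call sites in this
// diff: `M::C::from_file_path(...)`, `config.num_kv_heads()`,
// `config.hidden_dim()`, `config.num_hidden_layers()`. Everything beyond those
// names is an assumption.
use std::path::Path;

#[derive(Debug)]
pub struct ConfigError(pub String);

pub trait Config: Sized {
    /// Parse the model's configuration file (e.g. an HF-style `config.json`)
    /// without loading any weights, so KV-cache block profiling can run over
    /// every device in `device_ids` before the model threads start.
    fn from_file_path(path: impl AsRef<Path>) -> Result<Self, ConfigError>;

    /// Number of key/value attention heads, used to size the KV cache.
    fn num_kv_heads(&self) -> usize;
    /// Hidden dimension of the model.
    fn hidden_dim(&self) -> usize;
    /// Number of hidden transformer layers.
    fn num_hidden_layers(&self) -> usize;
}
```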