diff --git a/server/text_generation_server/models/paged_causal_lm.py b/server/text_generation_server/models/paged_causal_lm.py index 3524f738..c82fcd54 100644 --- a/server/text_generation_server/models/paged_causal_lm.py +++ b/server/text_generation_server/models/paged_causal_lm.py @@ -327,7 +327,7 @@ def __init__( model_config.num_attention_heads, model_config.hidden_size, kv_heads=model_config.num_key_value_heads, - tensor_parallel_size=1, + tensor_parallel_size=self.engine.world_size, dtype=dtype, device=self.device, total_num_gpu_blocks=total_num_gpu_blocks,