diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index e3a45b26d909b..71362299a9fcf 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -82,17 +82,15 @@ def __init__( self.router = DbrxRouter(config, self.params_dtype) self.ws = nn.Parameter( - torch.empty( - self.num_total_experts, - 2 * self.intermediate_size, - self.d_model, - dtype=self.params_dtype)) + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.d_model, + dtype=self.params_dtype)) self.w2s = nn.Parameter( - torch.empty( - self.num_total_experts, - self.d_model, - self.intermediate_size, - dtype=self.params_dtype)) + torch.empty(self.num_total_experts, + self.d_model, + self.intermediate_size, + dtype=self.params_dtype)) set_weight_attrs( self.ws,