Commit
basic support for symmetric smoothquant model
rasmith committed Oct 18, 2024
1 parent 968345a commit 90a3e0f
Showing 2 changed files with 36 additions and 17 deletions.
42 changes: 28 additions & 14 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -96,12 +96,15 @@ __global__ void static_scaled_int8_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
     scale_type const* scale_ptr, const int hidden_size) {
   int const tid = threadIdx.x;
-  int const token_idx = blockIdx.x;
+  int64_t const token_idx = blockIdx.x;
   scale_type const scale = *scale_ptr;
 
+  // Must be performed using 64-bit math to avoid integer overflow.
+  out += token_idx * hidden_size;
+  input += token_idx * hidden_size;
+
   for (int i = tid; i < hidden_size; i += blockDim.x) {
-    out[token_idx * hidden_size + i] = float_to_int8_rn(
-        static_cast<float>(input[token_idx * hidden_size + i]) / scale);
+    out[i] = float_to_int8_rn(static_cast<float>(input[i]) / scale);
   }
 }

@@ -111,14 +114,18 @@ __global__ void static_scaled_int8_azp_quant_kernel(
     scale_type const* scale_ptr, azp_type const* azp_ptr,
     const int hidden_size) {
   int const tid = threadIdx.x;
-  int const token_idx = blockIdx.x;
+  int64_t const token_idx = blockIdx.x;
   scale_type const scale = *scale_ptr;
   azp_type const azp = *azp_ptr;
 
+  // Must be performed using 64-bit math to avoid integer overflow.
+  out += token_idx * hidden_size;
+  input += token_idx * hidden_size;
+
   for (int i = tid; i < hidden_size; i += blockDim.x) {
-    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const val = static_cast<float>(input[i]);
     auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp);
-    out[token_idx * hidden_size + i] = quant_val;
+    out[i] = quant_val;
   }
 }

@@ -127,12 +134,16 @@ __global__ void dynamic_scaled_int8_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
     scale_type* scale, const int hidden_size) {
   int const tid = threadIdx.x;
-  int const token_idx = blockIdx.x;
+  int64_t const token_idx = blockIdx.x;
   float absmax_val = 0.0f;
   float const zero = 0.0f;
 
+  // Must be performed using 64-bit math to avoid integer overflow.
+  out += token_idx * hidden_size;
+  input += token_idx * hidden_size;
+
   for (int i = tid; i < hidden_size; i += blockDim.x) {
-    float val = static_cast<float>(input[token_idx * hidden_size + i]);
+    float val = static_cast<float>(input[i]);
     val = val > zero ? val : -val;
     absmax_val = val > absmax_val ? val : absmax_val;
   }
@@ -150,22 +161,25 @@ __global__ void dynamic_scaled_int8_quant_kernel(
 
   float const tmp_scale = 127.0f / block_absmax_val;
   for (int i = tid; i < hidden_size; i += blockDim.x) {
-    out[token_idx * hidden_size + i] = float_to_int8_rn(
-        static_cast<float>(input[token_idx * hidden_size + i]) * tmp_scale);
+    out[i] = float_to_int8_rn(static_cast<float>(input[i]) * tmp_scale);
   }
 }
 
 template <typename scalar_t, typename scale_type, typename azp_type>
 __global__ void dynamic_scaled_int8_azp_quant_kernel(
     scalar_t const* __restrict__ input, int8_t* __restrict__ out,
     scale_type* scale, azp_type* azp, const int hidden_size) {
-  int const token_idx = blockIdx.x;
+  int64_t const token_idx = blockIdx.x;
+
+  // Must be performed using 64-bit math to avoid integer overflow.
+  out += token_idx * hidden_size;
+  input += token_idx * hidden_size;
 
   // Scan for the min and max value for this token
   float max_val = std::numeric_limits<float>::min();
   float min_val = std::numeric_limits<float>::max();
   for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
-    auto val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto val = static_cast<float>(input[i]);
     max_val = std::max(max_val, val);
     min_val = std::min(min_val, val);
   }
@@ -200,10 +214,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
 
   // Quantize the values
   for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) {
-    auto const val = static_cast<float>(input[token_idx * hidden_size + i]);
+    auto const val = static_cast<float>(input[i]);
     auto const quant_val =
         int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val);
-    out[token_idx * hidden_size + i] = quant_val;
+    out[i] = quant_val;
   }
 }
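
Every kernel in this file gets the same change: the per-token offset is applied once as a 64-bit pointer adjustment (out += token_idx * hidden_size) instead of recomputing token_idx * hidden_size + i with 32-bit ints inside the loop. A minimal sketch of the overflow this avoids, written in Python for brevity and using made-up values for the token count and hidden size:

    import ctypes

    # Hypothetical sizes for illustration only.
    token_idx = 600_000                     # token index in a very large batch
    hidden_size = 4_096
    offset = token_idx * hidden_size        # 2_457_600_000, exceeds INT32_MAX
    wrapped = ctypes.c_int32(offset).value  # what 32-bit index math would yield
    print(offset, wrapped)                  # 2457600000 -1837367296

A negative wrapped offset like this would index far outside the tensor, which is why the kernels widen token_idx to int64_t before the multiply.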

11 changes: 8 additions & 3 deletions vllm/_custom_ops.py
@@ -517,7 +517,6 @@ def selective_scan_fwd_fake(u: torch.Tensor, delta: torch.Tensor,
 def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
     return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability)
 
-
 def cutlass_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
@@ -531,9 +530,15 @@ def cutlass_scaled_mm(a: torch.Tensor,
 
     m = a.shape[0]
     n = b.shape[1]
-    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
 
-    torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
+    if is_hip():
+        out = torch.mm(a.to(torch.float32), b.to(torch.float32))
+        out = scale_a * out
+        out = scale_b.T * out
+        out = out.to(out_dtype)
+    else:
+        out = torch.empty((m, n), dtype=out_dtype, device=a.device)
+        torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
 
     return out
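
On ROCm the new is_hip() branch replaces the CUTLASS kernel with a plain float32 matmul followed by the two scale multiplications. A rough reference of what that branch computes, using hypothetical shapes (per-token scale_a of shape (M, 1), per-output-channel scale_b of shape (N, 1)); like the branch in the diff, this sketch does not apply bias:

    import torch

    # Hypothetical sizes and random int8 inputs for illustration only.
    M, K, N = 4, 16, 8
    a = torch.randint(-128, 128, (M, K), dtype=torch.int8)
    b = torch.randint(-128, 128, (K, N), dtype=torch.int8)
    scale_a = torch.rand(M, 1)   # per-token scale
    scale_b = torch.rand(N, 1)   # per-output-channel scale

    out = torch.mm(a.to(torch.float32), b.to(torch.float32))
    out = scale_a * out          # broadcasts across rows
    out = scale_b.T * out        # broadcasts across columns
    out = out.to(torch.float16)  # cast to the requested out_dtype
    print(out.shape)             # torch.Size([4, 8])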

