fix: rewrite quantize Q4_0_4X4

yirongjie committed Jul 29, 2024
1 parent 1bbb173 commit 855e425
Showing 6 changed files with 75 additions and 2 deletions.
1 change: 1 addition & 0 deletions examples/demo_llama.cpp
@@ -51,6 +51,7 @@ int main(int argc, char **argv) {
             chatPostProcessing(out_token, input_tensor, {});
         }
         printf("\n");
+        model.profiling();
     }
 
     return 0;
7 changes: 7 additions & 0 deletions src/backends/cpu/compute/GEMM_AArch64.cpp
@@ -2201,4 +2201,11 @@ void quantize_row_q4_0_4x4(const float * __restrict x, void * __restrict y, int
     assert(k%QK4_0 == 0);
     std::cout<<"Quantize 4x4:"<<k<<"/4096="<<k/4096;
     auto size = quantize_q4_0_nr_bl(x, y, k/4096, 4096, 4, 4);
 }
+
+
+void quantize_row_q4_0_4x4(const float * __restrict x, void * __restrict y, int k, int raw){
+    assert(k%QK4_0 == 0);
+    std::cout<<"Quantize 4x4:"<<k<<"/"<<raw<<"="<<k/raw;
+    auto size = quantize_q4_0_nr_bl(x, y, k/raw, raw, 4, 4);
+}
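For reference, a minimal sketch of how a caller might drive the new four-argument overload, which replaces the hardcoded 4096 row width with an explicit raw parameter. The wrapper name quantize_weights_4x4, the include path, the pre-sized dst buffer, and hidden_dim are illustrative assumptions; only the quantize_row_q4_0_4x4 declaration comes from the repository (see the GEMM_AArch64.hpp change below).

    #include <vector>
    #include "backends/cpu/compute/GEMM_AArch64.hpp"  // include path assumed

    // Hypothetical caller: quantize one weight matrix whose rows are hidden_dim floats wide.
    // dst must already be large enough to hold the Q4_0 4x4-interleaved blocks.
    void quantize_weights_4x4(const std::vector<float> &weights, void *dst, int hidden_dim) {
        // k is the total element count; raw is the row width, so k / raw rows get quantized.
        quantize_row_q4_0_4x4(weights.data(), dst, static_cast<int>(weights.size()), hidden_dim);
    }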
1 change: 1 addition & 0 deletions src/backends/cpu/compute/GEMM_AArch64.hpp
@@ -29,5 +29,6 @@ void mllm_gemm_q4_0_4x8_q8_0(int n, float * __restrict s, size_t bs, const void
 void mllm_gemm_q4_0_8x8_q8_0(int n, float * __restrict s, size_t bs, const void * __restrict vx, const void * __restrict vy, int nr, int nc);
 
 void quantize_row_q4_0_4x4(const float * __restrict x, void * __restrict y, int k);
+void quantize_row_q4_0_4x4(const float * __restrict x, void * __restrict y, int k, int raw);
 
 #endif // MLLM_GEMM_HPP
65 changes: 64 additions & 1 deletion src/quantizer/QuantWriter.cpp
@@ -42,6 +42,7 @@ vector<string> fp32_layers = {"norm", "rope", "bias","rotary_emb", "embed_tokens
     "modality_preprocessors", "modality_heads", "modality_postprocessors", "pre_transformer_layer"};
 vector<string> q6_layers = {"w2", "wv", "dense_h_to_4h", "v_proj", "down_proj"};
 vector<string> q4x4_2_q4_layers = {"w2", "wv", "dense_h_to_4h", "v_proj", "down_proj"};
+vector<string> q4x4_2_q4_layers_ = {"wv", "v_proj"};
 
 int tmp_hidden_dim = -1;
 
@@ -128,7 +129,7 @@ void QuantWriter::quantParams(DataType dataType) {
             break;
         case MLLM_TYPE_Q4_0_4_4:
             std::cout << "Quantize param " << name << " to " << DataTypeName(MLLM_TYPE_Q4_0) << "\t";
-            block_t = alloc_quant_block(size, dataType);
+            block_t = alloc_quant_block(size, MLLM_TYPE_Q4_0);
             quant_ptr = block_t.first;
             quantize_row_q4_0(param, quant_ptr, size);
             size = block_t.second;
@@ -268,6 +269,68 @@
     }
     writeIndex();
 }
+
+void QuantWriter::quantParams_q4_(DataType dataType) {
+    for (const auto &name : param_names_) {
+        // int force_quant_type = -1;
+        auto size = param_loader_->offsets_[name].second / sizeof(float);
+        if (find_names(name, {"norm"})) {
+            tmp_hidden_dim = size;
+        }
+    }
+    quant_type_ = dataType;
+    for (const auto &name : param_names_) {
+        // int force_quant_type = -1;
+        auto *param = getParam(name);
+        if (param == nullptr) {
+            __exit(-1);
+        }
+        auto size = param_loader_->offsets_[name].second / sizeof(float);
+        if (find_names(name, {"norm"})) {
+            tmp_hidden_dim = size;
+        }
+        void *quant_ptr = nullptr;
+        std::pair<void *, uint64_t> block_t;
+        if (find_names(name, fp32_layers)) {
+            std::cout << "Quantize param " << name << " to " << DataTypeName(MLLM_TYPE_F32) << "\t";
+            const auto s = param_loader_->offsets_[name].second / sizeof(float);
+            const auto tsize = alloc_quant_block(s, MLLM_TYPE_F32).second;
+            writeParam(name, MLLM_TYPE_F32, param, tsize);
+            std::cout << " size:" << tsize << std::endl;
+        } else if (find_names(name, q4x4_2_q4_layers_)) {
+            // std::cout<<"q4x4_2_q4_layers"<<std::endl;
+            std::cout << "Quantize param " << name << " to " << DataTypeName(MLLM_TYPE_Q4_0) << "\t";
+            block_t = alloc_quant_block(size, MLLM_TYPE_Q4_0);
+            quant_ptr = block_t.first;
+            quantize_row_q4_0(param, quant_ptr, size);
+            size = block_t.second;
+            if (quant_ptr != nullptr) {
+                writeParam(name, MLLM_TYPE_Q4_0, quant_ptr, size);
+                std::cout << " size:" << size << " type:" << DataTypeName(MLLM_TYPE_Q4_0) << std::endl;
+            }
+        } else {
+            std::cout << "Quantize param " << name << " to " << DataTypeName(dataType) << "\t";
+            block_t = alloc_quant_block(size, dataType);
+            quant_ptr = block_t.first;
+            int tmp_hidden_dim_q4 = tmp_hidden_dim;
+            if (find_names(name, {"w2"})) {
+                tmp_hidden_dim_q4 = (size / tmp_hidden_dim);
+            }
+            quantize_row_q4_0_4x4(param, quant_ptr, size, tmp_hidden_dim_q4);
+            size = block_t.second;
+            if (quant_ptr != nullptr) {
+                writeParam(name, quant_type_, quant_ptr, size);
+                std::cout << " size:" << size << std::endl;
+            }
+            // writeParam(name, quant_type_, quant_ptr, size);
+#ifndef TEST
+            delete[] (char *)quant_ptr;
+#endif
+        }
+    }
+    writeIndex();
+}
+
 void QuantWriter::writeParam(string name, DataType type, void *data, uint64_t size) {
 #ifdef TEST
     data_[name] = (char *)data;
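In plain terms, quantParams_q4_ routes each parameter by name: tensors matching fp32_layers stay in MLLM_TYPE_F32, tensors in q4x4_2_q4_layers_ ("wv", "v_proj") fall back to plain MLLM_TYPE_Q4_0, and everything else is packed as MLLM_TYPE_Q4_0_4_4. A minimal sketch of that decision, reusing the find_names helper and layer-name vectors defined in QuantWriter.cpp; the helper name pick_q4_target itself is hypothetical:

    // Hypothetical helper mirroring the branches in quantParams_q4_ above.
    static DataType pick_q4_target(const std::string &name) {
        if (find_names(name, fp32_layers)) {
            return MLLM_TYPE_F32;        // norms, embeddings, heads, etc. stay in fp32
        }
        if (find_names(name, q4x4_2_q4_layers_)) {
            return MLLM_TYPE_Q4_0;       // "wv" / "v_proj" use plain Q4_0 blocks
        }
        return MLLM_TYPE_Q4_0_4_4;       // everything else gets the 4x4-interleaved layout
    }

For the Q4_0_4_4 branch, the row width handed to quantize_row_q4_0_4x4 is tmp_hidden_dim, captured earlier from the size of the "norm" tensors; for "w2"-style layers it is size / tmp_hidden_dim instead, presumably because their input dimension is the FFN intermediate size rather than the hidden size.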
1 change: 1 addition & 0 deletions src/quantizer/QuantWriter.hpp
@@ -35,6 +35,7 @@ class QuantWriter : public ParamWriter {
     explicit QuantWriter(std::string output_path, std::string input_path);
     int readParams();
     void quantParams(DataType dataType);
+    void quantParams_q4_(DataType dataType);
 
 #ifdef TEST
     std::unordered_map<string, char *> data_;
2 changes: 1 addition & 1 deletion src/quantizer/main.cpp
@@ -35,7 +35,7 @@ int main(int argc, char **argv) {
     } else if (quant_type == "Q8_K") {
         quant_writer.quantParams(MLLM_TYPE_Q8_K);
     } else if (quant_type == "Q4_0_4_4") {
-        quant_writer.quantParams(MLLM_TYPE_Q4_0_4_4);
+        quant_writer.quantParams_q4_(MLLM_TYPE_Q4_0_4_4);
     } else {
         std::cout << "Quant type " << quant_type << " is not supported\n";
         return -1;
