Commit
fix: rope
yirongjie committed Jul 30, 2024
1 parent 04f01b7 commit 6e941b5
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions src/backends/cpu/CPURoPE.cpp
@@ -140,7 +140,52 @@ void CPURoPE::rope_llama(shared_ptr<Tensor> input, shared_ptr<Tensor> output){
 void CPURoPE::rope_hf(shared_ptr<Tensor> input, shared_ptr<Tensor> output) {
     auto out_dtype = output->dtype();
     int partial_dimension = (input->dimension()) * partial_rotary_factor_;
+    int half = (int)(partial_dimension / 2);
     assert(partial_dimension % 2 == 0);
+    if (output->ctype() == BSHD) {
+        if (out_dtype == MLLM_TYPE_F32) {
+#pragma omp parallel for collapse(4) num_threads(thread_count)
+            for (int n = 0; n < input->batch(); ++n) {
+                for (int h = 0; h < input->head(); ++h) {
+                    for (int s = 0; s < input->sequence(); ++s) { // sequence
+                        for (int d = 0; d < partial_dimension / 2; ++d) {
+                            auto v = input->ptrAt<float>(n, h, s, d);
+                            auto o = output->ptrAt<float>(n, h, s, d);
+                            float in_value = v[0];
+                            float in_value_2 = v[half];
+                            float sin_value = sin_[s + h_cnt_][d];
+                            float cos_value = cos_[s + h_cnt_][d];
+                            auto value = in_value * cos_value - in_value_2 * sin_value;
+                            auto value2 = in_value * sin_value + in_value_2 * cos_value;
+                            o[0] = value;
+                            o[half] = value2;
+                        }
+                    }
+                }
+            }
+        } else if (out_dtype == MLLM_TYPE_F16) {
+#pragma omp parallel for collapse(4) num_threads(thread_count)
+            for (int n = 0; n < input->batch(); ++n) {
+                for (int h = 0; h < input->head(); ++h) {
+                    for (int s = 0; s < input->sequence(); ++s) { // sequence
+                        for (int d = 0; d < partial_dimension / 2; ++d) {
+                            auto v = input->ptrAt<float>(n, h, s, d);
+                            auto o = output->ptrAt<mllm_fp16_t>(n, h, s, d);
+                            float in_value = v[0];
+                            float in_value_2 = v[half];
+                            float sin_value = sin_[s + h_cnt_][d];
+                            float cos_value = cos_[s + h_cnt_][d];
+                            auto value = in_value * cos_value - in_value_2 * sin_value;
+                            auto value2 = in_value * sin_value + in_value_2 * cos_value;
+                            o[0] = MLLM_FP32_TO_FP16(value);
+                            o[half] = MLLM_FP32_TO_FP16(value2);
+                        }
+                    }
+                }
+            }
+        }
+        return;
+    }
 #pragma omp parallel for collapse(4) num_threads(thread_count)
     for (int n = 0; n < input->batch(); ++n) {
         for (int h = 0; h < input->head(); ++h) {
@@ -152,10 +197,10 @@ void CPURoPE::rope_hf(shared_ptr<Tensor> input, shared_ptr<Tensor> output){
                 float cos_value = cos_[s + h_cnt_][d];
                 auto value = in_value * cos_value - in_value_2 * sin_value;
                 auto value2 = in_value * sin_value + in_value_2 * cos_value;
-                if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F32) {
+                if (out_dtype == MLLM_TYPE_F32) {
                     output->setDataAt<float>(n, h, s, d, value);
                     output->setDataAt<float>(n, h, s, d + partial_dimension / 2, value2);
-                } else if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F16) {
+                } else if (out_dtype == MLLM_TYPE_F16) {
                     output->setDataAt<mllm_fp16_t>(n, h, s, d, MLLM_FP32_TO_FP16(value));
                     output->setDataAt<mllm_fp16_t>(n, h, s, d + partial_dimension / 2, MLLM_FP32_TO_FP16(value2));
                 }
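
Both branches compute the Hugging Face "rotate half" form of RoPE: the pair (x[d], x[d + half]) is rotated by the angle whose sine and cosine are cached in sin_ and cos_. A minimal standalone sketch of that inner step, for reference only (rope_rotate_half and its parameters are illustrative and do not appear in this commit):

// Illustrative sketch only; this helper does not exist in the repository.
// Rotate-half RoPE on one contiguous row `x` of even length `dim`, given
// per-position sine/cosine rows:
//   out[d]        = x[d] * cos[d] - x[d + half] * sin[d]
//   out[d + half] = x[d] * sin[d] + x[d + half] * cos[d]
void rope_rotate_half(const float *x, float *out,
                      const float *sin_row, const float *cos_row, int dim) {
    int half = dim / 2;
    for (int d = 0; d < half; ++d) {
        out[d] = x[d] * cos_row[d] - x[d + half] * sin_row[d];
        out[d + half] = x[d] * sin_row[d] + x[d + half] * cos_row[d];
    }
}

The new BSHD fast path fuses this arithmetic into the batch/head/sequence/dimension loops and writes through raw pointers, avoiding the per-element setDataAt() dispatch of the fallback path that follows the early return.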
