diff --git a/src/backends/cpu/CPURoPE.cpp b/src/backends/cpu/CPURoPE.cpp index e47c740e..2b9e2932 100644 --- a/src/backends/cpu/CPURoPE.cpp +++ b/src/backends/cpu/CPURoPE.cpp @@ -140,7 +140,52 @@ void CPURoPE::rope_llama(shared_ptr input, shared_ptr output){ void CPURoPE::rope_hf(shared_ptr input, shared_ptr output){ auto out_dtype = output->dtype(); int partial_dimension = (input->dimension()) * partial_rotary_factor_; + int half = (int)(partial_dimension / 2); assert(partial_dimension%2==0); + if(output->ctype() == BSHD){ + if (out_dtype == MLLM_TYPE_F32){ +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension/2; ++d) { + auto v = input->ptrAt(n, h, s, d); + auto o = output->ptrAt(n, h, s, d); + float in_value = v[0]; + float in_value_2 = v[half]; + float sin_value = sin_[s + h_cnt_][d]; + float cos_value = cos_[s + h_cnt_][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + o[0] = value; + o[half] = value2; + } + } + } + } + }else if(out_dtype == MLLM_TYPE_F16){ +#pragma omp parallel for collapse(4) num_threads(thread_count) + for (int n = 0; n < input->batch(); ++n) { + for (int h = 0; h < input->head(); ++h) { + for (int s = 0; s < input->sequence(); ++s) { // sequance + for (int d = 0; d < partial_dimension/2; ++d) { + auto v = input->ptrAt(n, h, s, d); + auto o = output->ptrAt(n, h, s, d); + float in_value = v[0]; + float in_value_2 = v[half]; + float sin_value = sin_[s + h_cnt_][d]; + float cos_value = cos_[s + h_cnt_][d]; + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + o[0] = MLLM_FP32_TO_FP16(value); + o[half] = MLLM_FP32_TO_FP16(value2); + } + } + } + } + } + return; + } #pragma omp parallel for collapse(4) num_threads(thread_count) for (int n = 0; n < input->batch(); ++n) { for (int h = 0; h < input->head(); ++h) { @@ -152,10 +197,10 @@ void CPURoPE::rope_hf(shared_ptr input, shared_ptr output){ float cos_value = cos_[s + h_cnt_][d]; auto value = in_value * cos_value - in_value_2 * sin_value; auto value2 = in_value * sin_value + in_value_2 * cos_value; - if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F32) { + if (out_dtype == MLLM_TYPE_F32) { output->setDataAt(n, h, s, d, value); output->setDataAt(n, h, s, d+ partial_dimension / 2, value2); - } else if (output->dtypeAt(n, h, s, d) == MLLM_TYPE_F16) { + } else if (out_dtype == MLLM_TYPE_F16) { output->setDataAt(n, h, s, d, MLLM_FP32_TO_FP16(value)); output->setDataAt(n, h, s, d+ partial_dimension / 2, MLLM_FP32_TO_FP16(value2)); }