diff --git a/LLama/runtimes/ggml-metal.metal b/LLama/runtimes/ggml-metal.metal index 82e1a0c7a..8cdf0b9d2 100644 --- a/LLama/runtimes/ggml-metal.metal +++ b/LLama/runtimes/ggml-metal.metal @@ -25,9 +25,9 @@ typedef struct { } block_q8_0; kernel void kernel_add( - device const float * src0, - device const float * src1, - device float * dst, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] + src1[tpig]; } @@ -35,18 +35,18 @@ kernel void kernel_add( // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_add_row( - device const float * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant int64_t & nb, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] + src1[tpig % ne00]; + dst[tpig] = src0[tpig] + src1[tpig % nb]; } kernel void kernel_mul( - device const float * src0, - device const float * src1, - device float * dst, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * src1[tpig]; } @@ -54,12 +54,12 @@ kernel void kernel_mul( // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_mul_row( - device const float * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, + device const float4 * src0, + device const float4 * src1, + device float4 * dst, + constant int64_t & nb, uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig % ne00]; + dst[tpig] = src0[tpig] * src1[tpig % nb]; } kernel void kernel_scale( @@ -528,24 +528,42 @@ kernel void kernel_mul_mat_f16_f32( device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); - sum[tpitg.x] = 0.0f; + uint ith = tpitg.x; + uint nth = tptg.x; + + sum[ith] = 0.0f; - for (int i = tpitg.x; i < ne00; i += tptg.x) { - sum[tpitg.x] += (float) x[i] * (float) y[i]; + for (int i = ith; i < ne00; i += nth) { + sum[ith] += (float) x[i] * (float) y[i]; } // accumulate the sum from all threads in the threadgroup threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = tptg.x/2; i > 0; i /= 2) { - if (tpitg.x < i) { - sum[tpitg.x] += sum[tpitg.x + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%4 == 0) { + for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; } - - if (tpitg.x == 0) { + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith%16 == 0) { + for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (ith == 0) { + for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; } + + // Original implementation. Left behind commented out for now + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (uint i = tptg.x/2; i > 0; i /= 2) { + // if (tpitg.x < i) { + // sum[tpitg.x] += sum[tpitg.x + i]; + // } + // threadgroup_barrier(mem_flags::mem_threadgroup); + //} + // + //if (tpitg.x == 0) { + // dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0]; + //} } kernel void kernel_alibi_f32( diff --git a/LLama/runtimes/libllama-metal.dylib b/LLama/runtimes/libllama-metal.dylib index e9c2ee283..3095fcbfb 100755 Binary files a/LLama/runtimes/libllama-metal.dylib and b/LLama/runtimes/libllama-metal.dylib differ diff --git a/LLama/runtimes/libllama.dylib b/LLama/runtimes/libllama.dylib index 53318c38c..c6bd44332 100755 Binary files a/LLama/runtimes/libllama.dylib and b/LLama/runtimes/libllama.dylib differ diff --git a/qodana.yaml b/qodana.yaml new file mode 100644 index 000000000..99a40de62 --- /dev/null +++ b/qodana.yaml @@ -0,0 +1,29 @@ +#-------------------------------------------------------------------------------# +# Qodana analysis is configured by qodana.yaml file # +# https://www.jetbrains.com/help/qodana/qodana-yaml.html # +#-------------------------------------------------------------------------------# +version: "1.0" + +#Specify inspection profile for code analysis +profile: + name: qodana.starter + +#Enable inspections +#include: +# - name: + +#Disable inspections +#exclude: +# - name: +# paths: +# - + +#Execute shell command before Qodana execution (Applied in CI/CD pipeline) +#bootstrap: sh ./prepare-qodana.sh + +#Install IDE plugins before Qodana execution (Applied in CI/CD pipeline) +#plugins: +# - id: #(plugin id can be found at https://plugins.jetbrains.com) + +#Specify Qodana linter for analysis (Applied in CI/CD pipeline) +linter: jetbrains/qodana-dotnet:latest