Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Efficient implementation of Tensor::ones() for metal #2512

Merged
9 commits merged (source/target branch names not captured in this page scrape)
Oct 1, 2024
36 changes: 32 additions & 4 deletions candle-core/src/metal_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1917,10 +1917,38 @@ impl BackendDevice for MetalDevice {
))
}

fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
// TODO Is there a faster way ?
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
self.storage_from_cpu_storage(&cpu_storage)
/// Allocate storage for a tensor of ones on the Metal device.
///
/// Fast path: dispatch the matching `fill_*` constant-fill kernel so the
/// buffer is written directly on the GPU. `F64` has no fill kernel, so it
/// falls back to building the values on the CPU and uploading them.
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
    let kernel_name = match dtype {
        DType::U8 => "fill_u8",
        DType::U32 => "fill_u32",
        DType::I64 => "fill_i64",
        DType::F16 => "fill_f16",
        DType::BF16 => "fill_bf16",
        DType::F32 => "fill_f32",
        DType::F64 => {
            // No f64 fill kernel on Metal: materialize on the CPU instead.
            let on_cpu = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
            return self.storage_from_cpu_storage(&on_cpu);
        }
    };
    let elem_count = shape.elem_count();
    let buffer = self.new_buffer(elem_count, dtype, "alloc-ones")?;
    let command_buffer = self.command_buffer()?;
    candle_metal_kernels::call_const_fill(
        &self.device,
        &command_buffer,
        &self.kernels,
        kernel_name,
        elem_count,
        &buffer,
        1.,
    )
    .map_err(MetalError::from)?;

    Ok(MetalStorage::new(buffer, self.clone(), elem_count, dtype))
}

fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
Expand Down
30 changes: 30 additions & 0 deletions candle-core/tests/tensor_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,36 @@ fn ones(device: &Device) -> Result<()> {
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
);
assert_eq!(
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
[
[
half::f16::from_f32(1.0),
half::f16::from_f32(1.0),
half::f16::from_f32(1.0)
],
[
half::f16::from_f32(1.0),
half::f16::from_f32(1.0),
half::f16::from_f32(1.0)
]
],
);
assert_eq!(
Tensor::ones((2, 3), DType::BF16, device)?.to_vec2::<half::bf16>()?,
[
[
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0)
],
[
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0),
half::bf16::from_f32(1.0)
]
],
);
Ok(())
}

Expand Down
39 changes: 39 additions & 0 deletions candle-metal-kernels/src/fill.metal
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <metal_stdlib>

using namespace metal;

// Write `value` (cast to T) into every element of `out`.
//
// One thread per element: threads launched beyond `numel` (grid padding from
// the threadgroup split) exit early so they never write out of bounds.
//
// Fix: the scraped source had GitHub review-UI text ("marked this
// conversation as resolved") embedded inside the function body, which is not
// valid Metal; it has been removed.
template<typename T> METAL_FUNC void fill_with(
    device T *out,
    constant float &value,
    constant size_t &numel,
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= numel) {
        return;
    }
    out[tid] = static_cast<T>(value);
}

// FILL_OP(NAME, T): expands to a `kernel void fill_NAME(...)` entry point
// that forwards its arguments to the `fill_with<T>` template above.
#define FILL_OP(NAME, T) \
kernel void fill_##NAME( \
device T *out, \
constant float &value, \
constant size_t &numel, \
uint tid [[thread_position_in_grid]] \
) { \
fill_with<T>(out, value, numel, tid); \
} \


// NOTE(review): FILL_OPS is currently a plain pass-through to FILL_OP —
// presumably a hook for generating additional per-type variants later.
#define FILL_OPS(NAME, T) \
FILL_OP(NAME, T) \

// Instantiate one constant-fill kernel per element type filled on the GPU
// (no f64 variant: the Metal backend falls back to the CPU for F64).
FILL_OPS(u8, uchar)
FILL_OPS(u32, uint)
FILL_OPS(i64, long)
FILL_OPS(f16, half)
FILL_OPS(f32, float)

// `bfloat` only exists from Metal Shading Language 3.1 onward; older
// targets simply get no bf16 fill kernel.
#if __METAL_VERSION__ >= 310
FILL_OPS(bf16, bfloat)
#endif
28 changes: 28 additions & 0 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ const AFFINE: &str = include_str!("affine.metal");
const BINARY: &str = include_str!("binary.metal");
const CAST: &str = include_str!("cast.metal");
const CONV: &str = include_str!("conv.metal");
const FILL: &str = include_str!("fill.metal");
const INDEXING: &str = include_str!("indexing.metal");
// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
Expand All @@ -31,6 +32,7 @@ pub enum Source {
Binary,
Cast,
Conv,
Fill,
Gemm,
Indexing,
Mfa,
Expand Down Expand Up @@ -196,6 +198,7 @@ impl Kernels {
Source::Binary => BINARY,
Source::Cast => CAST,
Source::Conv => CONV,
Source::Fill => FILL,
Source::Gemm => MLX_GEMM,
Source::Indexing => INDEXING,
Source::Quantized => QUANTIZED,
Expand Down Expand Up @@ -2357,5 +2360,30 @@ pub fn call_mlx_gemm(
Ok(())
}

/// Dispatch a `fill_*` kernel that writes the constant `v` into every
/// element of `output`.
///
/// `name` selects one of the kernels generated in `fill.metal`
/// (e.g. `"fill_f32"`); `length` is the element count of `output`.
pub fn call_const_fill(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    length: usize,
    output: &Buffer,
    v: f32,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    encoder.set_compute_pipeline_state(&pipeline);
    // Kernel parameter order matches fill.metal: (out, value, numel).
    set_params!(encoder, (output, v, length));
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    let (grid_count, group_size) = linear_split(&pipeline, length);
    encoder.dispatch_thread_groups(grid_count, group_size);
    Ok(())
}

#[cfg(test)]
mod tests;
65 changes: 65 additions & 0 deletions candle-metal-kernels/src/tests.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::*;
use half::{bf16, f16};
use metal::MTLResourceOptions;
use rand::Rng;

fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
let ptr = buffer.contents() as *const T;
Expand Down Expand Up @@ -2307,3 +2308,67 @@ fn conv_transpose1d_u32() {
let expected = vec![1, 4, 10, 20, 25, 24, 16];
assert_eq!(results, expected);
}

/// Run the `name` fill kernel over a freshly-allocated buffer of `len`
/// elements and read the result back to the host.
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
    let dev = device();
    let kernels = Kernels::new();
    let command_queue = dev.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();

    // Shared storage so the CPU can read the result back afterwards:
    // `MTLBuffer.contents()` (used by `read_to_vec`) returns a null pointer
    // for private-storage buffers, so StorageModePrivate here would make the
    // read-back invalid.
    let buffer = dev.new_buffer(
        (len * std::mem::size_of::<T>()) as u64,
        MTLResourceOptions::StorageModeShared,
    );

    call_const_fill(&dev, command_buffer, &kernels, name, len, &buffer, value).unwrap();

    command_buffer.commit();
    command_buffer.wait_until_completed();

    read_to_vec::<T>(&buffer, len)
}

// Exercise every `fill_*` kernel with a random length and fill value, and
// check the read-back buffer against a host-side vector of the same constant
// converted to the kernel's element type.
#[test]
fn const_fill() {
    let kernel_names = [
        "fill_u8",
        "fill_u32",
        "fill_i64",
        "fill_f16",
        "fill_bf16",
        "fill_f32",
    ];

    let mut rng = rand::thread_rng();
    for name in kernel_names {
        // Randomize the buffer size so the test isn't pinned to one shape.
        let len = rng.gen_range(2..16) * rng.gen_range(4..16);
        let value = rng.gen_range(1. ..19.);

        match name {
            "fill_u8" => {
                assert_eq!(constant_fill::<u8>(name, len, value), vec![value as u8; len])
            }
            "fill_u32" => {
                assert_eq!(constant_fill::<u32>(name, len, value), vec![value as u32; len])
            }
            "fill_i64" => {
                assert_eq!(constant_fill::<i64>(name, len, value), vec![value as i64; len])
            }
            "fill_f16" => {
                assert_eq!(
                    constant_fill::<f16>(name, len, value),
                    vec![f16::from_f32(value); len]
                )
            }
            "fill_bf16" => {
                assert_eq!(
                    constant_fill::<bf16>(name, len, value),
                    vec![bf16::from_f32(value); len]
                )
            }
            "fill_f32" => {
                assert_eq!(constant_fill::<f32>(name, len, value), vec![value; len])
            }
            _ => unimplemented!(),
        };
    }
}
Loading