Skip to content

Commit

Permalink
more custom binary
Browse files Browse the repository at this point in the history
  • Loading branch information
grzuy committed Jan 16, 2024
1 parent c130e88 commit baf52ec
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 17 deletions.
14 changes: 13 additions & 1 deletion native/candlex/src/metal_kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,19 @@ pub mod custom_unary {
}

pub mod custom_binary {
ops!(logical_and, logical_or, logical_xor);
ops!(
atan2,
bit_and,
bit_or,
bit_xor,
logical_and,
logical_or,
logical_xor,
pow,
remainder,
shl,
shr
);
}

#[derive(Debug)]
Expand Down
10 changes: 10 additions & 0 deletions native/candlex/src/metal_kernels/custom_binary.metal
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ kernel void FN_NAME( \
output[tid] = OUT_TYPENAME(FN); \
}

// Bitwise kernels for 64-bit integers. Each line instantiates one kernel named
// by the third argument whose per-element result is the fourth argument,
// converted to the second type argument (the output element type).
// NOTE(review): the first type argument is presumably the input element type,
// and `x` / `y` presumably the matching elements of the two inputs — confirm
// against the CUSTOM_BINARY definition above.
CUSTOM_BINARY(int64_t, int64_t, bit_and_i64, x & y)
CUSTOM_BINARY(int64_t, int64_t, bit_or_i64, x | y)
CUSTOM_BINARY(int64_t, int64_t, bit_xor_i64, x ^ y)

// Two-argument arctangent for 32-bit floats, calling atan2(x, y).
CUSTOM_BINARY(float, float, atan2_f32, atan2(x, y))

// Placeholders: no kernel is instantiated here for the ops below.
// TODO(review): confirm they are not defined elsewhere in this file before
// relying on the Rust-side `ops!` entries for them.
/* pow */
/* remainder */
/* shl */
/* shr */

CUSTOM_BINARY(int64_t, uint8_t, logical_and_i64, x && y)
CUSTOM_BINARY(uint8_t, uint8_t, logical_and_u8, x && y)
Expand Down
73 changes: 57 additions & 16 deletions native/candlex/src/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,12 +257,12 @@ macro_rules! custom_unary_bool_op {
}

macro_rules! custom_binary_op {
($struct_name:ident, $name:literal, $cpu_closure:expr, ($($dtypes:ident),+)) => {
($struct_name:ident, $name:ident, $cpu_closure:expr, ($($dtypes:ident),+)) => {
pub(crate) struct $struct_name;

impl CustomOp2 for $struct_name {
fn name(&self) -> &'static str {
$name
stringify!($name)
}

/// The forward pass, as run on a cpu device. Note that the storage can use arbitrary strides,
Expand Down Expand Up @@ -331,7 +331,7 @@ macro_rules! custom_binary_op {
.w()?;
let src1 = src1.slice(layout1.start_offset()..);
let src2 = src2.slice(layout2.start_offset()..);
let func = device.get_or_load_func(&kernel_name::<T>($name), kernels::CUSTOM_BINARY)?;
let func = device.get_or_load_func(&kernel_name::<T>(stringify!($name)), kernels::CUSTOM_BINARY)?;
// SAFETY: Set later by running the kernel.
let out = unsafe { device.alloc::<T>(elem_count1) }.w()?;
let params = (elem_count1, dims1.len(), &dims_and_strides, &src1, &src2, &out);
Expand All @@ -356,6 +356,52 @@ macro_rules! custom_binary_op {
)
)
}

/// Forward pass on a Metal device for this element-wise binary op.
///
/// Only contiguous inputs with a zero start offset are accepted; anything
/// else is rejected with an error. The output buffer has the element count
/// and dtype of the first operand, and the returned shape is a clone of the
/// first operand's shape.
///
/// NOTE(review): only `s1`/`l1` are consulted for dtype and shape, so this
/// assumes both operands already share dtype and shape — confirm callers
/// broadcast/cast beforehand.
///
/// # Errors
///
/// Returns an error when either layout is non-contiguous or offset, when the
/// dtype has no Metal kernel (only `F32` and `I64` are dispatched), when the
/// command buffer or output buffer cannot be obtained, or when dispatching
/// the kernel fails.
#[cfg(feature = "metal")]
fn metal_fwd(
    &self,
    s1: &MetalStorage,
    l1: &Layout,
    s2: &MetalStorage,
    l2: &Layout,
) -> Result<(MetalStorage, Shape), candle_core::Error> {
    use crate::metal_kernels;
    use candle_core::{backend::BackendStorage, DType};

    if !(l1.is_contiguous() && l1.start_offset() == 0) {
        candle_core::bail!("Non contiguous not supported - l1");
    }
    if !(l2.is_contiguous() && l2.start_offset() == 0) {
        candle_core::bail!("Non contiguous not supported - l2");
    }

    let device = s1.device();
    let dtype = s1.dtype();
    let shape = l1.shape();
    let elem_count = shape.elem_count();
    let command_buffer = device.command_buffer()?;
    let output_buffer = device.new_buffer(elem_count, dtype, stringify!($name))?;

    let kernel_name = match dtype {
        DType::F32 => metal_kernels::custom_binary::contiguous::$name::FLOAT,
        DType::I64 => metal_kernels::custom_binary::contiguous::$name::I64,
        dtype => {
            // Metavariables are not substituted inside string literals, so the
            // op name has to be passed as an explicit format argument.
            candle_core::bail!(
                "Metal contiguous custom binary {} {dtype:?} not implemented",
                stringify!($name)
            )
        }
    };

    // Propagate kernel dispatch failures instead of panicking; `Debug` is the
    // only bound the previous `unwrap()` guaranteed on the error type.
    metal_kernels::call_custom_binary_contiguous(
        &device.device(),
        &command_buffer,
        kernel_name,
        elem_count,
        &s1.buffer(),
        &s2.buffer(),
        &output_buffer,
    )
    .map_err(|err| {
        candle_core::Error::Msg(format!(
            "Metal custom binary {} failed: {err:?}",
            stringify!($name)
        ))
    })?;

    Ok((
        MetalStorage::new(output_buffer, device.clone(), dtype),
        shape.clone(),
    ))
}
}
}
}
Expand Down Expand Up @@ -536,19 +582,14 @@ custom_unary_op!(Tan, tan, |v| v.tan(), (BF16, F16, F32, F64));
custom_unary_bool_op!(IsInf, "is_inf", is_infinite, (F32, F64));
custom_unary_bool_op!(IsNan, "is_nan", is_nan, (F32, F64));

custom_binary_op!(BitAnd, "bit_and", |v1, v2| v1 & v2, (U32, I64));
custom_binary_op!(BitOr, "bit_or", |v1, v2| v1 | v2, (U32, I64));
custom_binary_op!(BitXor, "bit_xor", |v1, v2| v1 ^ v2, (U32, I64));
custom_binary_op!(Atan2, "atan2", |v1, v2| v1.atan2(v2), (F32, F64));
custom_binary_op!(Pow, "pow", |v1, v2| v1.powf(v2), (F32, F64));
custom_binary_op!(
Remainder,
"remainder",
|v1, v2| v1 % v2,
(U8, I64, F32, F64)
);
custom_binary_op!(Shl, "shl", |v1, v2| v1 << v2, (U32, I64));
custom_binary_op!(Shr, "shr", |v1, v2| v1 >> v2, (U32, I64));
custom_binary_op!(BitAnd, bit_and, |v1, v2| v1 & v2, (U32, I64));
custom_binary_op!(BitOr, bit_or, |v1, v2| v1 | v2, (U32, I64));
custom_binary_op!(BitXor, bit_xor, |v1, v2| v1 ^ v2, (U32, I64));
custom_binary_op!(Atan2, atan2, |v1, v2| v1.atan2(v2), (F32, F64));
custom_binary_op!(Pow, pow, |v1, v2| v1.powf(v2), (F32, F64));
custom_binary_op!(Remainder, remainder, |v1, v2| v1 % v2, (U8, I64, F32, F64));
custom_binary_op!(Shl, shl, |v1, v2| v1 << v2, (U32, I64));
custom_binary_op!(Shr, shr, |v1, v2| v1 >> v2, (U32, I64));
custom_binary_bool_op!(
LogicalAnd,
logical_and,
Expand Down

0 comments on commit baf52ec

Please sign in to comment.