From 2b16d41567fc6328be8861c3365e9372c517a2cb Mon Sep 17 00:00:00 2001 From: Laurent Date: Sun, 27 Oct 2024 13:10:47 +0100 Subject: [PATCH] Add a dedicated test. --- candle-core/tests/custom_op_tests.rs | 30 ++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/candle-core/tests/custom_op_tests.rs b/candle-core/tests/custom_op_tests.rs index be59e0c0c..f2c01aca8 100644 --- a/candle-core/tests/custom_op_tests.rs +++ b/candle-core/tests/custom_op_tests.rs @@ -143,3 +143,33 @@ fn inplace_op1() -> Result<()> { ); Ok(()) } + +#[cfg(feature = "cuda")] +#[allow(clippy::approx_constant)] +#[test] +fn ug_op() -> Result<()> { + let kernel = { + use ug::lang::op; + + let layout = ug::Layout::from_shape(&[12]); + let ptr = op::Arg::ptr(ug::DType::F32); + let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?; + let src = op::unary(op::UnaryOp::Exp, src)?; + let st = op::store(ptr.id(), layout, src)?; + let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]); + let opts: ug::lower_op::Opts = Default::default(); + kernel.lower(&opts.with_global(0, 12))? + }; + let device = Device::new_cuda(0)?; + let op = candle_core::UgIOp1::new("test", kernel, &device)?; + let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?; + t.inplace_op1(&op)?; + assert_eq!( + to_vec1_round(&t, 4)?, + &[ + 1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578, + 8103.0806, 22026.469, 59874.133 + ] + ); + Ok(()) +}