Skip to content

Commit

Permalink
Add a dedicated test.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Oct 27, 2024
1 parent 083d0ea commit 2b16d41
Showing 1 changed file with 30 additions and 0 deletions.
30 changes: 30 additions & 0 deletions candle-core/tests/custom_op_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,33 @@ fn inplace_op1() -> Result<()> {
);
Ok(())
}

#[cfg(feature = "cuda")]
#[allow(clippy::approx_constant)]
#[test]
fn ug_op() -> Result<()> {
let kernel = {
use ug::lang::op;

let layout = ug::Layout::from_shape(&[12]);
let ptr = op::Arg::ptr(ug::DType::F32);
let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
let src = op::unary(op::UnaryOp::Exp, src)?;
let st = op::store(ptr.id(), layout, src)?;
let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
let opts: ug::lower_op::Opts = Default::default();
kernel.lower(&opts.with_global(0, 12))?
};
let device = Device::new_cuda(0)?;
let op = candle_core::UgIOp1::new("test", kernel, &device)?;
let t = Tensor::arange(0u32, 12u32, &device)?.to_dtype(DType::F32)?;
t.inplace_op1(&op)?;
assert_eq!(
to_vec1_round(&t, 4)?,
&[
1.0, 2.7183, 7.3891, 20.0855, 54.5982, 148.4132, 403.4287, 1096.6334, 2980.9578,
8103.0806, 22026.469, 59874.133
]
);
Ok(())
}

0 comments on commit 2b16d41

Please sign in to comment.