From ada2ea2103fac3ae6b40a700acafb1f3c9cf6ed8 Mon Sep 17 00:00:00 2001 From: Andy Polyakov Date: Sun, 7 Jan 2024 22:05:23 +0100 Subject: [PATCH] poc/ntt-cuda: modernize. --- poc/ntt-cuda/src/lib.rs | 24 ++---------------------- poc/ntt-cuda/tests/ntt.rs | 2 +- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/poc/ntt-cuda/src/lib.rs b/poc/ntt-cuda/src/lib.rs index 6847b4e..92162a1 100644 --- a/poc/ntt-cuda/src/lib.rs +++ b/poc/ntt-cuda/src/lib.rs @@ -2,27 +2,7 @@ // Licensed under the Apache License, Version 2.0, see LICENSE for details. // SPDX-License-Identifier: Apache-2.0 -sppark::cuda_error!(); - -#[repr(C)] -pub enum NTTInputOutputOrder { - NN = 0, - NR = 1, - RN = 2, - RR = 3, -} - -#[repr(C)] -enum NTTDirection { - Forward = 0, - Inverse = 1, -} - -#[repr(C)] -enum NTTType { - Standard = 0, - Coset = 1, -} +use sppark::{NTTInputOutputOrder, NTTDirection, NTTType}; extern "C" { fn compute_ntt( @@ -32,7 +12,7 @@ extern "C" { ntt_order: NTTInputOutputOrder, ntt_direction: NTTDirection, ntt_type: NTTType, - ) -> cuda::Error; + ) -> sppark::Error; } /// Compute an in-place NTT on the input data. diff --git a/poc/ntt-cuda/tests/ntt.rs b/poc/ntt-cuda/tests/ntt.rs index 9f69a26..bb8c1ba 100644 --- a/poc/ntt-cuda/tests/ntt.rs +++ b/poc/ntt-cuda/tests/ntt.rs @@ -2,7 +2,7 @@ // Licensed under the Apache License, Version 2.0, see LICENSE for details. // SPDX-License-Identifier: Apache-2.0 -use ntt_cuda::NTTInputOutputOrder; +use sppark::NTTInputOutputOrder; const DEFAULT_GPU: usize = 0;