Skip to content

Commit

Permalink
test: fft testing against FFTW
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 5, 2024
1 parent 34a1da2 commit 3e26602
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 2 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
test_group:
- core
- neural_networks
- integration
arch:
- x64
assertions:
Expand Down
2 changes: 2 additions & 0 deletions ext/ReactantAbstractFFTsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ end

for op in (:rfft, :fft, :ifft)
@eval function AbstractFFTs.$(op)(x::TracedRArray, dims)
@assert maximum(dims) ndims(x) "dims out of range"
if dims isa Integer
if dims != 1
pdims = compute_correct_pdims(x, dims)
Expand All @@ -54,6 +55,7 @@ end

for op in (:irfft,)
@eval function AbstractFFTs.$(op)(x::TracedRArray, d::Int, dims)
@assert maximum(dims) ndims(x) "dims out of range"
if dims isa Integer
if dims != 1
pdims = compute_correct_pdims(x, dims)
Expand Down
46 changes: 46 additions & 0 deletions test/integration/fft.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using FFTW, Reactant

@testset "fft" begin
x = rand(ComplexF32, 2, 2, 3, 4)
x_ra = Reactant.ConcreteRArray(x)

@test_throws AssertionError @jit(fft(x_ra))

x = rand(ComplexF32, 2, 3, 4)
x_ra = Reactant.ConcreteRArray(x)

@test @jit(fft(x_ra)) fft(x)
@test @jit(fft(x_ra, (1, 2))) fft(x, (1, 2))
@test @jit(fft(x_ra, (1, 2, 3))) fft(x, (1, 2, 3))
@test @jit(fft(x_ra, (2, 3))) fft(x, (2, 3))
@test @jit(fft(x_ra, (1, 3))) fft(x, (1, 3))

@test_throws AssertionError @jit(fft(x_ra, (3, 2)))
@test_throws AssertionError @jit(fft(x_ra, (1, 4)))

y_ra = @jit(fft(x_ra))
@test @jit(ifft(y_ra)) x
end

@testset "rfft" begin
x = rand(2, 2, 3, 4)
x_ra = Reactant.ConcreteRArray(x)

@test_throws AssertionError @jit(rfft(x_ra))

x = rand(2, 3, 4)
x_ra = Reactant.ConcreteRArray(x)

@test @jit(rfft(x_ra)) rfft(x)
@test @jit(rfft(x_ra, (1, 2))) rfft(x, (1, 2))
@test @jit(rfft(x_ra, (1, 2, 3))) rfft(x, (1, 2, 3))
@test @jit(rfft(x_ra, (2, 3))) rfft(x, (2, 3))
@test @jit(rfft(x_ra, (1, 3))) rfft(x, (1, 3))

@test_throws AssertionError @jit(rfft(x_ra, (3, 2)))
@test_throws AssertionError @jit(rfft(x_ra, (1, 4)))

y_ra = @jit(rfft(x_ra))
@test @jit(irfft(y_ra, 2)) x
@test @jit(irfft(y_ra, 3)) irfft(rfft(x), 3)
end
3 changes: 1 addition & 2 deletions test/nn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,7 @@ end
@test @jit(NNlib.pad_constant(x_ra, (1, 0))) NNlib.pad_constant(x, (1, 0))

# Symmetric Padding with value (test type-casting)
@test @jit(NNlib.pad_constant(x_ra, (1, 1), 2))
NNlib.pad_constant(x, (1, 1), 2)
@test @jit(NNlib.pad_constant(x_ra, (1, 1), 2)) NNlib.pad_constant(x, (1, 1), 2)
@test @jit(NNlib.pad_constant(x_ra, (1, 1, 1, 1), 2))
NNlib.pad_constant(x, (1, 1, 1, 1), 2)

Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Control Flow" include("control_flow.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
@safetestset "AbstractFFTs" include("integration/fft.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
@testset "Neural Networks" begin
@safetestset "NNlib Primitives" include("nn/nnlib.jl")
Expand Down

0 comments on commit 3e26602

Please sign in to comment.