diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 482adfb1..14b400f0 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -30,6 +30,7 @@ jobs: test_group: - core - neural_networks + - integration arch: - x64 assertions: diff --git a/Project.toml b/Project.toml index 5f0869c8..08d38921 100644 --- a/Project.toml +++ b/Project.toml @@ -16,11 +16,13 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0" Scratch = "6c6a2e73-6563-6170-7368-637461726353" [weakdeps] +AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [extensions] +ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" ReactantNNlibExt = "NNlib" ReactantStatisticsExt = "Statistics" diff --git a/ext/ReactantAbstractFFTsExt.jl b/ext/ReactantAbstractFFTsExt.jl new file mode 100644 index 00000000..32a92fc1 --- /dev/null +++ b/ext/ReactantAbstractFFTsExt.jl @@ -0,0 +1,112 @@ +module ReactantAbstractFFTsExt + +using AbstractFFTs: AbstractFFTs +using Reactant: Reactant, MLIR, TracedRArray + +function check_contiguous_innermost_dims(dims, N) + @assert sort([dims...]) == [dims...] "un-sorted dims are not supported" + all(i -> dims[i] == dims[i - 1] + 1, 2:(length(dims))) || return false + dims[1] != 1 && return false + return true +end + +function compute_correct_pdims(x::AbstractArray, dims::Int) + counter = 0 + return ntuple(ndims(x)) do i + i == 1 && return dims + counter += 1 + return counter + end +end + +function compute_correct_pdims(x::AbstractArray, dims) + counter = 0 + return ntuple(ndims(x)) do i + i ≤ length(dims) && return dims[i] + counter += 1 + while counter ∈ dims + counter += 1 + end + return counter + end +end + +for op in (:rfft, :fft, :ifft) + @eval function AbstractFFTs.$(op)(x::TracedRArray, dims) + @assert maximum(dims) ≤ ndims(x) "dims out of range" + if dims isa Integer + if dims != 1 + pdims = compute_correct_pdims(x, dims) + return permutedims( + AbstractFFTs.$(op)(permutedims(x, pdims), 1), invperm(pdims) + ) + end + return generalized_fft(x, $(Meta.quot(op)), nothing, 1) + end + if !check_contiguous_innermost_dims(dims, ndims(x)) + pdims = compute_correct_pdims(x, dims) + return permutedims( + AbstractFFTs.$(op)(permutedims(x, pdims), 1:length(dims)), invperm(pdims) + ) + end + return generalized_fft(x, $(Meta.quot(op)), nothing, length(dims)) + end +end + +for op in (:irfft,) + @eval function AbstractFFTs.$(op)(x::TracedRArray, d::Int, dims) + @assert maximum(dims) ≤ ndims(x) "dims out of range" + if dims isa Integer + if dims != 1 + pdims = compute_correct_pdims(x, dims) + return permutedims( + AbstractFFTs.$(op)(permutedims(x, pdims), d, 1), invperm(pdims) + ) + end + return generalized_fft(x, $(Meta.quot(op)), d, 1) + end + if !check_contiguous_innermost_dims(dims, ndims(x)) + pdims = compute_correct_pdims(x, dims) + return permutedims( + AbstractFFTs.$(op)(permutedims(x, pdims), d, 1:length(dims)), invperm(pdims) + ) + end + return generalized_fft(x, $(Meta.quot(op)), d, length(dims)) + end +end + +function generalized_fft(x::TracedRArray{T,N}, mode::Symbol, d, first_n::Int) where {T,N} + @assert mode ∈ (:rfft, :irfft, :fft, :ifft) + + x = permutedims(x, reverse(1:N)) + fft_type_str = uppercase(string(mode)) + fft_type = MLIR.API.stablehloFftTypeAttrGet(MLIR.IR.context(), fft_type_str) + + if d === nothing + @assert mode ∈ (:rfft, :fft, :ifft) + if mode == :rfft + @assert T <: Real + rT = Complex{T} + res_size = [size(x)[1:(end - 1)]..., size(x, N) ÷ 2 + 1] + else + @assert T <: Complex + rT = T + res_size = [size(x)...] + end + fft_length = [size(x, i) for i in (ndims(x) - first_n + 1):ndims(x)] + else + @assert mode == :irfft + @assert T <: Complex + rT = real(T) + res_size = [size(x)[1:(end - 1)]..., d] + fft_length = [res_size[i] for i in (ndims(x) - first_n + 1):ndims(x)] + end + + @assert 1 ≤ length(fft_length) ≤ 3 "stablehlo.fft only supports up to rank 3" + mlir_type = MLIR.IR.TensorType(res_size, Reactant.MLIR.IR.Type(rT)) + op = MLIR.Dialects.stablehlo.fft(x.mlir_data; fft_type, fft_length, result_0=mlir_type) + x = TracedRArray{rT,N}((), MLIR.IR.result(op, 1), Tuple(res_size)) + return permutedims(x, reverse(1:N)) +end + +end diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index f7ee82c2..ef749eaf 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -260,4 +260,24 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe return permutedims(res, (2, 3, 1)) end +function NNlib.pad_constant( + x::TracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value +) where {T,N} + value = Reactant.promote_to(TracedRNumber{T}, value) + edge_padding_low = [i[1] for i in pad] + edge_padding_high = [i[2] for i in pad] + interior_padding = [0 for i in pad] + res = MLIR.IR.result( + MLIR.Dialects.stablehlo.pad( + x.mlir_data, + value.mlir_data; + edge_padding_low, + edge_padding_high, + interior_padding, + ), + 1, + ) + return TracedRArray{T,N}((), res, size(MLIR.IR.type(res))) +end + end # module ReactantNNlibExt diff --git a/test/Project.toml b/test/Project.toml index 937025d1..0a456679 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,6 +1,7 @@ [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" diff --git a/test/integration/fft.jl b/test/integration/fft.jl new file mode 100644 index 00000000..d39ac6d2 --- /dev/null +++ b/test/integration/fft.jl @@ -0,0 +1,46 @@ +using FFTW, Reactant + +@testset "fft" begin + x = rand(ComplexF32, 2, 2, 3, 4) + x_ra = Reactant.ConcreteRArray(x) + + @test_throws AssertionError @jit(fft(x_ra)) + + x = rand(ComplexF32, 2, 3, 4) + x_ra = Reactant.ConcreteRArray(x) + + @test @jit(fft(x_ra)) ≈ fft(x) + @test @jit(fft(x_ra, (1, 2))) ≈ fft(x, (1, 2)) + @test @jit(fft(x_ra, (1, 2, 3))) ≈ fft(x, (1, 2, 3)) + @test @jit(fft(x_ra, (2, 3))) ≈ fft(x, (2, 3)) + @test @jit(fft(x_ra, (1, 3))) ≈ fft(x, (1, 3)) + + @test_throws AssertionError @jit(fft(x_ra, (3, 2))) + @test_throws AssertionError @jit(fft(x_ra, (1, 4))) + + y_ra = @jit(fft(x_ra)) + @test @jit(ifft(y_ra)) ≈ x +end + +@testset "rfft" begin + x = rand(2, 2, 3, 4) + x_ra = Reactant.ConcreteRArray(x) + + @test_throws AssertionError @jit(rfft(x_ra)) + + x = rand(2, 3, 4) + x_ra = Reactant.ConcreteRArray(x) + + @test @jit(rfft(x_ra)) ≈ rfft(x) + @test @jit(rfft(x_ra, (1, 2))) ≈ rfft(x, (1, 2)) + @test @jit(rfft(x_ra, (1, 2, 3))) ≈ rfft(x, (1, 2, 3)) + @test @jit(rfft(x_ra, (2, 3))) ≈ rfft(x, (2, 3)) + @test @jit(rfft(x_ra, (1, 3))) ≈ rfft(x, (1, 3)) + + @test_throws AssertionError @jit(rfft(x_ra, (3, 2))) + @test_throws AssertionError @jit(rfft(x_ra, (1, 4))) + + y_ra = @jit(rfft(x_ra)) + @test @jit(irfft(y_ra, 2)) ≈ x + @test @jit(irfft(y_ra, 3)) ≈ irfft(rfft(x), 3) +end diff --git a/test/nn/nnlib.jl b/test/nn/nnlib.jl index be698df7..5b2f8bf1 100644 --- a/test/nn/nnlib.jl +++ b/test/nn/nnlib.jl @@ -117,3 +117,44 @@ end @test @jit(batched_mul(x_ra, y_ra)) ≈ batched_mul(x, y) end + +@testset "Constant Padding: NNlib.pad_constant" begin + x = rand(Float32, 4, 4) + x_ra = Reactant.ConcreteRArray(x) + + # Symmetric Padding + @test @jit(NNlib.pad_constant(x_ra, (1, 1))) ≈ NNlib.pad_constant(x, (1, 1)) + @test @jit(NNlib.pad_constant(x_ra, (1, 1, 1, 1))) ≈ NNlib.pad_constant(x, (1, 1, 1, 1)) + + # Asymmetric Padding + @test @jit(NNlib.pad_constant(x_ra, (1, 3, 2, 1))) ≈ NNlib.pad_constant(x, (1, 3, 2, 1)) + @test @jit(NNlib.pad_constant(x_ra, (1, 0))) ≈ NNlib.pad_constant(x, (1, 0)) + + # Symmetric Padding with value (test type-casting) + @test @jit(NNlib.pad_constant(x_ra, (1, 1), 2)) ≈ NNlib.pad_constant(x, (1, 1), 2) + @test @jit(NNlib.pad_constant(x_ra, (1, 1, 1, 1), 2)) ≈ + NNlib.pad_constant(x, (1, 1, 1, 1), 2) + + # Asymmetric Padding with value (test type-casting) + @test @jit(NNlib.pad_constant(x_ra, (1, 3, 2, 1), 2)) ≈ + NNlib.pad_constant(x, (1, 3, 2, 1), 2) + @test @jit(NNlib.pad_constant(x_ra, (1, 0), 2)) ≈ NNlib.pad_constant(x, (1, 0), 2) + + # pad_zeros just forward to pad_constant + @test @jit(NNlib.pad_zeros(x_ra, (1, 1))) ≈ NNlib.pad_zeros(x, (1, 1)) + @test @jit(NNlib.pad_zeros(x_ra, (1, 1, 1, 1))) ≈ NNlib.pad_zeros(x, (1, 1, 1, 1)) + + sumabs2(f, x) = sum(abs2, f(x)) + + function ∇sumabs2(f, x) + dx = Enzyme.make_zero(x) + Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx)) + return dx + end + + pad_fn = Base.Fix2(NNlib.pad_constant, (1, 1, 1, 1)) + @test @jit(∇sumabs2(pad_fn, x_ra)) ≈ ∇sumabs2(pad_fn, x) + + pad_fn2 = Base.Fix2(NNlib.pad_constant, (1, 0, 1, 3)) + @test @jit(∇sumabs2(pad_fn2, x_ra)) ≈ ∇sumabs2(pad_fn2, x) +end diff --git a/test/runtests.jl b/test/runtests.jl index bfc2f4a4..74414f30 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -56,6 +56,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) @safetestset "Control Flow" include("control_flow.jl") end + if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" + @safetestset "AbstractFFTs" include("integration/fft.jl") + end + if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" @testset "Neural Networks" begin @safetestset "NNlib Primitives" include("nn/nnlib.jl")