Skip to content

Commit

Permalink
feat: functionalities for supporting NeuralOperators.jl (#217)
Browse files Browse the repository at this point in the history
* feat: add `fft` and variants

* fix: support dimension argument to inverse ffts

* fix: correct the semantics of fft_length

* feat: extend support to non-standard dimensions

* feat: compile NNlib.pad_constant with stablehlo.pad

* test: NNlib.pad_constant

* test: fft testing against FFTW
  • Loading branch information
avik-pal authored Nov 5, 2024
1 parent b51717c commit cc8e24f
Show file tree
Hide file tree
Showing 8 changed files with 227 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
test_group:
- core
- neural_networks
- integration
arch:
- x64
assertions:
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@ Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"

[weakdeps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantNNlibExt = "NNlib"
ReactantStatisticsExt = "Statistics"
Expand Down
112 changes: 112 additions & 0 deletions ext/ReactantAbstractFFTsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
module ReactantAbstractFFTsExt

using AbstractFFTs: AbstractFFTs
using Reactant: Reactant, MLIR, TracedRArray

function check_contiguous_innermost_dims(dims, N)
    # Return `true` iff `dims` is exactly `1:length(dims)`, i.e. the transform
    # targets the leading dimensions in contiguous ascending order. `N` (the
    # array rank) is accepted for interface symmetry but not consulted here.
    @assert issorted(collect(dims)) "un-sorted dims are not supported"
    for k in 2:length(dims)
        # Any gap between consecutive entries breaks contiguity.
        dims[k] == dims[k - 1] + 1 || return false
    end
    dims[1] != 1 && return false
    return true
end

# Build a permutation of `1:ndims(x)` that moves dimension `dims` to the
# front while preserving the relative order of the remaining dimensions.
function compute_correct_pdims(x::AbstractArray, dims::Int)
    counter = 0
    return ntuple(ndims(x)) do i
        i == 1 && return dims
        counter += 1
        # Skip the dimension that was moved to the front; without this the
        # result would repeat `dims` and not be a valid permutation (so
        # `invperm` at the call sites would throw).
        counter == dims && (counter += 1)
        return counter
    end
end

# Build a permutation of `1:ndims(x)` that moves the dimensions listed in
# `dims` (a tuple/range) to the front, in the given order, followed by all
# remaining dimensions in ascending order.
function compute_correct_pdims(x::AbstractArray, dims)
    counter = 0
    return ntuple(ndims(x)) do i
        # The first `length(dims)` slots are the requested dimensions.
        i ≤ length(dims) && return dims[i]
        counter += 1
        # Skip over dimensions already placed at the front.
        while counter ∈ dims
            counter += 1
        end
        return counter
    end
end

# Forward transforms (`rfft`, `fft`, `ifft`) share one dispatch structure:
# reduce any `dims` specification to a transform over the leading contiguous
# dimensions (permuting when necessary), then lower via `generalized_fft`.
for op in (:rfft, :fft, :ifft)
    @eval function AbstractFFTs.$(op)(x::TracedRArray, dims)
        @assert maximum(dims) ≤ ndims(x) "dims out of range"
        if dims isa Integer
            if dims != 1
                # Move the requested dimension to the front, transform, undo.
                pdims = compute_correct_pdims(x, dims)
                return permutedims(
                    AbstractFFTs.$(op)(permutedims(x, pdims), 1), invperm(pdims)
                )
            end
            return generalized_fft(x, $(Meta.quot(op)), nothing, 1)
        end
        if !check_contiguous_innermost_dims(dims, ndims(x))
            # Non-leading / non-contiguous dims: permute them to the front.
            pdims = compute_correct_pdims(x, dims)
            return permutedims(
                AbstractFFTs.$(op)(permutedims(x, pdims), 1:length(dims)), invperm(pdims)
            )
        end
        return generalized_fft(x, $(Meta.quot(op)), nothing, length(dims))
    end
end

# `irfft` additionally takes `d`, the original length of the transformed
# dimension (needed because rfft output length `d ÷ 2 + 1` is not invertible).
for op in (:irfft,)
    @eval function AbstractFFTs.$(op)(x::TracedRArray, d::Int, dims)
        @assert maximum(dims) ≤ ndims(x) "dims out of range"
        if dims isa Integer
            if dims != 1
                # Move the requested dimension to the front, transform, undo.
                pdims = compute_correct_pdims(x, dims)
                return permutedims(
                    AbstractFFTs.$(op)(permutedims(x, pdims), d, 1), invperm(pdims)
                )
            end
            return generalized_fft(x, $(Meta.quot(op)), d, 1)
        end
        if !check_contiguous_innermost_dims(dims, ndims(x))
            # Non-leading / non-contiguous dims: permute them to the front.
            pdims = compute_correct_pdims(x, dims)
            return permutedims(
                AbstractFFTs.$(op)(permutedims(x, pdims), d, 1:length(dims)), invperm(pdims)
            )
        end
        return generalized_fft(x, $(Meta.quot(op)), d, length(dims))
    end
end

"""
    generalized_fft(x::TracedRArray{T,N}, mode::Symbol, d, first_n::Int)

Lower an FFT over the first `first_n` dimensions of `x` to `stablehlo.fft`.
`mode` is one of `:rfft`, `:irfft`, `:fft`, `:ifft`; `d` is the original real
length for `:irfft` and must be `nothing` otherwise.
"""
function generalized_fft(x::TracedRArray{T,N}, mode::Symbol, d, first_n::Int) where {T,N}
    @assert mode ∈ (:rfft, :irfft, :fft, :ifft)

    # stablehlo.fft transforms the *trailing* dimensions, while the Julia
    # convention here transforms the leading ones — reverse the axis order.
    x = permutedims(x, reverse(1:N))
    fft_type_str = uppercase(string(mode))
    fft_type = MLIR.API.stablehloFftTypeAttrGet(MLIR.IR.context(), fft_type_str)

    if d === nothing
        @assert mode ∈ (:rfft, :fft, :ifft)
        if mode == :rfft
            @assert T <: Real
            rT = Complex{T}
            # rfft keeps only the non-redundant half of the spectrum.
            res_size = [size(x)[1:(end - 1)]..., size(x, N) ÷ 2 + 1]
        else
            @assert T <: Complex
            rT = T
            res_size = [size(x)...]
        end
        fft_length = [size(x, i) for i in (ndims(x) - first_n + 1):ndims(x)]
    else
        @assert mode == :irfft
        @assert T <: Complex
        rT = real(T)
        # `d` is the original (real) length of the transformed dimension.
        res_size = [size(x)[1:(end - 1)]..., d]
        fft_length = [res_size[i] for i in (ndims(x) - first_n + 1):ndims(x)]
    end

    @assert 1 ≤ length(fft_length) ≤ 3 "stablehlo.fft only supports up to rank 3"
    mlir_type = MLIR.IR.TensorType(res_size, Reactant.MLIR.IR.Type(rT))
    op = MLIR.Dialects.stablehlo.fft(x.mlir_data; fft_type, fft_length, result_0=mlir_type)
    x = TracedRArray{rT,N}((), MLIR.IR.result(op, 1), Tuple(res_size))
    # Undo the axis reversal applied on entry.
    return permutedims(x, reverse(1:N))
end

end
20 changes: 20 additions & 0 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -260,4 +260,24 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
return permutedims(res, (2, 3, 1))
end

# Lower NNlib's constant padding onto `stablehlo.pad` (no interior padding).
function NNlib.pad_constant(
    x::TracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
) where {T,N}
    # Cast the fill value to the element type of `x` inside the traced graph.
    fill_val = Reactant.promote_to(TracedRNumber{T}, value)
    low = collect(first.(pad))
    high = collect(last.(pad))
    none = zeros(Int, N)
    padded = MLIR.IR.result(
        MLIR.Dialects.stablehlo.pad(
            x.mlir_data,
            fill_val.mlir_data;
            edge_padding_low=low,
            edge_padding_high=high,
            interior_padding=none,
        ),
        1,
    )
    return TracedRArray{T,N}((), padded, size(MLIR.IR.type(padded)))
end

end # module ReactantNNlibExt
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Expand Down
46 changes: 46 additions & 0 deletions test/integration/fft.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
using FFTW, Reactant

@testset "fft" begin
    x = rand(ComplexF32, 2, 2, 3, 4)
    x_ra = Reactant.ConcreteRArray(x)

    # Rank-4 transforms exceed stablehlo.fft's rank-3 limit.
    @test_throws AssertionError @jit(fft(x_ra))

    x = rand(ComplexF32, 2, 3, 4)
    x_ra = Reactant.ConcreteRArray(x)

    # Compare against FFTW for leading, contiguous, and permuted dims.
    @test @jit(fft(x_ra)) ≈ fft(x)
    @test @jit(fft(x_ra, (1, 2))) ≈ fft(x, (1, 2))
    @test @jit(fft(x_ra, (1, 2, 3))) ≈ fft(x, (1, 2, 3))
    @test @jit(fft(x_ra, (2, 3))) ≈ fft(x, (2, 3))
    @test @jit(fft(x_ra, (1, 3))) ≈ fft(x, (1, 3))

    # Unsorted and out-of-range dims are rejected.
    @test_throws AssertionError @jit(fft(x_ra, (3, 2)))
    @test_throws AssertionError @jit(fft(x_ra, (1, 4)))

    # Round-trip: ifft ∘ fft recovers the input.
    y_ra = @jit(fft(x_ra))
    @test @jit(ifft(y_ra)) ≈ x
end

@testset "rfft" begin
    x = rand(2, 2, 3, 4)
    x_ra = Reactant.ConcreteRArray(x)

    # Rank-4 transforms exceed stablehlo.fft's rank-3 limit.
    @test_throws AssertionError @jit(rfft(x_ra))

    x = rand(2, 3, 4)
    x_ra = Reactant.ConcreteRArray(x)

    # Compare against FFTW for leading, contiguous, and permuted dims.
    @test @jit(rfft(x_ra)) ≈ rfft(x)
    @test @jit(rfft(x_ra, (1, 2))) ≈ rfft(x, (1, 2))
    @test @jit(rfft(x_ra, (1, 2, 3))) ≈ rfft(x, (1, 2, 3))
    @test @jit(rfft(x_ra, (2, 3))) ≈ rfft(x, (2, 3))
    @test @jit(rfft(x_ra, (1, 3))) ≈ rfft(x, (1, 3))

    # Unsorted and out-of-range dims are rejected.
    @test_throws AssertionError @jit(rfft(x_ra, (3, 2)))
    @test_throws AssertionError @jit(rfft(x_ra, (1, 4)))

    # Round-trip with the original length `d`, plus a truncating inverse.
    y_ra = @jit(rfft(x_ra))
    @test @jit(irfft(y_ra, 2)) ≈ x
    @test @jit(irfft(y_ra, 3)) ≈ irfft(rfft(x), 3)
end
41 changes: 41 additions & 0 deletions test/nn/nnlib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,44 @@ end

@test @jit(batched_mul(x_ra, y_ra)) batched_mul(x, y)
end

@testset "Constant Padding: NNlib.pad_constant" begin
    x = rand(Float32, 4, 4)
    x_ra = Reactant.ConcreteRArray(x)

    # Symmetric Padding
    @test @jit(NNlib.pad_constant(x_ra, (1, 1))) ≈ NNlib.pad_constant(x, (1, 1))
    @test @jit(NNlib.pad_constant(x_ra, (1, 1, 1, 1))) ≈ NNlib.pad_constant(x, (1, 1, 1, 1))

    # Asymmetric Padding
    @test @jit(NNlib.pad_constant(x_ra, (1, 3, 2, 1))) ≈ NNlib.pad_constant(x, (1, 3, 2, 1))
    @test @jit(NNlib.pad_constant(x_ra, (1, 0))) ≈ NNlib.pad_constant(x, (1, 0))

    # Symmetric Padding with value (test type-casting)
    @test @jit(NNlib.pad_constant(x_ra, (1, 1), 2)) ≈ NNlib.pad_constant(x, (1, 1), 2)
    @test @jit(NNlib.pad_constant(x_ra, (1, 1, 1, 1), 2)) ≈
        NNlib.pad_constant(x, (1, 1, 1, 1), 2)

    # Asymmetric Padding with value (test type-casting)
    @test @jit(NNlib.pad_constant(x_ra, (1, 3, 2, 1), 2)) ≈
        NNlib.pad_constant(x, (1, 3, 2, 1), 2)
    @test @jit(NNlib.pad_constant(x_ra, (1, 0), 2)) ≈ NNlib.pad_constant(x, (1, 0), 2)

    # pad_zeros just forwards to pad_constant
    @test @jit(NNlib.pad_zeros(x_ra, (1, 1))) ≈ NNlib.pad_zeros(x, (1, 1))
    @test @jit(NNlib.pad_zeros(x_ra, (1, 1, 1, 1))) ≈ NNlib.pad_zeros(x, (1, 1, 1, 1))

    sumabs2(f, x) = sum(abs2, f(x))

    # Reverse-mode gradient of sum(abs2, f(x)) w.r.t. x, via Enzyme.
    function ∇sumabs2(f, x)
        dx = Enzyme.make_zero(x)
        Enzyme.autodiff(Reverse, sumabs2, Active, Const(f), Duplicated(x, dx))
        return dx
    end

    # Gradients through padding must match the eager NNlib reference.
    pad_fn = Base.Fix2(NNlib.pad_constant, (1, 1, 1, 1))
    @test @jit(∇sumabs2(pad_fn, x_ra)) ≈ ∇sumabs2(pad_fn, x)

    pad_fn2 = Base.Fix2(NNlib.pad_constant, (1, 0, 1, 3))
    @test @jit(∇sumabs2(pad_fn2, x_ra)) ≈ ∇sumabs2(pad_fn2, x)
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Control Flow" include("control_flow.jl")
end

# Integration tests exercise interop with external packages (FFTW via AbstractFFTs).
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
@safetestset "AbstractFFTs" include("integration/fft.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
@testset "Neural Networks" begin
@safetestset "NNlib Primitives" include("nn/nnlib.jl")
Expand Down

0 comments on commit cc8e24f

Please sign in to comment.