-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: functionalities for supporting NeuralOperators.jl (#217)
* feat: add `fft` and variants * fix: support dimension argument to inverse ffts * fix: correct the semantics of fft_length * feat: extend support to non-standard dimensions * feat: compile NNlib.pad_constant with stablehlo.pad * test: NNlib.pad_constant * test: fft testing against FFTW
- Loading branch information
Showing
8 changed files
with
227 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,6 +30,7 @@ jobs: | |
test_group: | ||
- core | ||
- neural_networks | ||
- integration | ||
arch: | ||
- x64 | ||
assertions: | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
module ReactantAbstractFFTsExt | ||
|
||
using AbstractFFTs: AbstractFFTs | ||
using Reactant: Reactant, MLIR, TracedRArray | ||
|
||
function check_contiguous_innermost_dims(dims, N) | ||
@assert sort([dims...]) == [dims...] "un-sorted dims are not supported" | ||
all(i -> dims[i] == dims[i - 1] + 1, 2:(length(dims))) || return false | ||
dims[1] != 1 && return false | ||
return true | ||
end | ||
|
||
function compute_correct_pdims(x::AbstractArray, dims::Int) | ||
counter = 0 | ||
return ntuple(ndims(x)) do i | ||
i == 1 && return dims | ||
counter += 1 | ||
return counter | ||
end | ||
end | ||
|
||
function compute_correct_pdims(x::AbstractArray, dims) | ||
counter = 0 | ||
return ntuple(ndims(x)) do i | ||
i ≤ length(dims) && return dims[i] | ||
counter += 1 | ||
while counter ∈ dims | ||
counter += 1 | ||
end | ||
return counter | ||
end | ||
end | ||
|
||
for op in (:rfft, :fft, :ifft) | ||
@eval function AbstractFFTs.$(op)(x::TracedRArray, dims) | ||
@assert maximum(dims) ≤ ndims(x) "dims out of range" | ||
if dims isa Integer | ||
if dims != 1 | ||
pdims = compute_correct_pdims(x, dims) | ||
return permutedims( | ||
AbstractFFTs.$(op)(permutedims(x, pdims), 1), invperm(pdims) | ||
) | ||
end | ||
return generalized_fft(x, $(Meta.quot(op)), nothing, 1) | ||
end | ||
if !check_contiguous_innermost_dims(dims, ndims(x)) | ||
pdims = compute_correct_pdims(x, dims) | ||
return permutedims( | ||
AbstractFFTs.$(op)(permutedims(x, pdims), 1:length(dims)), invperm(pdims) | ||
) | ||
end | ||
return generalized_fft(x, $(Meta.quot(op)), nothing, length(dims)) | ||
end | ||
end | ||
|
||
for op in (:irfft,) | ||
@eval function AbstractFFTs.$(op)(x::TracedRArray, d::Int, dims) | ||
@assert maximum(dims) ≤ ndims(x) "dims out of range" | ||
if dims isa Integer | ||
if dims != 1 | ||
pdims = compute_correct_pdims(x, dims) | ||
return permutedims( | ||
AbstractFFTs.$(op)(permutedims(x, pdims), d, 1), invperm(pdims) | ||
) | ||
end | ||
return generalized_fft(x, $(Meta.quot(op)), d, 1) | ||
end | ||
if !check_contiguous_innermost_dims(dims, ndims(x)) | ||
pdims = compute_correct_pdims(x, dims) | ||
return permutedims( | ||
AbstractFFTs.$(op)(permutedims(x, pdims), d, 1:length(dims)), invperm(pdims) | ||
) | ||
end | ||
return generalized_fft(x, $(Meta.quot(op)), d, length(dims)) | ||
end | ||
end | ||
|
||
function generalized_fft(x::TracedRArray{T,N}, mode::Symbol, d, first_n::Int) where {T,N} | ||
@assert mode ∈ (:rfft, :irfft, :fft, :ifft) | ||
|
||
x = permutedims(x, reverse(1:N)) | ||
fft_type_str = uppercase(string(mode)) | ||
fft_type = MLIR.API.stablehloFftTypeAttrGet(MLIR.IR.context(), fft_type_str) | ||
|
||
if d === nothing | ||
@assert mode ∈ (:rfft, :fft, :ifft) | ||
if mode == :rfft | ||
@assert T <: Real | ||
rT = Complex{T} | ||
res_size = [size(x)[1:(end - 1)]..., size(x, N) ÷ 2 + 1] | ||
else | ||
@assert T <: Complex | ||
rT = T | ||
res_size = [size(x)...] | ||
end | ||
fft_length = [size(x, i) for i in (ndims(x) - first_n + 1):ndims(x)] | ||
else | ||
@assert mode == :irfft | ||
@assert T <: Complex | ||
rT = real(T) | ||
res_size = [size(x)[1:(end - 1)]..., d] | ||
fft_length = [res_size[i] for i in (ndims(x) - first_n + 1):ndims(x)] | ||
end | ||
|
||
@assert 1 ≤ length(fft_length) ≤ 3 "stablehlo.fft only supports up to rank 3" | ||
mlir_type = MLIR.IR.TensorType(res_size, Reactant.MLIR.IR.Type(rT)) | ||
op = MLIR.Dialects.stablehlo.fft(x.mlir_data; fft_type, fft_length, result_0=mlir_type) | ||
x = TracedRArray{rT,N}((), MLIR.IR.result(op, 1), Tuple(res_size)) | ||
return permutedims(x, reverse(1:N)) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
using FFTW, Reactant | ||
|
||
@testset "fft" begin | ||
x = rand(ComplexF32, 2, 2, 3, 4) | ||
x_ra = Reactant.ConcreteRArray(x) | ||
|
||
@test_throws AssertionError @jit(fft(x_ra)) | ||
|
||
x = rand(ComplexF32, 2, 3, 4) | ||
x_ra = Reactant.ConcreteRArray(x) | ||
|
||
@test @jit(fft(x_ra)) ≈ fft(x) | ||
@test @jit(fft(x_ra, (1, 2))) ≈ fft(x, (1, 2)) | ||
@test @jit(fft(x_ra, (1, 2, 3))) ≈ fft(x, (1, 2, 3)) | ||
@test @jit(fft(x_ra, (2, 3))) ≈ fft(x, (2, 3)) | ||
@test @jit(fft(x_ra, (1, 3))) ≈ fft(x, (1, 3)) | ||
|
||
@test_throws AssertionError @jit(fft(x_ra, (3, 2))) | ||
@test_throws AssertionError @jit(fft(x_ra, (1, 4))) | ||
|
||
y_ra = @jit(fft(x_ra)) | ||
@test @jit(ifft(y_ra)) ≈ x | ||
end | ||
|
||
@testset "rfft" begin | ||
x = rand(2, 2, 3, 4) | ||
x_ra = Reactant.ConcreteRArray(x) | ||
|
||
@test_throws AssertionError @jit(rfft(x_ra)) | ||
|
||
x = rand(2, 3, 4) | ||
x_ra = Reactant.ConcreteRArray(x) | ||
|
||
@test @jit(rfft(x_ra)) ≈ rfft(x) | ||
@test @jit(rfft(x_ra, (1, 2))) ≈ rfft(x, (1, 2)) | ||
@test @jit(rfft(x_ra, (1, 2, 3))) ≈ rfft(x, (1, 2, 3)) | ||
@test @jit(rfft(x_ra, (2, 3))) ≈ rfft(x, (2, 3)) | ||
@test @jit(rfft(x_ra, (1, 3))) ≈ rfft(x, (1, 3)) | ||
|
||
@test_throws AssertionError @jit(rfft(x_ra, (3, 2))) | ||
@test_throws AssertionError @jit(rfft(x_ra, (1, 4))) | ||
|
||
y_ra = @jit(rfft(x_ra)) | ||
@test @jit(irfft(y_ra, 2)) ≈ x | ||
@test @jit(irfft(y_ra, 3)) ≈ irfft(rfft(x), 3) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters