Skip to content

Commit

Permalink
Implement conj, conj! for TracedRArray (#169)
Browse files Browse the repository at this point in the history
* Implement `conj` for `TracedRArray`

* Generalize `adjoint` for `TracedRVecOrMat`

* Implement `conj` for `TracedRNumber`

* Implement `conj!` for `TracedRArray`

* Implement `DenseElementsAttribute` for array of `Complex`

* Implement `Base.real`, `Base.imag` for `TracedRNumber`

* Fix typo

* Implement `Base.real`, `Base.imag` for `TracedRArray`

* Fix pointer length in `MLIR.IR.DenseElementsAttribute` on `Complex{T}`

* Move complex tests to new file

* Fix `ConcreteRArray` constructor on `Number`

* Fix `to_rarray` on primitive number types

* Fix `conj` tests on numbers

* Write tests for `real`, `imag`

* Remove duplicated method

* Update src/TracedRNumber.jl

Co-authored-by: Paul Berg <paul@plutojl.org>

* fix `image` on `TracedRArray` of reals

---------

Co-authored-by: Paul Berg <paul@plutojl.org>
  • Loading branch information
mofeing and Pangoraw authored Oct 28, 2024
1 parent 2f6a6d9 commit 1dd24b8
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mutable struct ConcreteRArray{T,N} <: RArray{T,N}
shape::NTuple{N,Int}
end

ConcreteRArray(data::T) where {T<:Number} = ConcreteRArray{T,0}(data, ())
ConcreteRArray(data::T) where {T<:Number} = ConcreteRArray(fill(data))

Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x)

Expand Down
53 changes: 52 additions & 1 deletion src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,62 @@ function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
)
end

Base.conj(A::TracedRArray) = A
function Base.conj(A::TracedRArray{T,N}) where {T<:Complex,N}
return TracedRArray{T,N}(
(),
MLIR.IR.result(
MLIR.Dialects.chlo.conj(
A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))
),
1,
),
size(A),
)
end

Base.conj!(A::TracedRArray) = A
function Base.conj!(A::TracedRArray{T,N}) where {T<:Complex,N}
A.mlir_data = MLIR.IR.result(
MLIR.Dialects.chlo.conj(A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))),
1,
)
return A
end

Base.real(A::TracedRArray) = A
function Base.real(A::TracedRArray{Complex{T},N}) where {T,N}
return TracedRArray{T,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.real(
A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))
),
1,
),
size(A),
)
end

Base.imag(A::TracedRArray) = zero(A)
function Base.imag(A::TracedRArray{Complex{T},N}) where {T,N}
return TracedRArray{T,N}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.imag(
A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))
),
1,
),
size(A),
)
end

function Base.transpose(A::AnyTracedRVecOrMat)
A = ndims(A) == 1 ? reshape(A, :, 1) : A
return permutedims(A, (2, 1))
end
Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A)
Base.adjoint(A::AnyTracedRVecOrMat) = conj(transpose(A))

function promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N}
if isa(rhs, TracedRArray)
Expand Down
30 changes: 30 additions & 0 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,36 @@ for (jlop, hloop) in (
end
end

Base.conj(x::TracedRNumber) = x
function Base.conj(x::TracedRNumber{T}) where {T<:Complex}
return TracedRNumber{T}(
(),
MLIR.IR.result(
MLIR.Dialects.chlo.conj(x.mlir_data; result=mlir_type(TracedRNumber{T})), 1
),
)
end

Base.real(x::TracedRNumber) = x
function Base.real(x::TracedRNumber{Complex{T}}) where {T}
return TracedRNumber{T}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.real(x.mlir_data; result=mlir_type(TracedRNumber{T})), 1
),
)
end

Base.imag(x::TracedRNumber) = zero(x)
function Base.imag(x::TracedRNumber{Complex{T}}) where {T}
return TracedRNumber{T}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.imag(x.mlir_data; result=mlir_type(TracedRNumber{T})), 1
),
)
end

# XXX: Enzyme-MLIR doesn't have `abs` adjoint defined
Base.abs2(x::TracedRNumber{<:Real}) = x^2

Expand Down
2 changes: 2 additions & 0 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -492,3 +492,5 @@ end
@inline function to_rarray(@nospecialize(x))
return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete)
end

to_rarray(x::ReactantPrimitive) = ConcreteRArray(x)
6 changes: 4 additions & 2 deletions src/mlir/IR/Attribute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -567,8 +567,10 @@ end
function DenseElementsAttribute(values::AbstractArray{<:Complex})
shaped_type = TensorType(size(values), Type(eltype(values)))
# TODO: row major
Attribute(
API.mlirDenseElementsAttrRawBufferGet(shaped_type, length(values) * sizeof(eltype(values)), values)
return Attribute(
API.mlirDenseElementsAttrRawBufferGet(
shaped_type, length(values) * Base.elsize(values), values
),
)
end

Expand Down
7 changes: 0 additions & 7 deletions src/mlir/IR/Type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,13 +85,6 @@ Creates a signless integer type of the given bitwidth in the context. The type i
Type(T::Core.Type{<:Integer}; context::Context=context()) =
Type(API.mlirIntegerTypeGet(context, sizeof(T) * 8))

"""
Type(T::Core.Type{<:Complex}; context=context())
Creates a complex type with the given element type.
"""
Type(T::Core.Type{<:Complex}; context=context()) = Type(API.mlirComplexTypeGet(Type(T(im) |> real |> typeof)))

"""
Type(T::Core.Type{<:Signed}; context=context()
Expand Down
89 changes: 89 additions & 0 deletions test/complex.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using Test
using Reactant

@testset "conj" begin
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
x_concrete = Reactant.to_rarray(x)
f = @compile conj(x_concrete)
@test only(f(x_concrete)) == conj(x)
end

@testset "$(typeof(x))" for x in (
fill(1.0 + 2.0im),
fill(1.0),
[1.0 + 2.0im; 3.0 + 4.0im],
[1.0; 3.0],
[1.0 + 2.0im 3.0 + 4.0im],
[1.0 2.0],
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
[1.0 3.0; 5.0 7.0],
)
x_concrete = Reactant.to_rarray(x)
f = @compile conj(x_concrete)
@test f(x_concrete) == conj(x)
end
end

@testset "conj!" begin
@testset "$(typeof(x))" for x in (
fill(1.0 + 2.0im),
fill(1.0),
[1.0 + 2.0im; 3.0 + 4.0im],
[1.0; 3.0],
[1.0 + 2.0im 3.0 + 4.0im],
[1.0 2.0],
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
[1.0 3.0; 5.0 7.0],
)
x_concrete = Reactant.to_rarray(x)
f = @compile conj!(x_concrete)
@test f(x_concrete) == conj(x)
@test x_concrete == conj(x)
end
end

@testset "real" begin
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
x_concrete = Reactant.to_rarray(x)
f = @compile real(x_concrete)
@test only(f(x_concrete)) == real(x)
end

@testset "$(typeof(x))" for x in (
fill(1.0 + 2.0im),
fill(1.0),
[1.0 + 2.0im; 3.0 + 4.0im],
[1.0; 3.0],
[1.0 + 2.0im 3.0 + 4.0im],
[1.0 2.0],
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
[1.0 3.0; 5.0 7.0],
)
x_concrete = Reactant.to_rarray(x)
f = @compile real(x_concrete)
@test f(x_concrete) == real(x)
end
end

@testset "imag" begin
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im)
x_concrete = Reactant.to_rarray(x)
f = @compile imag(x_concrete)
@test only(f(x_concrete)) == imag(x)
end

@testset "$(typeof(x))" for x in (
fill(1.0 + 2.0im),
fill(1.0),
[1.0 + 2.0im; 3.0 + 4.0im],
[1.0; 3.0],
[1.0 + 2.0im 3.0 + 4.0im],
[1.0 2.0],
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im],
[1.0 3.0; 5.0 7.0],
)
x_concrete = Reactant.to_rarray(x)
f = @compile imag(x_concrete)
@test f(x_concrete) == imag(x)
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Layout" include("layout.jl")
@safetestset "Tracing" include("tracing.jl")
@safetestset "Basic" include("basic.jl")
@safetestset "Complex" include("complex.jl")
@safetestset "Broadcast" include("bcast.jl")
@safetestset "Struct" include("struct.jl")
@safetestset "Closure" include("closure.jl")
Expand Down

1 comment on commit 1dd24b8

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: 1dd24b8 Previous: 2f6a6d9 Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1332810111 ns 1249734857 ns 1.07
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1310849226 ns 1242994162 ns 1.05
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1398288716 ns 1195264640 ns 1.17
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2623910053 ns 2306719819 ns 1.14
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 215139523 ns 213417826 ns 1.01
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 5334887020 ns 5619173445 ns 0.95
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5125148459 ns 5334345862 ns 0.96
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 5128099715 ns 5283772358 ns 0.97
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 7093041527 ns 6769958907 ns 1.05
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 31469979010 ns 35047348986 ns 0.90
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1390179366 ns 1277061285 ns 1.09
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1268184459.5 ns 1265330704 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1270873858.5 ns 1319818472.5 ns 0.96
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2492061660 ns 2485738704 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 8221824 ns 8548946 ns 0.96
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1711092077 ns 1640403929 ns 1.04
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1558493025 ns 1624981145 ns 0.96
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1546123882 ns 1613497904 ns 0.96
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 2735606253 ns 2925330438 ns 0.94
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 2456759080 ns 3005994981 ns 0.82
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1286461508.5 ns 1302523637 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1287819878 ns 1279881184.5 ns 1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1229738497.5 ns 1221896798 ns 1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2423291169 ns 2518812295 ns 0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 20850848 ns 21046449.5 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2148735048 ns 2222479042 ns 0.97
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2138885245 ns 2232948599 ns 0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2133624257 ns 2244870900 ns 0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 3388591813 ns 3549740380 ns 0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 5994160924 ns 5525916823.5 ns 1.08
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1312161644.5 ns 1276973264.5 ns 1.03
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1286524449.5 ns 1268819089 ns 1.01
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1304833938.5 ns 1208687674 ns 1.08
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2654597387 ns 2407396694 ns 1.10
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 7062374 ns 6966972 ns 1.01
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1463335858 ns 1477620923 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1418279069 ns 1472576060 ns 0.96
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1407215120 ns 1474827913 ns 0.95
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 2610977913 ns 2778453642 ns 0.94
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1313118431 ns 1130986304.5 ns 1.16
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1266033162 ns 1217166466 ns 1.04
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1246060690 ns 1289351790 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1334724021 ns 1290846696.5 ns 1.03
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2615371467 ns 2439329631 ns 1.07
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 11338191 ns 11335374 ns 1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1712640221 ns 1767669698 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1714535441 ns 1753861643 ns 0.98
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1699386874 ns 1730166719 ns 0.98
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 2934490244 ns 3055075413 ns 0.96
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3109868396.5 ns 3163273720 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1304791782 ns 1252706203 ns 1.04
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1285502833 ns 1244953037 ns 1.03
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1277688710 ns 1272847971 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2599975980 ns 2683822024 ns 0.97
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 25551082.5 ns 25478417 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2164823256 ns 2236278568 ns 0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2168156538 ns 2237924534 ns 0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2195060353 ns 2212087492 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3415103088 ns 3546320207 ns 0.96
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 6792188737 ns 5763894130.5 ns 1.18
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1251711662 ns 1320883045 ns 0.95
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1314443020 ns 1311502435 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1296451333 ns 1359516723 ns 0.95
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 2569030485 ns 2652444409 ns 0.97
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 50146964 ns 50149906 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 3044550425 ns 3038494804 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 3049319408 ns 3034868554 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2999963120 ns 3043196341 ns 0.99
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 4363286803 ns 4493792409 ns 0.97
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 10042481459 ns 11147401934 ns 0.90
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1302500995 ns 1322310080 ns 0.99
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1297504625 ns 1325171411 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1310827906 ns 1287824673 ns 1.02
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 2446072271 ns 2586974329 ns 0.95
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 67921126 ns 68180485 ns 1.00
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3156309083 ns 3262785501 ns 0.97
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3173315537 ns 3248806786 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3130494484 ns 3234623743 ns 0.97
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 4504318538 ns 4707557622 ns 0.96
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 14749507671 ns 13676029593 ns 1.08
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1303044095 ns 1316892556 ns 0.99
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1304317104.5 ns 1260539573 ns 1.03
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1332711162 ns 1339224548 ns 1.00
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 2642760989 ns 2439170674 ns 1.08
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 19441467 ns 19630937 ns 0.99
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1863044822 ns 1914379352 ns 0.97
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1837806913 ns 1908767496 ns 0.96
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1855662911 ns 1940246170 ns 0.96
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3030096911 ns 3177313151 ns 0.95
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 3331794574 ns 3458788796.5 ns 0.96

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.