-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Implement `conj` for `TracedRArray` * Generalize `adjoint` for `TracedRVecOrMat` * Implement `conj` for `TracedRNumber` * Implement `conj!` for `TracedRArray` * Implement `DenseElementsAttribute` for array of `Complex` * Implement `Base.real`, `Base.imag` for `TracedRNumber` * Fix typo * Implement `Base.real`, `Base.imag` for `TracedRArray` * Fix pointer length in `MLIR.IR.DenseElementsAttribute` on `Complex{T}` * Move complex tests to new file * Fix `ConcreteRArray` constructor on `Number` * Fix `to_rarray` on primitive number types * Fix `conj` tests on numbers * Write tests for `real`, `imag` * Remove duplicated method * Update src/TracedRNumber.jl Co-authored-by: Paul Berg <paul@plutojl.org> * fix `image` on `TracedRArray` of reals --------- Co-authored-by: Paul Berg <paul@plutojl.org>
- Loading branch information
Showing
8 changed files
with
179 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
using Test | ||
using Reactant | ||
|
||
@testset "conj" begin | ||
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im) | ||
x_concrete = Reactant.to_rarray(x) | ||
f = @compile conj(x_concrete) | ||
@test only(f(x_concrete)) == conj(x) | ||
end | ||
|
||
@testset "$(typeof(x))" for x in ( | ||
fill(1.0 + 2.0im), | ||
fill(1.0), | ||
[1.0 + 2.0im; 3.0 + 4.0im], | ||
[1.0; 3.0], | ||
[1.0 + 2.0im 3.0 + 4.0im], | ||
[1.0 2.0], | ||
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im], | ||
[1.0 3.0; 5.0 7.0], | ||
) | ||
x_concrete = Reactant.to_rarray(x) | ||
f = @compile conj(x_concrete) | ||
@test f(x_concrete) == conj(x) | ||
end | ||
end | ||
|
||
@testset "conj!" begin | ||
@testset "$(typeof(x))" for x in ( | ||
fill(1.0 + 2.0im), | ||
fill(1.0), | ||
[1.0 + 2.0im; 3.0 + 4.0im], | ||
[1.0; 3.0], | ||
[1.0 + 2.0im 3.0 + 4.0im], | ||
[1.0 2.0], | ||
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im], | ||
[1.0 3.0; 5.0 7.0], | ||
) | ||
x_concrete = Reactant.to_rarray(x) | ||
f = @compile conj!(x_concrete) | ||
@test f(x_concrete) == conj(x) | ||
@test x_concrete == conj(x) | ||
end | ||
end | ||
|
||
@testset "real" begin | ||
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im) | ||
x_concrete = Reactant.to_rarray(x) | ||
f = @compile real(x_concrete) | ||
@test only(f(x_concrete)) == real(x) | ||
end | ||
|
||
@testset "$(typeof(x))" for x in ( | ||
fill(1.0 + 2.0im), | ||
fill(1.0), | ||
[1.0 + 2.0im; 3.0 + 4.0im], | ||
[1.0; 3.0], | ||
[1.0 + 2.0im 3.0 + 4.0im], | ||
[1.0 2.0], | ||
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im], | ||
[1.0 3.0; 5.0 7.0], | ||
) | ||
x_concrete = Reactant.to_rarray(x) | ||
f = @compile real(x_concrete) | ||
@test f(x_concrete) == real(x) | ||
end | ||
end | ||
|
||
@testset "imag" begin | ||
@testset "$(typeof(x))" for x in (1.0, 1.0 + 2.0im) | ||
x_concrete = Reactant.to_rarray(x) | ||
f = @compile imag(x_concrete) | ||
@test only(f(x_concrete)) == imag(x) | ||
end | ||
|
||
@testset "$(typeof(x))" for x in ( | ||
fill(1.0 + 2.0im), | ||
fill(1.0), | ||
[1.0 + 2.0im; 3.0 + 4.0im], | ||
[1.0; 3.0], | ||
[1.0 + 2.0im 3.0 + 4.0im], | ||
[1.0 2.0], | ||
[1.0+2.0im 3.0+4.0im; 5.0+6.0im 7.0+8.0im], | ||
[1.0 3.0; 5.0 7.0], | ||
) | ||
x_concrete = Reactant.to_rarray(x) | ||
f = @compile imag(x_concrete) | ||
@test f(x_concrete) == imag(x) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1dd24b8
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reactant.jl Benchmarks
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme)
1332810111
ns1249734857
ns1.07
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant
1310849226
ns1242994162
ns1.05
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme)
1398288716
ns1195264640
ns1.17
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme)
2623910053
ns2306719819
ns1.14
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux
215139523
ns213417826
ns1.01
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme)
5334887020
ns5619173445
ns0.95
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant
5125148459
ns5334345862
ns0.96
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme)
5128099715
ns5283772358
ns0.97
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme)
7093041527
ns6769958907
ns1.05
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux
31469979010
ns35047348986
ns0.90
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme)
1390179366
ns1277061285
ns1.09
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant
1268184459.5
ns1265330704
ns1.00
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme)
1270873858.5
ns1319818472.5
ns0.96
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme)
2492061660
ns2485738704
ns1.00
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux
8221824
ns8548946
ns0.96
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme)
1711092077
ns1640403929
ns1.04
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant
1558493025
ns1624981145
ns0.96
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme)
1546123882
ns1613497904
ns0.96
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme)
2735606253
ns2925330438
ns0.94
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux
2456759080
ns3005994981
ns0.82
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme)
1286461508.5
ns1302523637
ns0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant
1287819878
ns1279881184.5
ns1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme)
1229738497.5
ns1221896798
ns1.01
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme)
2423291169
ns2518812295
ns0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux
20850848
ns21046449.5
ns0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme)
2148735048
ns2222479042
ns0.97
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant
2138885245
ns2232948599
ns0.96
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme)
2133624257
ns2244870900
ns0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme)
3388591813
ns3549740380
ns0.95
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux
5994160924
ns5525916823.5
ns1.08
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme)
1312161644.5
ns1276973264.5
ns1.03
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant
1286524449.5
ns1268819089
ns1.01
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme)
1304833938.5
ns1208687674
ns1.08
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme)
2654597387
ns2407396694
ns1.10
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux
7062374
ns6966972
ns1.01
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme)
1463335858
ns1477620923
ns0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant
1418279069
ns1472576060
ns0.96
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme)
1407215120
ns1474827913
ns0.95
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme)
2610977913
ns2778453642
ns0.94
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux
1313118431
ns1130986304.5
ns1.16
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme)
1266033162
ns1217166466
ns1.04
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant
1246060690
ns1289351790
ns0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme)
1334724021
ns1290846696.5
ns1.03
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme)
2615371467
ns2439329631
ns1.07
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux
11338191
ns11335374
ns1.00
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme)
1712640221
ns1767669698
ns0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant
1714535441
ns1753861643
ns0.98
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme)
1699386874
ns1730166719
ns0.98
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme)
2934490244
ns3055075413
ns0.96
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux
3109868396.5
ns3163273720
ns0.98
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme)
1304791782
ns1252706203
ns1.04
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant
1285502833
ns1244953037
ns1.03
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme)
1277688710
ns1272847971
ns1.00
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme)
2599975980
ns2683822024
ns0.97
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux
25551082.5
ns25478417
ns1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme)
2164823256
ns2236278568
ns0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant
2168156538
ns2237924534
ns0.97
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme)
2195060353
ns2212087492
ns0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme)
3415103088
ns3546320207
ns0.96
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux
6792188737
ns5763894130.5
ns1.18
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme)
1251711662
ns1320883045
ns0.95
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant
1314443020
ns1311502435
ns1.00
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme)
1296451333
ns1359516723
ns0.95
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme)
2569030485
ns2652444409
ns0.97
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux
50146964
ns50149906
ns1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme)
3044550425
ns3038494804
ns1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant
3049319408
ns3034868554
ns1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme)
2999963120
ns3043196341
ns0.99
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme)
4363286803
ns4493792409
ns0.97
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux
10042481459
ns11147401934
ns0.90
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme)
1302500995
ns1322310080
ns0.99
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant
1297504625
ns1325171411
ns0.98
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme)
1310827906
ns1287824673
ns1.02
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme)
2446072271
ns2586974329
ns0.95
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux
67921126
ns68180485
ns1.00
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme)
3156309083
ns3262785501
ns0.97
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant
3173315537
ns3248806786
ns0.98
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme)
3130494484
ns3234623743
ns0.97
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme)
4504318538
ns4707557622
ns0.96
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux
14749507671
ns13676029593
ns1.08
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme)
1303044095
ns1316892556
ns0.99
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant
1304317104.5
ns1260539573
ns1.03
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme)
1332711162
ns1339224548
ns1.00
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme)
2642760989
ns2439170674
ns1.08
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux
19441467
ns19630937
ns0.99
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme)
1863044822
ns1914379352
ns0.97
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant
1837806913
ns1908767496
ns0.96
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme)
1855662911
ns1940246170
ns0.96
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme)
3030096911
ns3177313151
ns0.95
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux
3331794574
ns3458788796.5
ns0.96
This comment was automatically generated by workflow using github-action-benchmark.