Skip to content

Commit

Permalink
feat: overload mul!
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 6, 2024
1 parent bba6203 commit 5f62262
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 10 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Expand All @@ -34,6 +35,7 @@ ArrayInterface = "7.10"
CEnum = "0.4, 0.5"
Downloads = "1.6"
Enzyme = "0.13"
LinearAlgebra = "1.10"
NNlib = "0.9"
OrderedCollections = "1"
Preferences = "1.4"
Expand Down
1 change: 1 addition & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module Reactant

using ReactantCore: ReactantCore, @trace, MissingTracedValue

using LinearAlgebra: LinearAlgebra
using Adapt: Adapt, WrappedArray

# auxiliary types and functions
Expand Down
30 changes: 20 additions & 10 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,30 +381,40 @@ for (jlop, hloop, hlocomp, merge) in
end
end

function Base.:*(
@nospecialize(lhs::TracedRArray{T,2}), @nospecialize(rhs::TracedRArray{T,2})
) where {T}
lhsty = MLIR.IR.type(lhs.mlir_data)
rhsty = MLIR.IR.type(rhs.mlir_data)
resty = MLIR.IR.TensorType((size(lhs, 1), size(rhs, 2)), eltype(lhsty))
function LinearAlgebra.mul!(
@nospecialize(C::TracedRArray{T1,2}),
@nospecialize(A::TracedRArray{T2,2}),
@nospecialize(B::TracedRArray{T3,2}),
) where {T1,T2,T3}
if size(C) != (size(A, 1), size(B, 2))
throw(
DimensionMismatch(
"C has size $(size(C)), A has size $(size(A)), B has size $(size(B))"
),
)
end
if size(A, 2) != size(B, 1)
throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))"))
end
resty = MLIR.IR.TensorType(size(C), MLIR.IR.Type(T1))
dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet(
MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0]
)
prec = MLIR.IR.Attribute(
MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
)
precar = MLIR.IR.Attribute([prec, prec])
res = MLIR.IR.result(
C.mlir_data = MLIR.IR.result(
MLIR.Dialects.stablehlo.dot_general(
lhs.mlir_data,
rhs.mlir_data;
A.mlir_data,
B.mlir_data;
result_0=resty,
dot_dimension_numbers=dot_dimension_numbers,
precision_config=precar,
),
1,
)
return TracedRArray{T,2}((), res, (size(lhs, 1), size(rhs, 2)))
return C
end

function Enzyme.Compiler.active_reg_inner(
Expand Down
25 changes: 25 additions & 0 deletions test/linear_algebra.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using LinearAlgebra, Reactant

function muladd2(A, x, b)
C = similar(A, promote_type(eltype(A), eltype(b)), size(A, 1), size(x, 2))
mul!(C, A, x)
C .+= b
return C
end

@testset begin
A = Reactant.to_rarray(rand(4, 4))
x = Reactant.to_rarray(rand(4, 2))
b = Reactant.to_rarray(rand(4))

@test @jit(muladd2(A, x, b)) muladd2(A, x, b)

# Mixed Precision
x = Reactant.to_rarray(rand(Float32, 4, 2))

@test @jit(muladd2(A, x, b)) muladd2(A, x, b)

C = similar(A, Float32, size(A, 1), size(x, 2))
@jit(mul!(C, A, x))
@test C A * x
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
@safetestset "Buffer Donation" include("buffer_donation.jl")
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
@safetestset "Control Flow" include("control_flow.jl")
@safetestset "Linear Algebra" include("linear_algebra.jl")
end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
Expand Down

0 comments on commit 5f62262

Please sign in to comment.