From fce399c78de376827e47b4f3bebcdcebfe51871d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 29 Oct 2024 11:09:06 -0400 Subject: [PATCH] feat: support tracing scalars (#205) * feat: support tracing scalars * test: add scalar tests * fix: return concrete scalars * refactor: rename union type to ConcreteRScalar --- deps/ReactantExtra/make-bindings.jl | 6 +- src/Compiler.jl | 11 ++++ src/ConcreteRArray.jl | 89 ++++++++++++++++++++--------- src/Reactant.jl | 2 +- src/Tracing.jl | 47 ++++++++++++--- test/basic.jl | 47 ++++++++++++++- test/compile.jl | 2 +- 7 files changed, 161 insertions(+), 43 deletions(-) diff --git a/deps/ReactantExtra/make-bindings.jl b/deps/ReactantExtra/make-bindings.jl index e8c8ef62..1ba13eb0 100644 --- a/deps/ReactantExtra/make-bindings.jl +++ b/deps/ReactantExtra/make-bindings.jl @@ -6,10 +6,8 @@ function build_file(output_path) dir=@__DIR__, ), ) - Base.Filesystem.cp( - joinpath(@__DIR__, "bazel-bin", file), - output_path; - force=true, + return Base.Filesystem.cp( + joinpath(@__DIR__, "bazel-bin", file), output_path; force=true ) end diff --git a/src/Compiler.jl b/src/Compiler.jl index 081819c0..71bf5157 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -5,6 +5,7 @@ import ..Reactant: MLIR, XLA, ConcreteRArray, + ConcreteRNumber, TracedRArray, TracedRNumber, OrderedIdDict, @@ -30,6 +31,16 @@ function create_result(tocopy::T, path, result_stores) where {T} return Expr(:new, T, elems...) end +function create_result(tocopy::ConcreteRNumber{T}, path, result_stores) where {T} + if haskey(result_stores, path) + restore = result_stores[path] + delete!(result_stores, path) + return :(ConcreteRNumber{$T}($restore)) + end + # We will set the data for this later + return :(ConcreteRNumber{$T}($(tocopy.data))) +end + function create_result(tocopy::ConcreteRArray{T,N}, path, result_stores) where {T,N} if haskey(result_stores, path) restore = result_stores[path] diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index e0534eea..f7bf1653 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -4,11 +4,34 @@ end mutable struct ConcreteRArray{T,N} <: RArray{T,N} data::XLA.AsyncBuffer - # data::XLAArray{T, N} + # data::XLAArray{T, N} shape::NTuple{N,Int} end -ConcreteRArray(data::T) where {T<:Number} = ConcreteRArray(fill(data)) +mutable struct ConcreteRNumber{T} <: RNumber{T} + data::XLA.AsyncBuffer +end + +function ConcreteRNumber( + data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[] +) where {T<:Number} + crarray = ConcreteRArray(fill(data); client, idx) + return ConcreteRNumber{T}(crarray.data) +end + +Base.size(::ConcreteRNumber) = () + +function ConcreteRArray( + data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[] +) where {T<:Number} + Base.depwarn( + "ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead", + :ConcreteRArray, + ) + return ConcreteRArray(fill(data); client, idx) +end + +const ConcreteRScalar{T} = Union{ConcreteRArray{T,0},ConcreteRNumber{T}} Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x) @@ -48,7 +71,7 @@ function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,El # XLA.from_row_major(data) end -function synchronize(x::ConcreteRArray) +function synchronize(x::Union{ConcreteRArray,ConcreteRNumber}) XLA.synced_buffer(x.data) return nothing end @@ -60,7 +83,7 @@ end # return ConcreteRArray{T,N}(x.data) # end -function to_float(X::ConcreteRArray{T,0}) where {T} +function to_number(X::ConcreteRScalar{T}) where {T} data = Ref{T}() XLA.await(X.data) buf = X.data.buffer @@ -70,36 +93,49 @@ function to_float(X::ConcreteRArray{T,0}) where {T} return data[] end -function Base.convert(::Type{T}, x::ConcreteRArray{T,0}) where {T} - return to_float(x) +Base.convert(::Type{T}, x::ConcreteRScalar{T}) where {T} = to_number(x) + +for jlop in ( + :(Base.isless), + :(Base.:+), + :(Base.:-), + :(Base.:*), + :(Base.:/), + :(Base.:^), + :(Base.:(==)), + ), + T in (ConcreteRNumber, ConcreteRArray{<:Any,0}) + + @eval begin + $(jlop)(x::$(T), y::$(T)) = $(jlop)(to_number(x), to_number(y)) + $(jlop)(x::$(T), y::Number) = $(jlop)(to_number(x), y) + $(jlop)(x::Number, y::$(T)) = $(jlop)(x, to_number(y)) + end end -for jlop in (:(Base.isless), :(Base.:+), :(Base.:-), :(Base.:*), :(Base.:/), :(Base.:^)) +for T in (ConcreteRNumber, ConcreteRArray{<:Any,0}) @eval begin - function $jlop(x::ConcreteRArray{T,0}, y::ConcreteRArray{U,0}) where {T,U} - return $jlop(to_float(x), to_float(y)) + function Base.isapprox(x::$(T), y::Number; kwargs...) + return Base.isapprox(to_number(x), y; kwargs...) end - function $jlop(x::ConcreteRArray{T,0}, y) where {T} - return $jlop(to_float(x), y) + + function Base.isapprox(x::Number, y::$(T); kwargs...) + return Base.isapprox(x, to_number(y); kwargs...) end - function $jlop(x, y::ConcreteRArray{U,0}) where {U} - return $jlop(x, to_float(y)) + + function Base.isapprox(x::$(T), y::$(T); kwargs...) + return Base.isapprox(to_number(x), to_number(y); kwargs...) end end end -function Base.isapprox(x::ConcreteRArray{T,0}, y; kwargs...) where {T} - return Base.isapprox(to_float(x), y; kwargs...) -end - -function Base.isapprox(x, y::ConcreteRArray{T,0}; kwargs...) where {T} - return Base.isapprox(x, to_float(y); kwargs...) -end - -function Base.isapprox( - x::ConcreteRArray{T,0}, y::ConcreteRArray{T2,0}; kwargs... -) where {T,T2} - return Base.isapprox(to_float(x), to_float(y); kwargs...) +function Base.show(io::IO, X::ConcreteRScalar{T}) where {T} + if X.data == XLA.AsyncEmptyBuffer + println(io, "") + return nothing + end + str = sprint(show, to_number(X)) + return print(io, "$(typeof(X))($(str))") end function Base.print_array(io::IO, X::ConcreteRArray) @@ -115,7 +151,8 @@ function Base.show(io::IO, X::ConcreteRArray) println(io, "") return nothing end - return Base.show(io, convert(Array, X)) + str = sprint(show, convert(Array, X)) + return print(io, "$(typeof(X))($(str))") end const getindex_warned = Ref(false) diff --git a/src/Reactant.jl b/src/Reactant.jl index 4ae7852b..0aa348d7 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -94,7 +94,7 @@ include("Tracing.jl") include("Compiler.jl") using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile -export ConcreteRArray, @compile, @code_hlo, @jit +export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit const registry = Ref{MLIR.IR.DialectRegistry}() function __init__() diff --git a/src/Tracing.jl b/src/Tracing.jl index c3c6ed78..d4bf45cd 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -189,7 +189,8 @@ function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedType,m elseif mode == TracedToConcrete @inline base_typec(TV::TT) where {TT<:UnionAll} = UnionAll(TV.var, base_typec(TV.body)) - @inline base_typec(TV::TT) where {TT<:DataType} = ConcreteRArray{TV.parameters...} + @inline base_typec(TV::TT) where {TT<:DataType} = + (T <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...} return base_typec(T) elseif mode == TracedTrack || mode == TracedSetPath return T @@ -232,6 +233,7 @@ function make_tracer( mode; toscalar=false, tobatch=nothing, + kwargs..., ) where {RT} if haskey(seen, prev) return seen[prev] @@ -308,6 +310,21 @@ function make_tracer( return res end +function make_tracer(seen, prev::ConcreteRNumber{T}, path, mode; kwargs...) where {T} + if mode == ArrayToConcrete + return prev + end + if mode != ConcreteToTraced + throw("Cannot trace existing trace type") + end + if haskey(seen, prev) + return seen[prev]::TracedRNumber{T} + end + res = TracedRNumber{T}((path,), nothing) + seen[prev] = res + return res +end + function make_tracer( seen, @nospecialize(prev::TracedRArray{T,N}), @@ -315,6 +332,7 @@ function make_tracer( mode; toscalar=false, tobatch=nothing, + kwargs..., ) where {T,N} if mode == ConcreteToTraced throw("Cannot trace existing trace type") @@ -389,9 +407,9 @@ function make_tracer( if mode == TracedToConcrete if haskey(seen, prev) - return seen[prev]::ConcreteRArray{T,0} + return seen[prev]::ConcreteRNumber{T} end - res = ConcreteRArray{T,0}(XLA.AsyncEmptyBuffer, size(prev)) + res = ConcreteRNumber{T}(XLA.AsyncEmptyBuffer) seen[prev] = res return res end @@ -400,8 +418,13 @@ function make_tracer( end function make_tracer( - seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs... -) where {RT<:AbstractFloat} + seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs... +) where {RT<:Number} + if mode == ArrayToConcrete + length(track_numbers) == 0 && return prev + should_convert = any(Base.Fix1(<:, RT), track_numbers) + return should_convert ? ConcreteRNumber(prev) : prev + end return prev end @@ -414,10 +437,15 @@ function make_tracer( mode; toscalar=false, tobatch=nothing, + kwargs..., ) where {RT} return Complex( - make_tracer(seen, prev.re, append_path(path, :re), mode; toscalar, tobatch), - make_tracer(seen, prev.im, append_path(path, :im), mode; toscalar, tobatch), + make_tracer( + seen, prev.re, append_path(path, :re), mode; toscalar, tobatch, kwargs... + ), + make_tracer( + seen, prev.im, append_path(path, :im), mode; toscalar, tobatch, kwargs... + ), ) end @@ -489,8 +517,9 @@ function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...) return res end -@inline function to_rarray(@nospecialize(x)) - return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete) +@inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=()) + track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ()) + return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers) end to_rarray(x::ReactantPrimitive) = ConcreteRArray(x) diff --git a/test/basic.jl b/test/basic.jl index 2e2d7c6a..505c1b25 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -80,8 +80,7 @@ sum_compare(x) = sum(x) > 0 # Ensure we are tracing as scalars. Else this will fail due to > not being defined on # arrays f = @compile sum_compare(a) - # We need to use [] to unwrap the scalar. We will fix this in the future. - @test f(a)[] == sum_compare(x) + @test f(a) == sum_compare(x) end function mysoftmax!(x) @@ -445,3 +444,47 @@ end c = Reactant.compile(+, (a, b))(a, b) @test c == ones(CT, 2) + ones(CT, 2) end + +@testset "Scalars" begin + @testset "Only Scalars" begin + x = (3, 3.14) + + f1(x) = x[1] * x[2] + + x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + f2 = @compile f1(x_ra) + @test f2(Reactant.to_rarray((5, 5.2); track_numbers=(Number,))) ≈ 5 * 5.2 + @test f2(Reactant.to_rarray((5, 5.2); track_numbers=(Number,))) isa ConcreteRNumber + + x_ra = Reactant.to_rarray(x) + f3 = @compile f1(x_ra) + @test f3(Reactant.to_rarray((5, 5.2))) ≈ f1(x) + @test !(f3(Reactant.to_rarray((5, 5.2))) isa ConcreteRNumber) + @test f3(Reactant.to_rarray((5, 5.2))) isa Number + + x_ra = Reactant.to_rarray(x; track_numbers=(Int,)) + f4 = @compile f1(x_ra) + @test f4(Reactant.to_rarray((5, 5.2); track_numbers=(Int,))) ≈ 5 * 3.14 + @test f4(Reactant.to_rarray((5, 5.2); track_numbers=(Int,))) isa ConcreteRNumber + end + + @testset "Mixed" begin + x = (3, [3.14]) + + f1(x) = x[1] * x[2] + + x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + + f2 = @compile f1(x_ra) + res2 = f2(Reactant.to_rarray((5, [3.14]); track_numbers=(Number,))) + @test only(res2) ≈ 5 * 3.14 + @test res2 isa ConcreteRArray + + x_ra = Reactant.to_rarray(x) + + f3 = @compile f1(x_ra) + res3 = f3(Reactant.to_rarray((5, [3.14]))) + @test only(res3) ≈ only(f1(x)) + @test res3 isa ConcreteRArray + end +end diff --git a/test/compile.jl b/test/compile.jl index 85692ad7..662d6142 100644 --- a/test/compile.jl +++ b/test/compile.jl @@ -11,7 +11,7 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a= f = @compile sum(x2) - @test f(x2) isa @NamedTuple{a::Reactant.ConcreteRArray{Float64,0}} + @test f(x2) isa @NamedTuple{a::Reactant.ConcreteRNumber{Float64}} @test isapprox(f(x2).a, sum(x.a)) end end