Skip to content

Commit

Permalink
feat: support tracing scalars (#205)
Browse files Browse the repository at this point in the history
* feat: support tracing scalars

* test: add scalar tests

* fix: return concrete scalars

* refactor: rename union type to ConcreteRScalar
  • Loading branch information
avik-pal authored Oct 29, 2024
1 parent 6866f05 commit fce399c
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 43 deletions.
6 changes: 2 additions & 4 deletions deps/ReactantExtra/make-bindings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ function build_file(output_path)
dir=@__DIR__,
),
)
Base.Filesystem.cp(
joinpath(@__DIR__, "bazel-bin", file),
output_path;
force=true,
return Base.Filesystem.cp(
joinpath(@__DIR__, "bazel-bin", file), output_path; force=true
)
end

Expand Down
11 changes: 11 additions & 0 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import ..Reactant:
MLIR,
XLA,
ConcreteRArray,
ConcreteRNumber,
TracedRArray,
TracedRNumber,
OrderedIdDict,
Expand All @@ -30,6 +31,16 @@ function create_result(tocopy::T, path, result_stores) where {T}
return Expr(:new, T, elems...)
end

function create_result(tocopy::ConcreteRNumber{T}, path, result_stores) where {T}
if haskey(result_stores, path)
restore = result_stores[path]
delete!(result_stores, path)
return :(ConcreteRNumber{$T}($restore))
end
# We will set the data for this later
return :(ConcreteRNumber{$T}($(tocopy.data)))
end

function create_result(tocopy::ConcreteRArray{T,N}, path, result_stores) where {T,N}
if haskey(result_stores, path)
restore = result_stores[path]
Expand Down
89 changes: 63 additions & 26 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,34 @@ end

mutable struct ConcreteRArray{T,N} <: RArray{T,N}
data::XLA.AsyncBuffer
# data::XLAArray{T, N}
# data::XLAArray{T, N}
shape::NTuple{N,Int}
end

ConcreteRArray(data::T) where {T<:Number} = ConcreteRArray(fill(data))
mutable struct ConcreteRNumber{T} <: RNumber{T}
data::XLA.AsyncBuffer
end

function ConcreteRNumber(
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[]
) where {T<:Number}
crarray = ConcreteRArray(fill(data); client, idx)
return ConcreteRNumber{T}(crarray.data)
end

Base.size(::ConcreteRNumber) = ()

function ConcreteRArray(
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[]
) where {T<:Number}
Base.depwarn(
"ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead",
:ConcreteRArray,
)
return ConcreteRArray(fill(data); client, idx)
end

const ConcreteRScalar{T} = Union{ConcreteRArray{T,0},ConcreteRNumber{T}}

Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x)

Expand Down Expand Up @@ -48,7 +71,7 @@ function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,El
# XLA.from_row_major(data)
end

function synchronize(x::ConcreteRArray)
function synchronize(x::Union{ConcreteRArray,ConcreteRNumber})
XLA.synced_buffer(x.data)
return nothing
end
Expand All @@ -60,7 +83,7 @@ end
# return ConcreteRArray{T,N}(x.data)
# end

function to_float(X::ConcreteRArray{T,0}) where {T}
function to_number(X::ConcreteRScalar{T}) where {T}
data = Ref{T}()
XLA.await(X.data)
buf = X.data.buffer
Expand All @@ -70,36 +93,49 @@ function to_float(X::ConcreteRArray{T,0}) where {T}
return data[]
end

function Base.convert(::Type{T}, x::ConcreteRArray{T,0}) where {T}
return to_float(x)
Base.convert(::Type{T}, x::ConcreteRScalar{T}) where {T} = to_number(x)

for jlop in (
:(Base.isless),
:(Base.:+),
:(Base.:-),
:(Base.:*),
:(Base.:/),
:(Base.:^),
:(Base.:(==)),
),
T in (ConcreteRNumber, ConcreteRArray{<:Any,0})

@eval begin
$(jlop)(x::$(T), y::$(T)) = $(jlop)(to_number(x), to_number(y))
$(jlop)(x::$(T), y::Number) = $(jlop)(to_number(x), y)
$(jlop)(x::Number, y::$(T)) = $(jlop)(x, to_number(y))
end
end

for jlop in (:(Base.isless), :(Base.:+), :(Base.:-), :(Base.:*), :(Base.:/), :(Base.:^))
for T in (ConcreteRNumber, ConcreteRArray{<:Any,0})
@eval begin
function $jlop(x::ConcreteRArray{T,0}, y::ConcreteRArray{U,0}) where {T,U}
return $jlop(to_float(x), to_float(y))
function Base.isapprox(x::$(T), y::Number; kwargs...)
return Base.isapprox(to_number(x), y; kwargs...)
end
function $jlop(x::ConcreteRArray{T,0}, y) where {T}
return $jlop(to_float(x), y)

function Base.isapprox(x::Number, y::$(T); kwargs...)
return Base.isapprox(x, to_number(y); kwargs...)
end
function $jlop(x, y::ConcreteRArray{U,0}) where {U}
return $jlop(x, to_float(y))

function Base.isapprox(x::$(T), y::$(T); kwargs...)
return Base.isapprox(to_number(x), to_number(y); kwargs...)
end
end
end

function Base.isapprox(x::ConcreteRArray{T,0}, y; kwargs...) where {T}
return Base.isapprox(to_float(x), y; kwargs...)
end

function Base.isapprox(x, y::ConcreteRArray{T,0}; kwargs...) where {T}
return Base.isapprox(x, to_float(y); kwargs...)
end

function Base.isapprox(
x::ConcreteRArray{T,0}, y::ConcreteRArray{T2,0}; kwargs...
) where {T,T2}
return Base.isapprox(to_float(x), to_float(y); kwargs...)
function Base.show(io::IO, X::ConcreteRScalar{T}) where {T}
if X.data == XLA.AsyncEmptyBuffer
println(io, "<Empty buffer>")
return nothing
end
str = sprint(show, to_number(X))
return print(io, "$(typeof(X))($(str))")
end

function Base.print_array(io::IO, X::ConcreteRArray)
Expand All @@ -115,7 +151,8 @@ function Base.show(io::IO, X::ConcreteRArray)
println(io, "<Empty buffer>")
return nothing
end
return Base.show(io, convert(Array, X))
str = sprint(show, convert(Array, X))
return print(io, "$(typeof(X))($(str))")
end

const getindex_warned = Ref(false)
Expand Down
2 changes: 1 addition & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ include("Tracing.jl")
include("Compiler.jl")

using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
export ConcreteRArray, @compile, @code_hlo, @jit
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit

const registry = Ref{MLIR.IR.DialectRegistry}()
function __init__()
Expand Down
47 changes: 38 additions & 9 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ function traced_type(::Type{T}, seen::ST, ::Val{mode}) where {ST,T<:TracedType,m
elseif mode == TracedToConcrete
@inline base_typec(TV::TT) where {TT<:UnionAll} =
UnionAll(TV.var, base_typec(TV.body))
@inline base_typec(TV::TT) where {TT<:DataType} = ConcreteRArray{TV.parameters...}
@inline base_typec(TV::TT) where {TT<:DataType} =
(T <: TracedRArray ? ConcreteRArray : ConcreteRNumber){TV.parameters...}
return base_typec(T)
elseif mode == TracedTrack || mode == TracedSetPath
return T
Expand Down Expand Up @@ -232,6 +233,7 @@ function make_tracer(
mode;
toscalar=false,
tobatch=nothing,
kwargs...,
) where {RT}
if haskey(seen, prev)
return seen[prev]
Expand Down Expand Up @@ -308,13 +310,29 @@ function make_tracer(
return res
end

function make_tracer(seen, prev::ConcreteRNumber{T}, path, mode; kwargs...) where {T}
if mode == ArrayToConcrete
return prev
end
if mode != ConcreteToTraced
throw("Cannot trace existing trace type")
end
if haskey(seen, prev)
return seen[prev]::TracedRNumber{T}
end
res = TracedRNumber{T}((path,), nothing)
seen[prev] = res
return res
end

function make_tracer(
seen,
@nospecialize(prev::TracedRArray{T,N}),
@nospecialize(path),
mode;
toscalar=false,
tobatch=nothing,
kwargs...,
) where {T,N}
if mode == ConcreteToTraced
throw("Cannot trace existing trace type")
Expand Down Expand Up @@ -389,9 +407,9 @@ function make_tracer(

if mode == TracedToConcrete
if haskey(seen, prev)
return seen[prev]::ConcreteRArray{T,0}
return seen[prev]::ConcreteRNumber{T}
end
res = ConcreteRArray{T,0}(XLA.AsyncEmptyBuffer, size(prev))
res = ConcreteRNumber{T}(XLA.AsyncEmptyBuffer)
seen[prev] = res
return res
end
Expand All @@ -400,8 +418,13 @@ function make_tracer(
end

function make_tracer(
seen, @nospecialize(prev::RT), @nospecialize(path), mode; kwargs...
) where {RT<:AbstractFloat}
seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs...
) where {RT<:Number}
if mode == ArrayToConcrete
length(track_numbers) == 0 && return prev
should_convert = any(Base.Fix1(<:, RT), track_numbers)
return should_convert ? ConcreteRNumber(prev) : prev
end
return prev
end

Expand All @@ -414,10 +437,15 @@ function make_tracer(
mode;
toscalar=false,
tobatch=nothing,
kwargs...,
) where {RT}
return Complex(
make_tracer(seen, prev.re, append_path(path, :re), mode; toscalar, tobatch),
make_tracer(seen, prev.im, append_path(path, :im), mode; toscalar, tobatch),
make_tracer(
seen, prev.re, append_path(path, :re), mode; toscalar, tobatch, kwargs...
),
make_tracer(
seen, prev.im, append_path(path, :im), mode; toscalar, tobatch, kwargs...
),
)
end

Expand Down Expand Up @@ -489,8 +517,9 @@ function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...)
return res
end

@inline function to_rarray(@nospecialize(x))
return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete)
@inline function to_rarray(@nospecialize(x); track_numbers::Union{Bool,Tuple}=())
track_numbers isa Bool && (track_numbers = track_numbers ? (Number,) : ())
return make_tracer(OrderedIdDict(), x, (), Reactant.ArrayToConcrete; track_numbers)
end

to_rarray(x::ReactantPrimitive) = ConcreteRArray(x)
47 changes: 45 additions & 2 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,7 @@ sum_compare(x) = sum(x) > 0
# Ensure we are tracing as scalars. Else this will fail due to > not being defined on
# arrays
f = @compile sum_compare(a)
# We need to use [] to unwrap the scalar. We will fix this in the future.
@test f(a)[] == sum_compare(x)
@test f(a) == sum_compare(x)
end

function mysoftmax!(x)
Expand Down Expand Up @@ -445,3 +444,47 @@ end
c = Reactant.compile(+, (a, b))(a, b)
@test c == ones(CT, 2) + ones(CT, 2)
end

@testset "Scalars" begin
@testset "Only Scalars" begin
x = (3, 3.14)

f1(x) = x[1] * x[2]

x_ra = Reactant.to_rarray(x; track_numbers=(Number,))
f2 = @compile f1(x_ra)
@test f2(Reactant.to_rarray((5, 5.2); track_numbers=(Number,))) 5 * 5.2
@test f2(Reactant.to_rarray((5, 5.2); track_numbers=(Number,))) isa ConcreteRNumber

x_ra = Reactant.to_rarray(x)
f3 = @compile f1(x_ra)
@test f3(Reactant.to_rarray((5, 5.2))) f1(x)
@test !(f3(Reactant.to_rarray((5, 5.2))) isa ConcreteRNumber)
@test f3(Reactant.to_rarray((5, 5.2))) isa Number

x_ra = Reactant.to_rarray(x; track_numbers=(Int,))
f4 = @compile f1(x_ra)
@test f4(Reactant.to_rarray((5, 5.2); track_numbers=(Int,))) 5 * 3.14
@test f4(Reactant.to_rarray((5, 5.2); track_numbers=(Int,))) isa ConcreteRNumber
end

@testset "Mixed" begin
x = (3, [3.14])

f1(x) = x[1] * x[2]

x_ra = Reactant.to_rarray(x; track_numbers=(Number,))

f2 = @compile f1(x_ra)
res2 = f2(Reactant.to_rarray((5, [3.14]); track_numbers=(Number,)))
@test only(res2) 5 * 3.14
@test res2 isa ConcreteRArray

x_ra = Reactant.to_rarray(x)

f3 = @compile f1(x_ra)
res3 = f3(Reactant.to_rarray((5, [3.14])))
@test only(res3) only(f1(x))
@test res3 isa ConcreteRArray
end
end
2 changes: 1 addition & 1 deletion test/compile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=

f = @compile sum(x2)

@test f(x2) isa @NamedTuple{a::Reactant.ConcreteRArray{Float64,0}}
@test f(x2) isa @NamedTuple{a::Reactant.ConcreteRNumber{Float64}}
@test isapprox(f(x2).a, sum(x.a))
end
end
Expand Down

0 comments on commit fce399c

Please sign in to comment.