diff --git a/lib/ReactantCore/src/ReactantCore.jl b/lib/ReactantCore/src/ReactantCore.jl index b67668e4..76d20582 100644 --- a/lib/ReactantCore/src/ReactantCore.jl +++ b/lib/ReactantCore/src/ReactantCore.jl @@ -16,6 +16,47 @@ end MissingTracedValue() = MissingTracedValue(()) # Code generation +""" + @trace + +Converts certain expressions like control flow into a Reactant friendly form. Importantly, +if no traced value is found inside the expression, then there is no overhead. + +## Currently Supported + +- `if` conditions (with `elseif` and other niceties) + +# Extended Help + +## Caveats (Deviations from Core Julia Semantics) + +### New variables introduced + +```julia +@trace if x > 0 + y = x + 1 + p = 1 +else + y = x - 1 +end +``` + +In the outer scope `p` is not defined if `x ≤ 0`. However, for the traced version, it is +defined and set to a dummy value. + +### Short Circuiting Operations + +```julia +@trace if x > 0 && z > 0 + y = x + 1 +else + y = x - 1 +end +``` + +`&&` and `||` are short circuiting operations. In the traced version, we replace them with +`&` and `|` respectively. +""" macro trace(expr) expr.head == :if && return esc(trace_if(__module__, expr)) return error("Only `if-elseif-else` blocks are currently supported by `@trace`") @@ -24,7 +65,8 @@ end function trace_if(mod, expr) expr.head == :if && error_if_return(expr) - condition_vars = [ExpressionExplorer.compute_symbols_state(expr.args[1]).references...] + cond_expr = remove_shortcircuiting(expr.args[1]) + condition_vars = [ExpressionExplorer.compute_symbols_state(cond_expr).references...] true_branch_symbols = ExpressionExplorer.compute_symbols_state(expr.args[2]) true_branch_input_list = [true_branch_symbols.references...] @@ -101,7 +143,7 @@ function trace_if(mod, expr) $(true_branch_fn) $(false_branch_fn) ($(all_output_vars...),) = $(traced_if)( - $(expr.args[1]), + $(cond_expr), $(true_branch_fn_name), $(false_branch_fn_name), ($(all_input_vars...),), @@ -112,6 +154,7 @@ function trace_if(mod, expr) return reactant_code_block, (true_branch_fn_name, false_branch_fn_name) all_check_vars = [all_input_vars..., condition_vars...] + unique!(all_check_vars) return quote if any($(is_traced), ($(all_check_vars...),)) $(reactant_code_block) @@ -121,6 +164,17 @@ function trace_if(mod, expr) end end +function remove_shortcircuiting(expr) + return MacroTools.prewalk(expr) do x + if MacroTools.@capture(x, a_ && b_) + return :($a & $b) + elseif MacroTools.@capture(x, a_ || b_) + return :($a | $b) + end + return x + end +end + # Generate this dummy function and later we remove it during tracing function traced_if(cond, true_fn::TFn, false_fn::FFn, args) where {TFn,FFn} return cond ? true_fn(args) : false_fn(args) diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index b564c4a3..d088fc14 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -188,6 +188,16 @@ function Base.ifelse( ) end +Base.:&(x::TracedRNumber{Bool}, y::TracedRNumber{Bool}) = x * y +Base.:|(x::TracedRNumber{Bool}, y::TracedRNumber{Bool}) = x + y +function Base.:!(x::TracedRNumber{Bool}) + true_val = promote_to(TracedRNumber{Bool}, true) + return TracedRNumber{Bool}( + (), + MLIR.IR.result(MLIR.Dialects.stablehlo.xor(x.mlir_data, true_val.mlir_data), 1), + ) +end + function Base.literal_pow( ::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}} ) where {T,P} diff --git a/test/control_flow.jl b/test/control_flow.jl index e69de29b..57131c3d 100644 --- a/test/control_flow.jl +++ b/test/control_flow.jl @@ -0,0 +1,267 @@ +using Reactant, Test + +function condition1(x) + y = sum(x) + @trace if y > 0 + z = y + 1 + else + z = y - 1 + end + return z +end + +@testset "condition1" begin + x = rand(2, 10) + x_ra = Reactant.to_rarray(x) + + @test @jit(condition1(x_ra)) ≈ condition1(x) + + x = -rand(2, 10) + x_ra = Reactant.to_rarray(x) + + @test @jit(condition1(x_ra)) ≈ condition1(x) +end + +function condition1_missing_var(x) + y = sum(x) + @trace if y > 0 + z = y + 1 + p = -1 + else + z = y - 1 + end + return z +end + +@testset "condition1_missing_var" begin + x = rand(2, 10) + x_ra = Reactant.to_rarray(x) + + @test @jit(condition1_missing_var(x_ra)) ≈ condition1_missing_var(x) + + x = -rand(2, 10) + x_ra = Reactant.to_rarray(x) + + @test @jit(condition1_missing_var(x_ra)) ≈ condition1_missing_var(x) +end + +@testset "return not supported" begin + @test_throws LoadError @eval @trace if x > 0 + return 1 + end +end + +function condition2_nested_if(x, y) + x_sum = sum(x) + @trace if x_sum > 0 + y_sum = sum(y) + @trace if y_sum > 0 + z = x_sum + y_sum + else + z = x_sum - y_sum + end + else + y_sum = sum(y) + z = x_sum - y_sum + end + return z +end + +function condition2_if_else_if(x, y) + x_sum = sum(x) + y_sum = sum(y) + @trace if x_sum > 0 && y_sum > 0 + z = x_sum + y_sum + elseif x_sum > 0 + z = x_sum - y_sum + else + z = y_sum - x_sum + end + return z +end + +@testset "condition2: multiple conditions" begin + x = rand(2, 10) + y = rand(2, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + @test @jit(condition2_nested_if(x_ra, y_ra)) ≈ condition2_nested_if(x, y) broken = true + @test @jit(condition2_if_else_if(x_ra, y_ra)) ≈ condition2_if_else_if(x, y) + + y = -rand(2, 10) + y_ra = Reactant.to_rarray(y) + + @test @jit(condition2_nested_if(x_ra, y_ra)) ≈ condition2_nested_if(x, y) broken = true + @test @jit(condition2_if_else_if(x_ra, y_ra)) ≈ condition2_if_else_if(x, y) + + x = -rand(2, 10) + y = -rand(2, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + @test @jit(condition2_nested_if(x_ra, y_ra)) ≈ condition2_nested_if(x, y) + @test @jit(condition2_if_else_if(x_ra, y_ra)) ≈ condition2_if_else_if(x, y) +end + +function condition3_mixed_conditions(x, y) + x_sum = sum(x) + y_sum = sum(y) + @trace if x_sum > 0 && y_sum > 0 + z = x_sum + y_sum + else + z = -(x_sum + y_sum) + end + return z +end + +@testset "condition3: mixed conditions" begin + x = rand(2, 10) + y = rand(2, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + @test @jit(condition3_mixed_conditions(x_ra, y_ra)) ≈ condition3_mixed_conditions(x, y) + + x = -rand(2, 10) + y = -rand(2, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + @test @jit(condition3_mixed_conditions(x_ra, y_ra)) ≈ condition3_mixed_conditions(x, y) + + x = rand(2, 10) + y = -rand(2, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + @test @jit(condition3_mixed_conditions(x_ra, y_ra)) ≈ condition3_mixed_conditions(x, y) + + y = rand(2, 10) + z = -rand(2, 10) + y_ra = Reactant.to_rarray(y) + z_ra = Reactant.to_rarray(z) + @test @jit(condition3_mixed_conditions(x_ra, y_ra)) ≈ condition3_mixed_conditions(x, y) +end + +function condition4_mixed_conditions(x, y) + x_sum = sum(x) + y_sum = sum(y) + @trace if x_sum > 0 || y_sum > 0 && !(y_sum > 0) + z = x_sum + y_sum + p = 1 + else + z = -(x_sum + y_sum) + p = -1 + end + return z +end + +@testset "condition4: mixed conditions" begin + x = rand(2, 10) + y = rand(2, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + @test @jit(condition4_mixed_conditions(x_ra, y_ra)) ≈ condition4_mixed_conditions(x, y) + + x = -rand(2, 10) + y = -rand(2, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + @test @jit(condition4_mixed_conditions(x_ra, y_ra)) ≈ condition4_mixed_conditions(x, y) + + x = rand(2, 10) + y = -rand(2, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + @test @jit(condition4_mixed_conditions(x_ra, y_ra)) ≈ condition4_mixed_conditions(x, y) + + y = rand(2, 10) + z = -rand(2, 10) + y_ra = Reactant.to_rarray(y) + z_ra = Reactant.to_rarray(z) + @test @jit(condition4_mixed_conditions(x_ra, y_ra)) ≈ condition4_mixed_conditions(x, y) +end + +function condition5_multiple_returns(x, y) + x_sum = sum(x) + y_sum = sum(y) + @trace if x_sum > 0 + z = x_sum + y_sum + p = 1 + else + z = -(x_sum + y_sum) + p = -1 + end + return z, p +end + +@testset "condition5: multiple returns" begin + x = rand(2, 10) + y = rand(2, 10) + x_ra = Reactant.to_rarray(x) + y_ra = Reactant.to_rarray(y) + + res_ra = @jit(condition5_multiple_returns(x_ra, y_ra)) + res = condition5_multiple_returns(x, y) + @test res_ra[1] ≈ res[1] + @test res_ra[2] ≈ res[2] +end + +function condition6_bareif_relu(x) + @trace if x < 0 + x = 0.0 + end + return x +end + +@testset "condition6: bareif relu" begin + x = 2.0 + x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + + res_ra = @jit(condition6_bareif_relu(x_ra)) + res = condition6_bareif_relu(x) + @test res_ra ≈ res + + x = -2.0 + x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + + res_ra = @jit(condition6_bareif_relu(x_ra)) + res = condition6_bareif_relu(x) + @test res_ra ≈ res +end + +function condition7_bare_elseif(x) + @trace if x > 0 + x = x + 1 + elseif x < 0 + x = x - 1 + elseif x == 0 + x = x + end + return x +end + +@testset "condition7: bare elseif" begin + x = 2.0 + x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + + res_ra = @jit(condition7_bare_elseif(x_ra)) + res = condition7_bare_elseif(x) + @test res_ra ≈ res + + x = -2.0 + x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + + res_ra = @jit(condition7_bare_elseif(x_ra)) + res = condition7_bare_elseif(x) + @test res_ra ≈ res + + x = 0.0 + x_ra = Reactant.to_rarray(x; track_numbers=(Number,)) + + res_ra = @jit(condition7_bare_elseif(x_ra)) + res = condition7_bare_elseif(x) + @test res_ra ≈ res +end