Skip to content

Commit

Permalink
test: test if conditions
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 31, 2024
1 parent e151a9e commit fad1819
Show file tree
Hide file tree
Showing 3 changed files with 333 additions and 2 deletions.
58 changes: 56 additions & 2 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,47 @@ end
MissingTracedValue() = MissingTracedValue(())

# Code generation
"""
@trace <expr>
Converts certain expressions like control flow into a Reactant friendly form. Importantly,
if no traced value is found inside the expression, then there is no overhead.
## Currently Supported
- `if` conditions (with `elseif` and other niceties)
# Extended Help
## Caveats (Deviations from Core Julia Semantics)
### New variables introduced
```julia
@trace if x > 0
y = x + 1
p = 1
else
y = x - 1
end
```
In the outer scope `p` is not defined if `x ≤ 0`. However, for the traced version, it is
defined and set to a dummy value.
### Short Circuiting Operations
```julia
@trace if x > 0 && z > 0
y = x + 1
else
y = x - 1
end
```
`&&` and `||` are short circuiting operations. In the traced version, we replace them with
`&` and `|` respectively.
"""
macro trace(expr)
expr.head == :if && return esc(trace_if(__module__, expr))
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
Expand All @@ -24,7 +65,8 @@ end
function trace_if(mod, expr)
expr.head == :if && error_if_return(expr)

condition_vars = [ExpressionExplorer.compute_symbols_state(expr.args[1]).references...]
cond_expr = remove_shortcircuiting(expr.args[1])
condition_vars = [ExpressionExplorer.compute_symbols_state(cond_expr).references...]

true_branch_symbols = ExpressionExplorer.compute_symbols_state(expr.args[2])
true_branch_input_list = [true_branch_symbols.references...]
Expand Down Expand Up @@ -101,7 +143,7 @@ function trace_if(mod, expr)
$(true_branch_fn)
$(false_branch_fn)
($(all_output_vars...),) = $(traced_if)(
$(expr.args[1]),
$(cond_expr),
$(true_branch_fn_name),
$(false_branch_fn_name),
($(all_input_vars...),),
Expand All @@ -112,6 +154,7 @@ function trace_if(mod, expr)
return reactant_code_block, (true_branch_fn_name, false_branch_fn_name)

all_check_vars = [all_input_vars..., condition_vars...]
unique!(all_check_vars)
return quote
if any($(is_traced), ($(all_check_vars...),))
$(reactant_code_block)
Expand All @@ -121,6 +164,17 @@ function trace_if(mod, expr)
end
end

function remove_shortcircuiting(expr)
return MacroTools.prewalk(expr) do x
if MacroTools.@capture(x, a_ && b_)
return :($a & $b)
elseif MacroTools.@capture(x, a_ || b_)
return :($a | $b)
end
return x
end
end

# Generate this dummy function and later we remove it during tracing
function traced_if(cond, true_fn::TFn, false_fn::FFn, args) where {TFn,FFn}
return cond ? true_fn(args) : false_fn(args)
Expand Down
10 changes: 10 additions & 0 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ function Base.ifelse(
)
end

Base.:&(x::TracedRNumber{Bool}, y::TracedRNumber{Bool}) = x * y
Base.:|(x::TracedRNumber{Bool}, y::TracedRNumber{Bool}) = x + y
function Base.:!(x::TracedRNumber{Bool})
true_val = promote_to(TracedRNumber{Bool}, true)
return TracedRNumber{Bool}(
(),
MLIR.IR.result(MLIR.Dialects.stablehlo.xor(x.mlir_data, true_val.mlir_data), 1),
)
end

function Base.literal_pow(
::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}}
) where {T,P}
Expand Down
267 changes: 267 additions & 0 deletions test/control_flow.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,267 @@
using Reactant, Test

function condition1(x)
y = sum(x)
@trace if y > 0
z = y + 1
else
z = y - 1
end
return z
end

@testset "condition1" begin
x = rand(2, 10)
x_ra = Reactant.to_rarray(x)

@test @jit(condition1(x_ra)) condition1(x)

x = -rand(2, 10)
x_ra = Reactant.to_rarray(x)

@test @jit(condition1(x_ra)) condition1(x)
end

function condition1_missing_var(x)
y = sum(x)
@trace if y > 0
z = y + 1
p = -1
else
z = y - 1
end
return z
end

@testset "condition1_missing_var" begin
x = rand(2, 10)
x_ra = Reactant.to_rarray(x)

@test @jit(condition1_missing_var(x_ra)) condition1_missing_var(x)

x = -rand(2, 10)
x_ra = Reactant.to_rarray(x)

@test @jit(condition1_missing_var(x_ra)) condition1_missing_var(x)
end

@testset "return not supported" begin
@test_throws LoadError @eval @trace if x > 0
return 1
end
end

function condition2_nested_if(x, y)
x_sum = sum(x)
@trace if x_sum > 0
y_sum = sum(y)
@trace if y_sum > 0
z = x_sum + y_sum
else
z = x_sum - y_sum
end
else
y_sum = sum(y)
z = x_sum - y_sum
end
return z
end

function condition2_if_else_if(x, y)
x_sum = sum(x)
y_sum = sum(y)
@trace if x_sum > 0 && y_sum > 0
z = x_sum + y_sum
elseif x_sum > 0
z = x_sum - y_sum
else
z = y_sum - x_sum
end
return z
end

@testset "condition2: multiple conditions" begin
x = rand(2, 10)
y = rand(2, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(condition2_nested_if(x_ra, y_ra)) condition2_nested_if(x, y) broken = true
@test @jit(condition2_if_else_if(x_ra, y_ra)) condition2_if_else_if(x, y)

y = -rand(2, 10)
y_ra = Reactant.to_rarray(y)

@test @jit(condition2_nested_if(x_ra, y_ra)) condition2_nested_if(x, y) broken = true
@test @jit(condition2_if_else_if(x_ra, y_ra)) condition2_if_else_if(x, y)

x = -rand(2, 10)
y = -rand(2, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(condition2_nested_if(x_ra, y_ra)) condition2_nested_if(x, y)
@test @jit(condition2_if_else_if(x_ra, y_ra)) condition2_if_else_if(x, y)
end

function condition3_mixed_conditions(x, y)
x_sum = sum(x)
y_sum = sum(y)
@trace if x_sum > 0 && y_sum > 0
z = x_sum + y_sum
else
z = -(x_sum + y_sum)
end
return z
end

@testset "condition3: mixed conditions" begin
x = rand(2, 10)
y = rand(2, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(condition3_mixed_conditions(x_ra, y_ra)) condition3_mixed_conditions(x, y)

x = -rand(2, 10)
y = -rand(2, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(condition3_mixed_conditions(x_ra, y_ra)) condition3_mixed_conditions(x, y)

x = rand(2, 10)
y = -rand(2, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)
@test @jit(condition3_mixed_conditions(x_ra, y_ra)) condition3_mixed_conditions(x, y)

y = rand(2, 10)
z = -rand(2, 10)
y_ra = Reactant.to_rarray(y)
z_ra = Reactant.to_rarray(z)
@test @jit(condition3_mixed_conditions(x_ra, y_ra)) condition3_mixed_conditions(x, y)
end

function condition4_mixed_conditions(x, y)
x_sum = sum(x)
y_sum = sum(y)
@trace if x_sum > 0 || y_sum > 0 && !(y_sum > 0)
z = x_sum + y_sum
p = 1
else
z = -(x_sum + y_sum)
p = -1
end
return z
end

@testset "condition4: mixed conditions" begin
x = rand(2, 10)
y = rand(2, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(condition4_mixed_conditions(x_ra, y_ra)) condition4_mixed_conditions(x, y)

x = -rand(2, 10)
y = -rand(2, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

@test @jit(condition4_mixed_conditions(x_ra, y_ra)) condition4_mixed_conditions(x, y)

x = rand(2, 10)
y = -rand(2, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)
@test @jit(condition4_mixed_conditions(x_ra, y_ra)) condition4_mixed_conditions(x, y)

y = rand(2, 10)
z = -rand(2, 10)
y_ra = Reactant.to_rarray(y)
z_ra = Reactant.to_rarray(z)
@test @jit(condition4_mixed_conditions(x_ra, y_ra)) condition4_mixed_conditions(x, y)
end

function condition5_multiple_returns(x, y)
x_sum = sum(x)
y_sum = sum(y)
@trace if x_sum > 0
z = x_sum + y_sum
p = 1
else
z = -(x_sum + y_sum)
p = -1
end
return z, p
end

@testset "condition5: multiple returns" begin
x = rand(2, 10)
y = rand(2, 10)
x_ra = Reactant.to_rarray(x)
y_ra = Reactant.to_rarray(y)

res_ra = @jit(condition5_multiple_returns(x_ra, y_ra))
res = condition5_multiple_returns(x, y)
@test res_ra[1] res[1]
@test res_ra[2] res[2]
end

function condition6_bareif_relu(x)
@trace if x < 0
x = 0.0
end
return x
end

@testset "condition6: bareif relu" begin
x = 2.0
x_ra = Reactant.to_rarray(x; track_numbers=(Number,))

res_ra = @jit(condition6_bareif_relu(x_ra))
res = condition6_bareif_relu(x)
@test res_ra res

x = -2.0
x_ra = Reactant.to_rarray(x; track_numbers=(Number,))

res_ra = @jit(condition6_bareif_relu(x_ra))
res = condition6_bareif_relu(x)
@test res_ra res
end

function condition7_bare_elseif(x)
@trace if x > 0
x = x + 1
elseif x < 0
x = x - 1
elseif x == 0
x = x
end
return x
end

@testset "condition7: bare elseif" begin
x = 2.0
x_ra = Reactant.to_rarray(x; track_numbers=(Number,))

res_ra = @jit(condition7_bare_elseif(x_ra))
res = condition7_bare_elseif(x)
@test res_ra res

x = -2.0
x_ra = Reactant.to_rarray(x; track_numbers=(Number,))

res_ra = @jit(condition7_bare_elseif(x_ra))
res = condition7_bare_elseif(x)
@test res_ra res

x = 0.0
x_ra = Reactant.to_rarray(x; track_numbers=(Number,))

res_ra = @jit(condition7_bare_elseif(x_ra))
res = condition7_bare_elseif(x)
@test res_ra res
end

0 comments on commit fad1819

Please sign in to comment.