Skip to content

Commit

Permalink
feat: support elseif
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 29, 2024
1 parent 8ab051f commit aa7a26b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
39 changes: 24 additions & 15 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,33 @@ function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
end

function trace_if(mod, expr)
@assert expr.head == :if
@assert length(expr.args) == 3 "`@trace` expects an `else` block for `if` blocks."
# XXX: support `elseif` blocks
@assert expr.args[3].head == :block "`elseif` blocks are not supported yet."

true_branch_symbols = ExpressionExplorer.compute_symbols_state(expr.args[2])
true_branch_input_list = [true_branch_symbols.references...]
true_branch_assignments = [true_branch_symbols.assignments...]
true_branch_fn_name = gensym(:true_branch)

false_branch_symbols = ExpressionExplorer.compute_symbols_state(expr.args[3])
else_block, discard_vars = if expr.args[3].head != :elseif
expr.args[3], nothing
else
trace_if(mod, expr.args[3])
end

false_branch_symbols = ExpressionExplorer.compute_symbols_state(else_block)
false_branch_input_list = [false_branch_symbols.references...]
false_branch_assignments = [false_branch_symbols.assignments...]
false_branch_fn_name = gensym(:false_branch)

all_input_vars = true_branch_input_list false_branch_input_list
all_output_vars = true_branch_assignments false_branch_assignments
discard_vars !== nothing && setdiff!(all_output_vars, discard_vars)

all_vars = all_input_vars all_output_vars

true_branch_fn = quote
$(true_branch_fn_name) =
($(all_input_vars...),) -> begin
$(Expr(:meta, :inline))
$(expr.args[2])
return ($(all_output_vars...),)
end
Expand All @@ -56,25 +59,31 @@ function trace_if(mod, expr)
false_branch_fn = quote
$(false_branch_fn_name) =
($(all_input_vars...),) -> begin
$(Expr(:meta, :inline))
$(expr.args[3])
$(else_block)
return ($(all_output_vars...),)
end
end
false_branch_fn = cleanup_expr_to_avoid_boxing(
false_branch_fn, false_branch_fn_name, all_vars
)

reactant_code_block = quote
$(true_branch_fn)
$(false_branch_fn)
($(all_output_vars...),) = $(traced_if)(
$(expr.args[1]),
$(true_branch_fn_name),
$(false_branch_fn_name),
($(all_input_vars...),),
)
end

expr.head != :if &&
return reactant_code_block, (true_branch_fn_name, false_branch_fn_name)

return quote
if any($(is_traced), ($(all_input_vars...),))
$(true_branch_fn)
$(false_branch_fn)
($(all_output_vars...),) = $(traced_if)(
$(expr.args[1]),
$(true_branch_fn_name),
$(false_branch_fn_name),
($(all_input_vars...),),
)
$(reactant_code_block)
else
$(expr)
end
Expand Down
10 changes: 10 additions & 0 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,20 @@ for (jlop, hloop, hlocomp) in (
function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T}
return $(jlop)(lhs, promote_to(lhs, rhs))
end
function $(jlop)(
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number)
) where {T}
return $(jlop)(lhs, promote_to(lhs, rhs))
end

function $(jlop)(@nospecialize(lhs), @nospecialize(rhs::TracedRNumber{T})) where {T}
return $(jlop)(promote_to(rhs, lhs), rhs)
end
function $(jlop)(
@nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T})
) where {T}
return $(jlop)(promote_to(rhs, lhs), rhs)
end

function $(jlop)(
@nospecialize(lhs::TracedRNumber{T1}), @nospecialize(rhs::TracedRNumber{T2})
Expand Down

0 comments on commit aa7a26b

Please sign in to comment.