From aa7a26bc822b24e6b81546b2043a6639e1c6e336 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 29 Oct 2024 12:31:53 -0400 Subject: [PATCH] feat: support elseif --- src/ControlFlow.jl | 39 ++++++++++++++++++++++++--------------- src/TracedRNumber.jl | 10 ++++++++++ 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/src/ControlFlow.jl b/src/ControlFlow.jl index 3f44d0d9..ad4ecef9 100644 --- a/src/ControlFlow.jl +++ b/src/ControlFlow.jl @@ -21,30 +21,33 @@ function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars) end function trace_if(mod, expr) - @assert expr.head == :if @assert length(expr.args) == 3 "`@trace` expects an `else` block for `if` blocks." - # XXX: support `elseif` blocks - @assert expr.args[3].head == :block "`elseif` blocks are not supported yet." true_branch_symbols = ExpressionExplorer.compute_symbols_state(expr.args[2]) true_branch_input_list = [true_branch_symbols.references...] true_branch_assignments = [true_branch_symbols.assignments...] true_branch_fn_name = gensym(:true_branch) - false_branch_symbols = ExpressionExplorer.compute_symbols_state(expr.args[3]) + else_block, discard_vars = if expr.args[3].head != :elseif + expr.args[3], nothing + else + trace_if(mod, expr.args[3]) + end + + false_branch_symbols = ExpressionExplorer.compute_symbols_state(else_block) false_branch_input_list = [false_branch_symbols.references...] false_branch_assignments = [false_branch_symbols.assignments...] false_branch_fn_name = gensym(:false_branch) all_input_vars = true_branch_input_list ∪ false_branch_input_list all_output_vars = true_branch_assignments ∪ false_branch_assignments + discard_vars !== nothing && setdiff!(all_output_vars, discard_vars) all_vars = all_input_vars ∪ all_output_vars true_branch_fn = quote $(true_branch_fn_name) = ($(all_input_vars...),) -> begin - $(Expr(:meta, :inline)) $(expr.args[2]) return ($(all_output_vars...),) end @@ -56,8 +59,7 @@ function trace_if(mod, expr) false_branch_fn = quote $(false_branch_fn_name) = ($(all_input_vars...),) -> begin - $(Expr(:meta, :inline)) - $(expr.args[3]) + $(else_block) return ($(all_output_vars...),) end end @@ -65,16 +67,23 @@ function trace_if(mod, expr) false_branch_fn, false_branch_fn_name, all_vars ) + reactant_code_block = quote + $(true_branch_fn) + $(false_branch_fn) + ($(all_output_vars...),) = $(traced_if)( + $(expr.args[1]), + $(true_branch_fn_name), + $(false_branch_fn_name), + ($(all_input_vars...),), + ) + end + + expr.head != :if && + return reactant_code_block, (true_branch_fn_name, false_branch_fn_name) + return quote if any($(is_traced), ($(all_input_vars...),)) - $(true_branch_fn) - $(false_branch_fn) - ($(all_output_vars...),) = $(traced_if)( - $(expr.args[1]), - $(true_branch_fn_name), - $(false_branch_fn_name), - ($(all_input_vars...),), - ) + $(reactant_code_block) else $(expr) end diff --git a/src/TracedRNumber.jl b/src/TracedRNumber.jl index 579f73e1..6b470ab4 100644 --- a/src/TracedRNumber.jl +++ b/src/TracedRNumber.jl @@ -145,10 +145,20 @@ for (jlop, hloop, hlocomp) in ( function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T} return $(jlop)(lhs, promote_to(lhs, rhs)) end + function $(jlop)( + @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::Number) + ) where {T} + return $(jlop)(lhs, promote_to(lhs, rhs)) + end function $(jlop)(@nospecialize(lhs), @nospecialize(rhs::TracedRNumber{T})) where {T} return $(jlop)(promote_to(rhs, lhs), rhs) end + function $(jlop)( + @nospecialize(lhs::Number), @nospecialize(rhs::TracedRNumber{T}) + ) where {T} + return $(jlop)(promote_to(rhs, lhs), rhs) + end function $(jlop)( @nospecialize(lhs::TracedRNumber{T1}), @nospecialize(rhs::TracedRNumber{T2})