Skip to content

Commit

Permalink
feat: use ExpressionsExplorer for parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 28, 2024
1 parent 6842c9e commit 09f78b6
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 127 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Expand All @@ -31,6 +32,7 @@ ArrayInterface = "7.10"
CEnum = "0.4, 0.5"
Downloads = "1.6"
Enzyme = "0.13"
ExpressionExplorer = "1"
MacroTools = "0.5.13"
NNlib = "0.9"
OrderedCollections = "1"
Expand Down
175 changes: 49 additions & 126 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,75 +3,84 @@ module ControlFlow
using ..Reactant: Reactant, TracedRNumber, TracedRArray
using ..MLIR: MLIR

using ExpressionExplorer: ExpressionExplorer
using MacroTools: MacroTools

macro trace(expr)
expr.head == :if && return esc(trace_if(__module__, expr))
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
end

function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
return MacroTools.postwalk(expr) do x
if x isa Symbol && x all_vars
return Symbol(prepend, x)
end
return x
end
end

function trace_if(mod, expr)
@assert expr.head == :if
@assert length(expr.args) == 3 "`@trace` expects an `else` block for `if` blocks."
# XXX: support `elseif` blocks
@assert expr.args[3].head == :block "`elseif` blocks are not supported yet."

# `var_list` is a list of input variables that are used in the `if` block
# `bound_vars` is a list of variables that are bound in the `if` block
true_branch_var_list = Symbol[]
true_branch_bound_vars = Symbol[]
find_var_uses!(true_branch_var_list, true_branch_bound_vars, expr.args[2])
true_branch_symbols = ExpressionExplorer.compute_symbols_state(expr.args[2])
true_branch_input_list = [true_branch_symbols.references...]
true_branch_assignments = [true_branch_symbols.assignments...]
true_branch_fn_name = gensym(:true_branch)

false_branch_var_list = Symbol[]
false_branch_bound_vars = Symbol[]
find_var_uses!(false_branch_var_list, false_branch_bound_vars, expr.args[3])
false_branch_symbols = ExpressionExplorer.compute_symbols_state(expr.args[3])
false_branch_input_list = [false_branch_symbols.references...]
false_branch_assignments = [false_branch_symbols.assignments...]
false_branch_fn_name = gensym(:false_branch)

all_input_vars = true_branch_var_list false_branch_var_list
all_output_vars = true_branch_bound_vars false_branch_bound_vars
all_input_vars = true_branch_input_list false_branch_input_list
all_output_vars = true_branch_assignments false_branch_assignments

all_vars = all_input_vars all_output_vars

true_branch_fn = quote
() -> begin
$(Expr(:meta, :inline))
$(expr.args[2])
return ($(all_output_vars...),)
end
end
true_branch_fn = cleanup_expr_to_avoid_boxing(
true_branch_fn, true_branch_fn_name, all_vars
)
true_branch_fn = quote
$(true_branch_fn_name) =
($(all_input_vars...),) -> begin
$(Expr(:meta, :inline))
$(expr.args[2])
return ($(all_output_vars...),)
let $(map(Base.Fix2(makelet, true_branch_fn_name), all_input_vars)...)
$(true_branch_fn)
end
end

true_branch_fn = MacroTools.prewalk(true_branch_fn) do x
if x isa Symbol && x all_output_vars
return Symbol(:true_branch₋, x)
false_branch_fn = quote
() -> begin
$(Expr(:meta, :inline))
$(expr.args[3])
return ($(all_output_vars...),)
end
return x
end

false_branch_fn = cleanup_expr_to_avoid_boxing(
false_branch_fn, false_branch_fn_name, all_vars
)
false_branch_fn = quote
$(false_branch_fn_name) =
($(all_input_vars...),) -> begin
$(Expr(:meta, :inline))
$(expr.args[3])
return ($(all_output_vars...),)
let $(map(Base.Fix2(makelet, false_branch_fn_name), all_input_vars)...)
$(false_branch_fn)
end
end

false_branch_fn = MacroTools.prewalk(false_branch_fn) do x
if x isa Symbol && x all_output_vars
return Symbol(:false_branch₋, x)
end
return x
end

return quote
if any($(is_traced), ($(all_input_vars...),))
$(true_branch_fn)
$(false_branch_fn)
($(all_output_vars...),) = $(traced_if)(
$(expr.args[1]),
$(true_branch_fn_name),
$(false_branch_fn_name),
($(all_input_vars...),),
$(expr.args[1]), $(true_branch_fn_name), $(false_branch_fn_name)
)
else
$(expr)
Expand All @@ -83,28 +92,20 @@ is_traced(x) = false
is_traced(::TracedRArray) = true
is_traced(::TracedRNumber) = true

makelet(x) = :($(x) = $(x))
makelet(x, prepend::Symbol) = :($(Symbol(prepend, x)) = $(x))

# Generate this dummy function and later we remove it during tracing
function traced_if(cond, true_fn::TFn, false_fn::FFn, args) where {TFn,FFn}
if cond
return true_fn(args...)
else
return false_fn(args...)
end
function traced_if(cond, true_fn::TFn, false_fn::FFn) where {TFn,FFn}
return cond ? true_fn() : false_fn()
end

function traced_if(
cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args
) where {TFn,FFn}
function traced_if(cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn) where {TFn,FFn}
_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results = Reactant.make_mlir_fn(
true_fn, args, (), string(gensym("true_branch")), false;
return_dialect=:stablehlo
true_fn, (), (), string(gensym("true_branch")), false; return_dialect=:stablehlo
)

_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results = Reactant.make_mlir_fn(
false_fn, args, (), string(gensym("false_branch")), false;
return_dialect=:stablehlo
false_fn, (), (), string(gensym("false_branch")), false; return_dialect=:stablehlo
)

@assert length(true_branch_results) == length(false_branch_results) "true branch returned $(length(true_branch_results)) results, false branch returned $(length(false_branch_results)). This shouldn't happen."
Expand Down Expand Up @@ -145,84 +146,6 @@ function traced_if(
end
end

# XXX: Use `ExpressionExplorer.jl` instead
# NOTE: Adapted from https://github.com/c42f/FastClosures.jl/blob/master/src/FastClosures.jl
function find_var_uses!(varlist, bound_vars, ex)
if isa(ex, Symbol)
var = ex
if !(var in bound_vars)
var varlist || push!(varlist, var)
end
return varlist
elseif isa(ex, Expr)
if ex.head == :quote || ex.head == :line || ex.head == :inbounds
return varlist
end
if ex.head == :(=)
find_var_uses_lhs!(varlist, bound_vars, ex.args[1])
find_var_uses!(varlist, bound_vars, ex.args[2])
elseif ex.head == :kw
find_var_uses!(varlist, bound_vars, ex.args[2])
elseif ex.head == :for ||
ex.head == :while ||
ex.head == :comprehension ||
ex.head == :let
# New scopes
inner_bindings = copy(bound_vars)
find_var_uses!(varlist, inner_bindings, ex.args)
elseif ex.head == :try
# New scope + ex.args[2] is a new binding
find_var_uses!(varlist, copy(bound_vars), ex.args[1])
catch_bindings = copy(bound_vars)
!isa(ex.args[2], Symbol) || push!(catch_bindings, ex.args[2])
find_var_uses!(varlist, catch_bindings, ex.args[3])
if length(ex.args) > 3
finally_bindings = copy(bound_vars)
find_var_uses!(varlist, finally_bindings, ex.args[4])
end
elseif ex.head == :call
find_var_uses!(varlist, bound_vars, ex.args[2:end])
elseif ex.head == :local
foreach(ex.args) do e
if !isa(e, Symbol)
find_var_uses!(varlist, bound_vars, e)
end
end
elseif ex.head == :(::)
find_var_uses_lhs!(varlist, bound_vars, ex)
else
find_var_uses!(varlist, bound_vars, ex.args)
end
end
return varlist
end

function find_var_uses!(varlist, bound_vars, exs::Vector)
return foreach(e -> find_var_uses!(varlist, bound_vars, e), exs)
end

# Find variable uses on the left hand side of an assignment. Some of what may
# be variable uses turn into bindings in this context (cf. tuple unpacking).
function find_var_uses_lhs!(varlist, bound_vars, ex)
if isa(ex, Symbol)
var = ex
var bound_vars || push!(bound_vars, var)
elseif isa(ex, Expr)
if ex.head == :tuple
find_var_uses_lhs!(varlist, bound_vars, ex.args)
elseif ex.head == :(::)
find_var_uses!(varlist, bound_vars, ex.args[2])
find_var_uses_lhs!(varlist, bound_vars, ex.args[1])
else
find_var_uses!(varlist, bound_vars, ex.args)
end
end
end

function find_var_uses_lhs!(varlist, bound_vars, exs::Vector)
return foreach(e -> find_var_uses_lhs!(varlist, bound_vars, e), exs)
end

export @trace

end
4 changes: 3 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ function make_mlir_fn(
if sizeof(typeof(f)) != 0 || f isa BroadcastFunction
return (
true,
make_mlir_fn(apply, (f, args...), kwargs, name, concretein; toscalar)[2:end]...,
make_mlir_fn(
apply, (f, args...), kwargs, name, concretein; toscalar, return_dialect
)[2:end]...,
)
end

Expand Down

0 comments on commit 09f78b6

Please sign in to comment.