Skip to content

Commit

Permalink
fix: replace args
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 28, 2024
1 parent 09f78b6 commit 54fa319
Showing 1 changed file with 32 additions and 30 deletions.
62 changes: 32 additions & 30 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,45 +42,38 @@ function trace_if(mod, expr)
all_vars = all_input_vars all_output_vars

true_branch_fn = quote
() -> begin
$(Expr(:meta, :inline))
$(expr.args[2])
return ($(all_output_vars...),)
end
$(true_branch_fn_name) =
($(all_input_vars...),) -> begin
$(Expr(:meta, :inline))
$(expr.args[2])
return ($(all_output_vars...),)
end
end
true_branch_fn = cleanup_expr_to_avoid_boxing(
true_branch_fn, true_branch_fn_name, all_vars
)
true_branch_fn = quote
$(true_branch_fn_name) =
let $(map(Base.Fix2(makelet, true_branch_fn_name), all_input_vars)...)
$(true_branch_fn)
end
end

false_branch_fn = quote
() -> begin
$(Expr(:meta, :inline))
$(expr.args[3])
return ($(all_output_vars...),)
end
$(false_branch_fn_name) =
($(all_input_vars...),) -> begin
$(Expr(:meta, :inline))
$(expr.args[3])
return ($(all_output_vars...),)
end
end
false_branch_fn = cleanup_expr_to_avoid_boxing(
false_branch_fn, false_branch_fn_name, all_vars
)
false_branch_fn = quote
$(false_branch_fn_name) =
let $(map(Base.Fix2(makelet, false_branch_fn_name), all_input_vars)...)
$(false_branch_fn)
end
end

return quote
if any($(is_traced), ($(all_input_vars...),))
$(true_branch_fn)
$(false_branch_fn)
($(all_output_vars...),) = $(traced_if)(
$(expr.args[1]), $(true_branch_fn_name), $(false_branch_fn_name)
$(expr.args[1]),
$(true_branch_fn_name),
$(false_branch_fn_name),
($(all_input_vars...),),
)
else
$(expr)
Expand All @@ -95,19 +88,28 @@ is_traced(::TracedRNumber) = true
makelet(x, prepend::Symbol) = :($(Symbol(prepend, x)) = $(x))

# Generate this dummy function and later we remove it during tracing
function traced_if(cond, true_fn::TFn, false_fn::FFn) where {TFn,FFn}
return cond ? true_fn() : false_fn()
function traced_if(cond, true_fn::TFn, false_fn::FFn, args) where {TFn,FFn}
return cond ? true_fn(args) : false_fn(args)
end

function traced_if(cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn) where {TFn,FFn}
_, true_branch_compiled, true_branch_results, _, _, _, _, _, true_linear_results = Reactant.make_mlir_fn(
true_fn, (), (), string(gensym("true_branch")), false; return_dialect=:stablehlo
function traced_if(
cond::TracedRNumber{Bool}, true_fn::TFn, false_fn::FFn, args
) where {TFn,FFn}
(_, true_branch_compiled, true_branch_results, _, _, _, true_input_args, _, true_linear_results) = Reactant.make_mlir_fn(
true_fn, args, (), string(gensym("true_branch")), false; return_dialect=:stablehlo
)

_, false_branch_compiled, false_branch_results, _, _, _, _, _, false_linear_results = Reactant.make_mlir_fn(
false_fn, (), (), string(gensym("false_branch")), false; return_dialect=:stablehlo
(_, false_branch_compiled, false_branch_results, _, _, _, false_input_args, _, false_linear_results) = Reactant.make_mlir_fn(
false_fn, args, (), string(gensym("false_branch")), false; return_dialect=:stablehlo
)

for (of, with) in zip(true_input_args, args)
MLIR.API.mlirValueReplaceAllUsesOfWith(of.mlir_data, with.mlir_data)
end
for (of, with) in zip(false_input_args, args)
MLIR.API.mlirValueReplaceAllUsesOfWith(of.mlir_data, with.mlir_data)
end

@assert length(true_branch_results) == length(false_branch_results) "true branch returned $(length(true_branch_results)) results, false branch returned $(length(false_branch_results)). This shouldn't happen."

for (i, (tr, fr)) in enumerate(zip(true_branch_results, false_branch_results))
Expand Down

0 comments on commit 54fa319

Please sign in to comment.