diff --git a/src/Compiler.jl b/src/Compiler.jl index d71eabaa..081819c0 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -421,6 +421,30 @@ macro compile(options, maybe_call=nothing) end end +""" + @jit f(args...) + + Run @compile f(args..) then immediately execute it +""" +macro jit(options, maybe_call=nothing) + call = something(maybe_call, options) + options = isnothing(maybe_call) ? :(optimize = true) : options + Meta.isexpr(call, :call) || error("@compile: expected call, got $call") + if !Meta.isexpr(options, :(=)) || options.args[1] != :optimize + error("@compile: expected options in format optimize=value, got $options") + end + + options = Expr(:tuple, Expr(:parameters, Expr(:kw, options.args...))) + + quote + options = $(esc(options)) + f = $(esc(call.args[1])) + args = $(esc(Expr(:tuple, call.args[2:end]...))) + fn = compile(f, args; options.optimize) + fn(args...) + end +end + """ codegen_flatten! diff --git a/src/Reactant.jl b/src/Reactant.jl index 82bf06ea..4ae7852b 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -93,8 +93,8 @@ const TracedType = Union{TracedRArray,TracedRNumber} include("Tracing.jl") include("Compiler.jl") -using .Compiler: @compile, @code_hlo, traced_getfield, create_result, compile -export ConcreteRArray, @compile, @code_hlo +using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile +export ConcreteRArray, @compile, @code_hlo, @jit const registry = Ref{MLIR.IR.DialectRegistry}() function __init__()