Skip to content

Commit

Permalink
refactor: move code to ReactantCore
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 30, 2024
1 parent 212783e commit 8a0c9a7
Show file tree
Hide file tree
Showing 11 changed files with 225 additions and 175 deletions.
15 changes: 11 additions & 4 deletions .buildkite/pipeline.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,20 @@ steps:
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.version}}"
- JuliaCI/julia-test#v1:
test_args: "--gpu"
- JuliaCI/julia-coverage#v1:
codecov: true
dirs:
- src
- ext
- lib/ReactantCore/src
commands: |
julia --project=. -e 'println("--- :julia: Instantiating project")
using Pkg
Pkg.develop([PackageSpec(path=pwd()), PackageSpec(path="lib/ReactantCore")])'
julia --project=. -e 'println("--- :julia: Run Tests")
using Pkg
Pkg.test(; coverage="user")'
agents:
queue: "juliagpu"
cuda: "*"
Expand All @@ -34,7 +41,7 @@ steps:
command: |
julia --project=benchmark -e 'println("--- :julia: Instantiating project")
using Pkg
Pkg.develop([PackageSpec(path=pwd())])'
Pkg.develop([PackageSpec(path=pwd()), PackageSpec(path="lib/ReactantCore")])'
julia --project=benchmark -e 'println("--- :julia: Run Benchmarks")
include("benchmark/runbenchmarks.jl")'
Expand All @@ -59,7 +66,7 @@ steps:
command: |
julia --project=benchmark -e 'println("--- :julia: Instantiating project")
using Pkg
Pkg.develop([PackageSpec(path=pwd())])'
Pkg.develop([PackageSpec(path=pwd()), PackageSpec(path="lib/ReactantCore")])'
julia --project=benchmark -e 'println("--- :julia: Run Benchmarks")
include("benchmark/runbenchmarks.jl")'
Expand Down
21 changes: 18 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,19 @@ jobs:
julia --color=yes --project=deps -e 'using Pkg; Pkg.instantiate()'
SDKROOT=`xcrun --show-sdk-path` julia --color=yes --project=deps deps/build_local.jl
cp LocalPreferences.toml test/
- uses: julia-actions/julia-runtest@v1
# if: steps.buildpkg.outcome == 'success'
- name: "Install Dependencies and Run Tests"
run: |
import Pkg
Pkg.Registry.update()
# Install packages present in subdirectories
dev_pks = Pkg.PackageSpec[]
for path in ("lib/ReactantCore",)
push!(dev_pks, Pkg.PackageSpec(; path))
end
Pkg.develop(dev_pks)
Pkg.instantiate()
Pkg.test(; coverage="user")
shell: julia --color=yes --code-coverage=user --depwarn=yes --project=. {0}
id: run_tests
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
Expand All @@ -108,7 +119,11 @@ jobs:
- run: |
julia --color=yes --project=docs -e '
using Pkg
Pkg.develop([PackageSpec(path=pwd()), PackageSpec("Reactant_jll")])
Pkg.develop([
PackageSpec(path=pwd()),
PackageSpec("Reactant_jll"),
PackageSpec(path="lib/ReactantCore")
])
Pkg.instantiate()'
env:
JULIA_PKG_SERVER_REGISTRY_PREFERENCE: eager
Expand Down
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
Expand Down
6 changes: 2 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,10 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Reactant_jll = "0192cb87-2b54-54ad-80e0-3be72ad8a3c0"
Scratch = "6c6a2e73-6563-6170-7368-637461726353"

Expand All @@ -32,11 +31,10 @@ ArrayInterface = "7.10"
CEnum = "0.4, 0.5"
Downloads = "1.6"
Enzyme = "0.13"
ExpressionExplorer = "1"
MacroTools = "0.5.13"
NNlib = "0.9"
OrderedCollections = "1"
Preferences = "1.4"
ReactantCore = "0.1"
Reactant_jll = "0.0.22"
Scratch = "1.2"
Statistics = "1.10"
Expand Down
21 changes: 21 additions & 0 deletions lib/ReactantCore/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Enzyme Automatic Differentiation Compiler

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
13 changes: 13 additions & 0 deletions lib/ReactantCore/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name = "ReactantCore"
uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>", "Sergio Sánchez Ramírez <sergio.sanchez.ramirez@bsc.es>", "Paul Berg <paul@plutojl.org>"]
version = "0.1.0"

[deps]
ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

[compat]
ExpressionExplorer = "1"
MacroTools = "0.5.13"
julia = "1.10"
147 changes: 147 additions & 0 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
module ReactantCore

using ExpressionExplorer: ExpressionExplorer
using MacroTools: MacroTools

export @trace, MissingTracedValue

# Traits
is_traced(x) = false

# New Type signifying that a value is missing
mutable struct MissingTracedValue
paths::Tuple
end

MissingTracedValue() = MissingTracedValue(())

# Code generation
macro trace(expr)
expr.head == :if && return esc(trace_if(__module__, expr))
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
end

function trace_if(mod, expr)
expr.head == :if && error_if_return(expr)

condition_vars = [ExpressionExplorer.compute_symbols_state(expr.args[1]).references...]

true_branch_symbols = ExpressionExplorer.compute_symbols_state(expr.args[2])
true_branch_input_list = [true_branch_symbols.references...]
true_branch_assignments = [true_branch_symbols.assignments...]
all_true_branch_vars = true_branch_input_list true_branch_assignments
true_branch_fn_name = gensym(:true_branch)

else_block, discard_vars = if length(expr.args) == 3
if expr.args[3].head != :elseif
expr.args[3], nothing
else
trace_if(mod, expr.args[3])
end
elseif length(expr.args) == 2
tmp_expr = []
for var in true_branch_assignments
push!(tmp_expr, :($(var) = $(var)))
end
Expr(:block, tmp_expr...), nothing
else
dump(expr)
error("This shouldn't happen")
end

false_branch_symbols = ExpressionExplorer.compute_symbols_state(else_block)
false_branch_input_list = [false_branch_symbols.references...]
false_branch_assignments = [false_branch_symbols.assignments...]
all_false_branch_vars = false_branch_input_list false_branch_assignments
false_branch_fn_name = gensym(:false_branch)

all_input_vars = true_branch_input_list false_branch_input_list
all_output_vars = true_branch_assignments false_branch_assignments
discard_vars !== nothing && setdiff!(all_output_vars, discard_vars)

all_vars = all_input_vars all_output_vars

non_existant_true_branch_vars = setdiff(all_output_vars, all_true_branch_vars)
true_branch_extras = Expr(
:block,
[:($(var) = $(MissingTracedValue())) for var in non_existant_true_branch_vars]...,
)

true_branch_fn = quote
$(true_branch_fn_name) =
($(all_input_vars...),) -> begin
$(expr.args[2])
$(true_branch_extras)
return ($(all_output_vars...),)
end
end
true_branch_fn = cleanup_expr_to_avoid_boxing(
true_branch_fn, true_branch_fn_name, all_vars
)

non_existant_false_branch_vars = setdiff(all_output_vars, all_false_branch_vars)
false_branch_extras = Expr(
:block,
[:($(var) = $(MissingTracedValue())) for var in non_existant_false_branch_vars]...,
)

false_branch_fn = quote
$(false_branch_fn_name) =
($(all_input_vars...),) -> begin
$(else_block)
$(false_branch_extras)
return ($(all_output_vars...),)
end
end
false_branch_fn = cleanup_expr_to_avoid_boxing(
false_branch_fn, false_branch_fn_name, all_vars
)

reactant_code_block = quote
$(true_branch_fn)
$(false_branch_fn)
($(all_output_vars...),) = $(traced_if)(
$(expr.args[1]),
$(true_branch_fn_name),
$(false_branch_fn_name),
($(all_input_vars...),),
)
end

expr.head != :if &&
return reactant_code_block, (true_branch_fn_name, false_branch_fn_name)

all_check_vars = [all_input_vars..., condition_vars...]
return quote
if any($(is_traced), ($(all_check_vars...),))
$(reactant_code_block)
else
$(expr)
end
end
end

# Generate this dummy function and later we remove it during tracing
function traced_if(cond, true_fn::TFn, false_fn::FFn, args) where {TFn,FFn}
return cond ? true_fn(args) : false_fn(args)
end

function cleanup_expr_to_avoid_boxing(expr, prepend::Symbol, all_vars)
return MacroTools.postwalk(expr) do x
if x isa Symbol && x all_vars
return Symbol(prepend, x)
end
return x
end
end

function error_if_return(expr)
return MacroTools.postwalk(expr) do x
if x isa Expr && x.head == :return
error("Cannot use @trace on a block that contains a return statement")
end
return x
end
end

end
Loading

0 comments on commit 8a0c9a7

Please sign in to comment.