From ff266f21ca4b1268e602e52d049453e5a24695de Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Thu, 2 May 2024 14:47:54 -0500 Subject: [PATCH] Fix primal evaluation (#45) --- Project.toml | 2 +- README.md | 3 +-- examples/optcontrol.jl | 4 ++-- src/Rules/EnzymeRules.jl | 9 +++------ 4 files changed, 7 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index e2e7df1..7292f3d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Checkpointing" uuid = "eb46d486-4f9c-4c3d-b445-a617f2a2f1ca" authors = ["Michel Schanen ", "Sri Hari Krishna Narayanan "] -version = "0.9.1" +version = "0.9.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/README.md b/README.md index e2fde81..0b42525 100644 --- a/README.md +++ b/README.md @@ -63,7 +63,6 @@ end function sumheat(heat::Heat, chkpscheme::Scheme, tsteps::Int64) # AD: Create shadow copy for derivatives @checkpoint_struct chkpscheme heat for i in 1:tsteps - # checkpoint_struct_for(advance, heat) heat.Tlast .= heat.Tnext advance(heat) end @@ -87,7 +86,7 @@ function heat(scheme::Scheme, tsteps::Int) heat.Tnext[end] = 0 # Compute gradient - autodiff(Enzyme.ReverseWithPrimal, sumheat, Duplicated(heat, dheat), scheme, tsteps) + autodiff(Enzyme.ReverseWithPrimal, sumheat, Duplicated(heat, dheat), Const(scheme), Const(tsteps)) return heat.Tnext, dheat.Tnext[2:end-1] end diff --git a/examples/optcontrol.jl b/examples/optcontrol.jl index 007b801..5daaf85 100644 --- a/examples/optcontrol.jl +++ b/examples/optcontrol.jl @@ -4,7 +4,7 @@ # Technique for Resilience. United States: N. p., 2016. https://www.osti.gov/biblio/1364654. using Checkpointing -using Zygote +using Enzyme include("optcontrolfunc.jl") @@ -69,7 +69,7 @@ function muoptcontrol(scheme, steps, ::EnzymeTool) end return model.F[2] end - autodiff(Enzyme.ReverseWithPrimal, foo, Duplicated(model, bmodel)) + autodiff(Enzyme.Reverse, foo, Duplicated(model, bmodel)) F = model.F L = bmodel.F diff --git a/src/Rules/EnzymeRules.jl b/src/Rules/EnzymeRules.jl index a85a161..97cebfc 100644 --- a/src/Rules/EnzymeRules.jl +++ b/src/Rules/EnzymeRules.jl @@ -11,8 +11,8 @@ function augmented_primal( model, range, ) + primal = func.val(body.val, alg.val, deepcopy(model.val), range.val) if needs_primal(config) - primal = func.val(body.val, alg.val, model.val, range.val) return AugmentedReturn(primal, nothing, (model.val,)) else return AugmentedReturn(nothing, nothing, (model.val,)) @@ -50,12 +50,9 @@ function augmented_primal( model, condition, ) + primal = func.val(body.val, alg.val, deepcopy(model.val), condition.val) if needs_primal(config) - return AugmentedReturn( - func.val(body.val, alg.val, model.val, condition.val), - nothing, - (model.val,), - ) + return AugmentedReturn(primal, nothing, (model.val,)) else return AugmentedReturn(nothing, nothing, (model.val,)) end