Skip to content

Commit

Permalink
Adding a test for writing out adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Nov 6, 2024
1 parent f8ff233 commit dd2a45a
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions test/output_chkp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@ using Checkpointing
using Enzyme
using Serialization
using Test

import Base.length
import Base.iterate
mutable struct ChkpOut
x::Vector{Float64}
end

Base.length(chkp::ChkpOut) = length(chkp.x)
Base.iterate(chkp::ChkpOut) = iterate(chkp.x)
Base.iterate(chkp::ChkpOut, i) = iterate(chkp.x, i)

function loops(chkp::ChkpOut, scheme::Scheme, iters::Int)
@checkpoint_struct scheme chkp for i = 1:iters
chkp.x .= 2.0 * sqrt.(chkp.x) .* sqrt.(chkp.x)
Expand All @@ -33,6 +38,7 @@ fid = Checkpointing.HDF5.h5open("adjoint_chkp.h5", "r")
# List all checkpoints
saved_chkp = sort(parse.(Int, (keys(fid))))
println("Checkpoints saved: $saved_chkp")
chkp = Checkpointing.deserialize(read(fid["3"]))
chkp = Checkpointing.deserialize(read(fid["1"]))
@test isa(chkp, ChkpOut)
@test all(dx .== chkp.x)
close(fid)

0 comments on commit dd2a45a

Please sign in to comment.