Skip to content

Commit

Permalink
Adding a test for writing out adjoints
Browse files Browse the repository at this point in the history
  • Loading branch information
michel2323 committed Nov 5, 2024
1 parent f8ff233 commit c9c1622
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions test/output_chkp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@ using Checkpointing
using Enzyme
using Serialization
using Test

import Base.length
import Base.iterate
mutable struct ChkpOut
x::Vector{Float64}
end

Base.length(chkp::ChkpOut) = length(chkp.x)
Base.iterate(chkp::ChkpOut, i) = iterate(chkp.x, i)

function loops(chkp::ChkpOut, scheme::Scheme, iters::Int)
@checkpoint_struct scheme chkp for i = 1:iters
chkp.x .= 2.0 * sqrt.(chkp.x) .* sqrt.(chkp.x)
Expand All @@ -33,6 +37,7 @@ fid = Checkpointing.HDF5.h5open("adjoint_chkp.h5", "r")
# List all checkpoints
saved_chkp = sort(parse.(Int, (keys(fid))))
println("Checkpoints saved: $saved_chkp")
chkp = Checkpointing.deserialize(read(fid["3"]))
chkp = Checkpointing.deserialize(read(fid["1"]))
@test isa(chkp, ChkpOut)
@test all(dx .== chkp.x)
close(fid)

0 comments on commit c9c1622

Please sign in to comment.