From dd2a45a99152e8330cf1af2583905a6b625eb556 Mon Sep 17 00:00:00 2001 From: Michel Schanen Date: Tue, 5 Nov 2024 15:15:57 -0600 Subject: [PATCH] Adding a test for writing out adjoints --- test/output_chkp.jl | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/output_chkp.jl b/test/output_chkp.jl index b790b41..909815a 100644 --- a/test/output_chkp.jl +++ b/test/output_chkp.jl @@ -2,11 +2,16 @@ using Checkpointing using Enzyme using Serialization using Test - +import Base.length +import Base.iterate mutable struct ChkpOut x::Vector{Float64} end +Base.length(chkp::ChkpOut) = length(chkp.x) +Base.iterate(chkp::ChkpOut) = iterate(chkp.x) +Base.iterate(chkp::ChkpOut, i) = iterate(chkp.x, i) + function loops(chkp::ChkpOut, scheme::Scheme, iters::Int) @checkpoint_struct scheme chkp for i = 1:iters chkp.x .= 2.0 * sqrt.(chkp.x) .* sqrt.(chkp.x) @@ -33,6 +38,7 @@ fid = Checkpointing.HDF5.h5open("adjoint_chkp.h5", "r") # List all checkpoints saved_chkp = sort(parse.(Int, (keys(fid)))) println("Checkpoints saved: $saved_chkp") -chkp = Checkpointing.deserialize(read(fid["3"])) +chkp = Checkpointing.deserialize(read(fid["1"])) @test isa(chkp, ChkpOut) +@test all(dx .== chkp.x) close(fid)