From 70de544710c558675186b2c176150a1fe9ea8883 Mon Sep 17 00:00:00 2001 From: Margaret Duff Date: Mon, 14 Aug 2023 12:43:50 +0000 Subject: [PATCH] Neatening block function tests --- Wrappers/Python/test/test_functions.py | 33 +++++++++++++++++++------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/Wrappers/Python/test/test_functions.py b/Wrappers/Python/test/test_functions.py index 952cd81fb3..f8b1e6becc 100644 --- a/Wrappers/Python/test/test_functions.py +++ b/Wrappers/Python/test/test_functions.py @@ -1172,17 +1172,34 @@ def test_non_scalar_tau_cil_tv(self): np.testing.assert_allclose(res1.array, res2.array, atol=1e-3) - def test_get_p2(self): - self.assertEquals(self.tv._p2, None) + def test_get_p2_default(self): data = dataexample.SHAPES.get(size=(64, 64)) tv=TotalVariation() + self.assertEquals(tv._p2, None) tv(data) - a=tv.gradient.range_geometry().allocate(0) - b=tv._get_p2() - np.testing.assert_allclose(tv._get_p2()[0].array, tv.gradient.range_geometry().allocate(0)[0].array) - np.testing.assert_allclose(tv._get_p2()[1].array, tv.gradient.range_geometry().allocate(0)[1].array) - - + if isinstance(tv._get_p2(), BlockDataContainer): + for xa,xb in zip(tv._get_p2(),tv.gradient.range_geometry().allocate(0)): + np.testing.assert_allclose(xa.as_array(), xb.as_array(), + rtol=1e-5, atol=1e-5) + + def test_get_p2_after_prox_iteration_has_changed(self): + data = dataexample.SHAPES.get(size=(64, 64)) + tv=TotalVariation() + self.assertEquals(tv._p2, None) + tv.proximal(data, 1.) + if isinstance(tv._get_p2(), BlockDataContainer): + for xa,xb in zip(tv._get_p2(),tv.gradient.range_geometry().allocate(0)): + np.testing.assert_equal(np.any(np.not_equal(xa.as_array(), xb.as_array())), True) + + def test_get_p2_after_prox_iteration_has_not_changed(self): + data = dataexample.SHAPES.get(size=(64, 64)) + tv=TotalVariation(warmstart=False) + self.assertEquals(tv._p2, None) + tv.proximal(data, 1.) + if isinstance(tv._get_p2(), BlockDataContainer): + for xa,xb in zip(tv._get_p2(),tv.gradient.range_geometry().allocate(0)): + np.testing.assert_allclose(xa.as_array(), xb.as_array(), + rtol=1e-5, atol=1e-5) class TestLeastSquares(unittest.TestCase):