From dd3b1804b9db1a9302328d944ca395a9e5cc2187 Mon Sep 17 00:00:00 2001 From: Mike Davidson Date: Thu, 11 Feb 2021 17:52:44 -0500 Subject: [PATCH] Made fb.plot.images more customizable. --- foolbox/plot.py | 12 ++++++++++-- tests/test_plot.py | 4 ++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/foolbox/plot.py b/foolbox/plot.py index 7d4141fad..4ce32df09 100644 --- a/foolbox/plot.py +++ b/foolbox/plot.py @@ -13,8 +13,10 @@ def images( nrows: Optional[int] = None, figsize: Optional[Tuple[float, float]] = None, scale: float = 1, + labels: Any = None, + return_fig: bool = False, **kwargs: Any, -) -> None: +) -> Optional[Tuple[Any, Any]]: import matplotlib.pyplot as plt x: ep.Tensor = ep.astensor(images) @@ -57,7 +59,7 @@ def images( nrows=nrows, figsize=figsize, squeeze=False, - constrained_layout=True, + constrained_layout=False, **kwargs, ) @@ -68,5 +70,11 @@ def images( ax.set_yticks([]) ax.axis("off") i = row * ncols + col + if labels is not None: + ax.set_title(labels[i]) if i < len(x): ax.imshow(x[i]) + + if return_fig: + return fig, axes + return None diff --git a/tests/test_plot.py b/tests/test_plot.py index 4928715fc..c3202c49b 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -13,6 +13,10 @@ def test_plot(dummy: ep.Tensor) -> None: fbn.plot.images(images, ncols=3) fbn.plot.images(images, nrows=2, ncols=6) fbn.plot.images(images, nrows=2, ncols=4) + fbn.plot.images( + images, nrows=2, ncols=4, labels=[str(i) for i in range(len(images))] + ) + fbn.plot.images(images, nrows=2, ncols=4, return_fig=True) with pytest.raises(ValueError): images = ep.zeros(dummy, (10, 3, 3, 3)) fbn.plot.images(images)