Skip to content

Commit

Permalink
Made fb.plot.images more customizable.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Davidson committed Feb 11, 2021
1 parent 2343d58 commit dd3b180
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
12 changes: 10 additions & 2 deletions foolbox/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ def images(
nrows: Optional[int] = None,
figsize: Optional[Tuple[float, float]] = None,
scale: float = 1,
labels: Any = None,
return_fig: bool = False,
**kwargs: Any,
) -> None:
) -> Optional[Tuple[Any, Any]]:
import matplotlib.pyplot as plt

x: ep.Tensor = ep.astensor(images)
Expand Down Expand Up @@ -57,7 +59,7 @@ def images(
nrows=nrows,
figsize=figsize,
squeeze=False,
constrained_layout=True,
constrained_layout=False,
**kwargs,
)

Expand All @@ -68,5 +70,11 @@ def images(
ax.set_yticks([])
ax.axis("off")
i = row * ncols + col
if labels is not None:
ax.set_title(labels[i])
if i < len(x):
ax.imshow(x[i])

if return_fig:
return fig, axes
return None
4 changes: 4 additions & 0 deletions tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ def test_plot(dummy: ep.Tensor) -> None:
fbn.plot.images(images, ncols=3)
fbn.plot.images(images, nrows=2, ncols=6)
fbn.plot.images(images, nrows=2, ncols=4)
fbn.plot.images(
images, nrows=2, ncols=4, labels=[str(i) for i in range(len(images))]
)
fbn.plot.images(images, nrows=2, ncols=4, return_fig=True)
with pytest.raises(ValueError):
images = ep.zeros(dummy, (10, 3, 3, 3))
fbn.plot.images(images)
Expand Down

0 comments on commit dd3b180

Please sign in to comment.