From c922127f72559cf503acc72cd18ebf808623d06b Mon Sep 17 00:00:00 2001 From: Ko Sugawara Date: Wed, 14 Aug 2024 11:51:45 +0900 Subject: [PATCH] Use skimage.measure.label if failed in importing sam2._C module --- sam2/utils/misc.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index e2d39a08..bfc27c73 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -58,9 +58,16 @@ def get_connected_components(mask): - counts: A tensor of shape (N, 1, H, W) containing the area of the connected components for foreground pixels and 0 for background pixels. """ - from sam2 import _C + try: + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + except Exception: + import skimage.measure - return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + return skimage.measure.label( + mask.to(torch.uint8).contiguous().cpu().numpy(), return_num=True + ) def mask_to_box(masks: torch.Tensor):