Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sam2][frame-loading] Add streaming frame loading on top of lazy loading #377

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions sam2/sam2_video_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,16 @@ def init_state(
video_path,
offload_video_to_cpu=False,
offload_state_to_cpu=False,
async_loading_frames=False,
frame_load_config=None,
):
"""Initialize an inference state."""
frame_load_config = frame_load_config or {}
compute_device = self.device # device of the model
images, video_height, video_width = load_video_frames(
video_path=video_path,
image_size=self.image_size,
offload_video_to_cpu=offload_video_to_cpu,
async_loading_frames=async_loading_frames,
frame_load_config=frame_load_config,
compute_device=compute_device,
)
inference_state = {}
Expand Down
211 changes: 186 additions & 25 deletions sam2/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,18 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import os
import warnings

from abc import abstractmethod
from threading import Thread

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from types import LambdaType


def get_sdpa_settings():
Expand Down Expand Up @@ -89,39 +93,35 @@ def mask_to_box(masks: torch.Tensor):
return bbox_coords


def _load_img_as_tensor(img_path, image_size):
    """Load the image stored at ``img_path`` and convert it to a tensor.

    Path-based convenience wrapper kept for callers that still pass a file
    path: it opens the file with PIL and delegates the conversion to
    ``_load_img_pil_as_tensor``, using the path as the image identifier.

    Args:
        img_path: Path of the image file to open.
        image_size: Side length of the square the image is resized to.

    Returns:
        The ``(img, video_height, video_width)`` tuple produced by
        ``_load_img_pil_as_tensor``.
    """
    # The pixel data is fully decoded by `convert`/`resize` inside the helper,
    # so the underlying file handle can be released as soon as it returns.
    with Image.open(img_path) as img_pil:
        return _load_img_pil_as_tensor(img_path, img_pil, image_size)


def _load_img_pil_as_tensor(img_id, img_pil, image_size):
    """Convert an already-opened PIL image into a channel-first tensor.

    Args:
        img_id: Identifier of the frame (e.g. its path or index). It is only
            used to build the error message.
        img_pil: Image object exposing ``convert``, ``resize`` and ``size``.
        image_size: Side length of the square the image is resized to.

    Returns:
        A tuple ``(img, video_height, video_width)``. ``img`` is the RGB image
        resized to ``image_size x image_size``, divided by 255 and permuted to
        channel-first layout; the two sizes are read from ``img_pil.size``,
        i.e. they describe the image before resizing.

    Raises:
        RuntimeError: If the decoded pixel array is not ``uint8``.
    """
    img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
    if img_np.dtype != np.uint8:  # np.uint8 is expected for JPEG images
        raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_id}")
    img_np = img_np / 255.0
    img = torch.from_numpy(img_np).permute(2, 0, 1)
    video_width, video_height = img_pil.size  # the original video size
    return img, video_height, video_width


class AsyncVideoFrameLoader:
class LazyVideoFrameLoader:
"""
A list of video frames to be load asynchronously without blocking session start.
Abstract class that defines primitives to load frames lazily.
"""

def __init__(
self,
img_paths,
image_size,
offload_video_to_cpu,
img_mean,
img_std,
compute_device,
):
self.img_paths = img_paths
self.image_size = image_size
self.offload_video_to_cpu = offload_video_to_cpu
self.img_mean = img_mean
self.img_std = img_std
# items in `self.images` will be loaded asynchronously
self.images = [None] * len(img_paths)
self.images = [None] * self.__len__()
# catch and raise any exceptions in the async loading thread
self.exception = None
# video_height and video_width be filled when loading the first image
Expand All @@ -131,18 +131,25 @@ def __init__(

# load the first frame to fill video_height and video_width and also
# to cache it (since it's most likely where the user will click)
self.__getitem__(0)
self.__getitem__(self.get_first_frame_num())

# load the rest of frames asynchronously without blocking the session start
def _load_frames():
if self.should_preload():
self.thread = Thread(
target=self.load_frames,
daemon=True,
)
self.thread.start()

def load_frames(self):
asyncio.run(self.preload())

async def preload(self):
async for index in self.get_preload_generator():
try:
for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
self.__getitem__(n)
self.__getitem__(index)
except Exception as e:
self.exception = e

self.thread = Thread(target=_load_frames, daemon=True)
self.thread.start()
if self.propagate_preload_errors():
self.exception = e

def __getitem__(self, index):
if self.exception is not None:
Expand All @@ -152,8 +159,8 @@ def __getitem__(self, index):
if img is not None:
return img

img, video_height, video_width = _load_img_as_tensor(
self.img_paths[index], self.image_size
img, video_height, video_width = _load_img_pil_as_tensor(
self.get_image_id(index), self.load_image(index), self.image_size
)
self.video_height = video_height
self.video_width = video_width
Expand All @@ -166,16 +173,132 @@ def __getitem__(self, index):
return img

def __len__(self):
return self.get_length()

@abstractmethod
def get_first_frame_num(self):
raise NotImplementedError

@abstractmethod
def should_preload(self):
raise NotImplementedError

@abstractmethod
def get_preload_generator(self):
raise NotImplementedError

@abstractmethod
def propagate_preload_errors(self):
raise NotImplementedError

@abstractmethod
def load_image(self, index):
raise NotImplementedError

@abstractmethod
def get_image_id(self, index):
raise NotImplementedError

@abstractmethod
def get_length(self):
raise NotImplementedError


class AsyncVideoFrameLoader(LazyVideoFrameLoader):
    """
    A list of video frames to be loaded asynchronously without blocking session start.

    Frames are read from a fixed list of image paths. Every frame is scheduled
    for background preloading, and any error raised while preloading is
    reported to the base loader.
    """

    def __init__(
        self,
        img_paths,
        image_size,
        offload_video_to_cpu,
        img_mean,
        img_std,
        compute_device,
    ):
        """Store the frame paths, then run the shared lazy-loader setup.

        Args:
            img_paths: Sequence of image file paths, one per frame.
            image_size: Forwarded to ``LazyVideoFrameLoader.__init__``.
            offload_video_to_cpu: Forwarded to ``LazyVideoFrameLoader.__init__``.
            img_mean: Forwarded to ``LazyVideoFrameLoader.__init__``.
            img_std: Forwarded to ``LazyVideoFrameLoader.__init__``.
            compute_device: Forwarded to ``LazyVideoFrameLoader.__init__``.
        """
        # `img_paths` must exist before the base initializer runs: the base
        # class queries the length and loads the first frame during setup.
        self.img_paths = img_paths
        super().__init__(
            image_size, offload_video_to_cpu, img_mean, img_std, compute_device
        )

    def get_first_frame_num(self):
        """Return the index of the frame to load eagerly (always the first)."""
        return 0

    def should_preload(self):
        """Return True: all frames are preloaded in the background."""
        return True

    def get_preload_generator(self):
        """Return an async generator yielding every frame index in order."""

        async def _available(num_frames):
            # tqdm needs an iterable, so wrap the frame count in `range`.
            for i in tqdm(range(num_frames), desc="frame loading (JPEG)"):
                yield i

        return _available(len(self.img_paths))

    def propagate_preload_errors(self):
        """Return True: preload failures are surfaced to the caller."""
        return True

    def load_image(self, index):
        """Open and return the PIL image stored at ``img_paths[index]``."""
        return Image.open(self.img_paths[index])

    def get_image_id(self, index):
        """Return the file path of frame ``index`` as its identifier."""
        return self.img_paths[index]

    def get_length(self):
        """Return the number of frames, i.e. the number of image paths."""
        # Must not rely on `self.images`: the base initializer calls this
        # method in order to build that list in the first place.
        return len(self.img_paths)



class StreamingVideoFrameLoader(LazyVideoFrameLoader):
    """
    A list of video frames that can be loaded lazily even if they are produced after session start.

    Frames are obtained by calling ``loader_func(index)``. Behaviour is driven
    by the ``stream_config`` mapping, which is read with these keys:

    - ``"max_frames"`` (required): the value reported as the loader's length.
    - ``"first_frame_num"`` (default ``0``): index of the frame loaded eagerly.
    - ``"preload_gen"`` (optional): object handed to the base class for
      preloading; the base class consumes it with ``async for``, so it is
      presumably an async iterable of frame indices.
    - ``"propagate_preload_errors"`` (default ``True``): whether errors raised
      while preloading are reported.
    """

    def __init__(
        self,
        loader_func,
        stream_config,
        image_size,
        offload_video_to_cpu,
        img_mean,
        img_std,
        compute_device,
    ):
        """Store the frame source and its configuration, then run base setup.

        Args:
            loader_func: Callable invoked as ``loader_func(index)`` to obtain
                the image for a frame.
            stream_config: Mapping with the keys described on the class, or
                ``None`` (treated as an empty mapping).
            image_size: Forwarded to ``LazyVideoFrameLoader.__init__``.
            offload_video_to_cpu: Forwarded to ``LazyVideoFrameLoader.__init__``.
            img_mean: Forwarded to ``LazyVideoFrameLoader.__init__``.
            img_std: Forwarded to ``LazyVideoFrameLoader.__init__``.
            compute_device: Forwarded to ``LazyVideoFrameLoader.__init__``.

        Raises:
            ValueError: If ``stream_config`` does not provide ``"max_frames"``.
        """
        self.loader_func = loader_func
        # Callers may forward an unset (None) config; normalise it so the
        # `.get` lookups below never hit an AttributeError.
        self.stream_config = stream_config or {}
        super().__init__(
            image_size, offload_video_to_cpu, img_mean, img_std, compute_device
        )

    def get_first_frame_num(self):
        """Return the configured first frame index, defaulting to 0."""
        return self.stream_config.get("first_frame_num", 0)

    def should_preload(self):
        """Return True only when a preload generator was configured."""
        return self.stream_config.get("preload_gen") is not None

    def get_preload_generator(self):
        """Return the configured ``"preload_gen"`` object (or None)."""
        return self.stream_config.get("preload_gen")

    def propagate_preload_errors(self):
        """Return whether preload errors are reported (default True)."""
        return self.stream_config.get("propagate_preload_errors", True)

    def load_image(self, index):
        """Return the image produced by ``loader_func`` for ``index``."""
        return self.loader_func(index)

    def get_image_id(self, index):
        """Return the frame index rendered as a string identifier."""
        return str(index)

    def get_length(self):
        """Return the configured ``"max_frames"`` value.

        Raises:
            ValueError: If ``"max_frames"`` is missing from the configuration.
        """
        max_frames = self.stream_config.get("max_frames")
        if max_frames is None:
            # Fail with an actionable message instead of the opaque TypeError
            # that a None length would trigger when the frame cache is sized.
            raise ValueError(
                'StreamingVideoFrameLoader requires stream_config["max_frames"] '
                "to be set to the maximum number of frames"
            )
        return max_frames


def load_video_frames(
video_path,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
frame_load_config=None,
compute_device=torch.device("cuda"),
):
"""
Expand All @@ -184,6 +307,7 @@ def load_video_frames(
"""
is_bytes = isinstance(video_path, bytes)
is_str = isinstance(video_path, str)
is_func = isinstance(video_path, LambdaType)
is_mp4_path = is_str and os.path.splitext(video_path)[-1] in [".mp4", ".MP4"]
if is_bytes or is_mp4_path:
return load_video_frames_from_video_file(
Expand All @@ -201,7 +325,18 @@ def load_video_frames(
offload_video_to_cpu=offload_video_to_cpu,
img_mean=img_mean,
img_std=img_std,
async_loading_frames=async_loading_frames,
frame_load_config=frame_load_config,
compute_device=compute_device,
)

elif is_func:
return load_video_frames_from_lambda(
loader_func=video_path,
image_size=image_size,
offload_video_to_cpu=offload_video_to_cpu,
img_mean=img_mean,
img_std=img_std,
frame_load_config=frame_load_config,
compute_device=compute_device,
)
else:
Expand All @@ -210,13 +345,37 @@ def load_video_frames(
)


def load_video_frames_from_lambda(
    loader_func,
    image_size,
    offload_video_to_cpu,
    img_mean=(0.485, 0.456, 0.406),
    img_std=(0.229, 0.224, 0.225),
    frame_load_config=None,
    compute_device=torch.device("cuda"),
):
    """Build a streaming frame loader backed by a user-supplied callable.

    Args:
        loader_func: Callable handed to ``StreamingVideoFrameLoader`` as the
            per-frame image source.
        image_size: Forwarded to the loader.
        offload_video_to_cpu: Forwarded to the loader.
        img_mean: Per-channel mean values; converted to a float32 tensor with
            two trailing singleton dimensions.
        img_std: Per-channel std values; converted the same way as
            ``img_mean``.
        frame_load_config: Passed to the loader as its ``stream_config``.
        compute_device: Forwarded to the loader.

    Returns:
        A tuple ``(loader, video_height, video_width)`` where the two sizes
        are read from the loader after construction.
    """

    def _per_channel(values):
        # Add two trailing singleton axes to the per-channel vector.
        return torch.tensor(values, dtype=torch.float32)[:, None, None]

    mean_tensor = _per_channel(img_mean)
    std_tensor = _per_channel(img_std)

    frames = StreamingVideoFrameLoader(
        loader_func=loader_func,
        stream_config=frame_load_config,
        image_size=image_size,
        offload_video_to_cpu=offload_video_to_cpu,
        img_mean=mean_tensor,
        img_std=std_tensor,
        compute_device=compute_device,
    )
    return frames, frames.video_height, frames.video_width


def load_video_frames_from_jpg_images(
video_path,
image_size,
offload_video_to_cpu,
img_mean=(0.485, 0.456, 0.406),
img_std=(0.229, 0.224, 0.225),
async_loading_frames=False,
frame_load_config=None,
compute_device=torch.device("cuda"),
):
"""
Expand Down Expand Up @@ -253,7 +412,7 @@ def load_video_frames_from_jpg_images(
img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]

if async_loading_frames:
if frame_load_config.get("async", False):
lazy_images = AsyncVideoFrameLoader(
img_paths,
image_size,
Expand All @@ -266,7 +425,9 @@ def load_video_frames_from_jpg_images(

images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
images[n], video_height, video_width = _load_img_pil_as_tensor(
img_path, Image.open(img_path), image_size
)
if not offload_video_to_cpu:
images = images.to(compute_device)
img_mean = img_mean.to(compute_device)
Expand Down
4 changes: 2 additions & 2 deletions tools/vos_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def vos_inference(
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(
video_path=video_dir, async_loading_frames=False
video_path=video_dir, frame_load_config=None
)
height = inference_state["video_height"]
width = inference_state["video_width"]
Expand Down Expand Up @@ -273,7 +273,7 @@ def vos_separate_inference_per_object(
]
frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
inference_state = predictor.init_state(
video_path=video_dir, async_loading_frames=False
video_path=video_dir, frame_load_config=None
)
height = inference_state["video_height"]
width = inference_state["video_width"]
Expand Down
Loading