From c64b577f1e2a9a8045668dc0e459219215b2429d Mon Sep 17 00:00:00 2001
From: ejolly
Date: Thu, 22 Jun 2023 18:04:32 -0400
Subject: [PATCH 1/3] Lazy loading of video frames

---
 .gitignore       |  3 ++
 feat/data.py     | 83 +++++++++++++++++++++++++++++++-----------------
 feat/detector.py | 15 ++++++---
 3 files changed, 67 insertions(+), 34 deletions(-)

diff --git a/.gitignore b/.gitignore
index 8126eed0..5804fa4f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -106,3 +106,6 @@ dev/
 
 gpu_profiling.py
 environment.yml
+
+# Longer video
+FNL_01.mp4
diff --git a/feat/data.py b/feat/data.py
index 6040ae56..5f3babd4 100644
--- a/feat/data.py
+++ b/feat/data.py
@@ -40,6 +40,8 @@
 import torch
 from PIL import Image
 import logging
+import av
+from itertools import islice
 
 __all__ = [
     "FexSeries",
@@ -664,7 +666,6 @@ def info(self):
         print(f"{self.__class__}\n" + "".join(attr_list))
 
     def _update_extracted_colnames(self, prefix=None, mode="replace"):
-
         cols2update = [
             "au_columns",
             "emotion_columns",
@@ -705,7 +706,6 @@
         _ = [setattr(self, col, val) for col, val in zip(cols2update, update)]
 
     def _parse_features_labels(self, X, y):
-
         feature_groups = [
             "sessions",
             "emotions",
@@ -1612,7 +1612,6 @@ def extract_boft(self, min_freq=0.06, max_freq=0.66, bank=8, *args, **kwargs):
         )
 
     def _prepare_plot_aus(self, row, muscles, gaze):
-
         """
         Plot one or more faces based on their AU representation. This method is just a
         convenient wrapper for feat.plotting.plot_face. See that function for additional
@@ -1723,10 +1722,8 @@ def plot_detections(
                 col_count += 1
 
         for _, row in plot_data.iterrows():
-
             # DRAW LANDMARKS ON IMAGE OR AU FACE
             if face_ax is not None:
-
                 facebox = row[self.facebox_columns].values
 
                 if not faces == "aus" and plot_original_image:
@@ -1910,7 +1907,6 @@ def __len__(self):
         return len(self.images)
 
     def __getitem__(self, idx):
-
         # Dimensions are [channels, height, width]
         try:
             img = read_image(self.images[idx])
@@ -2088,7 +2084,6 @@ def __len__(self):
         return self.main_file.shape[0]
 
     def __getitem__(self, idx):
-
         # Dimensions are [channels, height, width]
         img = read_image(self.main_file["image_path"].iloc[idx])
         label = self.main_file.loc[idx, self.avail_AUs].to_numpy().astype(np.int16)
@@ -2164,51 +2159,81 @@ class VideoDataset(Dataset):
     """
 
     def __init__(self, video_file, skip_frames=None, output_size=None):
-
-        # Ignore UserWarning: The pts_unit 'pts' gives wrong results. Please use
-        # pts_unit 'sec'. See why it's ok in this issue:
-        # https://github.com/pytorch/vision/issues/1931
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", UserWarning)
-            # Video dimensions are: [time, height, width, channels]
-            self.video, self.audio, self.info = read_video(video_file)
-        # Swap them to match output of read_image: [time, channels, height, width]
-        # Otherwise detectors face on tensor dimension mismatch
-        self.video = swapaxes(swapaxes(self.video, 1, 3), -1, -2)
         self.file_name = video_file
+        self.skip_frames = skip_frames
         self.output_size = output_size
+        self.get_video_metadata(video_file)
 
+        # This is the list of frame ids used to slice the video, not the frames themselves
         self.video_frames = np.arange(
-            0, self.video.shape[0], 1 if skip_frames is None else skip_frames
+            0, self.metadata["num_frames"], 1 if skip_frames is None else skip_frames
         )
-        self.video = self.video[self.video_frames, :, :]
 
     def __len__(self):
-        return self.video.shape[0]
+        # Number of frames respecting skip_frames
+        return len(self.video_frames)
 
     def __getitem__(self, idx):
+        # Get the frame data and frame number, respecting skip_frames
+        frame_data, frame_idx = self.load_frame(idx)
+
+        # Swap frame dims to match output of read_image: [channels, height, width]
+        # Otherwise detectors fail on a tensor dimension mismatch
+        frame_data = swapaxes(swapaxes(frame_data, 0, -1), 1, 2)
 
         # Rescale if needed like in ImageDataset
         if self.output_size is not None:
             logging.info(
-                f"VideoDataset: RESCALING WARNING: from {self.video[idx].shape} to output_size={self.output_size}"
+                f"VideoDataset: RESCALING WARNING: from {self.metadata['shape']} to output_size={self.output_size}"
             )
             transform = Compose(
                 [Rescale(self.output_size, preserve_aspect_ratio=True, padding=False)]
             )
-            transformed_img = transform(self.video[idx])
+            transformed_frame_data = transform(frame_data)
             return {
-                "Image": transformed_img["Image"],
-                "Frame": self.video_frames[idx],
-                "Scale": transformed_img["Scale"],
-                "Padding": transformed_img["Padding"],
+                "Image": transformed_frame_data["Image"],
+                "Frame": frame_idx,
                 "FileName": self.file_name,
+                "Scale": transformed_frame_data["Scale"],
+                "Padding": transformed_frame_data["Padding"],
             }
         else:
             return {
-                "Image": self.video[idx],
-                "Frame": self.video_frames[idx],
+                "Image": frame_data,
+                "Frame": frame_idx,
                 "FileName": self.file_name,
                 "Scale": 1.0,
                 "Padding": {"Left": 0, "Top": 0, "Right": 0, "Bottom": 0},
             }
+
+    def get_video_metadata(self, video_file):
+        container = av.open(video_file)
+        stream = container.streams.video[0]
+        fps = stream.average_rate
+        height = stream.height
+        width = stream.width
+        num_frames = stream.frames
+        container.close()
+        self.metadata = {
+            "fps": float(fps),
+            "fps_frac": fps,
+            "height": height,
+            "width": width,
+            "num_frames": num_frames,
+            "shape": (height, width),
+        }
+
+    def load_frame(self, idx):
+        """Load in a single frame from the video using a lazy generator"""
+
+        # Get frame number respecting skip_frames
+        frame_idx = int(self.video_frames[idx])
+
+        # Use a py-av generator to load in just this frame
+        container = av.open(self.file_name)
+        stream = container.streams.video[0]
+        frame = next(islice(container.decode(stream), frame_idx, None))
+        frame_data = torch.from_numpy(frame.to_ndarray(format="rgb24"))
+        container.close()
+
+        return frame_data, frame_idx
diff --git a/feat/detector.py b/feat/detector.py
index b490eaf8..0b83e01b 100644
--- a/feat/detector.py
+++ b/feat/detector.py
@@ -425,7 +425,6 @@ def detect_landmarks(self, frame, detected_faces, **landmark_model_kwargs):
         else:
             if self.info["landmark_model"]:
                 if self.info["landmark_model"].lower() == "mobilenet":
-
                     out_size = 224
                 else:
                     out_size = 112
@@ -459,7 +458,6 @@ def detect_landmarks(self, frame, detected_faces, **landmark_model_kwargs):
 
         landmark_results = []
         for ik in range(landmark.shape[0]):
-
             landmark_results.append(
                 new_bbox[ik].inverse_transform_landmark(landmark[ik, :, :])
             )
@@ -657,6 +655,7 @@ def _run_detection_waterfall(
         facepose_model_kwargs,
         emotion_model_kwargs,
         au_model_kwargs,
+        suppress_torchvision_warnings=True,
     ):
         """
         Main detection "waterfall." Calls each individual detector in the sequence
@@ -675,6 +674,15 @@
         Returns:
             tuple: faces, landmarks, poses, aus, emotions
         """
+
+        # Reset warnings
+        warnings.filterwarnings("default", category=UserWarning, module="torchvision")
+
+        if suppress_torchvision_warnings:
+            warnings.filterwarnings(
+                "ignore", category=UserWarning, module="torchvision"
+            )
+
         faces = self.detect_faces(
             batch_data["Image"],
             threshold=face_detection_threshold,
@@ -771,7 +779,6 @@ def detect_image(
         batch_output = []
 
         for batch_id, batch_data in enumerate(tqdm(data_loader)):
-
             faces, landmarks, poses, aus, emotions = self._run_detection_waterfall(
                 batch_data,
                 face_detection_threshold,
@@ -851,7 +858,6 @@ def detect_video(
         batch_output = []
 
         for batch_data in tqdm(data_loader):
-
             faces, landmarks, poses, aus, emotions = self._run_detection_waterfall(
                 batch_data,
                 face_detection_threshold,
@@ -1084,7 +1090,6 @@ def _match_faces_to_poses(faces, faces_pose, poses):
             return (faces, poses)
 
         else:
-
             overlap_faces = []
             overlap_poses = []
             for frame_face, frame_face_pose, frame_pose in zip(
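A quick sketch of the decoding pattern PATCH 1/3 adopts (illustration only, not part of the patch; nth_frame is a hypothetical helper): instead of read_video() materializing the whole clip as one tensor, PyAV's decode generator yields one frame at a time, so memory stays flat regardless of video length.

    import av
    from itertools import islice

    def nth_frame(path, n):
        """Decode lazily and return frame n as an RGB numpy array."""
        container = av.open(path)
        stream = container.streams.video[0]
        # decode() yields frames one at a time; islice skips the first n
        # frames without ever holding more than one decoded frame in memory
        frame = next(islice(container.decode(stream), n, None))
        container.close()
        return frame.to_ndarray(format="rgb24")  # shape: (height, width, 3)

One trade-off: islice still decodes every frame up to n, so random access is linear in the frame index. The win of the patch is memory, not seek speed.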
From f80dd5fc9ee0ab4ba04a86511d4b078217ab3e2c Mon Sep 17 00:00:00 2001
From: ejolly
Date: Thu, 22 Jun 2023 18:04:50 -0400
Subject: [PATCH 2/3] disable fex test until nltools issue is fixed

---
 feat/tests/test_data.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/feat/tests/test_data.py b/feat/tests/test_data.py
index 46a162b4..93973954 100644
--- a/feat/tests/test_data.py
+++ b/feat/tests/test_data.py
@@ -44,7 +44,6 @@ def test_fex_new(data_path):
 
 
 def test_fex_old(imotions_data):
-
     # Dropped support in >= 0.4.0
     with pytest.raises(Exception):
         Fex().read_facet()
@@ -125,7 +124,8 @@ def test_fex_old(imotions_data):
     assert len(dat.downsample(target=10)) == 52
 
     # Test upsample
-    assert len(dat.upsample(target=60, target_type="hz")) == (len(dat) - 1) * 2
+    # Commenting out because of a bug in nltools: https://github.com/cosanlab/nltools/issues/418
+    # assert len(dat.upsample(target=60, target_type="hz")) == (len(dat) - 1) * 2
 
     # Test interpolation
     assert (
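A possible alternative to commenting the assertion out (a sketch only; it assumes the dat object from test_fex_old is refactored into a fixture): marking the test xfail keeps it running and makes it surface as XPASS once the upstream nltools bug is fixed.

    import pytest

    @pytest.mark.xfail(
        reason="nltools upsample bug: https://github.com/cosanlab/nltools/issues/418"
    )
    def test_fex_upsample(dat):
        assert len(dat.upsample(target=60, target_type="hz")) == (len(dat) - 1) * 2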
From f5fb91f5f917c29c9532c7ea9dec1c9428f467b7 Mon Sep 17 00:00:00 2001
From: ejolly
Date: Thu, 29 Jun 2023 12:53:10 -0400
Subject: [PATCH 3/3] add approximate frame time to output of detect_video

---
 feat/data.py     | 20 ++++++++++++++++++++
 feat/detector.py |  9 ++++++++-
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/feat/data.py b/feat/data.py
index 5f3babd4..9b06e9ce 100644
--- a/feat/data.py
+++ b/feat/data.py
@@ -2237,3 +2237,23 @@ def load_frame(self, idx):
         container.close()
 
         return frame_data, frame_idx
+
+    def calc_approx_frame_time(self, idx):
+        """Calculate the approximate time of a frame in a video
+
+        Args:
+            idx (int): frame number
+
+        Returns:
+            str: approximate time formatted as "MM:SS"
+        """
+        frame_time = idx / self.metadata["fps"]
+        total_time = self.metadata["num_frames"] / self.metadata["fps"]
+        time = total_time if idx >= self.metadata["num_frames"] else frame_time
+        return self.convert_sec_to_min_sec(time)
+
+    @staticmethod
+    def convert_sec_to_min_sec(duration):
+        minutes = int(duration // 60)
+        seconds = int(duration % 60)
+        return f"{minutes:02d}:{seconds:02d}"
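To make the arithmetic in calc_approx_frame_time concrete, here is a worked example with hypothetical numbers (a 30 fps, 900-frame clip; none of this comes from the patch). The time is approximate because it assumes a constant frame rate taken from stream.average_rate, which can drift for variable-frame-rate videos.

    fps, num_frames = 30.0, 900        # hypothetical clip, not from the patch
    idx = 450
    frame_time = idx / fps             # 15.0 s
    total_time = num_frames / fps      # 30.0 s, used to clamp out-of-range indices
    time = total_time if idx >= num_frames else frame_time
    print(f"{int(time // 60):02d}:{int(time % 60):02d}")  # -> 00:15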
diff --git a/feat/detector.py b/feat/detector.py
index 0b83e01b..1773a9d4 100644
--- a/feat/detector.py
+++ b/feat/detector.py
@@ -847,8 +847,12 @@ def detect_video(
         emotion_model_kwargs = kwargs.pop("emotion_model_kwargs", dict())
         facepose_model_kwargs = kwargs.pop("facepose_model_kwargs", dict())
 
+        dataset = VideoDataset(
+            video_path, skip_frames=skip_frames, output_size=output_size
+        )
+
         data_loader = DataLoader(
-            VideoDataset(video_path, skip_frames=skip_frames, output_size=output_size),
+            dataset,
             num_workers=num_workers,
             batch_size=batch_size,
             pin_memory=pin_memory,
@@ -884,6 +888,9 @@ def detect_video(
 
         batch_output = pd.concat(batch_output)
         batch_output.reset_index(drop=True, inplace=True)
+        batch_output["approx_time"] = [
+            dataset.calc_approx_frame_time(x) for x in batch_output["frame"].to_numpy()
+        ]
        return batch_output.set_index("frame", drop=False)
 
     def _create_fex(
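Taken together, the series means detect_video decodes frames lazily and reports a human-readable timestamp per row. A minimal usage sketch (the video path is hypothetical; skip_frames=30 on a roughly 30 fps clip puts consecutive rows about one second apart):

    from feat import Detector

    detector = Detector()
    # Frames are decoded lazily by VideoDataset; every 30th frame is sampled
    fex = detector.detect_video("clip.mp4", skip_frames=30)
    # The new approx_time column is formatted as MM:SS
    print(fex[["frame", "approx_time"]].head())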