More efficient video processing #170

Merged · 3 commits · Jul 10, 2023
3 changes: 3 additions & 0 deletions .gitignore
@@ -106,3 +106,6 @@ dev/
gpu_profiling.py

environment.yml

+# Longer video
+FNL_01.mp4
103 changes: 74 additions & 29 deletions feat/data.py
@@ -40,6 +40,8 @@
import torch
from PIL import Image
import logging
+import av
+from itertools import islice

__all__ = [
"FexSeries",
@@ -664,7 +666,6 @@ def info(self):
print(f"{self.__class__}\n" + "".join(attr_list))

def _update_extracted_colnames(self, prefix=None, mode="replace"):
-
cols2update = [
"au_columns",
"emotion_columns",
@@ -705,7 +706,6 @@ def _update_extracted_colnames(self, prefix=None, mode="replace"):
_ = [setattr(self, col, val) for col, val in zip(cols2update, update)]

def _parse_features_labels(self, X, y):
-
feature_groups = [
"sessions",
"emotions",
@@ -1612,7 +1612,6 @@ def extract_boft(self, min_freq=0.06, max_freq=0.66, bank=8, *args, **kwargs):
)

def _prepare_plot_aus(self, row, muscles, gaze):
-
"""
Plot one or more faces based on their AU representation. This method is just a
convenient wrapper for feat.plotting.plot_face. See that function for additional
@@ -1723,10 +1722,8 @@ def plot_detections(
col_count += 1

for _, row in plot_data.iterrows():
-
# DRAW LANDMARKS ON IMAGE OR AU FACE
if face_ax is not None:
-
facebox = row[self.facebox_columns].values

if not faces == "aus" and plot_original_image:
@@ -1910,7 +1907,6 @@ def __len__(self):
return len(self.images)

def __getitem__(self, idx):
-
# Dimensions are [channels, height, width]
try:
img = read_image(self.images[idx])
@@ -2088,7 +2084,6 @@ def __len__(self):
return self.main_file.shape[0]

def __getitem__(self, idx):
-
# Dimensions are [channels, height, width]
img = read_image(self.main_file["image_path"].iloc[idx])
label = self.main_file.loc[idx, self.avail_AUs].to_numpy().astype(np.int16)
@@ -2164,51 +2159,101 @@ class VideoDataset(Dataset):
"""

def __init__(self, video_file, skip_frames=None, output_size=None):

-# Ignore UserWarning: The pts_unit 'pts' gives wrong results. Please use
-# pts_unit 'sec'. See why it's ok in this issue:
-# https://github.com/pytorch/vision/issues/1931
-with warnings.catch_warnings():
-warnings.simplefilter("ignore", UserWarning)
-# Video dimensions are: [time, height, width, channels]
-self.video, self.audio, self.info = read_video(video_file)
-# Swap them to match output of read_image: [time, channels, height, width]
-# Otherwise detectors fail on tensor dimension mismatch
-self.video = swapaxes(swapaxes(self.video, 1, 3), -1, -2)
self.file_name = video_file
self.skip_frames = skip_frames
self.output_size = output_size
+self.get_video_metadata(video_file)
# List of frame ids used to index the video, respecting skip_frames
self.video_frames = np.arange(
-0, self.video.shape[0], 1 if skip_frames is None else skip_frames
+0, self.metadata["num_frames"], 1 if skip_frames is None else skip_frames
)
-self.video = self.video[self.video_frames, :, :]

def __len__(self):
-return self.video.shape[0]
+# Number of frames, respecting skip_frames
+return len(self.video_frames)

def __getitem__(self, idx):
+# Get the frame data and frame number, respecting skip_frames
+frame_data, frame_idx = self.load_frame(idx)
+
+# Swap frame dims to match output of read_image: [time, channels, height, width]
+# Otherwise detectors fail on tensor dimension mismatch
+frame_data = swapaxes(swapaxes(frame_data, 0, -1), 1, 2)
+
# Rescale if needed like in ImageDataset
if self.output_size is not None:
logging.info(
-f"VideoDataset: RESCALING WARNING: from {self.video[idx].shape} to output_size={self.output_size}"
+f"VideoDataset: RESCALING WARNING: from {self.metadata['shape']} to output_size={self.output_size}"
)
transform = Compose(
[Rescale(self.output_size, preserve_aspect_ratio=True, padding=False)]
)
-transformed_img = transform(self.video[idx])
+transformed_frame_data = transform(frame_data)

return {
-"Image": transformed_img["Image"],
-"Frame": self.video_frames[idx],
-"Scale": transformed_img["Scale"],
-"Padding": transformed_img["Padding"],
+"Image": transformed_frame_data["Image"],
+"Frame": frame_idx,
+"FileName": self.file_name,
+"Scale": transformed_frame_data["Scale"],
+"Padding": transformed_frame_data["Padding"],
}
else:
return {
-"Image": self.video[idx],
-"Frame": self.video_frames[idx],
+"Image": frame_data,
+"Frame": frame_idx,
+"FileName": self.file_name,
"Scale": 1.0,
"Padding": {"Left": 0, "Top": 0, "Right": 0, "Bottom": 0},
}

+def get_video_metadata(self, video_file):
+container = av.open(video_file)
+stream = container.streams.video[0]
+fps = stream.average_rate
+height = stream.height
+width = stream.width
+num_frames = stream.frames
+container.close()
+self.metadata = {
+"fps": float(fps),
+"fps_frac": fps,
+"height": height,
+"width": width,
+"num_frames": num_frames,
+"shape": (height, width),
+}
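
A hedged aside, not part of this PR: for some containers PyAV reports stream.frames as 0 because no frame count is stored in the header. A fallback sketch (hypothetical helper; assumes the FFmpeg convention that container.duration is in AV_TIME_BASE, i.e. microsecond, units):

import av

def estimate_num_frames(video_file):
    # Prefer the stream's stored frame count; fall back to duration * fps
    container = av.open(video_file)
    stream = container.streams.video[0]
    num_frames = stream.frames
    if not num_frames and container.duration is not None:
        duration_sec = container.duration / 1_000_000  # AV_TIME_BASE ticks -> seconds
        num_frames = int(duration_sec * float(stream.average_rate))
    container.close()
    return num_frames
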

+def load_frame(self, idx):
+"""Load in a single frame from the video using a lazy generator"""
+
+# Get frame number respecting skip_frames
+frame_idx = int(self.video_frames[idx])
+
+# Use a py-av generator to load in just this frame
+container = av.open(self.file_name)
+stream = container.streams.video[0]
+frame = next(islice(container.decode(stream), frame_idx, None))
+frame_data = torch.from_numpy(frame.to_ndarray(format="rgb24"))
+container.close()
+
+return frame_data, frame_idx
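
A note on the pattern above: islice re-decodes every frame before frame_idx on each call, so random access costs O(frame_idx) per lookup. A hedged alternative, not in this PR, would seek to the nearest keyframe and decode forward (hypothetical helper; assumes frames carry monotonically increasing pts starting near 0):

import av

def load_frame_with_seek(video_file, frame_idx):
    container = av.open(video_file)
    stream = container.streams.video[0]
    # Convert the frame index to a timestamp in stream.time_base ticks
    target_pts = int(frame_idx / float(stream.average_rate) / stream.time_base)
    container.seek(target_pts, stream=stream)  # lands on a keyframe at/before target
    for frame in container.decode(stream):
        if frame.pts is not None and frame.pts >= target_pts:
            break
    container.close()
    return frame.to_ndarray(format="rgb24")
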

+def calc_approx_frame_time(self, idx):
+"""Calculate the approximate time of a frame in a video
+
+Args:
+idx (int): frame number
+
+Returns:
+str: approximate frame time formatted as "MM:SS"
+"""
+frame_time = idx / self.metadata["fps"]
+total_time = self.metadata["num_frames"] / self.metadata["fps"]
+time = total_time if idx >= self.metadata["num_frames"] else frame_time
+return self.convert_sec_to_min_sec(time)
+
+@staticmethod
+def convert_sec_to_min_sec(duration):
+minutes = int(duration // 60)
+seconds = int(duration % 60)
+return f"{minutes:02d}:{seconds:02d}"
24 changes: 18 additions & 6 deletions feat/detector.py
@@ -425,7 +425,6 @@ def detect_landmarks(self, frame, detected_faces, **landmark_model_kwargs):
else:
if self.info["landmark_model"]:
if self.info["landmark_model"].lower() == "mobilenet":
-
out_size = 224
else:
out_size = 112
@@ -459,7 +458,6 @@ def detect_landmarks(self, frame, detected_faces, **landmark_model_kwargs):

landmark_results = []
for ik in range(landmark.shape[0]):
-
landmark_results.append(
new_bbox[ik].inverse_transform_landmark(landmark[ik, :, :])
)
@@ -657,6 +655,7 @@ def _run_detection_waterfall(
facepose_model_kwargs,
emotion_model_kwargs,
au_model_kwargs,
+suppress_torchvision_warnings=True,
):
"""
Main detection "waterfall." Calls each individual detector in the sequence
@@ -675,6 +674,15 @@
Returns:
tuple: faces, landmarks, poses, aus, emotions
"""

+# Reset warnings
+warnings.filterwarnings("default", category=UserWarning, module="torchvision")
+
+if suppress_torchvision_warnings:
+warnings.filterwarnings(
+"ignore", category=UserWarning, module="torchvision"
+)

faces = self.detect_faces(
batch_data["Image"],
threshold=face_detection_threshold,
@@ -771,7 +779,6 @@ def detect_image(
batch_output = []

for batch_id, batch_data in enumerate(tqdm(data_loader)):
-
faces, landmarks, poses, aus, emotions = self._run_detection_waterfall(
batch_data,
face_detection_threshold,
@@ -840,8 +847,12 @@ def detect_video(
emotion_model_kwargs = kwargs.pop("emotion_model_kwargs", dict())
facepose_model_kwargs = kwargs.pop("facepose_model_kwargs", dict())

+dataset = VideoDataset(
+video_path, skip_frames=skip_frames, output_size=output_size
+)
+
data_loader = DataLoader(
-VideoDataset(video_path, skip_frames=skip_frames, output_size=output_size),
+dataset,
num_workers=num_workers,
batch_size=batch_size,
pin_memory=pin_memory,
@@ -851,7 +862,6 @@
batch_output = []

for batch_data in tqdm(data_loader):
-
faces, landmarks, poses, aus, emotions = self._run_detection_waterfall(
batch_data,
face_detection_threshold,
Expand All @@ -878,6 +888,9 @@ def detect_video(

batch_output = pd.concat(batch_output)
batch_output.reset_index(drop=True, inplace=True)
batch_output["approx_time"] = [
dataset.calc_approx_frame_time(x) for x in batch_output["frame"].to_numpy()
]
return batch_output.set_index("frame", drop=False)

def _create_fex(
@@ -1084,7 +1097,6 @@ def _match_faces_to_poses(faces, faces_pose, poses):
return (faces, poses)

else:
-
overlap_faces = []
overlap_poses = []
for frame_face, frame_face_pose, frame_pose in zip(
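
For context, a sketch of how the new column surfaces to callers of detect_video (file name hypothetical):

from feat import Detector

detector = Detector()
fex = detector.detect_video("my_video.mp4", skip_frames=30)
# Output is indexed by original frame number and now carries an approximate
# "MM:SS" timestamp per frame, computed from the PyAV metadata
print(fex[["frame", "approx_time"]].head())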
4 changes: 2 additions & 2 deletions feat/tests/test_data.py
@@ -44,7 +44,6 @@ def test_fex_new(data_path):


def test_fex_old(imotions_data):
-
# Dropped support in >= 0.4.0
with pytest.raises(Exception):
Fex().read_facet()
@@ -125,7 +124,8 @@ def test_fex_old(imotions_data):
assert len(dat.downsample(target=10)) == 52

# Test upsample
-assert len(dat.upsample(target=60, target_type="hz")) == (len(dat) - 1) * 2
+# Commenting out because of a bug in nltools: https://github.com/cosanlab/nltools/issues/418
+# assert len(dat.upsample(target=60, target_type="hz")) == (len(dat) - 1) * 2

# Test interpolation
assert (