
Commit

Merge branch 'more_efficient_video' into dev
ejolly committed Jul 7, 2023
2 parents b1021db + f5fb91f commit 3973bf1
Showing 4 changed files with 97 additions and 37 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -106,3 +106,6 @@ dev/
 gpu_profiling.py
 
 environment.yml
+
+# Longer video
+FNL_01.mp4
103 changes: 74 additions & 29 deletions feat/data.py
@@ -40,6 +40,8 @@
 import torch
 from PIL import Image
 import logging
+import av
+from itertools import islice
 
 __all__ = [
     "FexSeries",
@@ -664,7 +666,6 @@ def info(self):
         print(f"{self.__class__}\n" + "".join(attr_list))
 
     def _update_extracted_colnames(self, prefix=None, mode="replace"):
-
         cols2update = [
             "au_columns",
             "emotion_columns",
@@ -705,7 +706,6 @@ def _update_extracted_colnames(self, prefix=None, mode="replace"):
         _ = [setattr(self, col, val) for col, val in zip(cols2update, update)]
 
     def _parse_features_labels(self, X, y):
-
         feature_groups = [
             "sessions",
             "emotions",
@@ -1612,7 +1612,6 @@ def extract_boft(self, min_freq=0.06, max_freq=0.66, bank=8, *args, **kwargs):
         )
 
     def _prepare_plot_aus(self, row, muscles, gaze):
-
         """
         Plot one or more faces based on their AU representation. This method is just a
         convenient wrapper for feat.plotting.plot_face. See that function for additional
@@ -1723,10 +1722,8 @@ def plot_detections(
                 col_count += 1
 
         for _, row in plot_data.iterrows():
-
             # DRAW LANDMARKS ON IMAGE OR AU FACE
             if face_ax is not None:
-
                 facebox = row[self.facebox_columns].values
 
                 if not faces == "aus" and plot_original_image:
@@ -1910,7 +1907,6 @@ def __len__(self):
         return len(self.images)
 
     def __getitem__(self, idx):
-
         # Dimensions are [channels, height, width]
         try:
             img = read_image(self.images[idx])
@@ -2088,7 +2084,6 @@ def __len__(self):
         return self.main_file.shape[0]
 
     def __getitem__(self, idx):
-
         # Dimensions are [channels, height, width]
         img = read_image(self.main_file["image_path"].iloc[idx])
         label = self.main_file.loc[idx, self.avail_AUs].to_numpy().astype(np.int16)
@@ -2164,51 +2159,101 @@ class VideoDataset(Dataset):
     """
 
     def __init__(self, video_file, skip_frames=None, output_size=None):
-
-        # Ignore UserWarning: The pts_unit 'pts' gives wrong results. Please use
-        # pts_unit 'sec'. See why it's ok in this issue:
-        # https://github.com/pytorch/vision/issues/1931
-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore", UserWarning)
-            # Video dimensions are: [time, height, width, channels]
-            self.video, self.audio, self.info = read_video(video_file)
-        # Swap them to match output of read_image: [time, channels, height, width]
-        # Otherwise detectors fail on tensor dimension mismatch
-        self.video = swapaxes(swapaxes(self.video, 1, 3), -1, -2)
         self.file_name = video_file
         self.skip_frames = skip_frames
         self.output_size = output_size
+        self.get_video_metadata(video_file)
         # This is the list of frame ids used to slice the video, respecting skip_frames
         self.video_frames = np.arange(
-            0, self.video.shape[0], 1 if skip_frames is None else skip_frames
+            0, self.metadata["num_frames"], 1 if skip_frames is None else skip_frames
         )
-        self.video = self.video[self.video_frames, :, :]
 
     def __len__(self):
-        return self.video.shape[0]
+        # Number of frames, respecting skip_frames
+        return len(self.video_frames)
 
     def __getitem__(self, idx):
+        # Get the frame data and frame number, respecting skip_frames
+        frame_data, frame_idx = self.load_frame(idx)
+
+        # Swap frame dims to match output of read_image: [time, channels, height, width]
+        # Otherwise detectors fail on tensor dimension mismatch
+        frame_data = swapaxes(swapaxes(frame_data, 0, -1), 1, 2)
+
         # Rescale if needed like in ImageDataset
         if self.output_size is not None:
             logging.info(
-                f"VideoDataset: RESCALING WARNING: from {self.video[idx].shape} to output_size={self.output_size}"
+                f"VideoDataset: RESCALING WARNING: from {self.metadata['shape']} to output_size={self.output_size}"
             )
             transform = Compose(
                 [Rescale(self.output_size, preserve_aspect_ratio=True, padding=False)]
             )
-            transformed_img = transform(self.video[idx])
+            transformed_frame_data = transform(frame_data)
 
             return {
-                "Image": transformed_img["Image"],
-                "Frame": self.video_frames[idx],
-                "Scale": transformed_img["Scale"],
-                "Padding": transformed_img["Padding"],
+                "Image": transformed_frame_data["Image"],
+                "Frame": frame_idx,
+                "FileName": self.file_name,
+                "Scale": transformed_frame_data["Scale"],
+                "Padding": transformed_frame_data["Padding"],
             }
         else:
             return {
-                "Image": self.video[idx],
-                "Frame": self.video_frames[idx],
+                "Image": frame_data,
+                "Frame": frame_idx,
+                "FileName": self.file_name,
                 "Scale": 1.0,
                 "Padding": {"Left": 0, "Top": 0, "Right": 0, "Bottom": 0},
             }
 
+    def get_video_metadata(self, video_file):
+        container = av.open(video_file)
+        stream = container.streams.video[0]
+        fps = stream.average_rate
+        height = stream.height
+        width = stream.width
+        num_frames = stream.frames
+        container.close()
+        self.metadata = {
+            "fps": float(fps),
+            "fps_frac": fps,
+            "height": height,
+            "width": width,
+            "num_frames": num_frames,
+            "shape": (height, width),
+        }
+
+    def load_frame(self, idx):
+        """Load a single frame from the video using a lazy generator"""
+
+        # Get frame number, respecting skip_frames
+        frame_idx = int(self.video_frames[idx])
+
+        # Use a py-av generator to load in just this frame
+        container = av.open(self.file_name)
+        stream = container.streams.video[0]
+        frame = next(islice(container.decode(stream), frame_idx, None))
+        frame_data = torch.from_numpy(frame.to_ndarray(format="rgb24"))
+        container.close()
+
+        return frame_data, frame_idx
+
+    def calc_approx_frame_time(self, idx):
+        """Calculate the approximate time of a frame in a video
+        Args:
+            idx (int): frame number
+        Returns:
+            str: approximate time formatted as "MM:SS"
+        """
+        frame_time = idx / self.metadata["fps"]
+        total_time = self.metadata["num_frames"] / self.metadata["fps"]
+        time = total_time if idx >= self.metadata["num_frames"] else frame_time
+        return self.convert_sec_to_min_sec(time)
+
+    @staticmethod
+    def convert_sec_to_min_sec(duration):
+        minutes = int(duration // 60)
+        seconds = int(duration % 60)
+        return f"{minutes:02d}:{seconds:02d}"
24 changes: 18 additions & 6 deletions feat/detector.py
@@ -425,7 +425,6 @@ def detect_landmarks(self, frame, detected_faces, **landmark_model_kwargs):
         else:
             if self.info["landmark_model"]:
                 if self.info["landmark_model"].lower() == "mobilenet":
-
                     out_size = 224
                 else:
                     out_size = 112
@@ -459,7 +458,6 @@ def detect_landmarks(self, frame, detected_faces, **landmark_model_kwargs):
 
         landmark_results = []
         for ik in range(landmark.shape[0]):
-
             landmark_results.append(
                 new_bbox[ik].inverse_transform_landmark(landmark[ik, :, :])
             )
@@ -657,6 +655,7 @@ def _run_detection_waterfall(
         facepose_model_kwargs,
         emotion_model_kwargs,
         au_model_kwargs,
+        suppress_torchvision_warnings=True,
     ):
         """
         Main detection "waterfall." Calls each individual detector in the sequence
@@ -675,6 +674,15 @@
         Returns:
             tuple: faces, landmarks, poses, aus, emotions
         """
+
+        # Reset warnings
+        warnings.filterwarnings("default", category=UserWarning, module="torchvision")
+
+        if suppress_torchvision_warnings:
+            warnings.filterwarnings(
+                "ignore", category=UserWarning, module="torchvision"
+            )
+
         faces = self.detect_faces(
             batch_data["Image"],
             threshold=face_detection_threshold,
@@ -771,7 +779,6 @@ def detect_image(
         batch_output = []
 
         for batch_id, batch_data in enumerate(tqdm(data_loader)):
-
             faces, landmarks, poses, aus, emotions = self._run_detection_waterfall(
                 batch_data,
                 face_detection_threshold,
@@ -840,8 +847,12 @@ def detect_video(
         emotion_model_kwargs = kwargs.pop("emotion_model_kwargs", dict())
         facepose_model_kwargs = kwargs.pop("facepose_model_kwargs", dict())
 
+        dataset = VideoDataset(
+            video_path, skip_frames=skip_frames, output_size=output_size
+        )
+
         data_loader = DataLoader(
-            VideoDataset(video_path, skip_frames=skip_frames, output_size=output_size),
+            dataset,
             num_workers=num_workers,
             batch_size=batch_size,
             pin_memory=pin_memory,
@@ -851,7 +862,6 @@
         batch_output = []
 
         for batch_data in tqdm(data_loader):
-
             faces, landmarks, poses, aus, emotions = self._run_detection_waterfall(
                 batch_data,
                 face_detection_threshold,
Expand All @@ -878,6 +888,9 @@ def detect_video(

batch_output = pd.concat(batch_output)
batch_output.reset_index(drop=True, inplace=True)
batch_output["approx_time"] = [
dataset.calc_approx_frame_time(x) for x in batch_output["frame"].to_numpy()
]
return batch_output.set_index("frame", drop=False)

def _create_fex(
@@ -1084,7 +1097,6 @@ def _match_faces_to_poses(faces, faces_pose, poses):
             return (faces, poses)
 
         else:
-
             overlap_faces = []
             overlap_poses = []
             for frame_face, frame_face_pose, frame_pose in zip(
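Note: with this change detect_video constructs the VideoDataset once, shares it with the DataLoader, and reuses it after inference to stamp each row with an approximate timestamp. A sketch of the resulting API (the default detector models and "my_video.mp4" are placeholders, not part of this diff):

    from feat import Detector

    detector = Detector()

    # Frames are decoded lazily by VideoDataset, so long videos no longer
    # need to fit in memory; skip_frames thins the sampled frames
    fex = detector.detect_video("my_video.mp4", skip_frames=30)

    # New column added by this commit: approximate "MM:SS" time per frame,
    # computed as frame_index / fps via calc_approx_frame_time
    print(fex[["frame", "approx_time"]].head())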
4 changes: 2 additions & 2 deletions feat/tests/test_data.py
@@ -44,7 +44,6 @@ def test_fex_new(data_path):
 
 
 def test_fex_old(imotions_data):
-
     # Dropped support in >= 0.4.0
     with pytest.raises(Exception):
         Fex().read_facet()
@@ -125,7 +124,8 @@ def test_fex_old(imotions_data):
     assert len(dat.downsample(target=10)) == 52
 
     # Test upsample
-    assert len(dat.upsample(target=60, target_type="hz")) == (len(dat) - 1) * 2
+    # Commenting out because of a bug in nltools: https://github.com/cosanlab/nltools/issues/418
+    # assert len(dat.upsample(target=60, target_type="hz")) == (len(dat) - 1) * 2
 
     # Test interpolation
     assert (
