Skip to content

Commit

Permalink
Merge pull request #19 from sqoshi/mask-imposer-class
Browse files Browse the repository at this point in the history
MaskImposer package import limitations + imposing class
  • Loading branch information
sqoshi authored Oct 12, 2021
2 parents 6249759 + dd4d339 commit 3b98f66
Show file tree
Hide file tree
Showing 11 changed files with 240 additions and 87 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -131,5 +131,5 @@ dmypy.json
.idea/
shape_predictor_68_face_landmarks.bz2
shape_predictor_68_face_landmarks.dat
results
results*
results.*
1 change: 1 addition & 0 deletions mask_imposer/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .controller import MaskImposer
68 changes: 68 additions & 0 deletions mask_imposer/controller.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os
from pathlib import Path
from typing import List, Union, Any

import cv2
from cv2 import waitKey
from numpy.typing import NDArray

from mask_imposer.colored_logger import get_configured_logger
from mask_imposer.definitions import Improvements, MaskSet
from mask_imposer.detector.landmark_detector import Detector
from mask_imposer.imposer.mask_imposer import Imposer
from mask_imposer.input_inspector import Inspector


def _get_bundled_mask_set(set_index: int) -> MaskSet:
"""Creates MaskSet object from bundled sets."""
curr_fp = Path(os.path.dirname(os.path.realpath(__file__))).parent
return MaskSet(
os.path.join(curr_fp, f"mask_imposer/bundled/set_0{set_index}/mask_image.png"),
os.path.join(curr_fp, f"mask_imposer/bundled/set_0{set_index}/mask_coords.json")
)


class MaskImposer:
"""Class allow to use project as installable package and impose masks programmatically."""

def __init__(self, bundled_mask_set_idx: int = 1) -> None:
self._logger = get_configured_logger()
mask_set = _get_bundled_mask_set(bundled_mask_set_idx) # possibility to mix
self._inspector = Inspector(self._logger)
self._detector = Detector(
predictor_fp=None,
face_detection=True,
show_samples=False,
auto_download=True,
logger=self._logger
)
self._imposer = Imposer(
output=None,
mask_set=mask_set,
improvements=Improvements(False, False),
logger=self._logger
)

def impose_mask(self, image: Union[str, List[str]], show: bool = False) -> List[NDArray[Any]]:
"""Imposes mask on image.
:param image: List of paths to images or single image path
:param show: if True than displays results of imposing
:return: list of ndarrays images with imposed masks
"""
images = [image] if not isinstance(image, list) else image
self._detector.detect(images)
masked_images = self._imposer.impose(self._detector.get_landmarks())
self._detector.forget_landmarks()

if show:
for mi in masked_images:
cv2.imshow("Sample", mi)
waitKey(0)

return masked_images

@classmethod
def save(cls, img: NDArray[Any], filepath: str) -> None:
"""Saves image in given path using opencv."""
cv2.imwrite(filepath, img)
50 changes: 38 additions & 12 deletions mask_imposer/detector/download.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os.path
import sys
from bz2 import BZ2File
from http.client import HTTPException
from logging import Logger
from pathlib import Path
from tarfile import CompressionError
from typing import Union, Optional
from urllib.error import HTTPError, URLError
from urllib.request import urlretrieve

Expand All @@ -11,16 +14,34 @@
from mask_imposer.beautifiers import TerminalProgressBar


def _unpack_bz2(filepath: str) -> str:
def _unpack_bz2(filepath: Union[Path, str]) -> str:
"""Unpack downloaded bz2 file and returns path to content."""
model_name = filepath.replace(".bz2", ".dat")
with open(model_name, "wb") as fw:
model_name = str(Path(filepath).name).replace(".bz2", ".dat")
model_fp = os.path.join(Path(os.path.abspath(__file__)).parent.parent, model_name)
with open(model_fp, "wb") as fw:
fw.write(BZ2File(filepath).read())
return model_name
return model_fp


def _accepted_download() -> bool:
def find_predictor(default_name: str, logger: Logger) -> Optional[str]:
"""Looking for a predictor in place of installed package and in current working directory. """
file_dir = Path(os.path.dirname(os.path.realpath(__file__))).parent.parent
startup_dir = os.getcwd()
for hc_dir in (file_dir, startup_dir):
logger.info("Looking for shape predictor in '%s'" % hc_dir)
for dire, _, filenames in os.walk(hc_dir): # type: ignore
for fn in filenames:
if default_name in str(fn):
predictor_fp = os.path.join(str(dire), str(fn))
logger.info("Predictor found in '%s'" % predictor_fp)
return predictor_fp
return None


def _accepted_download(auto: bool) -> bool:
"""Ask for permission to download bundled model."""
if auto:
return True
response = input(
colored("Would you like to download ", "green")
+ colored("64 [MB]", "red")
Expand All @@ -30,17 +51,19 @@ def _accepted_download() -> bool:


def download_predictor(
logger: Logger,
url: str = "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2",
predictor_fp: str = "shape_predictor_68_face_landmarks.bz2"
logger: Logger,
url: str = "http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2",
predictor_name: str = "shape_predictor_68_face_landmarks.bz2",
auto: bool = False
) -> str:
"""Downloads default dlib shape predictor (68-landmark)"""

logger.warning("Shape predictor not found.")
if _accepted_download():
logger.warning("Shape predictor not passed directly.")
if _accepted_download(auto):
download_fp = Path(os.path.join("/tmp", predictor_name))
try:
urlretrieve(url, predictor_fp, TerminalProgressBar())
return _unpack_bz2(predictor_fp)
urlretrieve(url, download_fp, TerminalProgressBar())
return _unpack_bz2(download_fp)
except URLError or HTTPError or HTTPException:
logger.critical(
"Error occurred during model download. "
Expand All @@ -53,6 +76,9 @@ def download_predictor(
"Please input filepath to model via terminal arguments."
)
sys.exit()
finally:
if download_fp.exists() and download_fp.is_file():
os.remove(download_fp)
else:
logger.critical("Shape predictor not provided. Detection interrupted.")
sys.exit()
38 changes: 26 additions & 12 deletions mask_imposer/detector/landmark_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
shape_predictor)
from numpy.typing import NDArray

from .download import download_predictor
from .download import download_predictor, find_predictor
from .image import Image


Expand Down Expand Up @@ -39,23 +39,37 @@ def _shape_to_dict(shape: full_object_detection) -> Dict[int, Tuple[int, int]]:
return result


def get_predictor(predictor_fp: Optional[str], auto_download: bool, logger: Logger) -> str:
"""Looking for predictor in cwd or downloads it from dlib official page."""
if predictor_fp is None:
predictor_fp = find_predictor("shape_predictor_68_face_landmarks.dat", logger)
if predictor_fp is None:
predictor_fp = download_predictor(
logger,
auto=auto_download,
predictor_name="shape_predictor_68_face_landmarks.bz2"
)
return predictor_fp


class Detector:
def __init__(self, images: List[str], # pylint:disable=R0913
def __init__(self, # pylint:disable=R0913
predictor_fp: Optional[str],
face_detection: bool,
show_samples: bool,
auto_download: bool,
logger: Logger) -> None:
self._logger = logger
self._images = images
self._detector = get_frontal_face_detector()
if not predictor_fp:
predictor_fp = download_predictor(logger)
self._predictor = shape_predictor(predictor_fp)
self._predictor = shape_predictor(get_predictor(predictor_fp, auto_download, logger))
self._landmarks_collection: Dict[str, Dict[int, Tuple[int, int]]] = {}

self._should_detect_face_rect = face_detection
self._should_display_samples = show_samples

def forget_landmarks(self) -> None:
self._landmarks_collection = {}

@classmethod
def _display_sample(cls, image: Image, rect: dlib.rectangle,
shape: full_object_detection) -> None:
Expand Down Expand Up @@ -95,19 +109,19 @@ def _detect_face_rect(self, image: Image) -> dlib.rectangle:
# whole image as face rectangle (there should be only a center face on image)
return image.get_rectangle()

def _check_fails(self) -> None:
def _check_fails(self, images_list: List[str]) -> None:
"""Check if landmarks were found for every inputted image and warns about fails."""
if len(self._images) != len(self._landmarks_collection):
diff = len(self._images) - len(self._landmarks_collection)
if len(images_list) != len(self._landmarks_collection):
diff = len(images_list) - len(self._landmarks_collection)
self._logger.warning(f"Landmarks not found in {diff} images.")

def detect(self) -> None:
def detect(self, images_list: List[str]) -> None:
"""Creates landmark collection.
During creation may optionally display samples with drawn landmarks.
May detect face boxes, but it is preferred to pass images as stated in readme.
"""
for img_path in self._images:
for img_path in images_list:
image = Image(img_path)
try:
rect = self._detect_face_rect(image) # detect rectangles with faces
Expand All @@ -122,8 +136,8 @@ def detect(self) -> None:
except NotImplementedError: # must be changed
self._logger.warning(f"Landmarks not detected on {image}.")
continue
self._check_fails(images_list)
self._logger.info("Detection finished.")
self._check_fails()

def get_landmarks(self) -> Dict[str, Dict[int, Tuple[int, int]]]:
return self._landmarks_collection
31 changes: 19 additions & 12 deletions mask_imposer/imposer/mask_imposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import time
from logging import Logger
from os.path import join
from typing import Any, Dict, Tuple, Union
from typing import Any, Dict, Tuple, Union, Optional, List

import cv2
from numpy.typing import NDArray
Expand Down Expand Up @@ -34,18 +34,20 @@ class Imposer:

def __init__(
self,
landmarks: detections_dict,
output: Output,
output: Optional[Output],
mask_set: MaskSet,
improvements: Improvements,
logger: Logger
logger: Logger,
) -> None:
self._logger = logger
self._output_dir: str = output.directory
self._output_format: ImageFormat = output.format
if output is not None:
self.live_imposing = False
self._output_dir: str = output.directory
self._output_format: ImageFormat = output.format
else:
self.live_imposing = True
self._should_draw_landmarks = improvements.draw_landmarks
self._show_samples = improvements.show_samples
self._landmarks = landmarks
self._mask = MaskImage(mask_set)

@staticmethod
Expand All @@ -64,7 +66,7 @@ def _fit_left_top_coords(

@staticmethod
def _compute_size_surpluses(
target: NDArray, overlay: NDArray # type:ignore
target: NDArray, overlay: NDArray # type:ignore
) -> Tuple[Any, Any]:
"""Get differences between width and height limit of to-replace box
from original image and mask.
Expand Down Expand Up @@ -211,11 +213,16 @@ def save(self, filename: str, image: Image) -> None:
join(self._output_dir, f"{filename}.{self._output_format}"), image.img
)

def impose(self) -> None:
def impose(self, landmarks_collection: detections_dict) -> List[NDArray[Any]]:
"""Imposes mask image on images stored as a dictionary keys in landmarks detections."""
self._create_output_dir()
for image_fp, landmarks_dict in self._landmarks.items():
if not self.live_imposing:
self._create_output_dir()
masked_images = []
for image_fp, landmarks_dict in landmarks_collection.items():
if "masked" not in image_fp:
img_obj = Image(image_fp)
self._paste_mask(img_obj, landmarks_dict)
self.save(get_name_from(image_fp), img_obj)
masked_images.append(img_obj.img)
if not self.live_imposing:
self.save(get_name_from(image_fp), img_obj)
return masked_images
Loading

0 comments on commit 3b98f66

Please sign in to comment.