# Review notes for this patch (free text ahead of the first "diff --git" header;
# patch tools are expected to skip it -- confirm with your tooling).
#
# facefusion/content_analyser.py
#   - Model set entry 'open_nsfw' (size 224x224 plus 'mean') is replaced by
#     'yolo_nsfw' pointing at yolo_11m_nsfw.hash / yolo_11m_nsfw.onnx under
#     'models-3.2.0', with size (640, 640) and no 'mean' key.
#   - PROBABILITY_LIMIT, RATE_LIMIT and the cv2 import are removed; the old
#     forward() / prepare_frame() pair is removed.
#   - analyse_frame() now returns True when detect_nsfw() yields at least one score.
#   - analyse_video() calls progress.update() after set_postfix() and compares
#     the rate against the literal 10.0 instead of RATE_LIMIT.
#   - detect_nsfw(): resizes via resize_frame_resolution(), builds the model input
#     with prepare_detect_frame(), runs forward(), squeezes and transposes the
#     result, takes the per-row maximum over columns 4 onward and keeps the
#     values above 0.2, returned as a plain list.
#   - prepare_detect_frame(): pastes the frame into the top-left corner of a
#     zero canvas of the model size, divides by 255.0, moves channels first,
#     adds a leading axis and casts to float32.
#
# facefusion/face_helper.py
#   - apply_nms(): parameter 'face_scores' is renamed to 'scores'.
#     NOTE(review): callers passing face_scores= by keyword would need updating;
#     no such caller is visible in this patch -- verify.
#
# Fix applied in this revision of the patch:
#   - detect_nsfw() used `if numpy.any(keep_indices):`. keep_indices holds row
#     positions from numpy.where(...)[0], so numpy.any() is False when the only
#     row above the threshold is row 0 and that detection was dropped. The guard
#     now checks `keep_indices.size > 0`. The change is in place on one added
#     line, so the hunk header counts are unchanged.
#     NOTE(review): the "index ee25f58..59abc1f" blob ids still describe the
#     pre-fix post-image.
diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py index ee25f58..59abc1f 100644 --- a/facefusion/content_analyser.py +++ b/facefusion/content_analyser.py @@ -1,6 +1,6 @@ from functools import lru_cache +from typing import List -import cv2 import numpy from tqdm import tqdm @@ -8,11 +8,9 @@ from facefusion import inference_manager, state_manager, wording from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url from facefusion.filesystem import resolve_relative_path from facefusion.thread_helper import conditional_thread_semaphore -from facefusion.typing import DownloadScope, Fps, InferencePool, ModelOptions, ModelSet, VisionFrame -from facefusion.vision import detect_video_fps, get_video_frame, read_image +from facefusion.typing import Detection, DownloadScope, Fps, InferencePool, ModelOptions, ModelSet, Score, VisionFrame +from facefusion.vision import detect_video_fps, get_video_frame, read_image, resize_frame_resolution -PROBABILITY_LIMIT = 0.80 -RATE_LIMIT = 10 STREAM_COUNTER = 0 @@ -20,26 +18,25 @@ STREAM_COUNTER = 0 def create_static_model_set(download_scope : DownloadScope) -> ModelSet: return\ { - 'open_nsfw': + 'yolo_nsfw': { 'hashes': { 'content_analyser': { - 'url': resolve_download_url('models-3.0.0', 'open_nsfw.hash'), - 'path': resolve_relative_path('../.assets/models/open_nsfw.hash') + 'url': resolve_download_url('models-3.2.0', 'yolo_11m_nsfw.hash'), + 'path': resolve_relative_path('../.assets/models/yolo_11m_nsfw.hash') } }, 'sources': { 'content_analyser': { - 'url': resolve_download_url('models-3.0.0', 'open_nsfw.onnx'), - 'path': resolve_relative_path('../.assets/models/open_nsfw.onnx') + 'url': resolve_download_url('models-3.2.0', 'yolo_11m_nsfw.onnx'), + 'path': resolve_relative_path('../.assets/models/yolo_11m_nsfw.onnx') } }, - 'size': (224, 224), - 'mean': [ 104, 117, 123 ] + 'size': (640, 640) } } @@ -54,7 +51,7 @@ def clear_inference_pool() -> None: def 
get_model_options() -> ModelOptions: - return create_static_model_set('full').get('open_nsfw') + return create_static_model_set('full').get('yolo_nsfw') def pre_check() -> bool: @@ -74,31 +71,9 @@ def analyse_stream(vision_frame : VisionFrame, video_fps : Fps) -> bool: def analyse_frame(vision_frame : VisionFrame) -> bool: - vision_frame = prepare_frame(vision_frame) - probability = forward(vision_frame) + nsfw_scores = detect_nsfw(vision_frame) - return probability > PROBABILITY_LIMIT - - -def forward(vision_frame : VisionFrame) -> float: - content_analyser = get_inference_pool().get('content_analyser') - - with conditional_thread_semaphore(): - probability = content_analyser.run(None, - { - 'input': vision_frame - })[0][0][1] - - return probability - - -def prepare_frame(vision_frame : VisionFrame) -> VisionFrame: - model_size = get_model_options().get('size') - model_mean = get_model_options().get('mean') - vision_frame = cv2.resize(vision_frame, model_size).astype(numpy.float32) - vision_frame -= numpy.array(model_mean).astype(numpy.float32) - vision_frame = numpy.expand_dims(vision_frame, axis = 0) - return vision_frame + return len(nsfw_scores) > 0 @lru_cache(maxsize = None) @@ -115,12 +90,52 @@ def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int counter = 0 with tqdm(total = len(frame_range), desc = wording.get('analysing'), unit = 'frame', ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress: + for frame_number in frame_range: if frame_number % int(video_fps) == 0: vision_frame = get_video_frame(video_path, frame_number) if analyse_frame(vision_frame): counter += 1 rate = counter * int(video_fps) / len(frame_range) * 100 - progress.update() progress.set_postfix(rate = rate) - return rate > RATE_LIMIT + progress.update() + + return rate > 10.0 + + +def detect_nsfw(vision_frame : VisionFrame) -> List[Score]: + nsfw_scores = [] + model_size = get_model_options().get('size') + 
temp_vision_frame = resize_frame_resolution(vision_frame, model_size) + detect_vision_frame = prepare_detect_frame(temp_vision_frame) + detection = forward(detect_vision_frame) + detection = numpy.squeeze(detection).T + nsfw_scores_raw = numpy.amax(detection[:, 4:], axis = 1) + keep_indices = numpy.where(nsfw_scores_raw > 0.2)[0] + + if keep_indices.size > 0: + nsfw_scores_raw = nsfw_scores_raw[keep_indices] + nsfw_scores = nsfw_scores_raw.ravel().tolist() + + return nsfw_scores + + +def forward(vision_frame : VisionFrame) -> Detection: + content_analyser = get_inference_pool().get('content_analyser') + + with conditional_thread_semaphore(): + detection = content_analyser.run(None, + { + 'input': vision_frame + }) + + return detection + + +def prepare_detect_frame(temp_vision_frame : VisionFrame) -> VisionFrame: + model_size = get_model_options().get('size') + detect_vision_frame = numpy.zeros((model_size[0], model_size[1], 3)) + detect_vision_frame[:temp_vision_frame.shape[0], :temp_vision_frame.shape[1], :] = temp_vision_frame + detect_vision_frame = detect_vision_frame / 255.0 + detect_vision_frame = numpy.expand_dims(detect_vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32) + return detect_vision_frame diff --git a/facefusion/face_helper.py b/facefusion/face_helper.py index 9218fdb..970d328 100644 --- a/facefusion/face_helper.py +++ b/facefusion/face_helper.py @@ -208,9 +208,9 @@ def estimate_face_angle(face_landmark_68 : FaceLandmark68) -> Angle: return face_angle -def apply_nms(bounding_boxes : List[BoundingBox], face_scores : List[Score], score_threshold : float, nms_threshold : float) -> Sequence[int]: +def apply_nms(bounding_boxes : List[BoundingBox], scores : List[Score], score_threshold : float, nms_threshold : float) -> Sequence[int]: normed_bounding_boxes = [ (x1, y1, x2 - x1, y2 - y1) for (x1, y1, x2, y2) in bounding_boxes ] - keep_indices = cv2.dnn.NMSBoxes(normed_bounding_boxes, face_scores, score_threshold = score_threshold, 
nms_threshold = nms_threshold) + keep_indices = cv2.dnn.NMSBoxes(normed_bounding_boxes, scores, score_threshold = score_threshold, nms_threshold = nms_threshold) return keep_indices