diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py index bc2d86b..8c4cfa1 100644 --- a/facefusion/content_analyser.py +++ b/facefusion/content_analyser.py @@ -38,7 +38,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: } }, 'size': (640, 640), - 'threshold': 0.2, 'mean': (0.0, 0.0, 0.0), 'standard_deviation': (1.0, 1.0, 1.0) }, @@ -61,7 +60,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: } }, 'size': (384, 384), - 'threshold': 0.25, 'mean': (0.5, 0.5, 0.5), 'standard_deviation': (0.5, 0.5, 0.5) }, @@ -84,7 +82,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: } }, 'size': (448, 448), - 'threshold': 10.5, 'mean': (0.48145466, 0.4578275, 0.40821073), 'standard_deviation': (0.26862954, 0.26130258, 0.27577711) } @@ -167,7 +164,7 @@ def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int progress.set_postfix(rate = rate) progress.update() - return rate > 10.0 + return bool(rate > 10.0) def detect_nsfw(vision_frame : VisionFrame) -> bool: @@ -179,30 +176,24 @@ def detect_nsfw(vision_frame : VisionFrame) -> bool: def detect_with_nsfw_1(vision_frame : VisionFrame) -> bool: - model_set = create_static_model_set('full').get('nsfw_1') - model_threshold = model_set.get('threshold') detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_1') detection = forward_nsfw(detect_vision_frame, 'nsfw_1') detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1)) - return detection_score > model_threshold + return bool(detection_score > 0.2) def detect_with_nsfw_2(vision_frame : VisionFrame) -> bool: - model_set = create_static_model_set('full').get('nsfw_2') - model_threshold = model_set.get('threshold') detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_2') detection = forward_nsfw(detect_vision_frame, 'nsfw_2') detection_score = detection[0] - detection[1] - return detection_score > model_threshold + return bool(detection_score > 0.25) def detect_with_nsfw_3(vision_frame : VisionFrame) -> bool: - model_set = create_static_model_set('full').get('nsfw_3') - model_threshold = model_set.get('threshold') detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_3') detection = forward_nsfw(detect_vision_frame, 'nsfw_3') detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1]) - return detection_score > model_threshold + return bool(detection_score > 10.5) def forward_nsfw(vision_frame : VisionFrame, nsfw_model : str) -> Detection: