From 8e80ab0d21013d61f78dc5fd7763c1e89368c1a6 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Tue, 10 Jun 2025 19:30:58 +0200 Subject: [PATCH] Polish content analyser --- facefusion/content_analyser.py | 86 +++++++++++----------------------- 1 file changed, 28 insertions(+), 58 deletions(-) diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py index 15face6..b9605f5 100644 --- a/facefusion/content_analyser.py +++ b/facefusion/content_analyser.py @@ -18,7 +18,7 @@ STREAM_COUNTER = 0 def create_static_model_set(download_scope : DownloadScope) -> ModelSet: return\ { - 'yolo_11m': + 'nsfw_1': { 'hashes': { @@ -41,7 +41,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: 'mean': (0.0, 0.0, 0.0), 'standard_deviation': (1.0, 1.0, 1.0) }, - 'marqo': + 'nsfw_2': { 'hashes': { @@ -64,7 +64,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: 'mean': (0.5, 0.5, 0.5), 'standard_deviation': (0.5, 0.5, 0.5) }, - 'freepik': + 'nsfw_3': { 'hashes': { @@ -91,14 +91,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: - model_names = [ 'yolo_11m', 'marqo', 'freepik' ] + model_names = [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ] _, model_source_set = collect_model_downloads() return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: - model_names = [ 'yolo_11m', 'marqo', 'freepik' ] + model_names = [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ] inference_manager.clear_inference_pool(__name__, model_names) @@ -107,9 +107,9 @@ def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: model_hash_set = {} model_source_set = {} - for nsfw_model in [ 'yolo_11m', 'marqo', 'freepik' ]: - model_hash_set[nsfw_model] = model_set.get(nsfw_model).get('hashes').get('content_analyser') - model_source_set[nsfw_model] = model_set.get(nsfw_model).get('sources').get('content_analyser') + for content_analyser_model in [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ]: + model_hash_set[content_analyser_model] = model_set.get(content_analyser_model).get('hashes').get('content_analyser') + model_source_set[content_analyser_model] = model_set.get(content_analyser_model).get('sources').get('content_analyser') return model_hash_set, model_source_set @@ -164,51 +164,42 @@ def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int def detect_nsfw(vision_frame : VisionFrame) -> bool: + is_nsfw_1 = detect_with_nsfw_1(vision_frame) + is_nsfw_2 = detect_with_nsfw_2(vision_frame) + is_nsfw_3 = detect_with_nsfw_3(vision_frame) - if detect_with_yolo_11m(vision_frame): - - if detect_with_marqo(vision_frame): - return True - - return detect_with_freepik(vision_frame) - return False + return is_nsfw_1 and is_nsfw_2 or is_nsfw_1 and is_nsfw_3 or is_nsfw_2 and is_nsfw_3 -def detect_with_yolo_11m(vision_frame : VisionFrame) -> bool: - model_name = 'yolo_11m' - model_set = create_static_model_set('full').get(model_name) +def detect_with_nsfw_1(vision_frame : VisionFrame) -> bool: + model_set = create_static_model_set('full').get('nsfw_1') model_threshold = model_set.get('threshold') - - detect_vision_frame = prepare_detect_frame(vision_frame, model_name) - detection = forward_yolo_11m(detect_vision_frame) + detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_1') + detection = forward_nsfw(detect_vision_frame, 'nsfw_1') detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1)) return detection_score > model_threshold -def detect_with_marqo(vision_frame : VisionFrame) -> bool: - model_name = 'marqo' - model_set = create_static_model_set('full').get(model_name) +def detect_with_nsfw_2(vision_frame : VisionFrame) -> bool: + model_set = create_static_model_set('full').get('nsfw_2') model_threshold = model_set.get('threshold') - - detect_vision_frame = prepare_detect_frame(vision_frame, model_name) - detection = forward_marqo(detect_vision_frame)[0] + detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_2') + detection = forward_nsfw(detect_vision_frame, 'nsfw_2') detection_score = detection[0] - detection[1] return detection_score > model_threshold -def detect_with_freepik(vision_frame : VisionFrame) -> bool: - model_name = 'freepik' - model_set = create_static_model_set('full').get(model_name) +def detect_with_nsfw_3(vision_frame : VisionFrame) -> bool: + model_set = create_static_model_set('full').get('nsfw_3') model_threshold = model_set.get('threshold') - - detect_vision_frame = prepare_detect_frame(vision_frame, model_name) - detection = forward_freepik(detect_vision_frame)[0] + detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_3') + detection = forward_nsfw(detect_vision_frame, 'nsfw_3') detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1]) return detection_score > model_threshold -def forward_yolo_11m(vision_frame : VisionFrame) -> Detection: - content_analyser = get_inference_pool().get('yolo_11m') +def forward_nsfw(vision_frame : VisionFrame, nsfw_model : str) -> Detection: + content_analyser = get_inference_pool().get(nsfw_model) with conditional_thread_semaphore(): detection = content_analyser.run(None, @@ -216,29 +207,8 @@ def forward_yolo_11m(vision_frame : VisionFrame) -> Detection: 'input': vision_frame })[0] - return detection - - -def forward_marqo(vision_frame : VisionFrame) -> Detection: - content_analyser = get_inference_pool().get('marqo') - - with conditional_thread_semaphore(): - detection = content_analyser.run(None, - { - 'input': vision_frame - })[0] - - return detection - - -def forward_freepik(vision_frame : VisionFrame) -> Detection: - content_analyser = get_inference_pool().get('freepik') - - with conditional_thread_semaphore(): - detection = content_analyser.run(None, - { - 'input': vision_frame - })[0] + if nsfw_model in [ 'nsfw_2', 'nsfw_3' ]: + return detection[0] return detection