Merge pull request #895 from facefusion/multimodel-content-analyser

Reduce content analyser false positives
This commit is contained in:
Henry Ruhs
2025-06-10 18:21:10 +02:00
committed by GitHub
2 changed files with 150 additions and 37 deletions

View File

@@ -1,5 +1,5 @@
from functools import lru_cache from functools import lru_cache
from typing import List from typing import Tuple
import numpy import numpy
from tqdm import tqdm from tqdm import tqdm
@@ -8,7 +8,7 @@ from facefusion import inference_manager, state_manager, wording
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
from facefusion.filesystem import resolve_relative_path from facefusion.filesystem import resolve_relative_path
from facefusion.thread_helper import conditional_thread_semaphore from facefusion.thread_helper import conditional_thread_semaphore
from facefusion.types import Detection, DownloadScope, Fps, InferencePool, ModelOptions, ModelSet, Score, VisionFrame from facefusion.types import Detection, DownloadScope, DownloadSet, Fps, InferencePool, ModelSet, VisionFrame
from facefusion.vision import detect_video_fps, fit_frame, read_image, read_video_frame from facefusion.vision import detect_video_fps, fit_frame, read_image, read_video_frame
STREAM_COUNTER = 0 STREAM_COUNTER = 0
@@ -18,7 +18,7 @@ STREAM_COUNTER = 0
def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
return\ return\
{ {
'yolo_nsfw': 'yolo_11m':
{ {
'hashes': 'hashes':
{ {
@@ -36,30 +36,86 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'path': resolve_relative_path('../.assets/models/yolo_11m_nsfw.onnx') 'path': resolve_relative_path('../.assets/models/yolo_11m_nsfw.onnx')
} }
}, },
'size': (640, 640) 'threshold': 0.2,
'size': (640, 640),
'mean': (0.0, 0.0, 0.0),
'standard_deviation': (1.0, 1.0, 1.0)
},
'marqo':
{
'hashes':
{
'content_analyser':
{
'url': 'https://huggingface.co/bluefoxcreation/Models/resolve/main/nsfw_detectors/marqo_nsfw.hash',
'path': resolve_relative_path('../.assets/models/marqo_nsfw.hash')
}
},
'sources':
{
'content_analyser':
{
'url': 'https://huggingface.co/bluefoxcreation/Models/resolve/main/nsfw_detectors/marqo_nsfw.onnx',
'path': resolve_relative_path('../.assets/models/marqo_nsfw.onnx')
}
},
'threshold': 0.24,
'size': (384, 384),
'mean': (0.5, 0.5, 0.5),
'standard_deviation': (0.5, 0.5, 0.5)
},
'freepik':
{
'hashes':
{
'content_analyser':
{
'url': 'https://huggingface.co/bluefoxcreation/Models/resolve/main/nsfw_detectors/freepik_nsfw.hash',
'path': resolve_relative_path('../.assets/models/freepik_nsfw.hash')
}
},
'sources':
{
'content_analyser':
{
'url': 'https://huggingface.co/bluefoxcreation/Models/resolve/main/nsfw_detectors/freepik_nsfw.onnx',
'path': resolve_relative_path('../.assets/models/freepik_nsfw.onnx')
}
},
'threshold': 10.5,
'size': (448, 448),
'mean': (0.48145466, 0.4578275, 0.40821073),
'standard_deviation': (0.26862954, 0.26130258, 0.27577711)
} }
} }
def get_inference_pool() -> InferencePool:
    """Return the inference pool holding the three content analyser models.

    The pool is requested from ``inference_manager`` under this module's
    name, using the source entries gathered by ``collect_model_downloads``.
    """
    # only the source half of the (hashes, sources) pair is needed here
    model_source_set = collect_model_downloads()[1]
    return inference_manager.get_inference_pool(__name__, [ 'yolo_11m', 'marqo', 'freepik' ], model_source_set)
def clear_inference_pool() -> None:
    """Ask ``inference_manager`` to clear the pool registered for this module's three models."""
    inference_manager.clear_inference_pool(__name__, [ 'yolo_11m', 'marqo', 'freepik' ])
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
    """Gather the hash and source download entries of every content analyser model.

    Returns:
        A ``(model_hash_set, model_source_set)`` pair. Both dictionaries are
        keyed by model name ('yolo_11m', 'marqo', 'freepik') and hold that
        model's 'content_analyser' entry from its 'hashes' respectively
        'sources' section.
    """
    model_set = create_static_model_set('full')
    model_names = [ 'yolo_11m', 'marqo', 'freepik' ]

    model_hash_set = { model_name : model_set.get(model_name).get('hashes').get('content_analyser') for model_name in model_names }
    model_source_set = { model_name : model_set.get(model_name).get('sources').get('content_analyser') for model_name in model_names }
    return model_hash_set, model_source_set
def pre_check() -> bool:
    """Run the conditional hash and source downloads for all content analyser models.

    Returns:
        The combined result of ``conditional_download_hashes`` and
        ``conditional_download_sources``.
    """
    hash_set, source_set = collect_model_downloads()
    # `and` short-circuits: sources are only attempted when the hash step succeeded
    return conditional_download_hashes(hash_set) and conditional_download_sources(source_set)
@@ -74,9 +130,7 @@ def analyse_stream(vision_frame : VisionFrame, video_fps : Fps) -> bool:
def analyse_frame(vision_frame : VisionFrame) -> bool:
    """Return the verdict of ``detect_nsfw`` for a single frame."""
    is_nsfw = detect_nsfw(vision_frame)
    return is_nsfw
@lru_cache(maxsize = None) @lru_cache(maxsize = None)
@@ -109,36 +163,95 @@ def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int
return rate > 10.0 return rate > 10.0
def detect_nsfw(vision_frame : VisionFrame) -> bool:
    """Combine the three detectors into one verdict for a frame.

    The frame is flagged only when the yolo_11m detector fires and at least
    one of the marqo or freepik detectors agrees with it.
    """
    # yolo_11m acts as the gate: without a hit the other models are not run
    if not detect_with_yolo_11m(vision_frame):
        return False

    # the first confirmation is enough, freepik is only consulted when marqo disagrees
    if detect_with_marqo(vision_frame):
        return True
    return detect_with_freepik(vision_frame)
def detect_with_yolo_11m(vision_frame : VisionFrame) -> bool:
    """Check a frame with the 'yolo_11m' model.

    Args:
        vision_frame: the frame to analyse, it is resized and normalised by
            ``prepare_detect_frame`` before inference.

    Returns:
        True when the highest score found in the model output exceeds the
        model's configured 'threshold'.
    """
    model_name = 'yolo_11m'
    model_set = create_static_model_set('full').get(model_name)
    model_threshold = model_set.get('threshold')

    detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
    detection = forward_yolo_11m(detect_vision_frame)
    # entries 4 and up along axis 1 are treated as scores
    # NOTE(review): presumably the first four are box coordinates - confirm against the model export
    detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1))
    # the comparison yields a numpy boolean scalar, convert it to match the annotation
    return bool(detection_score > model_threshold)
def detect_with_marqo(vision_frame : VisionFrame) -> bool:
    """Check a frame with the 'marqo' model.

    Args:
        vision_frame: the frame to analyse, it is resized and normalised by
            ``prepare_detect_frame`` before inference.

    Returns:
        True when the margin between the first and the second model output
        exceeds the model's configured 'threshold'.
    """
    model_name = 'marqo'
    model_set = create_static_model_set('full').get(model_name)
    model_threshold = model_set.get('threshold')

    detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
    # take the first row of the model output
    detection = forward_marqo(detect_vision_frame)[0]
    # NOTE(review): presumably index 0 is the nsfw output and index 1 the safe one - confirm
    detection_score = detection[0] - detection[1]
    # the comparison may yield a numpy boolean scalar, convert it to match the annotation
    return bool(detection_score > model_threshold)
def detect_with_freepik(vision_frame : VisionFrame) -> bool:
    """Check a frame with the 'freepik' model.

    Args:
        vision_frame: the frame to analyse, it is resized and normalised by
            ``prepare_detect_frame`` before inference.

    Returns:
        True when the sum of the last two of the four model outputs exceeds
        the sum of the first two by more than the model's configured
        'threshold'.
    """
    model_name = 'freepik'
    model_set = create_static_model_set('full').get(model_name)
    model_threshold = model_set.get('threshold')

    detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
    # take the first row of the model output
    detection = forward_freepik(detect_vision_frame)[0]
    # NOTE(review): presumably outputs 2 and 3 are the explicit classes and 0 and 1 the harmless ones - confirm
    detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1])
    # the comparison may yield a numpy boolean scalar, convert it to match the annotation
    return bool(detection_score > model_threshold)
def forward_yolo_11m(vision_frame : VisionFrame) -> Detection:
    """Run the 'yolo_11m' session on a prepared frame and return its first output.

    Args:
        vision_frame: the already prepared frame, fed to the model as 'input'.
    """
    inference_session = get_inference_pool().get('yolo_11m')
    model_inputs =\
    {
        'input': vision_frame
    }

    # inference runs inside the context manager provided by conditional_thread_semaphore()
    with conditional_thread_semaphore():
        model_outputs = inference_session.run(None, model_inputs)

    return model_outputs[0]
def forward_marqo(vision_frame : VisionFrame) -> Detection:
    """Run the 'marqo' session on a prepared frame and return its first output.

    Args:
        vision_frame: the already prepared frame, fed to the model as 'input'.
    """
    inference_session = get_inference_pool().get('marqo')
    model_inputs =\
    {
        'input': vision_frame
    }

    # inference runs inside the context manager provided by conditional_thread_semaphore()
    with conditional_thread_semaphore():
        model_outputs = inference_session.run(None, model_inputs)

    return model_outputs[0]
def forward_freepik(vision_frame : VisionFrame) -> Detection:
    """Run the 'freepik' session on a prepared frame and return its first output.

    Args:
        vision_frame: the already prepared frame, fed to the model as 'input'.
    """
    inference_session = get_inference_pool().get('freepik')
    model_inputs =\
    {
        'input': vision_frame
    }

    # inference runs inside the context manager provided by conditional_thread_semaphore()
    with conditional_thread_semaphore():
        model_outputs = inference_session.run(None, model_inputs)

    return model_outputs[0]
def prepare_detect_frame(temp_vision_frame : VisionFrame, model_name : str) -> VisionFrame:
    """Turn a frame into the float32 input tensor of the given model.

    Args:
        temp_vision_frame: the frame to prepare.
        model_name: key of the model in the static model set, its 'size',
            'mean' and 'standard_deviation' entries drive the preparation.

    Returns:
        The fitted and normalised frame as float32, with the last axis moved
        to the front and a leading axis of length one added.
    """
    model_options = create_static_model_set('full').get(model_name)
    model_size = model_options.get('size')
    model_mean = model_options.get('mean')
    model_standard_deviation = model_options.get('standard_deviation')

    fitted_vision_frame = fit_frame(temp_vision_frame, model_size)
    # reverse the last axis (presumably BGR to RGB - confirm) and scale the values by 1 / 255
    scaled_vision_frame = fitted_vision_frame[:, :, ::-1] / 255.0
    normalized_vision_frame = (scaled_vision_frame - model_mean) / model_standard_deviation
    # move the last axis to the front, then add the leading axis
    detect_vision_frame = numpy.expand_dims(normalized_vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32)
    return detect_vision_frame

View File

@@ -16,17 +16,17 @@ def before_all() -> None:
def test_get_inference_pool() -> None:
    """Verify that the 'cli' and 'ui' app contexts each get a pool entry and share the session.

    The pool is created once per patched app context. Each context must hold
    an ``InferenceSession`` for 'yolo_11m', and both contexts must reference
    an equal session object.
    """
    model_names = [ 'yolo_11m', 'marqo', 'freepik' ]
    _, model_source_set = content_analyser.collect_model_downloads()
    # key under which the pool is stored, spelled out once to keep the assertions in sync
    inference_context = 'facefusion.content_analyser.yolo_11m.marqo.freepik.0.cpu'

    with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'):
        get_inference_pool('facefusion.content_analyser', model_names, model_source_set)

        assert isinstance(INFERENCE_POOL_SET.get('cli').get(inference_context).get('yolo_11m'), InferenceSession)

    with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'):
        get_inference_pool('facefusion.content_analyser', model_names, model_source_set)

        # must inspect the 'ui' pool, checking 'cli' again would leave the ui context unverified
        assert isinstance(INFERENCE_POOL_SET.get('ui').get(inference_context).get('yolo_11m'), InferenceSession)

    assert INFERENCE_POOL_SET.get('cli').get(inference_context).get('yolo_11m') == INFERENCE_POOL_SET.get('ui').get(inference_context).get('yolo_11m')