multi model content analyser
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from functools import lru_cache
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
|
||||
import numpy
|
||||
from tqdm import tqdm
|
||||
@@ -8,7 +8,7 @@ from facefusion import inference_manager, state_manager, wording
|
||||
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
|
||||
from facefusion.filesystem import resolve_relative_path
|
||||
from facefusion.thread_helper import conditional_thread_semaphore
|
||||
from facefusion.types import Detection, DownloadScope, Fps, InferencePool, ModelOptions, ModelSet, Score, VisionFrame
|
||||
from facefusion.types import Detection, DownloadScope, DownloadSet, Fps, InferencePool, ModelSet, VisionFrame
|
||||
from facefusion.vision import detect_video_fps, fit_frame, read_image, read_video_frame
|
||||
|
||||
STREAM_COUNTER = 0
|
||||
@@ -18,7 +18,7 @@ STREAM_COUNTER = 0
|
||||
def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
return\
|
||||
{
|
||||
'yolo_nsfw':
|
||||
'yolo_11m':
|
||||
{
|
||||
'hashes':
|
||||
{
|
||||
@@ -36,30 +36,86 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
'path': resolve_relative_path('../.assets/models/yolo_11m_nsfw.onnx')
|
||||
}
|
||||
},
|
||||
'size': (640, 640)
|
||||
'threshold': 0.2,
|
||||
'size': (640, 640),
|
||||
'mean': (0.0, 0.0, 0.0),
|
||||
'standard_deviation': (1.0, 1.0, 1.0)
|
||||
},
|
||||
'marqo':
|
||||
{
|
||||
'hashes':
|
||||
{
|
||||
'content_analyser':
|
||||
{
|
||||
'url': 'https://huggingface.co/bluefoxcreation/Models/resolve/main/nsfw_detectors/marqo_nsfw.hash',
|
||||
'path': resolve_relative_path('../.assets/models/marqo_nsfw.hash')
|
||||
}
|
||||
},
|
||||
'sources':
|
||||
{
|
||||
'content_analyser':
|
||||
{
|
||||
'url': 'https://huggingface.co/bluefoxcreation/Models/resolve/main/nsfw_detectors/marqo_nsfw.onnx',
|
||||
'path': resolve_relative_path('../.assets/models/marqo_nsfw.onnx')
|
||||
}
|
||||
},
|
||||
'threshold': 0.24,
|
||||
'size': (384, 384),
|
||||
'mean': (0.5, 0.5, 0.5),
|
||||
'standard_deviation': (0.5, 0.5, 0.5)
|
||||
},
|
||||
'freepik':
|
||||
{
|
||||
'hashes':
|
||||
{
|
||||
'content_analyser':
|
||||
{
|
||||
'url': 'https://huggingface.co/bluefoxcreation/Models/resolve/main/nsfw_detectors/freepik_nsfw.hash',
|
||||
'path': resolve_relative_path('../.assets/models/freepik_nsfw.hash')
|
||||
}
|
||||
},
|
||||
'sources':
|
||||
{
|
||||
'content_analyser':
|
||||
{
|
||||
'url': 'https://huggingface.co/bluefoxcreation/Models/resolve/main/nsfw_detectors/freepik_nsfw.onnx',
|
||||
'path': resolve_relative_path('../.assets/models/freepik_nsfw.onnx')
|
||||
}
|
||||
},
|
||||
'threshold': 10.5,
|
||||
'size': (448, 448),
|
||||
'mean': (0.48145466, 0.4578275, 0.40821073),
|
||||
'standard_deviation': (0.26862954, 0.26130258, 0.27577711)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ 'yolo_nsfw' ]
|
||||
model_source_set = get_model_options().get('sources')
|
||||
model_names = [ 'yolo_11m', 'marqo', 'freepik' ]
|
||||
_, model_source_set = collect_model_downloads()
|
||||
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_names = [ 'yolo_nsfw' ]
|
||||
model_names = [ 'yolo_11m', 'marqo', 'freepik' ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
return create_static_model_set('full').get('yolo_nsfw')
|
||||
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||
model_set = create_static_model_set('full')
|
||||
model_hash_set = {}
|
||||
model_source_set = {}
|
||||
|
||||
for nsfw_model in [ 'yolo_11m', 'marqo', 'freepik' ]:
|
||||
model_hash_set[nsfw_model] = model_set.get(nsfw_model).get('hashes').get('content_analyser')
|
||||
model_source_set[nsfw_model] = model_set.get(nsfw_model).get('sources').get('content_analyser')
|
||||
|
||||
return model_hash_set, model_source_set
|
||||
|
||||
|
||||
def pre_check() -> bool:
|
||||
model_hash_set = get_model_options().get('hashes')
|
||||
model_source_set = get_model_options().get('sources')
|
||||
model_hash_set, model_source_set = collect_model_downloads()
|
||||
|
||||
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
|
||||
|
||||
@@ -74,9 +130,7 @@ def analyse_stream(vision_frame : VisionFrame, video_fps : Fps) -> bool:
|
||||
|
||||
|
||||
def analyse_frame(vision_frame : VisionFrame) -> bool:
|
||||
nsfw_scores = detect_nsfw(vision_frame)
|
||||
|
||||
return len(nsfw_scores) > 0
|
||||
return detect_nsfw(vision_frame)
|
||||
|
||||
|
||||
@lru_cache(maxsize = None)
|
||||
@@ -109,36 +163,95 @@ def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int
|
||||
return rate > 10.0
|
||||
|
||||
|
||||
def detect_nsfw(vision_frame : VisionFrame) -> List[Score]:
|
||||
nsfw_scores = []
|
||||
model_size = get_model_options().get('size')
|
||||
temp_vision_frame = fit_frame(vision_frame, model_size)
|
||||
detect_vision_frame = prepare_detect_frame(temp_vision_frame)
|
||||
detection = forward(detect_vision_frame)
|
||||
detection = numpy.squeeze(detection).T
|
||||
nsfw_scores_raw = numpy.amax(detection[:, 4:], axis = 1)
|
||||
keep_indices = numpy.where(nsfw_scores_raw > 0.2)[0]
|
||||
def detect_nsfw(vision_frame : VisionFrame) -> bool:
|
||||
|
||||
if numpy.any(keep_indices):
|
||||
nsfw_scores_raw = nsfw_scores_raw[keep_indices]
|
||||
nsfw_scores = nsfw_scores_raw.ravel().tolist()
|
||||
if detect_with_yolo_11m(vision_frame):
|
||||
|
||||
return nsfw_scores
|
||||
if detect_with_marqo(vision_frame):
|
||||
return True
|
||||
|
||||
return detect_with_freepik(vision_frame)
|
||||
return False
|
||||
|
||||
|
||||
def forward(vision_frame : VisionFrame) -> Detection:
|
||||
content_analyser = get_inference_pool().get('content_analyser')
|
||||
def detect_with_yolo_11m(vision_frame : VisionFrame) -> bool:
|
||||
model_name = 'yolo_11m'
|
||||
model_set = create_static_model_set('full').get(model_name)
|
||||
model_threshold = model_set.get('threshold')
|
||||
|
||||
detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
|
||||
detection = forward_yolo_11m(detect_vision_frame)
|
||||
detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1))
|
||||
return detection_score > model_threshold
|
||||
|
||||
|
||||
def detect_with_marqo(vision_frame : VisionFrame) -> bool:
|
||||
model_name = 'marqo'
|
||||
model_set = create_static_model_set('full').get(model_name)
|
||||
model_threshold = model_set.get('threshold')
|
||||
|
||||
detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
|
||||
detection = forward_marqo(detect_vision_frame)[0]
|
||||
detection_score = detection[0] - detection[1]
|
||||
return detection_score > model_threshold
|
||||
|
||||
|
||||
def detect_with_freepik(vision_frame : VisionFrame) -> bool:
|
||||
model_name = 'freepik'
|
||||
model_set = create_static_model_set('full').get(model_name)
|
||||
model_threshold = model_set.get('threshold')
|
||||
|
||||
detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
|
||||
detection = forward_freepik(detect_vision_frame)[0]
|
||||
detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1])
|
||||
return detection_score > model_threshold
|
||||
|
||||
|
||||
def forward_yolo_11m(vision_frame : VisionFrame) -> Detection:
|
||||
content_analyser = get_inference_pool().get('yolo_11m')
|
||||
|
||||
with conditional_thread_semaphore():
|
||||
detection = content_analyser.run(None,
|
||||
{
|
||||
'input': vision_frame
|
||||
})
|
||||
})[0]
|
||||
|
||||
return detection
|
||||
|
||||
|
||||
def prepare_detect_frame(temp_vision_frame : VisionFrame) -> VisionFrame:
|
||||
detect_vision_frame = temp_vision_frame / 255.0
|
||||
def forward_marqo(vision_frame : VisionFrame) -> Detection:
|
||||
content_analyser = get_inference_pool().get('marqo')
|
||||
|
||||
with conditional_thread_semaphore():
|
||||
detection = content_analyser.run(None,
|
||||
{
|
||||
'input': vision_frame
|
||||
})[0]
|
||||
|
||||
return detection
|
||||
|
||||
|
||||
def forward_freepik(vision_frame : VisionFrame) -> Detection:
|
||||
content_analyser = get_inference_pool().get('freepik')
|
||||
|
||||
with conditional_thread_semaphore():
|
||||
detection = content_analyser.run(None,
|
||||
{
|
||||
'input': vision_frame
|
||||
})[0]
|
||||
|
||||
return detection
|
||||
|
||||
|
||||
def prepare_detect_frame(temp_vision_frame : VisionFrame, model_name : str) -> VisionFrame:
|
||||
model_set = create_static_model_set('full').get(model_name)
|
||||
model_size = model_set.get('size')
|
||||
model_mean = model_set.get('mean')
|
||||
model_standard_deviation = model_set.get('standard_deviation')
|
||||
|
||||
detect_vision_frame = fit_frame(temp_vision_frame, model_size)
|
||||
detect_vision_frame = detect_vision_frame[:, :, ::-1] / 255.0
|
||||
detect_vision_frame -= model_mean
|
||||
detect_vision_frame /= model_standard_deviation
|
||||
detect_vision_frame = numpy.expand_dims(detect_vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32)
|
||||
return detect_vision_frame
|
||||
|
||||
Reference in New Issue
Block a user