Polish content analyser

This commit is contained in:
henryruhs
2025-06-10 19:30:58 +02:00
parent f65aabfd72
commit 8e80ab0d21

View File

@@ -18,7 +18,7 @@ STREAM_COUNTER = 0
def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
return\ return\
{ {
'yolo_11m': 'nsfw_1':
{ {
'hashes': 'hashes':
{ {
@@ -41,7 +41,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'mean': (0.0, 0.0, 0.0), 'mean': (0.0, 0.0, 0.0),
'standard_deviation': (1.0, 1.0, 1.0) 'standard_deviation': (1.0, 1.0, 1.0)
}, },
'marqo': 'nsfw_2':
{ {
'hashes': 'hashes':
{ {
@@ -64,7 +64,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
'mean': (0.5, 0.5, 0.5), 'mean': (0.5, 0.5, 0.5),
'standard_deviation': (0.5, 0.5, 0.5) 'standard_deviation': (0.5, 0.5, 0.5)
}, },
'freepik': 'nsfw_3':
{ {
'hashes': 'hashes':
{ {
@@ -91,14 +91,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ 'yolo_11m', 'marqo', 'freepik' ] model_names = [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ]
_, model_source_set = collect_model_downloads() _, model_source_set = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_names, model_source_set) return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
model_names = [ 'yolo_11m', 'marqo', 'freepik' ] model_names = [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ]
inference_manager.clear_inference_pool(__name__, model_names) inference_manager.clear_inference_pool(__name__, model_names)
@@ -107,9 +107,9 @@ def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
model_hash_set = {} model_hash_set = {}
model_source_set = {} model_source_set = {}
for nsfw_model in [ 'yolo_11m', 'marqo', 'freepik' ]: for content_analyser_model in [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ]:
model_hash_set[nsfw_model] = model_set.get(nsfw_model).get('hashes').get('content_analyser') model_hash_set[content_analyser_model] = model_set.get(content_analyser_model).get('hashes').get('content_analyser')
model_source_set[nsfw_model] = model_set.get(nsfw_model).get('sources').get('content_analyser') model_source_set[content_analyser_model] = model_set.get(content_analyser_model).get('sources').get('content_analyser')
return model_hash_set, model_source_set return model_hash_set, model_source_set
@@ -164,51 +164,42 @@ def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int
def detect_nsfw(vision_frame : VisionFrame) -> bool: def detect_nsfw(vision_frame : VisionFrame) -> bool:
is_nsfw_1 = detect_with_nsfw_1(vision_frame)
is_nsfw_2 = detect_with_nsfw_2(vision_frame)
is_nsfw_3 = detect_with_nsfw_3(vision_frame)
if detect_with_yolo_11m(vision_frame): return is_nsfw_1 and is_nsfw_2 or is_nsfw_1 and is_nsfw_3 or is_nsfw_2 and is_nsfw_3
if detect_with_marqo(vision_frame):
return True
return detect_with_freepik(vision_frame)
return False
def detect_with_yolo_11m(vision_frame : VisionFrame) -> bool: def detect_with_nsfw_1(vision_frame : VisionFrame) -> bool:
model_name = 'yolo_11m' model_set = create_static_model_set('full').get('nsfw_1')
model_set = create_static_model_set('full').get(model_name)
model_threshold = model_set.get('threshold') model_threshold = model_set.get('threshold')
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_1')
detect_vision_frame = prepare_detect_frame(vision_frame, model_name) detection = forward_nsfw(detect_vision_frame, 'nsfw_1')
detection = forward_yolo_11m(detect_vision_frame)
detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1)) detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1))
return detection_score > model_threshold return detection_score > model_threshold
def detect_with_marqo(vision_frame : VisionFrame) -> bool: def detect_with_nsfw_2(vision_frame : VisionFrame) -> bool:
model_name = 'marqo' model_set = create_static_model_set('full').get('nsfw_2')
model_set = create_static_model_set('full').get(model_name)
model_threshold = model_set.get('threshold') model_threshold = model_set.get('threshold')
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_2')
detect_vision_frame = prepare_detect_frame(vision_frame, model_name) detection = forward_nsfw(detect_vision_frame, 'nsfw_2')
detection = forward_marqo(detect_vision_frame)[0]
detection_score = detection[0] - detection[1] detection_score = detection[0] - detection[1]
return detection_score > model_threshold return detection_score > model_threshold
def detect_with_freepik(vision_frame : VisionFrame) -> bool: def detect_with_nsfw_3(vision_frame : VisionFrame) -> bool:
model_name = 'freepik' model_set = create_static_model_set('full').get('nsfw_3')
model_set = create_static_model_set('full').get(model_name)
model_threshold = model_set.get('threshold') model_threshold = model_set.get('threshold')
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_3')
detect_vision_frame = prepare_detect_frame(vision_frame, model_name) detection = forward_nsfw(detect_vision_frame, 'nsfw_3')
detection = forward_freepik(detect_vision_frame)[0]
detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1]) detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1])
return detection_score > model_threshold return detection_score > model_threshold
def forward_yolo_11m(vision_frame : VisionFrame) -> Detection: def forward_nsfw(vision_frame : VisionFrame, nsfw_model : str) -> Detection:
content_analyser = get_inference_pool().get('yolo_11m') content_analyser = get_inference_pool().get(nsfw_model)
with conditional_thread_semaphore(): with conditional_thread_semaphore():
detection = content_analyser.run(None, detection = content_analyser.run(None,
@@ -216,29 +207,8 @@ def forward_yolo_11m(vision_frame : VisionFrame) -> Detection:
'input': vision_frame 'input': vision_frame
})[0] })[0]
return detection if nsfw_model in [ 'nsfw_2', 'nsfw_3' ]:
return detection[0]
def forward_marqo(vision_frame : VisionFrame) -> Detection:
content_analyser = get_inference_pool().get('marqo')
with conditional_thread_semaphore():
detection = content_analyser.run(None,
{
'input': vision_frame
})[0]
return detection
def forward_freepik(vision_frame : VisionFrame) -> Detection:
content_analyser = get_inference_pool().get('freepik')
with conditional_thread_semaphore():
detection = content_analyser.run(None,
{
'input': vision_frame
})[0]
return detection return detection