Polish content analyser
This commit is contained in:
@@ -18,7 +18,7 @@ STREAM_COUNTER = 0
|
||||
def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
return\
|
||||
{
|
||||
'yolo_11m':
|
||||
'nsfw_1':
|
||||
{
|
||||
'hashes':
|
||||
{
|
||||
@@ -41,7 +41,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
'mean': (0.0, 0.0, 0.0),
|
||||
'standard_deviation': (1.0, 1.0, 1.0)
|
||||
},
|
||||
'marqo':
|
||||
'nsfw_2':
|
||||
{
|
||||
'hashes':
|
||||
{
|
||||
@@ -64,7 +64,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
'mean': (0.5, 0.5, 0.5),
|
||||
'standard_deviation': (0.5, 0.5, 0.5)
|
||||
},
|
||||
'freepik':
|
||||
'nsfw_3':
|
||||
{
|
||||
'hashes':
|
||||
{
|
||||
@@ -91,14 +91,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ 'yolo_11m', 'marqo', 'freepik' ]
|
||||
model_names = [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ]
|
||||
_, model_source_set = collect_model_downloads()
|
||||
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_names = [ 'yolo_11m', 'marqo', 'freepik' ]
|
||||
model_names = [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
@@ -107,9 +107,9 @@ def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||
model_hash_set = {}
|
||||
model_source_set = {}
|
||||
|
||||
for nsfw_model in [ 'yolo_11m', 'marqo', 'freepik' ]:
|
||||
model_hash_set[nsfw_model] = model_set.get(nsfw_model).get('hashes').get('content_analyser')
|
||||
model_source_set[nsfw_model] = model_set.get(nsfw_model).get('sources').get('content_analyser')
|
||||
for content_analyser_model in [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ]:
|
||||
model_hash_set[content_analyser_model] = model_set.get(content_analyser_model).get('hashes').get('content_analyser')
|
||||
model_source_set[content_analyser_model] = model_set.get(content_analyser_model).get('sources').get('content_analyser')
|
||||
|
||||
return model_hash_set, model_source_set
|
||||
|
||||
@@ -164,51 +164,42 @@ def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int
|
||||
|
||||
|
||||
def detect_nsfw(vision_frame : VisionFrame) -> bool:
|
||||
is_nsfw_1 = detect_with_nsfw_1(vision_frame)
|
||||
is_nsfw_2 = detect_with_nsfw_2(vision_frame)
|
||||
is_nsfw_3 = detect_with_nsfw_3(vision_frame)
|
||||
|
||||
if detect_with_yolo_11m(vision_frame):
|
||||
|
||||
if detect_with_marqo(vision_frame):
|
||||
return True
|
||||
|
||||
return detect_with_freepik(vision_frame)
|
||||
return False
|
||||
return is_nsfw_1 and is_nsfw_2 or is_nsfw_1 and is_nsfw_3 or is_nsfw_2 and is_nsfw_3
|
||||
|
||||
|
||||
def detect_with_yolo_11m(vision_frame : VisionFrame) -> bool:
|
||||
model_name = 'yolo_11m'
|
||||
model_set = create_static_model_set('full').get(model_name)
|
||||
def detect_with_nsfw_1(vision_frame : VisionFrame) -> bool:
|
||||
model_set = create_static_model_set('full').get('nsfw_1')
|
||||
model_threshold = model_set.get('threshold')
|
||||
|
||||
detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
|
||||
detection = forward_yolo_11m(detect_vision_frame)
|
||||
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_1')
|
||||
detection = forward_nsfw(detect_vision_frame, 'nsfw_1')
|
||||
detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1))
|
||||
return detection_score > model_threshold
|
||||
|
||||
|
||||
def detect_with_marqo(vision_frame : VisionFrame) -> bool:
|
||||
model_name = 'marqo'
|
||||
model_set = create_static_model_set('full').get(model_name)
|
||||
def detect_with_nsfw_2(vision_frame : VisionFrame) -> bool:
|
||||
model_set = create_static_model_set('full').get('nsfw_2')
|
||||
model_threshold = model_set.get('threshold')
|
||||
|
||||
detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
|
||||
detection = forward_marqo(detect_vision_frame)[0]
|
||||
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_2')
|
||||
detection = forward_nsfw(detect_vision_frame, 'nsfw_2')
|
||||
detection_score = detection[0] - detection[1]
|
||||
return detection_score > model_threshold
|
||||
|
||||
|
||||
def detect_with_freepik(vision_frame : VisionFrame) -> bool:
|
||||
model_name = 'freepik'
|
||||
model_set = create_static_model_set('full').get(model_name)
|
||||
def detect_with_nsfw_3(vision_frame : VisionFrame) -> bool:
|
||||
model_set = create_static_model_set('full').get('nsfw_3')
|
||||
model_threshold = model_set.get('threshold')
|
||||
|
||||
detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
|
||||
detection = forward_freepik(detect_vision_frame)[0]
|
||||
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_3')
|
||||
detection = forward_nsfw(detect_vision_frame, 'nsfw_3')
|
||||
detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1])
|
||||
return detection_score > model_threshold
|
||||
|
||||
|
||||
def forward_yolo_11m(vision_frame : VisionFrame) -> Detection:
|
||||
content_analyser = get_inference_pool().get('yolo_11m')
|
||||
def forward_nsfw(vision_frame : VisionFrame, nsfw_model : str) -> Detection:
|
||||
content_analyser = get_inference_pool().get(nsfw_model)
|
||||
|
||||
with conditional_thread_semaphore():
|
||||
detection = content_analyser.run(None,
|
||||
@@ -216,29 +207,8 @@ def forward_yolo_11m(vision_frame : VisionFrame) -> Detection:
|
||||
'input': vision_frame
|
||||
})[0]
|
||||
|
||||
return detection
|
||||
|
||||
|
||||
def forward_marqo(vision_frame : VisionFrame) -> Detection:
|
||||
content_analyser = get_inference_pool().get('marqo')
|
||||
|
||||
with conditional_thread_semaphore():
|
||||
detection = content_analyser.run(None,
|
||||
{
|
||||
'input': vision_frame
|
||||
})[0]
|
||||
|
||||
return detection
|
||||
|
||||
|
||||
def forward_freepik(vision_frame : VisionFrame) -> Detection:
|
||||
content_analyser = get_inference_pool().get('freepik')
|
||||
|
||||
with conditional_thread_semaphore():
|
||||
detection = content_analyser.run(None,
|
||||
{
|
||||
'input': vision_frame
|
||||
})[0]
|
||||
if nsfw_model in [ 'nsfw_2', 'nsfw_3' ]:
|
||||
return detection[0]
|
||||
|
||||
return detection
|
||||
|
||||
|
||||
Reference in New Issue
Block a user