Polish content analyser
This commit is contained in:
@@ -18,7 +18,7 @@ STREAM_COUNTER = 0
|
|||||||
def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||||
return\
|
return\
|
||||||
{
|
{
|
||||||
'yolo_11m':
|
'nsfw_1':
|
||||||
{
|
{
|
||||||
'hashes':
|
'hashes':
|
||||||
{
|
{
|
||||||
@@ -41,7 +41,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
'mean': (0.0, 0.0, 0.0),
|
'mean': (0.0, 0.0, 0.0),
|
||||||
'standard_deviation': (1.0, 1.0, 1.0)
|
'standard_deviation': (1.0, 1.0, 1.0)
|
||||||
},
|
},
|
||||||
'marqo':
|
'nsfw_2':
|
||||||
{
|
{
|
||||||
'hashes':
|
'hashes':
|
||||||
{
|
{
|
||||||
@@ -64,7 +64,7 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
'mean': (0.5, 0.5, 0.5),
|
'mean': (0.5, 0.5, 0.5),
|
||||||
'standard_deviation': (0.5, 0.5, 0.5)
|
'standard_deviation': (0.5, 0.5, 0.5)
|
||||||
},
|
},
|
||||||
'freepik':
|
'nsfw_3':
|
||||||
{
|
{
|
||||||
'hashes':
|
'hashes':
|
||||||
{
|
{
|
||||||
@@ -91,14 +91,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
model_names = [ 'yolo_11m', 'marqo', 'freepik' ]
|
model_names = [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ]
|
||||||
_, model_source_set = collect_model_downloads()
|
_, model_source_set = collect_model_downloads()
|
||||||
|
|
||||||
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
|
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_names = [ 'yolo_11m', 'marqo', 'freepik' ]
|
model_names = [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_names)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
@@ -107,9 +107,9 @@ def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
|||||||
model_hash_set = {}
|
model_hash_set = {}
|
||||||
model_source_set = {}
|
model_source_set = {}
|
||||||
|
|
||||||
for nsfw_model in [ 'yolo_11m', 'marqo', 'freepik' ]:
|
for content_analyser_model in [ 'nsfw_1', 'nsfw_2', 'nsfw_3' ]:
|
||||||
model_hash_set[nsfw_model] = model_set.get(nsfw_model).get('hashes').get('content_analyser')
|
model_hash_set[content_analyser_model] = model_set.get(content_analyser_model).get('hashes').get('content_analyser')
|
||||||
model_source_set[nsfw_model] = model_set.get(nsfw_model).get('sources').get('content_analyser')
|
model_source_set[content_analyser_model] = model_set.get(content_analyser_model).get('sources').get('content_analyser')
|
||||||
|
|
||||||
return model_hash_set, model_source_set
|
return model_hash_set, model_source_set
|
||||||
|
|
||||||
@@ -164,51 +164,42 @@ def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int
|
|||||||
|
|
||||||
|
|
||||||
def detect_nsfw(vision_frame : VisionFrame) -> bool:
|
def detect_nsfw(vision_frame : VisionFrame) -> bool:
|
||||||
|
is_nsfw_1 = detect_with_nsfw_1(vision_frame)
|
||||||
|
is_nsfw_2 = detect_with_nsfw_2(vision_frame)
|
||||||
|
is_nsfw_3 = detect_with_nsfw_3(vision_frame)
|
||||||
|
|
||||||
if detect_with_yolo_11m(vision_frame):
|
return is_nsfw_1 and is_nsfw_2 or is_nsfw_1 and is_nsfw_3 or is_nsfw_2 and is_nsfw_3
|
||||||
|
|
||||||
if detect_with_marqo(vision_frame):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return detect_with_freepik(vision_frame)
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def detect_with_yolo_11m(vision_frame : VisionFrame) -> bool:
|
def detect_with_nsfw_1(vision_frame : VisionFrame) -> bool:
|
||||||
model_name = 'yolo_11m'
|
model_set = create_static_model_set('full').get('nsfw_1')
|
||||||
model_set = create_static_model_set('full').get(model_name)
|
|
||||||
model_threshold = model_set.get('threshold')
|
model_threshold = model_set.get('threshold')
|
||||||
|
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_1')
|
||||||
detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
|
detection = forward_nsfw(detect_vision_frame, 'nsfw_1')
|
||||||
detection = forward_yolo_11m(detect_vision_frame)
|
|
||||||
detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1))
|
detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1))
|
||||||
return detection_score > model_threshold
|
return detection_score > model_threshold
|
||||||
|
|
||||||
|
|
||||||
def detect_with_marqo(vision_frame : VisionFrame) -> bool:
|
def detect_with_nsfw_2(vision_frame : VisionFrame) -> bool:
|
||||||
model_name = 'marqo'
|
model_set = create_static_model_set('full').get('nsfw_2')
|
||||||
model_set = create_static_model_set('full').get(model_name)
|
|
||||||
model_threshold = model_set.get('threshold')
|
model_threshold = model_set.get('threshold')
|
||||||
|
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_2')
|
||||||
detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
|
detection = forward_nsfw(detect_vision_frame, 'nsfw_2')
|
||||||
detection = forward_marqo(detect_vision_frame)[0]
|
|
||||||
detection_score = detection[0] - detection[1]
|
detection_score = detection[0] - detection[1]
|
||||||
return detection_score > model_threshold
|
return detection_score > model_threshold
|
||||||
|
|
||||||
|
|
||||||
def detect_with_freepik(vision_frame : VisionFrame) -> bool:
|
def detect_with_nsfw_3(vision_frame : VisionFrame) -> bool:
|
||||||
model_name = 'freepik'
|
model_set = create_static_model_set('full').get('nsfw_3')
|
||||||
model_set = create_static_model_set('full').get(model_name)
|
|
||||||
model_threshold = model_set.get('threshold')
|
model_threshold = model_set.get('threshold')
|
||||||
|
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_3')
|
||||||
detect_vision_frame = prepare_detect_frame(vision_frame, model_name)
|
detection = forward_nsfw(detect_vision_frame, 'nsfw_3')
|
||||||
detection = forward_freepik(detect_vision_frame)[0]
|
|
||||||
detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1])
|
detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1])
|
||||||
return detection_score > model_threshold
|
return detection_score > model_threshold
|
||||||
|
|
||||||
|
|
||||||
def forward_yolo_11m(vision_frame : VisionFrame) -> Detection:
|
def forward_nsfw(vision_frame : VisionFrame, nsfw_model : str) -> Detection:
|
||||||
content_analyser = get_inference_pool().get('yolo_11m')
|
content_analyser = get_inference_pool().get(nsfw_model)
|
||||||
|
|
||||||
with conditional_thread_semaphore():
|
with conditional_thread_semaphore():
|
||||||
detection = content_analyser.run(None,
|
detection = content_analyser.run(None,
|
||||||
@@ -216,29 +207,8 @@ def forward_yolo_11m(vision_frame : VisionFrame) -> Detection:
|
|||||||
'input': vision_frame
|
'input': vision_frame
|
||||||
})[0]
|
})[0]
|
||||||
|
|
||||||
return detection
|
if nsfw_model in [ 'nsfw_2', 'nsfw_3' ]:
|
||||||
|
return detection[0]
|
||||||
|
|
||||||
def forward_marqo(vision_frame : VisionFrame) -> Detection:
|
|
||||||
content_analyser = get_inference_pool().get('marqo')
|
|
||||||
|
|
||||||
with conditional_thread_semaphore():
|
|
||||||
detection = content_analyser.run(None,
|
|
||||||
{
|
|
||||||
'input': vision_frame
|
|
||||||
})[0]
|
|
||||||
|
|
||||||
return detection
|
|
||||||
|
|
||||||
|
|
||||||
def forward_freepik(vision_frame : VisionFrame) -> Detection:
|
|
||||||
content_analyser = get_inference_pool().get('freepik')
|
|
||||||
|
|
||||||
with conditional_thread_semaphore():
|
|
||||||
detection = content_analyser.run(None,
|
|
||||||
{
|
|
||||||
'input': vision_frame
|
|
||||||
})[0]
|
|
||||||
|
|
||||||
return detection
|
return detection
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user