Final preparations

This commit is contained in:
henryruhs
2025-06-21 12:55:10 +02:00
parent 43e1e4bf44
commit 8f2687801b

View File

@@ -38,7 +38,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
} }
}, },
'size': (640, 640), 'size': (640, 640),
'threshold': 0.2,
'mean': (0.0, 0.0, 0.0), 'mean': (0.0, 0.0, 0.0),
'standard_deviation': (1.0, 1.0, 1.0) 'standard_deviation': (1.0, 1.0, 1.0)
}, },
@@ -61,7 +60,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
} }
}, },
'size': (384, 384), 'size': (384, 384),
'threshold': 0.25,
'mean': (0.5, 0.5, 0.5), 'mean': (0.5, 0.5, 0.5),
'standard_deviation': (0.5, 0.5, 0.5) 'standard_deviation': (0.5, 0.5, 0.5)
}, },
@@ -84,7 +82,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
} }
}, },
'size': (448, 448), 'size': (448, 448),
'threshold': 10.5,
'mean': (0.48145466, 0.4578275, 0.40821073), 'mean': (0.48145466, 0.4578275, 0.40821073),
'standard_deviation': (0.26862954, 0.26130258, 0.27577711) 'standard_deviation': (0.26862954, 0.26130258, 0.27577711)
} }
@@ -167,7 +164,7 @@ def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int
progress.set_postfix(rate = rate) progress.set_postfix(rate = rate)
progress.update() progress.update()
return rate > 10.0 return bool(rate > 10.0)
def detect_nsfw(vision_frame : VisionFrame) -> bool: def detect_nsfw(vision_frame : VisionFrame) -> bool:
@@ -179,30 +176,24 @@ def detect_nsfw(vision_frame : VisionFrame) -> bool:
def detect_with_nsfw_1(vision_frame : VisionFrame) -> bool: def detect_with_nsfw_1(vision_frame : VisionFrame) -> bool:
model_set = create_static_model_set('full').get('nsfw_1')
model_threshold = model_set.get('threshold')
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_1') detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_1')
detection = forward_nsfw(detect_vision_frame, 'nsfw_1') detection = forward_nsfw(detect_vision_frame, 'nsfw_1')
detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1)) detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1))
return detection_score > model_threshold return bool(detection_score > 0.2)
def detect_with_nsfw_2(vision_frame : VisionFrame) -> bool: def detect_with_nsfw_2(vision_frame : VisionFrame) -> bool:
model_set = create_static_model_set('full').get('nsfw_2')
model_threshold = model_set.get('threshold')
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_2') detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_2')
detection = forward_nsfw(detect_vision_frame, 'nsfw_2') detection = forward_nsfw(detect_vision_frame, 'nsfw_2')
detection_score = detection[0] - detection[1] detection_score = detection[0] - detection[1]
return detection_score > model_threshold return bool(detection_score > 0.25)
def detect_with_nsfw_3(vision_frame : VisionFrame) -> bool: def detect_with_nsfw_3(vision_frame : VisionFrame) -> bool:
model_set = create_static_model_set('full').get('nsfw_3')
model_threshold = model_set.get('threshold')
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_3') detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_3')
detection = forward_nsfw(detect_vision_frame, 'nsfw_3') detection = forward_nsfw(detect_vision_frame, 'nsfw_3')
detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1]) detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1])
return detection_score > model_threshold return bool(detection_score > 10.5)
def forward_nsfw(vision_frame : VisionFrame, nsfw_model : str) -> Detection: def forward_nsfw(vision_frame : VisionFrame, nsfw_model : str) -> Detection: