Final preparations
This commit is contained in:
@@ -38,7 +38,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
'size': (640, 640),
|
'size': (640, 640),
|
||||||
'threshold': 0.2,
|
|
||||||
'mean': (0.0, 0.0, 0.0),
|
'mean': (0.0, 0.0, 0.0),
|
||||||
'standard_deviation': (1.0, 1.0, 1.0)
|
'standard_deviation': (1.0, 1.0, 1.0)
|
||||||
},
|
},
|
||||||
@@ -61,7 +60,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
'size': (384, 384),
|
'size': (384, 384),
|
||||||
'threshold': 0.25,
|
|
||||||
'mean': (0.5, 0.5, 0.5),
|
'mean': (0.5, 0.5, 0.5),
|
||||||
'standard_deviation': (0.5, 0.5, 0.5)
|
'standard_deviation': (0.5, 0.5, 0.5)
|
||||||
},
|
},
|
||||||
@@ -84,7 +82,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
'size': (448, 448),
|
'size': (448, 448),
|
||||||
'threshold': 10.5,
|
|
||||||
'mean': (0.48145466, 0.4578275, 0.40821073),
|
'mean': (0.48145466, 0.4578275, 0.40821073),
|
||||||
'standard_deviation': (0.26862954, 0.26130258, 0.27577711)
|
'standard_deviation': (0.26862954, 0.26130258, 0.27577711)
|
||||||
}
|
}
|
||||||
@@ -167,7 +164,7 @@ def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int
|
|||||||
progress.set_postfix(rate = rate)
|
progress.set_postfix(rate = rate)
|
||||||
progress.update()
|
progress.update()
|
||||||
|
|
||||||
return rate > 10.0
|
return bool(rate > 10.0)
|
||||||
|
|
||||||
|
|
||||||
def detect_nsfw(vision_frame : VisionFrame) -> bool:
|
def detect_nsfw(vision_frame : VisionFrame) -> bool:
|
||||||
@@ -179,30 +176,24 @@ def detect_nsfw(vision_frame : VisionFrame) -> bool:
|
|||||||
|
|
||||||
|
|
||||||
def detect_with_nsfw_1(vision_frame : VisionFrame) -> bool:
|
def detect_with_nsfw_1(vision_frame : VisionFrame) -> bool:
|
||||||
model_set = create_static_model_set('full').get('nsfw_1')
|
|
||||||
model_threshold = model_set.get('threshold')
|
|
||||||
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_1')
|
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_1')
|
||||||
detection = forward_nsfw(detect_vision_frame, 'nsfw_1')
|
detection = forward_nsfw(detect_vision_frame, 'nsfw_1')
|
||||||
detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1))
|
detection_score = numpy.max(numpy.amax(detection[:, 4:], axis = 1))
|
||||||
return detection_score > model_threshold
|
return bool(detection_score > 0.2)
|
||||||
|
|
||||||
|
|
||||||
def detect_with_nsfw_2(vision_frame : VisionFrame) -> bool:
|
def detect_with_nsfw_2(vision_frame : VisionFrame) -> bool:
|
||||||
model_set = create_static_model_set('full').get('nsfw_2')
|
|
||||||
model_threshold = model_set.get('threshold')
|
|
||||||
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_2')
|
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_2')
|
||||||
detection = forward_nsfw(detect_vision_frame, 'nsfw_2')
|
detection = forward_nsfw(detect_vision_frame, 'nsfw_2')
|
||||||
detection_score = detection[0] - detection[1]
|
detection_score = detection[0] - detection[1]
|
||||||
return detection_score > model_threshold
|
return bool(detection_score > 0.25)
|
||||||
|
|
||||||
|
|
||||||
def detect_with_nsfw_3(vision_frame : VisionFrame) -> bool:
|
def detect_with_nsfw_3(vision_frame : VisionFrame) -> bool:
|
||||||
model_set = create_static_model_set('full').get('nsfw_3')
|
|
||||||
model_threshold = model_set.get('threshold')
|
|
||||||
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_3')
|
detect_vision_frame = prepare_detect_frame(vision_frame, 'nsfw_3')
|
||||||
detection = forward_nsfw(detect_vision_frame, 'nsfw_3')
|
detection = forward_nsfw(detect_vision_frame, 'nsfw_3')
|
||||||
detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1])
|
detection_score = (detection[2] + detection[3]) - (detection[0] + detection[1])
|
||||||
return detection_score > model_threshold
|
return bool(detection_score > 10.5)
|
||||||
|
|
||||||
|
|
||||||
def forward_nsfw(vision_frame : VisionFrame, nsfw_model : str) -> Detection:
|
def forward_nsfw(vision_frame : VisionFrame, nsfw_model : str) -> Detection:
|
||||||
|
|||||||
Reference in New Issue
Block a user