Feat/improve content analyser (#861)

* Introduce fit_frame to improve content analyser, rename resize_frame_resolution to restrict_frame

* Fix CI, Add some spaces

* Normalize according to face detector
This commit is contained in:
Henry Ruhs
2025-01-29 12:50:29 +01:00
committed by henryruhs
parent e79a99fac4
commit c70b45bd39
4 changed files with 42 additions and 19 deletions

View File

@@ -9,7 +9,7 @@ from facefusion.download import conditional_download_hashes, conditional_downloa
from facefusion.filesystem import resolve_relative_path
from facefusion.thread_helper import conditional_thread_semaphore
from facefusion.typing import Detection, DownloadScope, Fps, InferencePool, ModelOptions, ModelSet, Score, VisionFrame
from facefusion.vision import detect_video_fps, read_image, read_video_frame, resize_frame_resolution
from facefusion.vision import detect_video_fps, fit_frame, read_image, read_video_frame
STREAM_COUNTER = 0
@@ -106,7 +106,7 @@ def analyse_video(video_path : str, trim_frame_start : int, trim_frame_end : int
def detect_nsfw(vision_frame : VisionFrame) -> List[Score]:
nsfw_scores = []
model_size = get_model_options().get('size')
temp_vision_frame = resize_frame_resolution(vision_frame, model_size)
temp_vision_frame = fit_frame(vision_frame, model_size)
detect_vision_frame = prepare_detect_frame(temp_vision_frame)
detection = forward(detect_vision_frame)
detection = numpy.squeeze(detection).T
@@ -133,9 +133,6 @@ def forward(vision_frame : VisionFrame) -> Detection:
def prepare_detect_frame(temp_vision_frame : VisionFrame) -> VisionFrame:
model_size = get_model_options().get('size')
detect_vision_frame = numpy.zeros((model_size[0], model_size[1], 3))
detect_vision_frame[:temp_vision_frame.shape[0], :temp_vision_frame.shape[1], :] = temp_vision_frame
detect_vision_frame = detect_vision_frame / 255.0
detect_vision_frame = temp_vision_frame / 255.0
detect_vision_frame = numpy.expand_dims(detect_vision_frame.transpose(2, 0, 1), axis = 0).astype(numpy.float32)
return detect_vision_frame