diff --git a/facefusion/benchmarker.py b/facefusion/benchmarker.py index ae90f5e..a9d4bcc 100644 --- a/facefusion/benchmarker.py +++ b/facefusion/benchmarker.py @@ -3,7 +3,7 @@ import os import statistics import tempfile from time import perf_counter -from typing import Generator, List +from typing import Iterator, List import facefusion.choices from facefusion import content_analyser, core, state_manager @@ -31,7 +31,7 @@ def pre_check() -> bool: return True -def run() -> Generator[List[BenchmarkCycleSet], None, None]: +def run() -> Iterator[List[BenchmarkCycleSet]]: benchmark_resolutions = state_manager.get_item('benchmark_resolutions') benchmark_cycle_count = state_manager.get_item('benchmark_cycle_count') diff --git a/facefusion/execution.py b/facefusion/execution.py index d39be91..c5acd50 100644 --- a/facefusion/execution.py +++ b/facefusion/execution.py @@ -28,7 +28,7 @@ def get_available_execution_providers() -> List[ExecutionProvider]: return available_execution_providers -def create_inference_session_providers(execution_device_id : str, execution_providers : List[ExecutionProvider]) -> List[InferenceSessionProvider]: +def create_inference_session_providers(execution_device_id : int, execution_providers : List[ExecutionProvider]) -> List[InferenceSessionProvider]: inference_session_providers : List[InferenceSessionProvider] = [] for execution_provider in execution_providers: @@ -89,10 +89,10 @@ def resolve_cudnn_conv_algo_search() -> str: return 'EXHAUSTIVE' -def resolve_openvino_device_type(execution_device_id : str) -> str: - if execution_device_id == '0': +def resolve_openvino_device_type(execution_device_id : int) -> str: + if execution_device_id == 0: return 'GPU' - return 'GPU.' + execution_device_id + return 'GPU.' + str(execution_device_id) def run_nvidia_smi() -> subprocess.Popen[bytes]: diff --git a/facefusion/inference_manager.py b/facefusion/inference_manager.py index e14e715..1060fbf 100644 --- a/facefusion/inference_manager.py +++ b/facefusion/inference_manager.py @@ -42,7 +42,7 @@ def get_inference_pool(module_name : str, model_names : List[str], model_source_ return INFERENCE_POOL_SET.get(app_context).get(current_inference_context) -def create_inference_pool(model_source_set : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferencePool: +def create_inference_pool(model_source_set : DownloadSet, execution_device_id : int, execution_providers : List[ExecutionProvider]) -> InferencePool: inference_pool : InferencePool = {} for model_name in model_source_set.keys(): @@ -67,7 +67,7 @@ def clear_inference_pool(module_name : str, model_names : List[str]) -> None: del INFERENCE_POOL_SET[app_context][inference_context] -def create_inference_session(model_path : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferenceSession: +def create_inference_session(model_path : str, execution_device_id : int, execution_providers : List[ExecutionProvider]) -> InferenceSession: model_file_name = get_file_name(model_path) start_time = time() @@ -82,8 +82,8 @@ def create_inference_session(model_path : str, execution_device_id : str, execut fatal_exit(1) -def get_inference_context(module_name : str, model_names : List[str], execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str: - inference_context = '.'.join([ module_name ] + model_names + [ execution_device_id ] + list(execution_providers)) +def get_inference_context(module_name : str, model_names : List[str], execution_device_id : int, execution_providers : List[ExecutionProvider]) -> str: + inference_context = '.'.join([ module_name ] + model_names + [ str(execution_device_id) ] + list(execution_providers)) return inference_context diff --git a/facefusion/metadata.py b/facefusion/metadata.py index eefc4cd..462ebf2 100644 --- a/facefusion/metadata.py +++ b/facefusion/metadata.py @@ -4,7 +4,7 @@ METADATA =\ { 'name': 'FaceFusion', 'description': 'Industry leading face manipulation platform', - 'version': '3.5.0', + 'version': '3.5.1', 'license': 'OpenRAIL-AS', 'author': 'Henry Ruhs', 'url': 'https://facefusion.io' diff --git a/facefusion/program.py b/facefusion/program.py index d3d6087..ab7760b 100755 --- a/facefusion/program.py +++ b/facefusion/program.py @@ -234,7 +234,7 @@ def create_execution_program() -> ArgumentParser: program = ArgumentParser(add_help = False) available_execution_providers = get_available_execution_providers() group_execution = program.add_argument_group('execution') - group_execution.add_argument('--execution-device-ids', help = translator.get('help.execution_device_ids'), type = int, default = config.get_str_list('execution', 'execution_device_ids', '0'), nargs = '+', metavar = 'EXECUTION_DEVICE_IDS') + group_execution.add_argument('--execution-device-ids', help = translator.get('help.execution_device_ids'), type = int, default = config.get_int_list('execution', 'execution_device_ids', '0'), nargs = '+', metavar = 'EXECUTION_DEVICE_IDS') group_execution.add_argument('--execution-providers', help = translator.get('help.execution_providers').format(choices = ', '.join(available_execution_providers)), default = config.get_str_list('execution', 'execution_providers', get_first(available_execution_providers)), choices = available_execution_providers, nargs = '+', metavar = 'EXECUTION_PROVIDERS') group_execution.add_argument('--execution-thread-count', help = translator.get('help.execution_thread_count'), type = int, default = config.get_int_value('execution', 'execution_thread_count', '8'), choices = facefusion.choices.execution_thread_count_range, metavar = create_int_metavar(facefusion.choices.execution_thread_count_range)) job_store.register_job_keys([ 'execution_device_ids', 'execution_providers', 'execution_thread_count' ]) diff --git a/facefusion/streamer.py b/facefusion/streamer.py index 5cb4702..66e523e 100644 --- a/facefusion/streamer.py +++ b/facefusion/streamer.py @@ -2,7 +2,7 @@ import os import subprocess from collections import deque from concurrent.futures import ThreadPoolExecutor -from typing import Deque, Generator +from typing import Deque, Iterator import cv2 import numpy @@ -18,7 +18,7 @@ from facefusion.types import Fps, StreamMode, VisionFrame from facefusion.vision import extract_vision_mask, read_static_images -def multi_process_capture(camera_capture : cv2.VideoCapture, camera_fps : Fps) -> Generator[VisionFrame, None, None]: +def multi_process_capture(camera_capture : cv2.VideoCapture, camera_fps : Fps) -> Iterator[VisionFrame]: capture_deque : Deque[VisionFrame] = deque() with tqdm(desc = translator.get('streaming'), unit = 'frame', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress: diff --git a/facefusion/types.py b/facefusion/types.py index f5b20df..ac58699 100755 --- a/facefusion/types.py +++ b/facefusion/types.py @@ -385,7 +385,7 @@ State = TypedDict('State', 'open_browser' : bool, 'ui_layouts' : List[str], 'ui_workflow' : UiWorkflow, - 'execution_device_ids' : List[str], + 'execution_device_ids' : List[int], 'execution_providers' : List[ExecutionProvider], 'execution_thread_count' : int, 'video_memory_strategy' : VideoMemoryStrategy, diff --git a/facefusion/uis/components/benchmark.py b/facefusion/uis/components/benchmark.py index c0c7695..b62d81a 100644 --- a/facefusion/uis/components/benchmark.py +++ b/facefusion/uis/components/benchmark.py @@ -1,4 +1,4 @@ -from typing import Any, Generator, List, Optional +from typing import Any, Iterator, List, Optional import gradio @@ -44,7 +44,7 @@ def listen() -> None: BENCHMARK_START_BUTTON.click(start, outputs = BENCHMARK_BENCHMARKS_DATAFRAME) -def start() -> Generator[List[Any], None, None]: +def start() -> Iterator[List[Any]]: state_manager.sync_state() for benchmark in benchmarker.run(): diff --git a/facefusion/uis/components/preview.py b/facefusion/uis/components/preview.py index a7b8ded..1482299 100755 --- a/facefusion/uis/components/preview.py +++ b/facefusion/uis/components/preview.py @@ -18,7 +18,7 @@ from facefusion.types import AudioFrame, Face, Mask, VisionFrame from facefusion.uis import choices as uis_choices from facefusion.uis.core import get_ui_component, get_ui_components, register_ui_component from facefusion.uis.types import ComponentOptions, PreviewMode -from facefusion.vision import conditional_merge_vision_mask, detect_frame_orientation, extract_vision_mask, fit_cover_frame, obscure_frame, read_static_image, read_static_images, read_video_frame, restrict_frame, unpack_resolution +from facefusion.vision import detect_frame_orientation, extract_vision_mask, fit_cover_frame, merge_vision_mask, obscure_frame, read_static_image, read_static_images, read_video_frame, restrict_frame, unpack_resolution PREVIEW_IMAGE : Optional[gradio.Image] = None @@ -197,7 +197,7 @@ def update_preview_image(preview_mode : PreviewMode, preview_resolution : str, f reference_vision_frame = read_static_image(state_manager.get_item('target_path')) target_vision_frame = read_static_image(state_manager.get_item('target_path'), 'rgba') target_vision_mask = extract_vision_mask(target_vision_frame) - target_vision_frame = conditional_merge_vision_mask(target_vision_frame, target_vision_mask) + target_vision_frame = merge_vision_mask(target_vision_frame, target_vision_mask) preview_vision_frame = process_preview_frame(reference_vision_frame, source_vision_frames, source_audio_frame, source_voice_frame, target_vision_frame, preview_mode, preview_resolution) preview_vision_frame = cv2.cvtColor(preview_vision_frame, cv2.COLOR_BGRA2RGBA) return gradio.Image(value = preview_vision_frame, elem_classes = [ 'image-preview', 'is-' + detect_frame_orientation(preview_vision_frame) ]) @@ -206,7 +206,7 @@ def update_preview_image(preview_mode : PreviewMode, preview_resolution : str, f reference_vision_frame = read_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number')) temp_vision_frame = read_video_frame(state_manager.get_item('target_path'), frame_number) temp_vision_mask = extract_vision_mask(temp_vision_frame) - temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask) + temp_vision_frame = merge_vision_mask(temp_vision_frame, temp_vision_mask) preview_vision_frame = process_preview_frame(reference_vision_frame, source_vision_frames, source_audio_frame, source_voice_frame, temp_vision_frame, preview_mode, preview_resolution) preview_vision_frame = cv2.cvtColor(preview_vision_frame, cv2.COLOR_BGRA2RGBA) return gradio.Image(value = preview_vision_frame, elem_classes = [ 'image-preview', 'is-' + detect_frame_orientation(preview_vision_frame) ]) @@ -297,6 +297,6 @@ def extract_crop_frame(vision_frame : VisionFrame, face : Face) -> Optional[Visi def prepare_output_frame(target_vision_frame : VisionFrame, temp_vision_frame : VisionFrame, temp_vision_mask : Mask) -> VisionFrame: temp_vision_mask = temp_vision_mask.clip(state_manager.get_item('background_remover_color')[-1], 255) - temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask) + temp_vision_frame = merge_vision_mask(temp_vision_frame, temp_vision_mask) temp_vision_frame = cv2.resize(temp_vision_frame, target_vision_frame.shape[1::-1]) return temp_vision_frame diff --git a/facefusion/uis/components/webcam.py b/facefusion/uis/components/webcam.py index 1988c68..effc9c3 100644 --- a/facefusion/uis/components/webcam.py +++ b/facefusion/uis/components/webcam.py @@ -1,4 +1,4 @@ -from typing import Generator, List, Optional, Tuple +from typing import Iterator, List, Optional, Tuple import cv2 import gradio @@ -82,7 +82,7 @@ def pre_stop() -> Tuple[gradio.File, gradio.Image, gradio.Button, gradio.Button] return gradio.File(visible = True), gradio.Image(visible = False), gradio.Button(visible = True), gradio.Button(visible = False) -def start(webcam_device_id : int, webcam_mode : WebcamMode, webcam_resolution : str, webcam_fps : Fps) -> Generator[VisionFrame, None, None]: +def start(webcam_device_id : int, webcam_mode : WebcamMode, webcam_resolution : str, webcam_fps : Fps) -> Iterator[VisionFrame]: state_manager.init_item('face_selector_mode', 'one') state_manager.sync_state() diff --git a/facefusion/vision.py b/facefusion/vision.py index f50256b..561e374 100644 --- a/facefusion/vision.py +++ b/facefusion/vision.py @@ -355,7 +355,11 @@ def extract_vision_mask(vision_frame : VisionFrame) -> Mask: return numpy.full(vision_frame.shape[:2], 255, dtype = numpy.uint8) +def merge_vision_mask(vision_frame : VisionFrame, vision_mask : Mask) -> VisionFrame: + return numpy.dstack((vision_frame[:, :, :3], vision_mask)) + + def conditional_merge_vision_mask(vision_frame : VisionFrame, vision_mask : Mask) -> VisionFrame: if numpy.any(vision_mask < 255): - return numpy.dstack((vision_frame[:, :, :3], vision_mask)) + return merge_vision_mask(vision_frame, vision_mask) return vision_frame diff --git a/facefusion/workflows/image_to_image.py b/facefusion/workflows/image_to_image.py index f1df279..ab16a93 100644 --- a/facefusion/workflows/image_to_image.py +++ b/facefusion/workflows/image_to_image.py @@ -19,7 +19,7 @@ def process(start_time : float) -> ErrorCode: setup, prepare_image, process_image, - partial(finalize_image, start_time), + partial(finalize_image, start_time) ] process_manager.start() diff --git a/facefusion/workflows/image_to_video.py b/facefusion/workflows/image_to_video.py index f247854..00cf326 100644 --- a/facefusion/workflows/image_to_video.py +++ b/facefusion/workflows/image_to_video.py @@ -152,18 +152,6 @@ def restore_audio() -> ErrorCode: return 0 -def finalize_video(start_time : float) -> ErrorCode: - logger.debug(translator.get('clearing_temp'), __name__) - clear_temp_directory(state_manager.get_item('target_path')) - - if is_video(state_manager.get_item('output_path')): - logger.info(translator.get('processing_video_succeeded').format(seconds = calculate_end_time(start_time)), __name__) - else: - logger.error(translator.get('processing_video_failed'), __name__) - return 1 - return 0 - - def process_temp_frame(temp_frame_path : str, frame_number : int) -> bool: reference_vision_frame = read_static_video_frame(state_manager.get_item('target_path'), state_manager.get_item('reference_frame_number')) source_vision_frames = read_static_images(state_manager.get_item('source_paths')) @@ -195,3 +183,15 @@ def process_temp_frame(temp_frame_path : str, frame_number : int) -> bool: temp_vision_frame = conditional_merge_vision_mask(temp_vision_frame, temp_vision_mask) return write_image(temp_frame_path, temp_vision_frame) + + +def finalize_video(start_time : float) -> ErrorCode: + logger.debug(translator.get('clearing_temp'), __name__) + clear_temp_directory(state_manager.get_item('target_path')) + + if is_video(state_manager.get_item('output_path')): + logger.info(translator.get('processing_video_succeeded').format(seconds = calculate_end_time(start_time)), __name__) + else: + logger.error(translator.get('processing_video_failed'), __name__) + return 1 + return 0 diff --git a/tests/test_execution.py b/tests/test_execution.py index 3c4a8c1..c2cd241 100644 --- a/tests/test_execution.py +++ b/tests/test_execution.py @@ -15,10 +15,10 @@ def test_create_inference_session_providers() -> None: [ ('CUDAExecutionProvider', { - 'device_id': '1', + 'device_id': 1, 'cudnn_conv_algo_search': 'EXHAUSTIVE' }), 'CPUExecutionProvider' ] - assert create_inference_session_providers('1', [ 'cpu', 'cuda' ]) == inference_session_providers + assert create_inference_session_providers(1, [ 'cpu', 'cuda' ]) == inference_session_providers diff --git a/tests/test_face_analyser.py b/tests/test_face_analyser.py index 6bd3a7c..962c7f7 100644 --- a/tests/test_face_analyser.py +++ b/tests/test_face_analyser.py @@ -18,7 +18,7 @@ def before_all() -> None: subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.8:ih*0.8', get_test_example_file('source-80crop.jpg') ]) subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.7:ih*0.7', get_test_example_file('source-70crop.jpg') ]) subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.6:ih*0.6', get_test_example_file('source-60crop.jpg') ]) - state_manager.init_item('execution_device_ids', [ '0' ]) + state_manager.init_item('execution_device_ids', [ 0 ]) state_manager.init_item('execution_providers', [ 'cpu' ]) state_manager.init_item('download_providers', [ 'github' ]) state_manager.init_item('face_detector_angles', [ 0 ]) diff --git a/tests/test_inference_manager.py b/tests/test_inference_manager.py index 78cb46d..8f90e9f 100644 --- a/tests/test_inference_manager.py +++ b/tests/test_inference_manager.py @@ -9,7 +9,7 @@ from facefusion.inference_manager import INFERENCE_POOL_SET, get_inference_pool @pytest.fixture(scope = 'module', autouse = True) def before_all() -> None: - state_manager.init_item('execution_device_ids', [ '0' ]) + state_manager.init_item('execution_device_ids', [ 0 ]) state_manager.init_item('execution_providers', [ 'cpu' ]) state_manager.init_item('download_providers', [ 'github' ]) content_analyser.pre_check()