From 8efd7c1fa0feabc6a93a9f72d24df30188390539 Mon Sep 17 00:00:00 2001 From: Henry Ruhs Date: Tue, 31 Dec 2024 18:34:39 +0100 Subject: [PATCH] Revert CoreML fallbacks (#841) --- facefusion/processors/modules/face_swapper.py | 3 +++ facefusion/processors/modules/frame_enhancer.py | 9 +++++++++ 2 files changed, 12 insertions(+) diff --git a/facefusion/processors/modules/face_swapper.py b/facefusion/processors/modules/face_swapper.py index bcd2b4e..f73795b 100755 --- a/facefusion/processors/modules/face_swapper.py +++ b/facefusion/processors/modules/face_swapper.py @@ -346,6 +346,9 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: face_swapper_model = state_manager.get_item('face_swapper_model') + + if has_execution_provider('coreml') and face_swapper_model == 'inswapper_128_fp16': + return create_static_model_set('full').get('inswapper_128') return create_static_model_set('full').get(face_swapper_model) diff --git a/facefusion/processors/modules/frame_enhancer.py b/facefusion/processors/modules/frame_enhancer.py index 33df5f1..b6dea11 100644 --- a/facefusion/processors/modules/frame_enhancer.py +++ b/facefusion/processors/modules/frame_enhancer.py @@ -11,6 +11,7 @@ import facefusion.processors.core as processors from facefusion import config, content_analyser, inference_manager, logger, process_manager, state_manager, wording from facefusion.common_helper import create_int_metavar from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.execution import has_execution_provider from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension from facefusion.processors import choices as processors_choices from facefusion.processors.typing import FrameEnhancerInputs @@ -395,6 +396,14 @@ def clear_inference_pool() -> None: def get_model_options() -> ModelOptions: frame_enhancer_model = state_manager.get_item('frame_enhancer_model') + + if has_execution_provider('coreml'): + if frame_enhancer_model == 'real_esrgan_x2_fp16': + return create_static_model_set('full').get('real_esrgan_x2') + if frame_enhancer_model == 'real_esrgan_x4_fp16': + return create_static_model_set('full').get('real_esrgan_x4') + if frame_enhancer_model == 'real_esrgan_x8_fp16': + return create_static_model_set('full').get('real_esrgan_x8') return create_static_model_set('full').get(frame_enhancer_model)