From 3b80d66bf42debdb66f2eb6a3672f7bbe798335d Mon Sep 17 00:00:00 2001 From: Henry Ruhs Date: Sun, 19 Jan 2025 11:05:36 +0100 Subject: [PATCH] Feat/better resolve execution (#856) * A better way to resolve execution providers * Fix issues * Fix issues --- facefusion/inference_manager.py | 23 ++++++++++++------- .../processors/modules/frame_colorizer.py | 9 +++++++- tests/test_face_analyser.py | 2 +- ...ence_pool.py => test_inference_manager.py} | 8 +++---- 4 files changed, 28 insertions(+), 14 deletions(-) rename tests/{test_inference_pool.py => test_inference_manager.py} (66%) diff --git a/facefusion/inference_manager.py b/facefusion/inference_manager.py index 703835e..7c1e9b7 100644 --- a/facefusion/inference_manager.py +++ b/facefusion/inference_manager.py @@ -1,3 +1,4 @@ +import importlib from time import sleep from typing import List @@ -5,7 +6,7 @@ from onnxruntime import InferenceSession from facefusion import process_manager, state_manager from facefusion.app_context import detect_app_context -from facefusion.execution import create_inference_execution_providers, has_execution_provider +from facefusion.execution import create_inference_execution_providers from facefusion.filesystem import is_file from facefusion.typing import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet @@ -21,15 +22,17 @@ def get_inference_pool(model_context : str, model_sources : DownloadSet) -> Infe while process_manager.is_checking(): sleep(0.5) + execution_device_id = state_manager.get_item('execution_device_id') + execution_providers = resolve_execution_providers(model_context) app_context = detect_app_context() - inference_context = get_inference_context(model_context) + inference_context = get_inference_context(model_context, execution_device_id, execution_providers) if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context): INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context) if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context): INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context) if not INFERENCE_POOLS.get(app_context).get(inference_context): - INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, state_manager.get_item('execution_device_id'), resolve_execution_providers(model_context)) + INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, execution_device_id, execution_providers) return INFERENCE_POOLS.get(app_context).get(inference_context) @@ -48,8 +51,10 @@ def create_inference_pool(model_sources : DownloadSet, execution_device_id : str def clear_inference_pool(model_context : str) -> None: global INFERENCE_POOLS + execution_device_id = state_manager.get_item('execution_device_id') + execution_providers = resolve_execution_providers(model_context) app_context = detect_app_context() - inference_context = get_inference_context(model_context) + inference_context = get_inference_context(model_context, execution_device_id, execution_providers) if INFERENCE_POOLS.get(app_context).get(inference_context): del INFERENCE_POOLS[app_context][inference_context] @@ -60,12 +65,14 @@ def create_inference_session(model_path : str, execution_device_id : str, execut return InferenceSession(model_path, providers = inference_execution_providers) -def get_inference_context(model_context : str) -> str: - inference_context = model_context + '.' + '_'.join(state_manager.get_item('execution_providers')) +def get_inference_context(model_context : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str: + inference_context = model_context + '.' + execution_device_id + '.' + '_'.join(execution_providers) return inference_context def resolve_execution_providers(model_context : str) -> List[ExecutionProvider]: - if has_execution_provider('coreml') and model_context == 'facefusion.processors.modules.frame_colorizer': - return [ 'cpu' ] + module = importlib.import_module(model_context) + + if hasattr(module, 'resolve_execution_providers'): + return getattr(module, 'resolve_execution_providers')() return state_manager.get_item('execution_providers') diff --git a/facefusion/processors/modules/frame_colorizer.py b/facefusion/processors/modules/frame_colorizer.py index 3f53e9a..789cb11 100644 --- a/facefusion/processors/modules/frame_colorizer.py +++ b/facefusion/processors/modules/frame_colorizer.py @@ -11,12 +11,13 @@ import facefusion.processors.core as processors from facefusion import config, content_analyser, inference_manager, logger, process_manager, state_manager, wording from facefusion.common_helper import create_int_metavar from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.execution import has_execution_provider from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension from facefusion.processors import choices as processors_choices from facefusion.processors.typing import FrameColorizerInputs from facefusion.program_helper import find_argument_group from facefusion.thread_helper import thread_semaphore -from facefusion.typing import ApplyStateItem, Args, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame +from facefusion.typing import ApplyStateItem, Args, DownloadScope, ExecutionProvider, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame from facefusion.vision import read_image, read_static_image, unpack_resolution, write_image @@ -136,6 +137,12 @@ def clear_inference_pool() -> None: inference_manager.clear_inference_pool(__name__) +def resolve_execution_providers() -> List[ExecutionProvider]: + if has_execution_provider('coreml'): + return [ 'cpu' ] + return state_manager.get_item('execution_providers') + + def get_model_options() -> ModelOptions: frame_colorizer_model = state_manager.get_item('frame_colorizer_model') return create_static_model_set('full').get(frame_colorizer_model) diff --git a/tests/test_face_analyser.py b/tests/test_face_analyser.py index 81b479e..785faf0 100644 --- a/tests/test_face_analyser.py +++ b/tests/test_face_analyser.py @@ -19,7 +19,7 @@ def before_all() -> None: subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.8:ih*0.8', get_test_example_file('source-80crop.jpg') ]) subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.7:ih*0.7', get_test_example_file('source-70crop.jpg') ]) subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.6:ih*0.6', get_test_example_file('source-60crop.jpg') ]) - state_manager.init_item('execution_device_id', 0) + state_manager.init_item('execution_device_id', '0') state_manager.init_item('execution_providers', [ 'cpu' ]) state_manager.init_item('download_providers', [ 'github' ]) state_manager.init_item('face_detector_angles', [ 0 ]) diff --git a/tests/test_inference_pool.py b/tests/test_inference_manager.py similarity index 66% rename from tests/test_inference_pool.py rename to tests/test_inference_manager.py index 1749226..3667cd2 100644 --- a/tests/test_inference_pool.py +++ b/tests/test_inference_manager.py @@ -9,7 +9,7 @@ from facefusion.inference_manager import INFERENCE_POOLS, get_inference_pool @pytest.fixture(scope = 'module', autouse = True) def before_all() -> None: - state_manager.init_item('execution_device_id', 0) + state_manager.init_item('execution_device_id', '0') state_manager.init_item('execution_providers', [ 'cpu' ]) state_manager.init_item('download_providers', [ 'github' ]) content_analyser.pre_check() @@ -21,11 +21,11 @@ def test_get_inference_pool() -> None: with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'): get_inference_pool('test', model_sources) - assert isinstance(INFERENCE_POOLS.get('cli').get('test.cpu').get('content_analyser'), InferenceSession) + assert isinstance(INFERENCE_POOLS.get('cli').get('test.0.cpu').get('content_analyser'), InferenceSession) with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'): get_inference_pool('test', model_sources) - assert isinstance(INFERENCE_POOLS.get('ui').get('test.cpu').get('content_analyser'), InferenceSession) + assert isinstance(INFERENCE_POOLS.get('ui').get('test.0.cpu').get('content_analyser'), InferenceSession) - assert INFERENCE_POOLS.get('cli').get('test.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.cpu').get('content_analyser') + assert INFERENCE_POOLS.get('cli').get('test.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.0.cpu').get('content_analyser')