Feat/better resolve execution (#856)

* A better way to resolve execution providers

* Fix issues

* Fix issues
This commit is contained in:
Henry Ruhs
2025-01-19 11:05:36 +01:00
committed by henryruhs
parent 330f86a4e4
commit 3b80d66bf4
4 changed files with 28 additions and 14 deletions

View File

@@ -1,3 +1,4 @@
import importlib
from time import sleep from time import sleep
from typing import List from typing import List
@@ -5,7 +6,7 @@ from onnxruntime import InferenceSession
from facefusion import process_manager, state_manager from facefusion import process_manager, state_manager
from facefusion.app_context import detect_app_context from facefusion.app_context import detect_app_context
from facefusion.execution import create_inference_execution_providers, has_execution_provider from facefusion.execution import create_inference_execution_providers
from facefusion.filesystem import is_file from facefusion.filesystem import is_file
from facefusion.typing import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet from facefusion.typing import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet
@@ -21,15 +22,17 @@ def get_inference_pool(model_context : str, model_sources : DownloadSet) -> Infe
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
execution_device_id = state_manager.get_item('execution_device_id')
execution_providers = resolve_execution_providers(model_context)
app_context = detect_app_context() app_context = detect_app_context()
inference_context = get_inference_context(model_context) inference_context = get_inference_context(model_context, execution_device_id, execution_providers)
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context): if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context):
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context) INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context)
if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context): if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context):
INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context) INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context)
if not INFERENCE_POOLS.get(app_context).get(inference_context): if not INFERENCE_POOLS.get(app_context).get(inference_context):
INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, state_manager.get_item('execution_device_id'), resolve_execution_providers(model_context)) INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, execution_device_id, execution_providers)
return INFERENCE_POOLS.get(app_context).get(inference_context) return INFERENCE_POOLS.get(app_context).get(inference_context)
@@ -48,8 +51,10 @@ def create_inference_pool(model_sources : DownloadSet, execution_device_id : str
def clear_inference_pool(model_context : str) -> None:
	# Evict the cached inference pool for this model context so the next
	# get_inference_pool() call rebuilds it from scratch.
	global INFERENCE_POOLS

	device_id = state_manager.get_item('execution_device_id')
	providers = resolve_execution_providers(model_context)
	context_of_app = detect_app_context()
	context_key = get_inference_context(model_context, device_id, providers)
	# pool_map is the same dict object stored in INFERENCE_POOLS, so deleting
	# from it removes the entry from the global registry.
	pool_map = INFERENCE_POOLS.get(context_of_app)

	if pool_map.get(context_key):
		del pool_map[context_key]
@@ -60,12 +65,14 @@ def create_inference_session(model_path : str, execution_device_id : str, execut
return InferenceSession(model_path, providers = inference_execution_providers) return InferenceSession(model_path, providers = inference_execution_providers)
def get_inference_context(model_context : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str:
	# Compose the cache key for an inference pool: the model context, the
	# device id and the underscore-joined provider names, separated by dots.
	provider_part = '_'.join(execution_providers)
	return '.'.join([ model_context, execution_device_id, provider_part ])
def resolve_execution_providers(model_context : str) -> List[ExecutionProvider]:
	# Let the processor module override its execution providers: when the module
	# named by model_context defines resolve_execution_providers(), use its
	# result, otherwise fall back to the globally configured providers.
	module = importlib.import_module(model_context)

	if hasattr(module, 'resolve_execution_providers'):
		# Direct attribute access instead of getattr() with a literal name —
		# equivalent, clearer, and what the hasattr() check already guards.
		return module.resolve_execution_providers()
	return state_manager.get_item('execution_providers')

View File

@@ -11,12 +11,13 @@ import facefusion.processors.core as processors
from facefusion import config, content_analyser, inference_manager, logger, process_manager, state_manager, wording from facefusion import config, content_analyser, inference_manager, logger, process_manager, state_manager, wording
from facefusion.common_helper import create_int_metavar from facefusion.common_helper import create_int_metavar
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
from facefusion.execution import has_execution_provider
from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension
from facefusion.processors import choices as processors_choices from facefusion.processors import choices as processors_choices
from facefusion.processors.typing import FrameColorizerInputs from facefusion.processors.typing import FrameColorizerInputs
from facefusion.program_helper import find_argument_group from facefusion.program_helper import find_argument_group
from facefusion.thread_helper import thread_semaphore from facefusion.thread_helper import thread_semaphore
from facefusion.typing import ApplyStateItem, Args, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame from facefusion.typing import ApplyStateItem, Args, DownloadScope, ExecutionProvider, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
from facefusion.vision import read_image, read_static_image, unpack_resolution, write_image from facefusion.vision import read_image, read_static_image, unpack_resolution, write_image
@@ -136,6 +137,12 @@ def clear_inference_pool() -> None:
inference_manager.clear_inference_pool(__name__) inference_manager.clear_inference_pool(__name__)
def resolve_execution_providers() -> List[ExecutionProvider]:
	# NOTE(review): forces CPU whenever CoreML is available — presumably the
	# frame colorizer model does not run correctly under CoreML; confirm.
	if not has_execution_provider('coreml'):
		return state_manager.get_item('execution_providers')
	return [ 'cpu' ]
def get_model_options() -> ModelOptions: def get_model_options() -> ModelOptions:
frame_colorizer_model = state_manager.get_item('frame_colorizer_model') frame_colorizer_model = state_manager.get_item('frame_colorizer_model')
return create_static_model_set('full').get(frame_colorizer_model) return create_static_model_set('full').get(frame_colorizer_model)

View File

@@ -19,7 +19,7 @@ def before_all() -> None:
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.8:ih*0.8', get_test_example_file('source-80crop.jpg') ]) subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.8:ih*0.8', get_test_example_file('source-80crop.jpg') ])
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.7:ih*0.7', get_test_example_file('source-70crop.jpg') ]) subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.7:ih*0.7', get_test_example_file('source-70crop.jpg') ])
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.6:ih*0.6', get_test_example_file('source-60crop.jpg') ]) subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.6:ih*0.6', get_test_example_file('source-60crop.jpg') ])
state_manager.init_item('execution_device_id', 0) state_manager.init_item('execution_device_id', '0')
state_manager.init_item('execution_providers', [ 'cpu' ]) state_manager.init_item('execution_providers', [ 'cpu' ])
state_manager.init_item('download_providers', [ 'github' ]) state_manager.init_item('download_providers', [ 'github' ])
state_manager.init_item('face_detector_angles', [ 0 ]) state_manager.init_item('face_detector_angles', [ 0 ])

View File

@@ -9,7 +9,7 @@ from facefusion.inference_manager import INFERENCE_POOLS, get_inference_pool
@pytest.fixture(scope = 'module', autouse = True) @pytest.fixture(scope = 'module', autouse = True)
def before_all() -> None: def before_all() -> None:
state_manager.init_item('execution_device_id', 0) state_manager.init_item('execution_device_id', '0')
state_manager.init_item('execution_providers', [ 'cpu' ]) state_manager.init_item('execution_providers', [ 'cpu' ])
state_manager.init_item('download_providers', [ 'github' ]) state_manager.init_item('download_providers', [ 'github' ])
content_analyser.pre_check() content_analyser.pre_check()
@@ -21,11 +21,11 @@ def test_get_inference_pool() -> None:
with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'): with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'):
get_inference_pool('test', model_sources) get_inference_pool('test', model_sources)
assert isinstance(INFERENCE_POOLS.get('cli').get('test.cpu').get('content_analyser'), InferenceSession) assert isinstance(INFERENCE_POOLS.get('cli').get('test.0.cpu').get('content_analyser'), InferenceSession)
with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'): with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'):
get_inference_pool('test', model_sources) get_inference_pool('test', model_sources)
assert isinstance(INFERENCE_POOLS.get('ui').get('test.cpu').get('content_analyser'), InferenceSession) assert isinstance(INFERENCE_POOLS.get('ui').get('test.0.cpu').get('content_analyser'), InferenceSession)
assert INFERENCE_POOLS.get('cli').get('test.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.cpu').get('content_analyser') assert INFERENCE_POOLS.get('cli').get('test.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.0.cpu').get('content_analyser')