Feat/better resolve execution (#856)
* A better way to resolve execution providers * Fix issues * Fix issues
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
import importlib
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
@@ -5,7 +6,7 @@ from onnxruntime import InferenceSession
|
|||||||
|
|
||||||
from facefusion import process_manager, state_manager
|
from facefusion import process_manager, state_manager
|
||||||
from facefusion.app_context import detect_app_context
|
from facefusion.app_context import detect_app_context
|
||||||
from facefusion.execution import create_inference_execution_providers, has_execution_provider
|
from facefusion.execution import create_inference_execution_providers
|
||||||
from facefusion.filesystem import is_file
|
from facefusion.filesystem import is_file
|
||||||
from facefusion.typing import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet
|
from facefusion.typing import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet
|
||||||
|
|
||||||
@@ -21,15 +22,17 @@ def get_inference_pool(model_context : str, model_sources : DownloadSet) -> Infe
|
|||||||
|
|
||||||
while process_manager.is_checking():
|
while process_manager.is_checking():
|
||||||
sleep(0.5)
|
sleep(0.5)
|
||||||
|
execution_device_id = state_manager.get_item('execution_device_id')
|
||||||
|
execution_providers = resolve_execution_providers(model_context)
|
||||||
app_context = detect_app_context()
|
app_context = detect_app_context()
|
||||||
inference_context = get_inference_context(model_context)
|
inference_context = get_inference_context(model_context, execution_device_id, execution_providers)
|
||||||
|
|
||||||
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context):
|
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context):
|
||||||
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context)
|
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context)
|
||||||
if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context):
|
if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context):
|
||||||
INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context)
|
INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context)
|
||||||
if not INFERENCE_POOLS.get(app_context).get(inference_context):
|
if not INFERENCE_POOLS.get(app_context).get(inference_context):
|
||||||
INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, state_manager.get_item('execution_device_id'), resolve_execution_providers(model_context))
|
INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, execution_device_id, execution_providers)
|
||||||
|
|
||||||
return INFERENCE_POOLS.get(app_context).get(inference_context)
|
return INFERENCE_POOLS.get(app_context).get(inference_context)
|
||||||
|
|
||||||
@@ -48,8 +51,10 @@ def create_inference_pool(model_sources : DownloadSet, execution_device_id : str
|
|||||||
def clear_inference_pool(model_context : str) -> None:
|
def clear_inference_pool(model_context : str) -> None:
|
||||||
global INFERENCE_POOLS
|
global INFERENCE_POOLS
|
||||||
|
|
||||||
|
execution_device_id = state_manager.get_item('execution_device_id')
|
||||||
|
execution_providers = resolve_execution_providers(model_context)
|
||||||
app_context = detect_app_context()
|
app_context = detect_app_context()
|
||||||
inference_context = get_inference_context(model_context)
|
inference_context = get_inference_context(model_context, execution_device_id, execution_providers)
|
||||||
|
|
||||||
if INFERENCE_POOLS.get(app_context).get(inference_context):
|
if INFERENCE_POOLS.get(app_context).get(inference_context):
|
||||||
del INFERENCE_POOLS[app_context][inference_context]
|
del INFERENCE_POOLS[app_context][inference_context]
|
||||||
@@ -60,12 +65,14 @@ def create_inference_session(model_path : str, execution_device_id : str, execut
|
|||||||
return InferenceSession(model_path, providers = inference_execution_providers)
|
return InferenceSession(model_path, providers = inference_execution_providers)
|
||||||
|
|
||||||
|
|
||||||
def get_inference_context(model_context : str) -> str:
|
def get_inference_context(model_context : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str:
|
||||||
inference_context = model_context + '.' + '_'.join(state_manager.get_item('execution_providers'))
|
inference_context = model_context + '.' + execution_device_id + '.' + '_'.join(execution_providers)
|
||||||
return inference_context
|
return inference_context
|
||||||
|
|
||||||
|
|
||||||
def resolve_execution_providers(model_context : str) -> List[ExecutionProvider]:
|
def resolve_execution_providers(model_context : str) -> List[ExecutionProvider]:
|
||||||
if has_execution_provider('coreml') and model_context == 'facefusion.processors.modules.frame_colorizer':
|
module = importlib.import_module(model_context)
|
||||||
return [ 'cpu' ]
|
|
||||||
|
if hasattr(module, 'resolve_execution_providers'):
|
||||||
|
return getattr(module, 'resolve_execution_providers')()
|
||||||
return state_manager.get_item('execution_providers')
|
return state_manager.get_item('execution_providers')
|
||||||
|
|||||||
@@ -11,12 +11,13 @@ import facefusion.processors.core as processors
|
|||||||
from facefusion import config, content_analyser, inference_manager, logger, process_manager, state_manager, wording
|
from facefusion import config, content_analyser, inference_manager, logger, process_manager, state_manager, wording
|
||||||
from facefusion.common_helper import create_int_metavar
|
from facefusion.common_helper import create_int_metavar
|
||||||
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
|
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
|
||||||
|
from facefusion.execution import has_execution_provider
|
||||||
from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension
|
from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension
|
||||||
from facefusion.processors import choices as processors_choices
|
from facefusion.processors import choices as processors_choices
|
||||||
from facefusion.processors.typing import FrameColorizerInputs
|
from facefusion.processors.typing import FrameColorizerInputs
|
||||||
from facefusion.program_helper import find_argument_group
|
from facefusion.program_helper import find_argument_group
|
||||||
from facefusion.thread_helper import thread_semaphore
|
from facefusion.thread_helper import thread_semaphore
|
||||||
from facefusion.typing import ApplyStateItem, Args, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
|
from facefusion.typing import ApplyStateItem, Args, DownloadScope, ExecutionProvider, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
|
||||||
from facefusion.vision import read_image, read_static_image, unpack_resolution, write_image
|
from facefusion.vision import read_image, read_static_image, unpack_resolution, write_image
|
||||||
|
|
||||||
|
|
||||||
@@ -136,6 +137,12 @@ def clear_inference_pool() -> None:
|
|||||||
inference_manager.clear_inference_pool(__name__)
|
inference_manager.clear_inference_pool(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_execution_providers() -> List[ExecutionProvider]:
|
||||||
|
if has_execution_provider('coreml'):
|
||||||
|
return [ 'cpu' ]
|
||||||
|
return state_manager.get_item('execution_providers')
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
frame_colorizer_model = state_manager.get_item('frame_colorizer_model')
|
frame_colorizer_model = state_manager.get_item('frame_colorizer_model')
|
||||||
return create_static_model_set('full').get(frame_colorizer_model)
|
return create_static_model_set('full').get(frame_colorizer_model)
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ def before_all() -> None:
|
|||||||
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.8:ih*0.8', get_test_example_file('source-80crop.jpg') ])
|
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.8:ih*0.8', get_test_example_file('source-80crop.jpg') ])
|
||||||
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.7:ih*0.7', get_test_example_file('source-70crop.jpg') ])
|
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.7:ih*0.7', get_test_example_file('source-70crop.jpg') ])
|
||||||
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.6:ih*0.6', get_test_example_file('source-60crop.jpg') ])
|
subprocess.run([ 'ffmpeg', '-i', get_test_example_file('source.jpg'), '-vf', 'crop=iw*0.6:ih*0.6', get_test_example_file('source-60crop.jpg') ])
|
||||||
state_manager.init_item('execution_device_id', 0)
|
state_manager.init_item('execution_device_id', '0')
|
||||||
state_manager.init_item('execution_providers', [ 'cpu' ])
|
state_manager.init_item('execution_providers', [ 'cpu' ])
|
||||||
state_manager.init_item('download_providers', [ 'github' ])
|
state_manager.init_item('download_providers', [ 'github' ])
|
||||||
state_manager.init_item('face_detector_angles', [ 0 ])
|
state_manager.init_item('face_detector_angles', [ 0 ])
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from facefusion.inference_manager import INFERENCE_POOLS, get_inference_pool
|
|||||||
|
|
||||||
@pytest.fixture(scope = 'module', autouse = True)
|
@pytest.fixture(scope = 'module', autouse = True)
|
||||||
def before_all() -> None:
|
def before_all() -> None:
|
||||||
state_manager.init_item('execution_device_id', 0)
|
state_manager.init_item('execution_device_id', '0')
|
||||||
state_manager.init_item('execution_providers', [ 'cpu' ])
|
state_manager.init_item('execution_providers', [ 'cpu' ])
|
||||||
state_manager.init_item('download_providers', [ 'github' ])
|
state_manager.init_item('download_providers', [ 'github' ])
|
||||||
content_analyser.pre_check()
|
content_analyser.pre_check()
|
||||||
@@ -21,11 +21,11 @@ def test_get_inference_pool() -> None:
|
|||||||
with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'):
|
with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'):
|
||||||
get_inference_pool('test', model_sources)
|
get_inference_pool('test', model_sources)
|
||||||
|
|
||||||
assert isinstance(INFERENCE_POOLS.get('cli').get('test.cpu').get('content_analyser'), InferenceSession)
|
assert isinstance(INFERENCE_POOLS.get('cli').get('test.0.cpu').get('content_analyser'), InferenceSession)
|
||||||
|
|
||||||
with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'):
|
with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'):
|
||||||
get_inference_pool('test', model_sources)
|
get_inference_pool('test', model_sources)
|
||||||
|
|
||||||
assert isinstance(INFERENCE_POOLS.get('ui').get('test.cpu').get('content_analyser'), InferenceSession)
|
assert isinstance(INFERENCE_POOLS.get('ui').get('test.0.cpu').get('content_analyser'), InferenceSession)
|
||||||
|
|
||||||
assert INFERENCE_POOLS.get('cli').get('test.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.cpu').get('content_analyser')
|
assert INFERENCE_POOLS.get('cli').get('test.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.0.cpu').get('content_analyser')
|
||||||
Reference in New Issue
Block a user