Feat/better resolve execution (#856)

* A better way to resolve execution providers
* Fix issues
* Fix issues
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import importlib
|
||||
from time import sleep
|
||||
from typing import List
|
||||
|
||||
@@ -5,7 +6,7 @@ from onnxruntime import InferenceSession
|
||||
|
||||
from facefusion import process_manager, state_manager
|
||||
from facefusion.app_context import detect_app_context
|
||||
from facefusion.execution import create_inference_execution_providers, has_execution_provider
|
||||
from facefusion.execution import create_inference_execution_providers
|
||||
from facefusion.filesystem import is_file
|
||||
from facefusion.typing import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet
|
||||
|
||||
@@ -21,15 +22,17 @@ def get_inference_pool(model_context : str, model_sources : DownloadSet) -> Infe
|
||||
|
||||
while process_manager.is_checking():
|
||||
sleep(0.5)
|
||||
execution_device_id = state_manager.get_item('execution_device_id')
|
||||
execution_providers = resolve_execution_providers(model_context)
|
||||
app_context = detect_app_context()
|
||||
inference_context = get_inference_context(model_context)
|
||||
inference_context = get_inference_context(model_context, execution_device_id, execution_providers)
|
||||
|
||||
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context):
|
||||
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context)
|
||||
if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context):
|
||||
INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context)
|
||||
if not INFERENCE_POOLS.get(app_context).get(inference_context):
|
||||
INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, state_manager.get_item('execution_device_id'), resolve_execution_providers(model_context))
|
||||
INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, execution_device_id, execution_providers)
|
||||
|
||||
return INFERENCE_POOLS.get(app_context).get(inference_context)
|
||||
|
||||
@@ -48,8 +51,10 @@ def create_inference_pool(model_sources : DownloadSet, execution_device_id : str
|
||||
def clear_inference_pool(model_context : str) -> None:
|
||||
global INFERENCE_POOLS
|
||||
|
||||
execution_device_id = state_manager.get_item('execution_device_id')
|
||||
execution_providers = resolve_execution_providers(model_context)
|
||||
app_context = detect_app_context()
|
||||
inference_context = get_inference_context(model_context)
|
||||
inference_context = get_inference_context(model_context, execution_device_id, execution_providers)
|
||||
|
||||
if INFERENCE_POOLS.get(app_context).get(inference_context):
|
||||
del INFERENCE_POOLS[app_context][inference_context]
|
||||
@@ -60,12 +65,14 @@ def create_inference_session(model_path : str, execution_device_id : str, execut
|
||||
return InferenceSession(model_path, providers = inference_execution_providers)
|
||||
|
||||
|
||||
def get_inference_context(model_context : str) -> str:
|
||||
inference_context = model_context + '.' + '_'.join(state_manager.get_item('execution_providers'))
|
||||
def get_inference_context(model_context : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str:
	# Compose the inference-pool key as "<module>.<device>.<provider_a>_<provider_b>..."
	# so pools are shared only for an identical module/device/provider combination.
	provider_part = '_'.join(execution_providers)
	return '.'.join([ model_context, execution_device_id, provider_part ])
|
||||
|
||||
|
||||
def resolve_execution_providers(model_context : str) -> List[ExecutionProvider]:
|
||||
if has_execution_provider('coreml') and model_context == 'facefusion.processors.modules.frame_colorizer':
|
||||
return [ 'cpu' ]
|
||||
module = importlib.import_module(model_context)
|
||||
|
||||
if hasattr(module, 'resolve_execution_providers'):
|
||||
return getattr(module, 'resolve_execution_providers')()
|
||||
return state_manager.get_item('execution_providers')
|
||||
|
||||
@@ -11,12 +11,13 @@ import facefusion.processors.core as processors
|
||||
from facefusion import config, content_analyser, inference_manager, logger, process_manager, state_manager, wording
|
||||
from facefusion.common_helper import create_int_metavar
|
||||
from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url
|
||||
from facefusion.execution import has_execution_provider
|
||||
from facefusion.filesystem import in_directory, is_image, is_video, resolve_relative_path, same_file_extension
|
||||
from facefusion.processors import choices as processors_choices
|
||||
from facefusion.processors.typing import FrameColorizerInputs
|
||||
from facefusion.program_helper import find_argument_group
|
||||
from facefusion.thread_helper import thread_semaphore
|
||||
from facefusion.typing import ApplyStateItem, Args, DownloadScope, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
|
||||
from facefusion.typing import ApplyStateItem, Args, DownloadScope, ExecutionProvider, Face, InferencePool, ModelOptions, ModelSet, ProcessMode, QueuePayload, UpdateProgress, VisionFrame
|
||||
from facefusion.vision import read_image, read_static_image, unpack_resolution, write_image
|
||||
|
||||
|
||||
@@ -136,6 +137,12 @@ def clear_inference_pool() -> None:
|
||||
inference_manager.clear_inference_pool(__name__)
|
||||
|
||||
|
||||
def resolve_execution_providers() -> List[ExecutionProvider]:
	# NOTE(review): CoreML is forced down to CPU here — presumably the frame
	# colorizer model is not CoreML-compatible; confirm against upstream.
	# Otherwise the user-selected execution providers from state are honored.
	coreml_present = has_execution_provider('coreml')
	return [ 'cpu' ] if coreml_present else state_manager.get_item('execution_providers')
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
	# Resolve the options for whichever frame colorizer model is currently
	# selected in the application state.
	model_name = state_manager.get_item('frame_colorizer_model')
	model_set = create_static_model_set('full')
	return model_set.get(model_name)
|
||||
|
||||
Reference in New Issue
Block a user