We name it InferenceSessionProvider

This commit is contained in:
henryruhs
2025-01-22 22:51:52 +01:00
parent 5270bd679c
commit dbbf3445b6
4 changed files with 20 additions and 19 deletions

View File

@@ -8,7 +8,7 @@ from onnxruntime import get_available_providers, set_default_logger_severity
import facefusion.choices import facefusion.choices
from facefusion.common_helper import get_last from facefusion.common_helper import get_last
from facefusion.typing import ExecutionDevice, ExecutionProvider, ValueAndUnit from facefusion.typing import ExecutionDevice, ExecutionProvider, InferenceSessionProvider, ValueAndUnit
set_default_logger_severity(3) set_default_logger_severity(3)
@@ -26,28 +26,28 @@ def suggest_execution_provider(execution_providers : List[ExecutionProvider]) ->
def get_available_execution_providers() -> List[ExecutionProvider]: def get_available_execution_providers() -> List[ExecutionProvider]:
inference_execution_providers = get_available_providers() inference_session_providers = get_available_providers()
available_execution_providers = [] available_execution_providers = []
for execution_provider, execution_provider_value in facefusion.choices.execution_provider_set.items(): for execution_provider, execution_provider_value in facefusion.choices.execution_provider_set.items():
if execution_provider_value in inference_execution_providers: if execution_provider_value in inference_session_providers:
available_execution_providers.append(execution_provider) available_execution_providers.append(execution_provider)
return available_execution_providers return available_execution_providers
def create_inference_execution_providers(execution_device_id : str, execution_providers : List[ExecutionProvider]) -> List[Any]: def create_inference_session_providers(execution_device_id : str, execution_providers : List[ExecutionProvider]) -> List[InferenceSessionProvider]:
inference_execution_providers : List[Any] = [] inference_session_providers : List[InferenceSessionProvider] = []
for execution_provider in execution_providers: for execution_provider in execution_providers:
if execution_provider == 'cuda': if execution_provider == 'cuda':
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), inference_session_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{ {
'device_id': execution_device_id, 'device_id': execution_device_id,
'cudnn_conv_algo_search': 'DEFAULT' if is_geforce_16_series() else 'EXHAUSTIVE' 'cudnn_conv_algo_search': 'DEFAULT' if is_geforce_16_series() else 'EXHAUSTIVE'
})) }))
if execution_provider == 'tensorrt': if execution_provider == 'tensorrt':
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), inference_session_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{ {
'device_id': execution_device_id, 'device_id': execution_device_id,
'trt_engine_cache_enable': True, 'trt_engine_cache_enable': True,
@@ -57,23 +57,23 @@ def create_inference_execution_providers(execution_device_id : str, execution_pr
'trt_builder_optimization_level': 5 'trt_builder_optimization_level': 5
})) }))
if execution_provider in [ 'directml', 'rocm' ]: if execution_provider in [ 'directml', 'rocm' ]:
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), inference_session_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{ {
'device_id': execution_device_id 'device_id': execution_device_id
})) }))
if execution_provider == 'openvino': if execution_provider == 'openvino':
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), inference_session_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{ {
'device_type': 'GPU' if execution_device_id == '0' else 'GPU.' + execution_device_id, 'device_type': 'GPU' if execution_device_id == '0' else 'GPU.' + execution_device_id,
'precision': 'FP32' 'precision': 'FP32'
})) }))
if execution_provider == 'coreml': if execution_provider == 'coreml':
inference_execution_providers.append(facefusion.choices.execution_provider_set.get(execution_provider)) inference_session_providers.append(facefusion.choices.execution_provider_set.get(execution_provider))
if 'cpu' in execution_providers: if 'cpu' in execution_providers:
inference_execution_providers.append(facefusion.choices.execution_provider_set.get('cpu')) inference_session_providers.append(facefusion.choices.execution_provider_set.get('cpu'))
return inference_execution_providers return inference_session_providers
def is_geforce_16_series() -> bool: def is_geforce_16_series() -> bool:

View File

@@ -6,7 +6,7 @@ from onnxruntime import InferenceSession
from facefusion import process_manager, state_manager from facefusion import process_manager, state_manager
from facefusion.app_context import detect_app_context from facefusion.app_context import detect_app_context
from facefusion.execution import create_inference_execution_providers from facefusion.execution import create_inference_session_providers
from facefusion.filesystem import is_file from facefusion.filesystem import is_file
from facefusion.typing import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet from facefusion.typing import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet
@@ -61,8 +61,8 @@ def clear_inference_pool(model_context : str) -> None:
def create_inference_session(model_path : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferenceSession: def create_inference_session(model_path : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferenceSession:
inference_execution_providers = create_inference_execution_providers(execution_device_id, execution_providers) inference_session_providers = create_inference_session_providers(execution_device_id, execution_providers)
return InferenceSession(model_path, providers = inference_execution_providers) return InferenceSession(model_path, providers = inference_session_providers)
def get_inference_context(model_context : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str: def get_inference_context(model_context : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str:

View File

@@ -131,6 +131,7 @@ ModelInitializer = NDArray[Any]
ExecutionProvider = Literal['cpu', 'coreml', 'cuda', 'directml', 'openvino', 'rocm', 'tensorrt'] ExecutionProvider = Literal['cpu', 'coreml', 'cuda', 'directml', 'openvino', 'rocm', 'tensorrt']
ExecutionProviderValue = Literal['CPUExecutionProvider', 'CoreMLExecutionProvider', 'CUDAExecutionProvider', 'DmlExecutionProvider', 'OpenVINOExecutionProvider', 'ROCMExecutionProvider', 'TensorrtExecutionProvider'] ExecutionProviderValue = Literal['CPUExecutionProvider', 'CoreMLExecutionProvider', 'CUDAExecutionProvider', 'DmlExecutionProvider', 'OpenVINOExecutionProvider', 'ROCMExecutionProvider', 'TensorrtExecutionProvider']
ExecutionProviderSet = Dict[ExecutionProvider, ExecutionProviderValue] ExecutionProviderSet = Dict[ExecutionProvider, ExecutionProviderValue]
InferenceSessionProvider = Any
ValueAndUnit = TypedDict('ValueAndUnit', ValueAndUnit = TypedDict('ValueAndUnit',
{ {
'value' : int, 'value' : int,

View File

@@ -1,4 +1,4 @@
from facefusion.execution import create_inference_execution_providers, get_available_execution_providers, has_execution_provider from facefusion.execution import create_inference_session_providers, get_available_execution_providers, has_execution_provider
def test_has_execution_provider() -> None: def test_has_execution_provider() -> None:
@@ -10,8 +10,8 @@ def test_get_available_execution_providers() -> None:
assert 'cpu' in get_available_execution_providers() assert 'cpu' in get_available_execution_providers()
def test_create_inference_execution_providers() -> None: def test_create_inference_session_providers() -> None:
execution_providers =\ inference_session_providers =\
[ [
('CUDAExecutionProvider', ('CUDAExecutionProvider',
{ {
@@ -21,4 +21,4 @@ def test_create_inference_execution_providers() -> None:
'CPUExecutionProvider' 'CPUExecutionProvider'
] ]
assert create_inference_execution_providers('1', [ 'cpu', 'cuda' ]) == execution_providers assert create_inference_session_providers('1', [ 'cpu', 'cuda' ]) == inference_session_providers