Suggest best execution provider, Simplify ONNXRUNTIME_SET (#854)

This commit is contained in:
Henry Ruhs
2025-01-15 09:42:46 +01:00
committed by henryruhs
parent 732f096da0
commit faf5020051
4 changed files with 35 additions and 29 deletions

View File

@@ -73,13 +73,13 @@ webcam_resolutions : List[str] = [ '320x240', '640x480', '800x600', '1024x768',
execution_provider_set : ExecutionProviderSet =\ execution_provider_set : ExecutionProviderSet =\
{ {
'cpu': 'CPUExecutionProvider',
'coreml': 'CoreMLExecutionProvider',
'cuda': 'CUDAExecutionProvider', 'cuda': 'CUDAExecutionProvider',
'tensorrt': 'TensorrtExecutionProvider',
'directml': 'DmlExecutionProvider', 'directml': 'DmlExecutionProvider',
'openvino': 'OpenVINOExecutionProvider',
'rocm': 'ROCMExecutionProvider', 'rocm': 'ROCMExecutionProvider',
'tensorrt': 'TensorrtExecutionProvider' 'openvino': 'OpenVINOExecutionProvider',
'coreml': 'CoreMLExecutionProvider',
'cpu': 'CPUExecutionProvider'
} }
execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys()) execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys())
download_provider_set : DownloadProviderSet =\ download_provider_set : DownloadProviderSet =\

View File

@@ -7,6 +7,7 @@ from typing import Any, List, Optional
from onnxruntime import get_available_providers, set_default_logger_severity from onnxruntime import get_available_providers, set_default_logger_severity
import facefusion.choices import facefusion.choices
from facefusion.common_helper import get_last
from facefusion.typing import ExecutionDevice, ExecutionProvider, ValueAndUnit from facefusion.typing import ExecutionDevice, ExecutionProvider, ValueAndUnit
set_default_logger_severity(3) set_default_logger_severity(3)
@@ -16,6 +17,14 @@ def has_execution_provider(execution_provider : ExecutionProvider) -> bool:
return execution_provider in get_available_execution_providers() return execution_provider in get_available_execution_providers()
def suggest_execution_provider(execution_providers : List[ExecutionProvider]) -> ExecutionProvider:
for execution_provider in facefusion.choices.execution_providers:
if execution_provider in execution_providers:
return execution_provider
return get_last(facefusion.choices.execution_providers)
def get_available_execution_providers() -> List[ExecutionProvider]: def get_available_execution_providers() -> List[ExecutionProvider]:
inference_execution_providers = get_available_providers() inference_execution_providers = get_available_providers()
available_execution_providers = [] available_execution_providers = []
@@ -47,17 +56,17 @@ def create_inference_execution_providers(execution_device_id : str, execution_pr
'trt_timing_cache_path': '.caches', 'trt_timing_cache_path': '.caches',
'trt_builder_optimization_level': 5 'trt_builder_optimization_level': 5
})) }))
if execution_provider in [ 'directml', 'rocm' ]:
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{
'device_id': execution_device_id
}))
if execution_provider == 'openvino': if execution_provider == 'openvino':
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{ {
'device_type': 'GPU' if execution_device_id == '0' else 'GPU.' + execution_device_id, 'device_type': 'GPU' if execution_device_id == '0' else 'GPU.' + execution_device_id,
'precision': 'FP32' 'precision': 'FP32'
})) }))
if execution_provider in [ 'directml', 'rocm' ]:
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{
'device_id': execution_device_id
}))
if execution_provider == 'coreml': if execution_provider == 'coreml':
inference_execution_providers.append(facefusion.choices.execution_provider_set.get(execution_provider)) inference_execution_providers.append(facefusion.choices.execution_provider_set.get(execution_provider))

View File

@@ -4,29 +4,28 @@ import signal
import subprocess import subprocess
import sys import sys
from argparse import ArgumentParser, HelpFormatter from argparse import ArgumentParser, HelpFormatter
from typing import Dict, Tuple
from facefusion import metadata, wording from facefusion import metadata, wording
from facefusion.common_helper import is_linux, is_macos, is_windows from facefusion.common_helper import is_linux, is_windows
ONNXRUNTIMES : Dict[str, Tuple[str, str]] = {}
if is_macos(): ONNXRUNTIME_SET =\
ONNXRUNTIMES['default'] = ('onnxruntime', '1.20.1') {
else: 'default': ('onnxruntime', '1.20.1')
ONNXRUNTIMES['default'] = ('onnxruntime', '1.20.1') }
ONNXRUNTIMES['cuda'] = ('onnxruntime-gpu', '1.20.1') if is_windows() or is_linux():
ONNXRUNTIMES['openvino'] = ('onnxruntime-openvino', '1.20.0') ONNXRUNTIME_SET['cuda'] = ('onnxruntime-gpu', '1.20.1')
if is_linux(): ONNXRUNTIME_SET['openvino'] = ('onnxruntime-openvino', '1.20.0')
ONNXRUNTIMES['rocm'] = ('onnxruntime-rocm', '1.19.0')
if is_windows(): if is_windows():
ONNXRUNTIMES['directml'] = ('onnxruntime-directml', '1.17.3') ONNXRUNTIME_SET['directml'] = ('onnxruntime-directml', '1.17.3')
if is_linux():
ONNXRUNTIME_SET['rocm'] = ('onnxruntime-rocm', '1.19.0')
def cli() -> None: def cli() -> None:
signal.signal(signal.SIGINT, lambda signal_number, frame: sys.exit(0)) signal.signal(signal.SIGINT, lambda signal_number, frame: sys.exit(0))
program = ArgumentParser(formatter_class = lambda prog: HelpFormatter(prog, max_help_position = 50)) program = ArgumentParser(formatter_class = lambda prog: HelpFormatter(prog, max_help_position = 50))
program.add_argument('--onnxruntime', help = wording.get('help.install_dependency').format(dependency = 'onnxruntime'), choices = ONNXRUNTIMES.keys(), required = True) program.add_argument('--onnxruntime', help = wording.get('help.install_dependency').format(dependency = 'onnxruntime'), choices = ONNXRUNTIME_SET.keys(), required = True)
program.add_argument('--skip-conda', help = wording.get('help.skip_conda'), action = 'store_true') program.add_argument('--skip-conda', help = wording.get('help.skip_conda'), action = 'store_true')
program.add_argument('-v', '--version', version = metadata.get('name') + ' ' + metadata.get('version'), action = 'version') program.add_argument('-v', '--version', version = metadata.get('name') + ' ' + metadata.get('version'), action = 'version')
run(program) run(program)
@@ -35,7 +34,7 @@ def cli() -> None:
def run(program : ArgumentParser) -> None: def run(program : ArgumentParser) -> None:
args = program.parse_args() args = program.parse_args()
has_conda = 'CONDA_PREFIX' in os.environ has_conda = 'CONDA_PREFIX' in os.environ
onnxruntime_name, onnxruntime_version = ONNXRUNTIMES.get(args.onnxruntime) onnxruntime_name, onnxruntime_version = ONNXRUNTIME_SET.get(args.onnxruntime)
if not args.skip_conda and not has_conda: if not args.skip_conda and not has_conda:
sys.stdout.write(wording.get('conda_not_activated') + os.linesep) sys.stdout.write(wording.get('conda_not_activated') + os.linesep)

View File

@@ -4,7 +4,7 @@ from argparse import ArgumentParser, HelpFormatter
import facefusion.choices import facefusion.choices
from facefusion import config, metadata, state_manager, wording from facefusion import config, metadata, state_manager, wording
from facefusion.common_helper import create_float_metavar, create_int_metavar, get_last from facefusion.common_helper import create_float_metavar, create_int_metavar, get_last
from facefusion.execution import get_available_execution_providers from facefusion.execution import get_available_execution_providers, suggest_execution_provider
from facefusion.filesystem import get_file_name, resolve_file_paths from facefusion.filesystem import get_file_name, resolve_file_paths
from facefusion.jobs import job_store from facefusion.jobs import job_store
from facefusion.processors.core import get_processors_modules from facefusion.processors.core import get_processors_modules
@@ -196,7 +196,7 @@ def create_execution_program() -> ArgumentParser:
available_execution_providers = get_available_execution_providers() available_execution_providers = get_available_execution_providers()
group_execution = program.add_argument_group('execution') group_execution = program.add_argument_group('execution')
group_execution.add_argument('--execution-device-id', help = wording.get('help.execution_device_id'), default = config.get_str_value('execution.execution_device_id', '0')) group_execution.add_argument('--execution-device-id', help = wording.get('help.execution_device_id'), default = config.get_str_value('execution.execution_device_id', '0'))
group_execution.add_argument('--execution-providers', help = wording.get('help.execution_providers').format(choices = ', '.join(available_execution_providers)), default = config.get_str_list('execution.execution_providers', 'cpu'), choices = available_execution_providers, nargs = '+', metavar = 'EXECUTION_PROVIDERS') group_execution.add_argument('--execution-providers', help = wording.get('help.execution_providers').format(choices = ', '.join(available_execution_providers)), default = config.get_str_list('execution.execution_providers', suggest_execution_provider(available_execution_providers)), choices = available_execution_providers, nargs = '+', metavar = 'EXECUTION_PROVIDERS')
group_execution.add_argument('--execution-thread-count', help = wording.get('help.execution_thread_count'), type = int, default = config.get_int_value('execution.execution_thread_count', '4'), choices = facefusion.choices.execution_thread_count_range, metavar = create_int_metavar(facefusion.choices.execution_thread_count_range)) group_execution.add_argument('--execution-thread-count', help = wording.get('help.execution_thread_count'), type = int, default = config.get_int_value('execution.execution_thread_count', '4'), choices = facefusion.choices.execution_thread_count_range, metavar = create_int_metavar(facefusion.choices.execution_thread_count_range))
group_execution.add_argument('--execution-queue-count', help = wording.get('help.execution_queue_count'), type = int, default = config.get_int_value('execution.execution_queue_count', '1'), choices = facefusion.choices.execution_queue_count_range, metavar = create_int_metavar(facefusion.choices.execution_queue_count_range)) group_execution.add_argument('--execution-queue-count', help = wording.get('help.execution_queue_count'), type = int, default = config.get_int_value('execution.execution_queue_count', '1'), choices = facefusion.choices.execution_queue_count_range, metavar = create_int_metavar(facefusion.choices.execution_queue_count_range))
job_store.register_job_keys([ 'execution_device_id', 'execution_providers', 'execution_thread_count', 'execution_queue_count' ]) job_store.register_job_keys([ 'execution_device_id', 'execution_providers', 'execution_thread_count', 'execution_queue_count' ])
@@ -205,9 +205,8 @@ def create_execution_program() -> ArgumentParser:
def create_download_providers_program() -> ArgumentParser: def create_download_providers_program() -> ArgumentParser:
program = ArgumentParser(add_help = False) program = ArgumentParser(add_help = False)
download_providers = list(facefusion.choices.download_provider_set.keys())
group_download = program.add_argument_group('download') group_download = program.add_argument_group('download')
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', ' '.join(facefusion.choices.download_providers)), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS') group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(facefusion.choices.download_providers)), default = config.get_str_list('download.download_providers', ' '.join(facefusion.choices.download_providers)), choices = facefusion.choices.download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
job_store.register_job_keys([ 'download_providers' ]) job_store.register_job_keys([ 'download_providers' ])
return program return program
@@ -231,9 +230,8 @@ def create_memory_program() -> ArgumentParser:
def create_misc_program() -> ArgumentParser: def create_misc_program() -> ArgumentParser:
program = ArgumentParser(add_help = False) program = ArgumentParser(add_help = False)
log_level_keys = list(facefusion.choices.log_level_set.keys())
group_misc = program.add_argument_group('misc') group_misc = program.add_argument_group('misc')
group_misc.add_argument('--log-level', help = wording.get('help.log_level'), default = config.get_str_value('misc.log_level', 'info'), choices = log_level_keys) group_misc.add_argument('--log-level', help = wording.get('help.log_level'), default = config.get_str_value('misc.log_level', 'info'), choices = facefusion.choices.log_levels)
job_store.register_job_keys([ 'log_level' ]) job_store.register_job_keys([ 'log_level' ])
return program return program