diff --git a/facefusion/choices.py b/facefusion/choices.py index f8ff95d..5b3e8ae 100755 --- a/facefusion/choices.py +++ b/facefusion/choices.py @@ -73,13 +73,13 @@ webcam_resolutions : List[str] = [ '320x240', '640x480', '800x600', '1024x768', execution_provider_set : ExecutionProviderSet =\ { - 'cpu': 'CPUExecutionProvider', - 'coreml': 'CoreMLExecutionProvider', 'cuda': 'CUDAExecutionProvider', + 'tensorrt': 'TensorrtExecutionProvider', 'directml': 'DmlExecutionProvider', - 'openvino': 'OpenVINOExecutionProvider', 'rocm': 'ROCMExecutionProvider', - 'tensorrt': 'TensorrtExecutionProvider' + 'openvino': 'OpenVINOExecutionProvider', + 'coreml': 'CoreMLExecutionProvider', + 'cpu': 'CPUExecutionProvider' } execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys()) download_provider_set : DownloadProviderSet =\ diff --git a/facefusion/execution.py b/facefusion/execution.py index cf2e814..b59df2e 100644 --- a/facefusion/execution.py +++ b/facefusion/execution.py @@ -7,6 +7,7 @@ from typing import Any, List, Optional from onnxruntime import get_available_providers, set_default_logger_severity import facefusion.choices +from facefusion.common_helper import get_last from facefusion.typing import ExecutionDevice, ExecutionProvider, ValueAndUnit set_default_logger_severity(3) @@ -16,6 +17,14 @@ def has_execution_provider(execution_provider : ExecutionProvider) -> bool: return execution_provider in get_available_execution_providers() +def suggest_execution_provider(execution_providers : List[ExecutionProvider]) -> ExecutionProvider: + for execution_provider in facefusion.choices.execution_providers: + if execution_provider in execution_providers: + return execution_provider + + return get_last(facefusion.choices.execution_providers) + + def get_available_execution_providers() -> List[ExecutionProvider]: inference_execution_providers = get_available_providers() available_execution_providers = [] @@ -47,17 +56,17 @@ def create_inference_execution_providers(execution_device_id : str, execution_pr 'trt_timing_cache_path': '.caches', 'trt_builder_optimization_level': 5 })) + if execution_provider in [ 'directml', 'rocm' ]: + inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), + { + 'device_id': execution_device_id + })) if execution_provider == 'openvino': inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), { 'device_type': 'GPU' if execution_device_id == '0' else 'GPU.' + execution_device_id, 'precision': 'FP32' })) - if execution_provider in [ 'directml', 'rocm' ]: - inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), - { - 'device_id': execution_device_id - })) if execution_provider == 'coreml': inference_execution_providers.append(facefusion.choices.execution_provider_set.get(execution_provider)) diff --git a/facefusion/installer.py b/facefusion/installer.py index 2b815e8..d4a4a1f 100644 --- a/facefusion/installer.py +++ b/facefusion/installer.py @@ -4,29 +4,28 @@ import signal import subprocess import sys from argparse import ArgumentParser, HelpFormatter -from typing import Dict, Tuple from facefusion import metadata, wording -from facefusion.common_helper import is_linux, is_macos, is_windows +from facefusion.common_helper import is_linux, is_windows -ONNXRUNTIMES : Dict[str, Tuple[str, str]] = {} -if is_macos(): - ONNXRUNTIMES['default'] = ('onnxruntime', '1.20.1') -else: - ONNXRUNTIMES['default'] = ('onnxruntime', '1.20.1') - ONNXRUNTIMES['cuda'] = ('onnxruntime-gpu', '1.20.1') - ONNXRUNTIMES['openvino'] = ('onnxruntime-openvino', '1.20.0') -if is_linux(): - ONNXRUNTIMES['rocm'] = ('onnxruntime-rocm', '1.19.0') +ONNXRUNTIME_SET =\ +{ + 'default': ('onnxruntime', '1.20.1') +} +if is_windows() or is_linux(): + ONNXRUNTIME_SET['cuda'] = ('onnxruntime-gpu', '1.20.1') + ONNXRUNTIME_SET['openvino'] = ('onnxruntime-openvino', '1.20.0') if is_windows(): - ONNXRUNTIMES['directml'] = ('onnxruntime-directml', '1.17.3') + ONNXRUNTIME_SET['directml'] = ('onnxruntime-directml', '1.17.3') +if is_linux(): + ONNXRUNTIME_SET['rocm'] = ('onnxruntime-rocm', '1.19.0') def cli() -> None: signal.signal(signal.SIGINT, lambda signal_number, frame: sys.exit(0)) program = ArgumentParser(formatter_class = lambda prog: HelpFormatter(prog, max_help_position = 50)) - program.add_argument('--onnxruntime', help = wording.get('help.install_dependency').format(dependency = 'onnxruntime'), choices = ONNXRUNTIMES.keys(), required = True) + program.add_argument('--onnxruntime', help = wording.get('help.install_dependency').format(dependency = 'onnxruntime'), choices = ONNXRUNTIME_SET.keys(), required = True) program.add_argument('--skip-conda', help = wording.get('help.skip_conda'), action = 'store_true') program.add_argument('-v', '--version', version = metadata.get('name') + ' ' + metadata.get('version'), action = 'version') run(program) @@ -35,7 +34,7 @@ def cli() -> None: def run(program : ArgumentParser) -> None: args = program.parse_args() has_conda = 'CONDA_PREFIX' in os.environ - onnxruntime_name, onnxruntime_version = ONNXRUNTIMES.get(args.onnxruntime) + onnxruntime_name, onnxruntime_version = ONNXRUNTIME_SET.get(args.onnxruntime) if not args.skip_conda and not has_conda: sys.stdout.write(wording.get('conda_not_activated') + os.linesep) diff --git a/facefusion/program.py b/facefusion/program.py index db85bb7..ccc0d25 100755 --- a/facefusion/program.py +++ b/facefusion/program.py @@ -4,7 +4,7 @@ from argparse import ArgumentParser, HelpFormatter import facefusion.choices from facefusion import config, metadata, state_manager, wording from facefusion.common_helper import create_float_metavar, create_int_metavar, get_last -from facefusion.execution import get_available_execution_providers +from facefusion.execution import get_available_execution_providers, suggest_execution_provider from facefusion.filesystem import get_file_name, resolve_file_paths from facefusion.jobs import job_store from facefusion.processors.core import get_processors_modules @@ -196,7 +196,7 @@ def create_execution_program() -> ArgumentParser: available_execution_providers = get_available_execution_providers() group_execution = program.add_argument_group('execution') group_execution.add_argument('--execution-device-id', help = wording.get('help.execution_device_id'), default = config.get_str_value('execution.execution_device_id', '0')) - group_execution.add_argument('--execution-providers', help = wording.get('help.execution_providers').format(choices = ', '.join(available_execution_providers)), default = config.get_str_list('execution.execution_providers', 'cpu'), choices = available_execution_providers, nargs = '+', metavar = 'EXECUTION_PROVIDERS') + group_execution.add_argument('--execution-providers', help = wording.get('help.execution_providers').format(choices = ', '.join(available_execution_providers)), default = config.get_str_list('execution.execution_providers', suggest_execution_provider(available_execution_providers)), choices = available_execution_providers, nargs = '+', metavar = 'EXECUTION_PROVIDERS') group_execution.add_argument('--execution-thread-count', help = wording.get('help.execution_thread_count'), type = int, default = config.get_int_value('execution.execution_thread_count', '4'), choices = facefusion.choices.execution_thread_count_range, metavar = create_int_metavar(facefusion.choices.execution_thread_count_range)) group_execution.add_argument('--execution-queue-count', help = wording.get('help.execution_queue_count'), type = int, default = config.get_int_value('execution.execution_queue_count', '1'), choices = facefusion.choices.execution_queue_count_range, metavar = create_int_metavar(facefusion.choices.execution_queue_count_range)) job_store.register_job_keys([ 'execution_device_id', 'execution_providers', 'execution_thread_count', 'execution_queue_count' ]) @@ -205,9 +205,8 @@ def create_execution_program() -> ArgumentParser: def create_download_providers_program() -> ArgumentParser: program = ArgumentParser(add_help = False) - download_providers = list(facefusion.choices.download_provider_set.keys()) group_download = program.add_argument_group('download') - group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', ' '.join(facefusion.choices.download_providers)), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS') + group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(facefusion.choices.download_providers)), default = config.get_str_list('download.download_providers', ' '.join(facefusion.choices.download_providers)), choices = facefusion.choices.download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS') job_store.register_job_keys([ 'download_providers' ]) return program @@ -231,9 +230,8 @@ def create_memory_program() -> ArgumentParser: def create_misc_program() -> ArgumentParser: program = ArgumentParser(add_help = False) - log_level_keys = list(facefusion.choices.log_level_set.keys()) group_misc = program.add_argument_group('misc') - group_misc.add_argument('--log-level', help = wording.get('help.log_level'), default = config.get_str_value('misc.log_level', 'info'), choices = log_level_keys) + group_misc.add_argument('--log-level', help = wording.get('help.log_level'), default = config.get_str_value('misc.log_level', 'info'), choices = facefusion.choices.log_levels) job_store.register_job_keys([ 'log_level' ]) return program