Suggest best execution provider, Simplify ONNXRUNTIME_SET (#854)
This commit is contained in:
@@ -73,13 +73,13 @@ webcam_resolutions : List[str] = [ '320x240', '640x480', '800x600', '1024x768',
|
|||||||
|
|
||||||
execution_provider_set : ExecutionProviderSet =\
|
execution_provider_set : ExecutionProviderSet =\
|
||||||
{
|
{
|
||||||
'cpu': 'CPUExecutionProvider',
|
|
||||||
'coreml': 'CoreMLExecutionProvider',
|
|
||||||
'cuda': 'CUDAExecutionProvider',
|
'cuda': 'CUDAExecutionProvider',
|
||||||
|
'tensorrt': 'TensorrtExecutionProvider',
|
||||||
'directml': 'DmlExecutionProvider',
|
'directml': 'DmlExecutionProvider',
|
||||||
'openvino': 'OpenVINOExecutionProvider',
|
|
||||||
'rocm': 'ROCMExecutionProvider',
|
'rocm': 'ROCMExecutionProvider',
|
||||||
'tensorrt': 'TensorrtExecutionProvider'
|
'openvino': 'OpenVINOExecutionProvider',
|
||||||
|
'coreml': 'CoreMLExecutionProvider',
|
||||||
|
'cpu': 'CPUExecutionProvider'
|
||||||
}
|
}
|
||||||
execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys())
|
execution_providers : List[ExecutionProvider] = list(execution_provider_set.keys())
|
||||||
download_provider_set : DownloadProviderSet =\
|
download_provider_set : DownloadProviderSet =\
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import Any, List, Optional
|
|||||||
from onnxruntime import get_available_providers, set_default_logger_severity
|
from onnxruntime import get_available_providers, set_default_logger_severity
|
||||||
|
|
||||||
import facefusion.choices
|
import facefusion.choices
|
||||||
|
from facefusion.common_helper import get_last
|
||||||
from facefusion.typing import ExecutionDevice, ExecutionProvider, ValueAndUnit
|
from facefusion.typing import ExecutionDevice, ExecutionProvider, ValueAndUnit
|
||||||
|
|
||||||
set_default_logger_severity(3)
|
set_default_logger_severity(3)
|
||||||
@@ -16,6 +17,14 @@ def has_execution_provider(execution_provider : ExecutionProvider) -> bool:
|
|||||||
return execution_provider in get_available_execution_providers()
|
return execution_provider in get_available_execution_providers()
|
||||||
|
|
||||||
|
|
||||||
|
def suggest_execution_provider(execution_providers : List[ExecutionProvider]) -> ExecutionProvider:
|
||||||
|
for execution_provider in facefusion.choices.execution_providers:
|
||||||
|
if execution_provider in execution_providers:
|
||||||
|
return execution_provider
|
||||||
|
|
||||||
|
return get_last(facefusion.choices.execution_providers)
|
||||||
|
|
||||||
|
|
||||||
def get_available_execution_providers() -> List[ExecutionProvider]:
|
def get_available_execution_providers() -> List[ExecutionProvider]:
|
||||||
inference_execution_providers = get_available_providers()
|
inference_execution_providers = get_available_providers()
|
||||||
available_execution_providers = []
|
available_execution_providers = []
|
||||||
@@ -47,17 +56,17 @@ def create_inference_execution_providers(execution_device_id : str, execution_pr
|
|||||||
'trt_timing_cache_path': '.caches',
|
'trt_timing_cache_path': '.caches',
|
||||||
'trt_builder_optimization_level': 5
|
'trt_builder_optimization_level': 5
|
||||||
}))
|
}))
|
||||||
|
if execution_provider in [ 'directml', 'rocm' ]:
|
||||||
|
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
|
||||||
|
{
|
||||||
|
'device_id': execution_device_id
|
||||||
|
}))
|
||||||
if execution_provider == 'openvino':
|
if execution_provider == 'openvino':
|
||||||
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
|
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
|
||||||
{
|
{
|
||||||
'device_type': 'GPU' if execution_device_id == '0' else 'GPU.' + execution_device_id,
|
'device_type': 'GPU' if execution_device_id == '0' else 'GPU.' + execution_device_id,
|
||||||
'precision': 'FP32'
|
'precision': 'FP32'
|
||||||
}))
|
}))
|
||||||
if execution_provider in [ 'directml', 'rocm' ]:
|
|
||||||
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
|
|
||||||
{
|
|
||||||
'device_id': execution_device_id
|
|
||||||
}))
|
|
||||||
if execution_provider == 'coreml':
|
if execution_provider == 'coreml':
|
||||||
inference_execution_providers.append(facefusion.choices.execution_provider_set.get(execution_provider))
|
inference_execution_providers.append(facefusion.choices.execution_provider_set.get(execution_provider))
|
||||||
|
|
||||||
|
|||||||
@@ -4,29 +4,28 @@ import signal
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from argparse import ArgumentParser, HelpFormatter
|
from argparse import ArgumentParser, HelpFormatter
|
||||||
from typing import Dict, Tuple
|
|
||||||
|
|
||||||
from facefusion import metadata, wording
|
from facefusion import metadata, wording
|
||||||
from facefusion.common_helper import is_linux, is_macos, is_windows
|
from facefusion.common_helper import is_linux, is_windows
|
||||||
|
|
||||||
ONNXRUNTIMES : Dict[str, Tuple[str, str]] = {}
|
|
||||||
|
|
||||||
if is_macos():
|
ONNXRUNTIME_SET =\
|
||||||
ONNXRUNTIMES['default'] = ('onnxruntime', '1.20.1')
|
{
|
||||||
else:
|
'default': ('onnxruntime', '1.20.1')
|
||||||
ONNXRUNTIMES['default'] = ('onnxruntime', '1.20.1')
|
}
|
||||||
ONNXRUNTIMES['cuda'] = ('onnxruntime-gpu', '1.20.1')
|
if is_windows() or is_linux():
|
||||||
ONNXRUNTIMES['openvino'] = ('onnxruntime-openvino', '1.20.0')
|
ONNXRUNTIME_SET['cuda'] = ('onnxruntime-gpu', '1.20.1')
|
||||||
if is_linux():
|
ONNXRUNTIME_SET['openvino'] = ('onnxruntime-openvino', '1.20.0')
|
||||||
ONNXRUNTIMES['rocm'] = ('onnxruntime-rocm', '1.19.0')
|
|
||||||
if is_windows():
|
if is_windows():
|
||||||
ONNXRUNTIMES['directml'] = ('onnxruntime-directml', '1.17.3')
|
ONNXRUNTIME_SET['directml'] = ('onnxruntime-directml', '1.17.3')
|
||||||
|
if is_linux():
|
||||||
|
ONNXRUNTIME_SET['rocm'] = ('onnxruntime-rocm', '1.19.0')
|
||||||
|
|
||||||
|
|
||||||
def cli() -> None:
|
def cli() -> None:
|
||||||
signal.signal(signal.SIGINT, lambda signal_number, frame: sys.exit(0))
|
signal.signal(signal.SIGINT, lambda signal_number, frame: sys.exit(0))
|
||||||
program = ArgumentParser(formatter_class = lambda prog: HelpFormatter(prog, max_help_position = 50))
|
program = ArgumentParser(formatter_class = lambda prog: HelpFormatter(prog, max_help_position = 50))
|
||||||
program.add_argument('--onnxruntime', help = wording.get('help.install_dependency').format(dependency = 'onnxruntime'), choices = ONNXRUNTIMES.keys(), required = True)
|
program.add_argument('--onnxruntime', help = wording.get('help.install_dependency').format(dependency = 'onnxruntime'), choices = ONNXRUNTIME_SET.keys(), required = True)
|
||||||
program.add_argument('--skip-conda', help = wording.get('help.skip_conda'), action = 'store_true')
|
program.add_argument('--skip-conda', help = wording.get('help.skip_conda'), action = 'store_true')
|
||||||
program.add_argument('-v', '--version', version = metadata.get('name') + ' ' + metadata.get('version'), action = 'version')
|
program.add_argument('-v', '--version', version = metadata.get('name') + ' ' + metadata.get('version'), action = 'version')
|
||||||
run(program)
|
run(program)
|
||||||
@@ -35,7 +34,7 @@ def cli() -> None:
|
|||||||
def run(program : ArgumentParser) -> None:
|
def run(program : ArgumentParser) -> None:
|
||||||
args = program.parse_args()
|
args = program.parse_args()
|
||||||
has_conda = 'CONDA_PREFIX' in os.environ
|
has_conda = 'CONDA_PREFIX' in os.environ
|
||||||
onnxruntime_name, onnxruntime_version = ONNXRUNTIMES.get(args.onnxruntime)
|
onnxruntime_name, onnxruntime_version = ONNXRUNTIME_SET.get(args.onnxruntime)
|
||||||
|
|
||||||
if not args.skip_conda and not has_conda:
|
if not args.skip_conda and not has_conda:
|
||||||
sys.stdout.write(wording.get('conda_not_activated') + os.linesep)
|
sys.stdout.write(wording.get('conda_not_activated') + os.linesep)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from argparse import ArgumentParser, HelpFormatter
|
|||||||
import facefusion.choices
|
import facefusion.choices
|
||||||
from facefusion import config, metadata, state_manager, wording
|
from facefusion import config, metadata, state_manager, wording
|
||||||
from facefusion.common_helper import create_float_metavar, create_int_metavar, get_last
|
from facefusion.common_helper import create_float_metavar, create_int_metavar, get_last
|
||||||
from facefusion.execution import get_available_execution_providers
|
from facefusion.execution import get_available_execution_providers, suggest_execution_provider
|
||||||
from facefusion.filesystem import get_file_name, resolve_file_paths
|
from facefusion.filesystem import get_file_name, resolve_file_paths
|
||||||
from facefusion.jobs import job_store
|
from facefusion.jobs import job_store
|
||||||
from facefusion.processors.core import get_processors_modules
|
from facefusion.processors.core import get_processors_modules
|
||||||
@@ -196,7 +196,7 @@ def create_execution_program() -> ArgumentParser:
|
|||||||
available_execution_providers = get_available_execution_providers()
|
available_execution_providers = get_available_execution_providers()
|
||||||
group_execution = program.add_argument_group('execution')
|
group_execution = program.add_argument_group('execution')
|
||||||
group_execution.add_argument('--execution-device-id', help = wording.get('help.execution_device_id'), default = config.get_str_value('execution.execution_device_id', '0'))
|
group_execution.add_argument('--execution-device-id', help = wording.get('help.execution_device_id'), default = config.get_str_value('execution.execution_device_id', '0'))
|
||||||
group_execution.add_argument('--execution-providers', help = wording.get('help.execution_providers').format(choices = ', '.join(available_execution_providers)), default = config.get_str_list('execution.execution_providers', 'cpu'), choices = available_execution_providers, nargs = '+', metavar = 'EXECUTION_PROVIDERS')
|
group_execution.add_argument('--execution-providers', help = wording.get('help.execution_providers').format(choices = ', '.join(available_execution_providers)), default = config.get_str_list('execution.execution_providers', suggest_execution_provider(available_execution_providers)), choices = available_execution_providers, nargs = '+', metavar = 'EXECUTION_PROVIDERS')
|
||||||
group_execution.add_argument('--execution-thread-count', help = wording.get('help.execution_thread_count'), type = int, default = config.get_int_value('execution.execution_thread_count', '4'), choices = facefusion.choices.execution_thread_count_range, metavar = create_int_metavar(facefusion.choices.execution_thread_count_range))
|
group_execution.add_argument('--execution-thread-count', help = wording.get('help.execution_thread_count'), type = int, default = config.get_int_value('execution.execution_thread_count', '4'), choices = facefusion.choices.execution_thread_count_range, metavar = create_int_metavar(facefusion.choices.execution_thread_count_range))
|
||||||
group_execution.add_argument('--execution-queue-count', help = wording.get('help.execution_queue_count'), type = int, default = config.get_int_value('execution.execution_queue_count', '1'), choices = facefusion.choices.execution_queue_count_range, metavar = create_int_metavar(facefusion.choices.execution_queue_count_range))
|
group_execution.add_argument('--execution-queue-count', help = wording.get('help.execution_queue_count'), type = int, default = config.get_int_value('execution.execution_queue_count', '1'), choices = facefusion.choices.execution_queue_count_range, metavar = create_int_metavar(facefusion.choices.execution_queue_count_range))
|
||||||
job_store.register_job_keys([ 'execution_device_id', 'execution_providers', 'execution_thread_count', 'execution_queue_count' ])
|
job_store.register_job_keys([ 'execution_device_id', 'execution_providers', 'execution_thread_count', 'execution_queue_count' ])
|
||||||
@@ -205,9 +205,8 @@ def create_execution_program() -> ArgumentParser:
|
|||||||
|
|
||||||
def create_download_providers_program() -> ArgumentParser:
|
def create_download_providers_program() -> ArgumentParser:
|
||||||
program = ArgumentParser(add_help = False)
|
program = ArgumentParser(add_help = False)
|
||||||
download_providers = list(facefusion.choices.download_provider_set.keys())
|
|
||||||
group_download = program.add_argument_group('download')
|
group_download = program.add_argument_group('download')
|
||||||
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(download_providers)), default = config.get_str_list('download.download_providers', ' '.join(facefusion.choices.download_providers)), choices = download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
|
group_download.add_argument('--download-providers', help = wording.get('help.download_providers').format(choices = ', '.join(facefusion.choices.download_providers)), default = config.get_str_list('download.download_providers', ' '.join(facefusion.choices.download_providers)), choices = facefusion.choices.download_providers, nargs = '+', metavar = 'DOWNLOAD_PROVIDERS')
|
||||||
job_store.register_job_keys([ 'download_providers' ])
|
job_store.register_job_keys([ 'download_providers' ])
|
||||||
return program
|
return program
|
||||||
|
|
||||||
@@ -231,9 +230,8 @@ def create_memory_program() -> ArgumentParser:
|
|||||||
|
|
||||||
def create_misc_program() -> ArgumentParser:
|
def create_misc_program() -> ArgumentParser:
|
||||||
program = ArgumentParser(add_help = False)
|
program = ArgumentParser(add_help = False)
|
||||||
log_level_keys = list(facefusion.choices.log_level_set.keys())
|
|
||||||
group_misc = program.add_argument_group('misc')
|
group_misc = program.add_argument_group('misc')
|
||||||
group_misc.add_argument('--log-level', help = wording.get('help.log_level'), default = config.get_str_value('misc.log_level', 'info'), choices = log_level_keys)
|
group_misc.add_argument('--log-level', help = wording.get('help.log_level'), default = config.get_str_value('misc.log_level', 'info'), choices = facefusion.choices.log_levels)
|
||||||
job_store.register_job_keys([ 'log_level' ])
|
job_store.register_job_keys([ 'log_level' ])
|
||||||
return program
|
return program
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user