Hotfix Geforce 16 series

This commit is contained in:
henryruhs
2025-01-02 22:03:48 +01:00
parent 8656411336
commit 197773c346
2 changed files with 10 additions and 2 deletions

View File

@@ -34,7 +34,8 @@ def create_inference_execution_providers(execution_device_id : str, execution_pr
if execution_provider == 'cuda': if execution_provider == 'cuda':
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
{ {
'device_id': execution_device_id 'device_id': execution_device_id,
'cudnn_conv_algo_search': 'DEFAULT' if is_geforce_16_series() else 'EXHAUSTIVE'
})) }))
if execution_provider == 'tensorrt': if execution_provider == 'tensorrt':
inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider), inference_execution_providers.append((facefusion.choices.execution_provider_set.get(execution_provider),
@@ -66,6 +67,13 @@ def create_inference_execution_providers(execution_device_id : str, execution_pr
return inference_execution_providers return inference_execution_providers
def is_geforce_16_series() -> bool:
execution_devices = detect_static_execution_devices()
product_names = ('GeForce GTX 1630', 'GeForce GTX 1650', 'GeForce GTX 1660')
return any(execution_device.get('product').get('name').startswith(product_names) for execution_device in execution_devices)
def run_nvidia_smi() -> subprocess.Popen[bytes]: def run_nvidia_smi() -> subprocess.Popen[bytes]:
commands = [ shutil.which('nvidia-smi'), '--query', '--xml-format' ] commands = [ shutil.which('nvidia-smi'), '--query', '--xml-format' ]
return subprocess.Popen(commands, stdout = subprocess.PIPE) return subprocess.Popen(commands, stdout = subprocess.PIPE)

View File

@@ -89,5 +89,5 @@ def run(program : ArgumentParser) -> None:
subprocess.call([ shutil.which('conda'), 'env', 'config', 'vars', 'set', 'PATH=' + os.pathsep.join(library_paths) ]) subprocess.call([ shutil.which('conda'), 'env', 'config', 'vars', 'set', 'PATH=' + os.pathsep.join(library_paths) ])
if args.onnxruntime in [ 'rocm', 'directml' ]: if args.onnxruntime in [ 'directml', 'rocm' ]:
subprocess.call([ shutil.which('pip'), 'install', 'numpy==1.26.4', '--force-reinstall' ]) subprocess.call([ shutil.which('pip'), 'install', 'numpy==1.26.4', '--force-reinstall' ])