diff --git a/facefusion/choices.py b/facefusion/choices.py index b4ba7a6..78996da 100755 --- a/facefusion/choices.py +++ b/facefusion/choices.py @@ -83,12 +83,19 @@ download_provider_set : DownloadProviderSet =\ { 'github': { - 'url': 'https://github.com', + 'urls': + [ + 'https://github.com' + ], 'path': '/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}' }, 'huggingface': { - 'url': 'https://huggingface.co', + 'urls': + [ + 'https://huggingface.co', + 'https://hf-mirror.com' + ], 'path': '/facefusion/{base_name}/resolve/main/{file_name}' } } diff --git a/facefusion/download.py b/facefusion/download.py index 273aa69..ac54316 100644 --- a/facefusion/download.py +++ b/facefusion/download.py @@ -15,7 +15,7 @@ from facefusion.typing import DownloadProvider, DownloadSet def open_curl(args : List[str]) -> subprocess.Popen[bytes]: - commands = [ shutil.which('curl'), '--silent', '--insecure', '--location', '--connect-timeout', '10' ] + commands = [ shutil.which('curl'), '--silent', '--insecure', '--location' ] commands.extend(args) return subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE) @@ -29,7 +29,7 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non if initial_size < download_size: with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress: - commands = [ '--create-dirs', '--continue-at', '-', '--output', download_file_path, url ] + commands = [ '--create-dirs', '--continue-at', '-', '--output', download_file_path, url, '--connect-timeout', '10' ] open_curl(commands) current_size = initial_size progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name) @@ -42,7 +42,7 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non @lru_cache(maxsize = None) def get_static_download_size(url : str) -> int: - commands = [ '-I', url ] + commands = [ '-I', url, '--connect-timeout', '5' ] process = open_curl(commands) lines = reversed(process.stdout.readlines()) @@ -57,7 +57,7 @@ def get_static_download_size(url : str) -> int: @lru_cache(maxsize = None) def ping_static_url(url : str) -> bool: - commands = [ '-I', url ] + commands = [ '-I', url, '--connect-timeout', '5' ] process = open_curl(commands) process.communicate() return process.returncode == 0 @@ -129,6 +129,7 @@ def validate_hash_paths(hash_paths : List[str]) -> Tuple[List[str], List[str]]: valid_hash_paths.append(hash_path) else: invalid_hash_paths.append(hash_path) + return valid_hash_paths, invalid_hash_paths @@ -141,6 +142,7 @@ def validate_source_paths(source_paths : List[str]) -> Tuple[List[str], List[str valid_source_paths.append(source_path) else: invalid_source_paths.append(source_path) + return valid_source_paths, invalid_source_paths @@ -148,16 +150,18 @@ def resolve_download_url(base_name : str, file_name : str) -> Optional[str]: download_providers = state_manager.get_item('download_providers') for download_provider in download_providers: - if ping_download_provider(download_provider): - return resolve_download_url_by_provider(download_provider, base_name, file_name) + download_url = resolve_download_url_by_provider(download_provider, base_name, file_name) + if download_url: + return download_url + return None -def ping_download_provider(download_provider : DownloadProvider) -> bool: - download_provider_value = facefusion.choices.download_provider_set.get(download_provider) - return ping_static_url(download_provider_value.get('url')) - - def resolve_download_url_by_provider(download_provider : DownloadProvider, base_name : str, file_name : str) -> Optional[str]: download_provider_value = facefusion.choices.download_provider_set.get(download_provider) - return download_provider_value.get('url') + download_provider_value.get('path').format(base_name = base_name, file_name = file_name) + + for download_provider_url in download_provider_value.get('urls'): + if ping_static_url(download_provider_url): + return download_provider_url + download_provider_value.get('path').format(base_name = base_name, file_name = file_name) + + return None diff --git a/facefusion/typing.py b/facefusion/typing.py index 5c1b540..ad2bf0a 100755 --- a/facefusion/typing.py +++ b/facefusion/typing.py @@ -169,7 +169,7 @@ ExecutionDevice = TypedDict('ExecutionDevice', DownloadProvider = Literal['github', 'huggingface'] DownloadProviderValue = TypedDict('DownloadProviderValue', { - 'url' : str, + 'urls' : List[str], 'path' : str }) DownloadProviderSet = Dict[DownloadProvider, DownloadProviderValue]