Support for download provider mirrors (#847)
This commit is contained in:
@@ -83,12 +83,19 @@ download_provider_set : DownloadProviderSet =\
|
||||
{
|
||||
'github':
|
||||
{
|
||||
'url': 'https://github.com',
|
||||
'urls':
|
||||
[
|
||||
'https://github.com'
|
||||
],
|
||||
'path': '/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}'
|
||||
},
|
||||
'huggingface':
|
||||
{
|
||||
'url': 'https://huggingface.co',
|
||||
'urls':
|
||||
[
|
||||
'https://huggingface.co',
|
||||
'https://hf-mirror.com'
|
||||
],
|
||||
'path': '/facefusion/{base_name}/resolve/main/{file_name}'
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ from facefusion.typing import DownloadProvider, DownloadSet
|
||||
|
||||
|
||||
def open_curl(args : List[str]) -> subprocess.Popen[bytes]:
|
||||
commands = [ shutil.which('curl'), '--silent', '--insecure', '--location', '--connect-timeout', '10' ]
|
||||
commands = [ shutil.which('curl'), '--silent', '--insecure', '--location' ]
|
||||
commands.extend(args)
|
||||
return subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE)
|
||||
|
||||
@@ -29,7 +29,7 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
|
||||
|
||||
if initial_size < download_size:
|
||||
with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
|
||||
commands = [ '--create-dirs', '--continue-at', '-', '--output', download_file_path, url ]
|
||||
commands = [ '--create-dirs', '--continue-at', '-', '--output', download_file_path, url, '--connect-timeout', '10' ]
|
||||
open_curl(commands)
|
||||
current_size = initial_size
|
||||
progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name)
|
||||
@@ -42,7 +42,7 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
|
||||
|
||||
@lru_cache(maxsize = None)
|
||||
def get_static_download_size(url : str) -> int:
|
||||
commands = [ '-I', url ]
|
||||
commands = [ '-I', url, '--connect-timeout', '5' ]
|
||||
process = open_curl(commands)
|
||||
lines = reversed(process.stdout.readlines())
|
||||
|
||||
@@ -57,7 +57,7 @@ def get_static_download_size(url : str) -> int:
|
||||
|
||||
@lru_cache(maxsize = None)
|
||||
def ping_static_url(url : str) -> bool:
|
||||
commands = [ '-I', url ]
|
||||
commands = [ '-I', url, '--connect-timeout', '5' ]
|
||||
process = open_curl(commands)
|
||||
process.communicate()
|
||||
return process.returncode == 0
|
||||
@@ -129,6 +129,7 @@ def validate_hash_paths(hash_paths : List[str]) -> Tuple[List[str], List[str]]:
|
||||
valid_hash_paths.append(hash_path)
|
||||
else:
|
||||
invalid_hash_paths.append(hash_path)
|
||||
|
||||
return valid_hash_paths, invalid_hash_paths
|
||||
|
||||
|
||||
@@ -141,6 +142,7 @@ def validate_source_paths(source_paths : List[str]) -> Tuple[List[str], List[str
|
||||
valid_source_paths.append(source_path)
|
||||
else:
|
||||
invalid_source_paths.append(source_path)
|
||||
|
||||
return valid_source_paths, invalid_source_paths
|
||||
|
||||
|
||||
@@ -148,16 +150,18 @@ def resolve_download_url(base_name : str, file_name : str) -> Optional[str]:
|
||||
download_providers = state_manager.get_item('download_providers')
|
||||
|
||||
for download_provider in download_providers:
|
||||
if ping_download_provider(download_provider):
|
||||
return resolve_download_url_by_provider(download_provider, base_name, file_name)
|
||||
download_url = resolve_download_url_by_provider(download_provider, base_name, file_name)
|
||||
if download_url:
|
||||
return download_url
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def ping_download_provider(download_provider : DownloadProvider) -> bool:
|
||||
download_provider_value = facefusion.choices.download_provider_set.get(download_provider)
|
||||
return ping_static_url(download_provider_value.get('url'))
|
||||
|
||||
|
||||
def resolve_download_url_by_provider(download_provider : DownloadProvider, base_name : str, file_name : str) -> Optional[str]:
|
||||
download_provider_value = facefusion.choices.download_provider_set.get(download_provider)
|
||||
return download_provider_value.get('url') + download_provider_value.get('path').format(base_name = base_name, file_name = file_name)
|
||||
|
||||
for download_provider_url in download_provider_value.get('urls'):
|
||||
if ping_static_url(download_provider_url):
|
||||
return download_provider_url + download_provider_value.get('path').format(base_name = base_name, file_name = file_name)
|
||||
|
||||
return None
|
||||
|
||||
@@ -169,7 +169,7 @@ ExecutionDevice = TypedDict('ExecutionDevice',
|
||||
DownloadProvider = Literal['github', 'huggingface']
|
||||
DownloadProviderValue = TypedDict('DownloadProviderValue',
|
||||
{
|
||||
'url' : str,
|
||||
'urls' : List[str],
|
||||
'path' : str
|
||||
})
|
||||
DownloadProviderSet = Dict[DownloadProvider, DownloadProviderValue]
|
||||
|
||||
Reference in New Issue
Block a user