Support for download provider mirrors (#847)
This commit is contained in:
@@ -83,12 +83,19 @@ download_provider_set : DownloadProviderSet =\
|
|||||||
{
|
{
|
||||||
'github':
|
'github':
|
||||||
{
|
{
|
||||||
'url': 'https://github.com',
|
'urls':
|
||||||
|
[
|
||||||
|
'https://github.com'
|
||||||
|
],
|
||||||
'path': '/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}'
|
'path': '/facefusion/facefusion-assets/releases/download/{base_name}/{file_name}'
|
||||||
},
|
},
|
||||||
'huggingface':
|
'huggingface':
|
||||||
{
|
{
|
||||||
'url': 'https://huggingface.co',
|
'urls':
|
||||||
|
[
|
||||||
|
'https://huggingface.co',
|
||||||
|
'https://hf-mirror.com'
|
||||||
|
],
|
||||||
'path': '/facefusion/{base_name}/resolve/main/{file_name}'
|
'path': '/facefusion/{base_name}/resolve/main/{file_name}'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from facefusion.typing import DownloadProvider, DownloadSet
|
|||||||
|
|
||||||
|
|
||||||
def open_curl(args : List[str]) -> subprocess.Popen[bytes]:
|
def open_curl(args : List[str]) -> subprocess.Popen[bytes]:
|
||||||
commands = [ shutil.which('curl'), '--silent', '--insecure', '--location', '--connect-timeout', '10' ]
|
commands = [ shutil.which('curl'), '--silent', '--insecure', '--location' ]
|
||||||
commands.extend(args)
|
commands.extend(args)
|
||||||
return subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE)
|
return subprocess.Popen(commands, stdin = subprocess.PIPE, stdout = subprocess.PIPE)
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
|
|||||||
|
|
||||||
if initial_size < download_size:
|
if initial_size < download_size:
|
||||||
with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
|
with tqdm(total = download_size, initial = initial_size, desc = wording.get('downloading'), unit = 'B', unit_scale = True, unit_divisor = 1024, ascii = ' =', disable = state_manager.get_item('log_level') in [ 'warn', 'error' ]) as progress:
|
||||||
commands = [ '--create-dirs', '--continue-at', '-', '--output', download_file_path, url ]
|
commands = [ '--create-dirs', '--continue-at', '-', '--output', download_file_path, url, '--connect-timeout', '10' ]
|
||||||
open_curl(commands)
|
open_curl(commands)
|
||||||
current_size = initial_size
|
current_size = initial_size
|
||||||
progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name)
|
progress.set_postfix(download_providers = state_manager.get_item('download_providers'), file_name = download_file_name)
|
||||||
@@ -42,7 +42,7 @@ def conditional_download(download_directory_path : str, urls : List[str]) -> Non
|
|||||||
|
|
||||||
@lru_cache(maxsize = None)
|
@lru_cache(maxsize = None)
|
||||||
def get_static_download_size(url : str) -> int:
|
def get_static_download_size(url : str) -> int:
|
||||||
commands = [ '-I', url ]
|
commands = [ '-I', url, '--connect-timeout', '5' ]
|
||||||
process = open_curl(commands)
|
process = open_curl(commands)
|
||||||
lines = reversed(process.stdout.readlines())
|
lines = reversed(process.stdout.readlines())
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ def get_static_download_size(url : str) -> int:
|
|||||||
|
|
||||||
@lru_cache(maxsize = None)
|
@lru_cache(maxsize = None)
|
||||||
def ping_static_url(url : str) -> bool:
|
def ping_static_url(url : str) -> bool:
|
||||||
commands = [ '-I', url ]
|
commands = [ '-I', url, '--connect-timeout', '5' ]
|
||||||
process = open_curl(commands)
|
process = open_curl(commands)
|
||||||
process.communicate()
|
process.communicate()
|
||||||
return process.returncode == 0
|
return process.returncode == 0
|
||||||
@@ -129,6 +129,7 @@ def validate_hash_paths(hash_paths : List[str]) -> Tuple[List[str], List[str]]:
|
|||||||
valid_hash_paths.append(hash_path)
|
valid_hash_paths.append(hash_path)
|
||||||
else:
|
else:
|
||||||
invalid_hash_paths.append(hash_path)
|
invalid_hash_paths.append(hash_path)
|
||||||
|
|
||||||
return valid_hash_paths, invalid_hash_paths
|
return valid_hash_paths, invalid_hash_paths
|
||||||
|
|
||||||
|
|
||||||
@@ -141,6 +142,7 @@ def validate_source_paths(source_paths : List[str]) -> Tuple[List[str], List[str
|
|||||||
valid_source_paths.append(source_path)
|
valid_source_paths.append(source_path)
|
||||||
else:
|
else:
|
||||||
invalid_source_paths.append(source_path)
|
invalid_source_paths.append(source_path)
|
||||||
|
|
||||||
return valid_source_paths, invalid_source_paths
|
return valid_source_paths, invalid_source_paths
|
||||||
|
|
||||||
|
|
||||||
@@ -148,16 +150,18 @@ def resolve_download_url(base_name : str, file_name : str) -> Optional[str]:
|
|||||||
download_providers = state_manager.get_item('download_providers')
|
download_providers = state_manager.get_item('download_providers')
|
||||||
|
|
||||||
for download_provider in download_providers:
|
for download_provider in download_providers:
|
||||||
if ping_download_provider(download_provider):
|
download_url = resolve_download_url_by_provider(download_provider, base_name, file_name)
|
||||||
return resolve_download_url_by_provider(download_provider, base_name, file_name)
|
if download_url:
|
||||||
|
return download_url
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def ping_download_provider(download_provider : DownloadProvider) -> bool:
|
|
||||||
download_provider_value = facefusion.choices.download_provider_set.get(download_provider)
|
|
||||||
return ping_static_url(download_provider_value.get('url'))
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_download_url_by_provider(download_provider : DownloadProvider, base_name : str, file_name : str) -> Optional[str]:
|
def resolve_download_url_by_provider(download_provider : DownloadProvider, base_name : str, file_name : str) -> Optional[str]:
|
||||||
download_provider_value = facefusion.choices.download_provider_set.get(download_provider)
|
download_provider_value = facefusion.choices.download_provider_set.get(download_provider)
|
||||||
return download_provider_value.get('url') + download_provider_value.get('path').format(base_name = base_name, file_name = file_name)
|
|
||||||
|
for download_provider_url in download_provider_value.get('urls'):
|
||||||
|
if ping_static_url(download_provider_url):
|
||||||
|
return download_provider_url + download_provider_value.get('path').format(base_name = base_name, file_name = file_name)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|||||||
@@ -169,7 +169,7 @@ ExecutionDevice = TypedDict('ExecutionDevice',
|
|||||||
DownloadProvider = Literal['github', 'huggingface']
|
DownloadProvider = Literal['github', 'huggingface']
|
||||||
DownloadProviderValue = TypedDict('DownloadProviderValue',
|
DownloadProviderValue = TypedDict('DownloadProviderValue',
|
||||||
{
|
{
|
||||||
'url' : str,
|
'urls' : List[str],
|
||||||
'path' : str
|
'path' : str
|
||||||
})
|
})
|
||||||
DownloadProviderSet = Dict[DownloadProvider, DownloadProviderValue]
|
DownloadProviderSet = Dict[DownloadProvider, DownloadProviderValue]
|
||||||
|
|||||||
Reference in New Issue
Block a user