diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py index 18e0367..bc2d86b 100644 --- a/facefusion/content_analyser.py +++ b/facefusion/content_analyser.py @@ -1,14 +1,15 @@ from functools import lru_cache -from typing import Tuple +from typing import List, Tuple import numpy from tqdm import tqdm from facefusion import inference_manager, state_manager, wording from facefusion.download import conditional_download_hashes, conditional_download_sources, resolve_download_url +from facefusion.execution import has_execution_provider from facefusion.filesystem import resolve_relative_path from facefusion.thread_helper import conditional_thread_semaphore -from facefusion.types import Detection, DownloadScope, DownloadSet, Fps, InferencePool, ModelSet, VisionFrame +from facefusion.types import Detection, DownloadScope, DownloadSet, ExecutionProvider, Fps, InferencePool, ModelSet, VisionFrame from facefusion.vision import detect_video_fps, fit_frame, read_image, read_video_frame STREAM_COUNTER = 0 @@ -102,6 +103,12 @@ def clear_inference_pool() -> None: inference_manager.clear_inference_pool(__name__, model_names) +def resolve_execution_providers() -> List[ExecutionProvider]: + if has_execution_provider('coreml'): + return [ 'cpu' ] + return state_manager.get_item('execution_providers') + + def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: model_set = create_static_model_set('full') model_hash_set = {}