diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py index 8b5de1b..f2e6404 100644 --- a/facefusion/content_analyser.py +++ b/facefusion/content_analyser.py @@ -47,7 +47,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/facefusion/face_classifier.py b/facefusion/face_classifier.py index dcae123..a58c27a 100644 --- a/facefusion/face_classifier.py +++ b/facefusion/face_classifier.py @@ -47,7 +47,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/facefusion/face_detector.py b/facefusion/face_detector.py index 590bc3f..4435599 100644 --- a/facefusion/face_detector.py +++ b/facefusion/face_detector.py @@ -83,7 +83,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + _, model_sources = collect_model_downloads() + inference_manager.clear_inference_pool(__name__, model_sources) def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: diff --git a/facefusion/face_landmarker.py b/facefusion/face_landmarker.py index e2453f5..8d0140d 100644 --- a/facefusion/face_landmarker.py +++ b/facefusion/face_landmarker.py @@ -84,7 +84,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + _, model_sources = collect_model_downloads() + inference_manager.clear_inference_pool(__name__, model_sources) def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: diff --git a/facefusion/face_masker.py b/facefusion/face_masker.py index cd1fc51..2a5c85d 100755 --- a/facefusion/face_masker.py +++ b/facefusion/face_masker.py @@ -126,7 +126,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + _, model_sources = collect_model_downloads() + inference_manager.clear_inference_pool(__name__, model_sources) def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: diff --git a/facefusion/face_recognizer.py b/facefusion/face_recognizer.py index af95a07..c7f1f99 100644 --- a/facefusion/face_recognizer.py +++ b/facefusion/face_recognizer.py @@ -45,7 +45,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/facefusion/inference_manager.py b/facefusion/inference_manager.py index 0ae8103..db009bb 100644 --- a/facefusion/inference_manager.py +++ b/facefusion/inference_manager.py @@ -17,15 +17,15 @@ INFERENCE_POOLS : InferencePoolSet =\ } -def get_inference_pool(model_context : str, model_sources : DownloadSet) -> InferencePool: +def get_inference_pool(module_name : str, model_sources : DownloadSet) -> InferencePool: global INFERENCE_POOLS while process_manager.is_checking(): sleep(0.5) execution_device_id = state_manager.get_item('execution_device_id') - execution_providers = resolve_execution_providers(model_context) + execution_providers = resolve_execution_providers(module_name) app_context = detect_app_context() - inference_context = get_inference_context(model_context, execution_device_id, execution_providers) + inference_context = get_inference_context(module_name, model_sources, execution_device_id, execution_providers) if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context): INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context) @@ -48,13 +48,13 @@ def create_inference_pool(model_sources : DownloadSet, execution_device_id : str return inference_pool -def clear_inference_pool(model_context : str) -> None: +def clear_inference_pool(module_name : str, model_sources : DownloadSet) -> None: global INFERENCE_POOLS execution_device_id = state_manager.get_item('execution_device_id') - execution_providers = resolve_execution_providers(model_context) + execution_providers = resolve_execution_providers(module_name) app_context = detect_app_context() - inference_context = get_inference_context(model_context, execution_device_id, execution_providers) + inference_context = get_inference_context(module_name, model_sources, execution_device_id, execution_providers) if INFERENCE_POOLS.get(app_context).get(inference_context): del INFERENCE_POOLS[app_context][inference_context] @@ -65,13 +65,13 @@ def create_inference_session(model_path : str, execution_device_id : str, execut return InferenceSession(model_path, providers = inference_session_providers) -def get_inference_context(model_context : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str: - inference_context = model_context + '.' + execution_device_id + '.' + '_'.join(execution_providers) +def get_inference_context(module_name : str, model_sources : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str: + inference_context = '.'.join([ module_name ] + list(model_sources.keys()) + [ execution_device_id ] + list(execution_providers)) return inference_context -def resolve_execution_providers(model_context : str) -> List[ExecutionProvider]: - module = importlib.import_module(model_context) +def resolve_execution_providers(module_name : str) -> List[ExecutionProvider]: + module = importlib.import_module(module_name) if hasattr(module, 'resolve_execution_providers'): return getattr(module, 'resolve_execution_providers')() diff --git a/facefusion/processors/modules/age_modifier.py b/facefusion/processors/modules/age_modifier.py index 79903bf..1b2899d 100755 --- a/facefusion/processors/modules/age_modifier.py +++ b/facefusion/processors/modules/age_modifier.py @@ -69,7 +69,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/deep_swapper.py b/facefusion/processors/modules/deep_swapper.py index 786a4fc..93fc264 100755 --- a/facefusion/processors/modules/deep_swapper.py +++ b/facefusion/processors/modules/deep_swapper.py @@ -245,7 +245,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/expression_restorer.py b/facefusion/processors/modules/expression_restorer.py index 68a9c78..b88e538 100755 --- a/facefusion/processors/modules/expression_restorer.py +++ b/facefusion/processors/modules/expression_restorer.py @@ -80,7 +80,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/face_editor.py b/facefusion/processors/modules/face_editor.py index 4d33cfe..51103b3 100755 --- a/facefusion/processors/modules/face_editor.py +++ b/facefusion/processors/modules/face_editor.py @@ -110,7 +110,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/face_enhancer.py b/facefusion/processors/modules/face_enhancer.py index f722e7b..cf4e890 100755 --- a/facefusion/processors/modules/face_enhancer.py +++ b/facefusion/processors/modules/face_enhancer.py @@ -227,7 +227,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/face_swapper.py b/facefusion/processors/modules/face_swapper.py index d432bf8..21ecd5b 100755 --- a/facefusion/processors/modules/face_swapper.py +++ b/facefusion/processors/modules/face_swapper.py @@ -341,7 +341,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/frame_colorizer.py b/facefusion/processors/modules/frame_colorizer.py index 9fef9c1..06b5615 100644 --- a/facefusion/processors/modules/frame_colorizer.py +++ b/facefusion/processors/modules/frame_colorizer.py @@ -134,7 +134,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def resolve_execution_providers() -> List[ExecutionProvider]: diff --git a/facefusion/processors/modules/frame_enhancer.py b/facefusion/processors/modules/frame_enhancer.py index d360e35..6d1c1d7 100644 --- a/facefusion/processors/modules/frame_enhancer.py +++ b/facefusion/processors/modules/frame_enhancer.py @@ -391,7 +391,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/lip_syncer.py b/facefusion/processors/modules/lip_syncer.py index ad49771..d12a02d 100755 --- a/facefusion/processors/modules/lip_syncer.py +++ b/facefusion/processors/modules/lip_syncer.py @@ -79,7 +79,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/facefusion/voice_extractor.py b/facefusion/voice_extractor.py index f1f4a0a..a115cce 100644 --- a/facefusion/voice_extractor.py +++ b/facefusion/voice_extractor.py @@ -43,7 +43,8 @@ def get_inference_pool() -> InferencePool: def clear_inference_pool() -> None: - inference_manager.clear_inference_pool(__name__) + model_sources = get_model_options().get('sources') + inference_manager.clear_inference_pool(__name__, model_sources) def get_model_options() -> ModelOptions: diff --git a/tests/test_inference_manager.py b/tests/test_inference_manager.py index 3667cd2..dbda051 100644 --- a/tests/test_inference_manager.py +++ b/tests/test_inference_manager.py @@ -21,11 +21,11 @@ def test_get_inference_pool() -> None: with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'): get_inference_pool('test', model_sources) - assert isinstance(INFERENCE_POOLS.get('cli').get('test.0.cpu').get('content_analyser'), InferenceSession) + assert isinstance(INFERENCE_POOLS.get('cli').get('test.content_analyser.0.cpu').get('content_analyser'), InferenceSession) with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'): get_inference_pool('test', model_sources) - assert isinstance(INFERENCE_POOLS.get('ui').get('test.0.cpu').get('content_analyser'), InferenceSession) + assert isinstance(INFERENCE_POOLS.get('ui').get('test.content_analyser.0.cpu').get('content_analyser'), InferenceSession) - assert INFERENCE_POOLS.get('cli').get('test.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.0.cpu').get('content_analyser') + assert INFERENCE_POOLS.get('cli').get('test.content_analyser.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.content_analyser.0.cpu').get('content_analyser')