From 1bdc02014ce00babb1c916c4eeeb3f9077e535c8 Mon Sep 17 00:00:00 2001 From: henryruhs Date: Sun, 9 Feb 2025 01:43:58 +0100 Subject: [PATCH] Fix broken inference pools part2 --- facefusion/content_analyser.py | 7 ++++--- facefusion/face_classifier.py | 7 ++++--- facefusion/face_detector.py | 7 ++++--- facefusion/face_landmarker.py | 7 ++++--- facefusion/face_masker.py | 7 ++++--- facefusion/face_recognizer.py | 7 ++++--- facefusion/inference_manager.py | 12 +++++------ facefusion/processors/modules/age_modifier.py | 7 ++++--- facefusion/processors/modules/deep_swapper.py | 7 ++++--- .../processors/modules/expression_restorer.py | 7 ++++--- facefusion/processors/modules/face_editor.py | 7 ++++--- .../processors/modules/face_enhancer.py | 7 ++++--- facefusion/processors/modules/face_swapper.py | 16 ++++++++++----- .../processors/modules/frame_colorizer.py | 7 ++++--- .../processors/modules/frame_enhancer.py | 20 ++++++++++++------- facefusion/processors/modules/lip_syncer.py | 7 ++++--- facefusion/voice_extractor.py | 7 ++++--- tests/test_inference_manager.py | 11 +++++----- 18 files changed, 92 insertions(+), 65 deletions(-) diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py index f2e6404..cef3095 100644 --- a/facefusion/content_analyser.py +++ b/facefusion/content_analyser.py @@ -42,13 +42,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ 'yolo_nsfw' ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ 'yolo_nsfw' ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: diff --git a/facefusion/face_classifier.py b/facefusion/face_classifier.py index a58c27a..c1c7637 100644 --- a/facefusion/face_classifier.py +++ b/facefusion/face_classifier.py @@ -42,13 +42,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ 'fairface' ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ 'fairface' ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: diff --git a/facefusion/face_detector.py b/facefusion/face_detector.py index 4435599..3d0518d 100644 --- a/facefusion/face_detector.py +++ b/facefusion/face_detector.py @@ -78,13 +78,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('face_detector_model') ] _, model_sources = collect_model_downloads() - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - _, model_sources = collect_model_downloads() - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ state_manager.get_item('face_detector_model') ] + inference_manager.clear_inference_pool(__name__, model_names) def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: diff --git a/facefusion/face_landmarker.py b/facefusion/face_landmarker.py index 8d0140d..a90948e 100644 --- a/facefusion/face_landmarker.py +++ b/facefusion/face_landmarker.py @@ -79,13 +79,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('face_landmarker_model'), 'fan_68_5' ] _, model_sources = collect_model_downloads() - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - _, model_sources = collect_model_downloads() - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ state_manager.get_item('face_landmarker_model'), 'fan_68_5' ] + inference_manager.clear_inference_pool(__name__, model_names) def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: diff --git a/facefusion/face_masker.py b/facefusion/face_masker.py index 2a5c85d..04ead79 100755 --- a/facefusion/face_masker.py +++ b/facefusion/face_masker.py @@ -121,13 +121,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [state_manager.get_item('face_occluder_model'), state_manager.get_item('face_parser_model')] _, model_sources = collect_model_downloads() - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - _, model_sources = collect_model_downloads() - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ state_manager.get_item('face_occluder_model'), state_manager.get_item('face_parser_model') ] + inference_manager.clear_inference_pool(__name__, model_names) def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: diff --git a/facefusion/face_recognizer.py b/facefusion/face_recognizer.py index c7f1f99..a794b0d 100644 --- a/facefusion/face_recognizer.py +++ b/facefusion/face_recognizer.py @@ -40,13 +40,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ 'arcface' ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ 'arcface' ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: diff --git a/facefusion/inference_manager.py b/facefusion/inference_manager.py index db009bb..cf15f27 100644 --- a/facefusion/inference_manager.py +++ b/facefusion/inference_manager.py @@ -17,7 +17,7 @@ INFERENCE_POOLS : InferencePoolSet =\ } -def get_inference_pool(module_name : str, model_sources : DownloadSet) -> InferencePool: +def get_inference_pool(module_name : str, model_names : List[str], model_sources : DownloadSet) -> InferencePool: global INFERENCE_POOLS while process_manager.is_checking(): @@ -25,7 +25,7 @@ def get_inference_pool(module_name : str, model_sources : DownloadSet) -> Infere execution_device_id = state_manager.get_item('execution_device_id') execution_providers = resolve_execution_providers(module_name) app_context = detect_app_context() - inference_context = get_inference_context(module_name, model_sources, execution_device_id, execution_providers) + inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers) if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context): INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context) @@ -48,13 +48,13 @@ def create_inference_pool(model_sources : DownloadSet, execution_device_id : str return inference_pool -def clear_inference_pool(module_name : str, model_sources : DownloadSet) -> None: +def clear_inference_pool(module_name : str, model_names : List[str]) -> None: global INFERENCE_POOLS execution_device_id = state_manager.get_item('execution_device_id') execution_providers = resolve_execution_providers(module_name) app_context = detect_app_context() - inference_context = get_inference_context(module_name, model_sources, execution_device_id, execution_providers) + inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers) if INFERENCE_POOLS.get(app_context).get(inference_context): del INFERENCE_POOLS[app_context][inference_context] @@ -65,8 +65,8 @@ def create_inference_session(model_path : str, execution_device_id : str, execut return InferenceSession(model_path, providers = inference_session_providers) -def get_inference_context(module_name : str, model_sources : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str: - inference_context = '.'.join([ module_name ] + list(model_sources.keys()) + [ execution_device_id ] + list(execution_providers)) +def get_inference_context(module_name : str, model_names : List[str], execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str: + inference_context = '.'.join([ module_name ] + model_names + [ execution_device_id ] + list(execution_providers)) return inference_context diff --git a/facefusion/processors/modules/age_modifier.py b/facefusion/processors/modules/age_modifier.py index 1b2899d..1172878 100755 --- a/facefusion/processors/modules/age_modifier.py +++ b/facefusion/processors/modules/age_modifier.py @@ -64,13 +64,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('age_modifier_model') ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ state_manager.get_item('age_modifier_model') ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/deep_swapper.py b/facefusion/processors/modules/deep_swapper.py index 93fc264..5d6e595 100755 --- a/facefusion/processors/modules/deep_swapper.py +++ b/facefusion/processors/modules/deep_swapper.py @@ -240,13 +240,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('deep_swapper_model') ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ state_manager.get_item('deep_swapper_model') ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/expression_restorer.py b/facefusion/processors/modules/expression_restorer.py index b88e538..c83ca76 100755 --- a/facefusion/processors/modules/expression_restorer.py +++ b/facefusion/processors/modules/expression_restorer.py @@ -75,13 +75,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('expression_restorer_model') ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ state_manager.get_item('expression_restorer_model') ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/face_editor.py b/facefusion/processors/modules/face_editor.py index 51103b3..87af4d3 100755 --- a/facefusion/processors/modules/face_editor.py +++ b/facefusion/processors/modules/face_editor.py @@ -105,13 +105,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('face_editor_model') ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ state_manager.get_item('face_editor_model') ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/face_enhancer.py b/facefusion/processors/modules/face_enhancer.py index cf4e890..bfdb06a 100755 --- a/facefusion/processors/modules/face_enhancer.py +++ b/facefusion/processors/modules/face_enhancer.py @@ -222,13 +222,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('face_enhancer_model') ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ state_manager.get_item('face_enhancer_model') ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: diff --git a/facefusion/processors/modules/face_swapper.py b/facefusion/processors/modules/face_swapper.py index 21ecd5b..0adfc76 100755 --- a/facefusion/processors/modules/face_swapper.py +++ b/facefusion/processors/modules/face_swapper.py @@ -336,21 +336,27 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ get_face_swapper_model() ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ get_face_swapper_model() ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: + face_swapper_model = get_face_swapper_model() + return create_static_model_set('full').get(face_swapper_model) + + +def get_face_swapper_model() -> str: face_swapper_model = state_manager.get_item('face_swapper_model') if has_execution_provider('coreml') and face_swapper_model == 'inswapper_128_fp16': - return create_static_model_set('full').get('inswapper_128') - return create_static_model_set('full').get(face_swapper_model) + return 'inswapper_128' + return face_swapper_model def register_args(program : ArgumentParser) -> None: diff --git a/facefusion/processors/modules/frame_colorizer.py b/facefusion/processors/modules/frame_colorizer.py index 06b5615..1daf35c 100644 --- a/facefusion/processors/modules/frame_colorizer.py +++ b/facefusion/processors/modules/frame_colorizer.py @@ -129,13 +129,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('frame_colorizer_model') ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ state_manager.get_item('frame_colorizer_model') ] + inference_manager.clear_inference_pool(__name__, model_names) def resolve_execution_providers() -> List[ExecutionProvider]: diff --git a/facefusion/processors/modules/frame_enhancer.py b/facefusion/processors/modules/frame_enhancer.py index 6d1c1d7..fd36a16 100644 --- a/facefusion/processors/modules/frame_enhancer.py +++ b/facefusion/processors/modules/frame_enhancer.py @@ -386,26 +386,32 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ get_frame_enhancer_model() ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ get_frame_enhancer_model() ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: + frame_enhancer_model = get_frame_enhancer_model() + return create_static_model_set('full').get(frame_enhancer_model) + + +def get_frame_enhancer_model() -> str: frame_enhancer_model = state_manager.get_item('frame_enhancer_model') if has_execution_provider('coreml'): if frame_enhancer_model == 'real_esrgan_x2_fp16': - return create_static_model_set('full').get('real_esrgan_x2') + return 'real_esrgan_x2' if frame_enhancer_model == 'real_esrgan_x4_fp16': - return create_static_model_set('full').get('real_esrgan_x4') + return 'real_esrgan_x4' if frame_enhancer_model == 'real_esrgan_x8_fp16': - return create_static_model_set('full').get('real_esrgan_x8') - return create_static_model_set('full').get(frame_enhancer_model) + return 'real_esrgan_x8' + return frame_enhancer_model def register_args(program : ArgumentParser) -> None: diff --git a/facefusion/processors/modules/lip_syncer.py b/facefusion/processors/modules/lip_syncer.py index d12a02d..d8e7929 100755 --- a/facefusion/processors/modules/lip_syncer.py +++ b/facefusion/processors/modules/lip_syncer.py @@ -74,13 +74,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ state_manager.get_item('lip_syncer_model') ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ state_manager.get_item('lip_syncer_model') ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: diff --git a/facefusion/voice_extractor.py b/facefusion/voice_extractor.py index a115cce..a1f2ab1 100644 --- a/facefusion/voice_extractor.py +++ b/facefusion/voice_extractor.py @@ -38,13 +38,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: + model_names = [ 'kim_vocal_2' ] model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_sources) + return inference_manager.get_inference_pool(__name__, model_names, model_sources) def clear_inference_pool() -> None: - model_sources = get_model_options().get('sources') - inference_manager.clear_inference_pool(__name__, model_sources) + model_names = [ 'kim_vocal_2' ] + inference_manager.clear_inference_pool(__name__, model_names) def get_model_options() -> ModelOptions: diff --git a/tests/test_inference_manager.py b/tests/test_inference_manager.py index dbda051..797c808 100644 --- a/tests/test_inference_manager.py +++ b/tests/test_inference_manager.py @@ -16,16 +16,17 @@ def before_all() -> None: def test_get_inference_pool() -> None: + model_names = [ 'yolo_nsfw' ] model_sources = content_analyser.get_model_options().get('sources') with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'): - get_inference_pool('test', model_sources) + get_inference_pool('facefusion.content_analyser', model_names, model_sources) - assert isinstance(INFERENCE_POOLS.get('cli').get('test.content_analyser.0.cpu').get('content_analyser'), InferenceSession) + assert isinstance(INFERENCE_POOLS.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession) with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'): - get_inference_pool('test', model_sources) + get_inference_pool('facefusion.content_analyser', model_names, model_sources) - assert isinstance(INFERENCE_POOLS.get('ui').get('test.content_analyser.0.cpu').get('content_analyser'), InferenceSession) + assert isinstance(INFERENCE_POOLS.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession) - assert INFERENCE_POOLS.get('cli').get('test.content_analyser.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.content_analyser.0.cpu').get('content_analyser') + assert INFERENCE_POOLS.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser')