diff --git a/facefusion/inference_manager.py b/facefusion/inference_manager.py index 4905063..840c2ce 100644 --- a/facefusion/inference_manager.py +++ b/facefusion/inference_manager.py @@ -17,35 +17,22 @@ INFERENCE_POOLS : InferencePoolSet =\ } -def has_inference_model(model_context : str, model_name : str) -> bool: +def get_inference_pool(model_context : str, model_sources : DownloadSet) -> InferencePool: + global INFERENCE_POOLS + while process_manager.is_checking(): sleep(0.5) app_context = detect_app_context() inference_context = get_inference_context(model_context) - inference_pool = INFERENCE_POOLS.get(app_context).get(inference_context) - if inference_pool: - return model_name in inference_pool - return False + if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context): + INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context) + if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context): + INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context) + if not INFERENCE_POOLS.get(app_context).get(inference_context): + INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, state_manager.get_item('execution_device_id'), state_manager.get_item('execution_providers')) - -def get_inference_pool(model_context : str, model_sources : DownloadSet) -> InferencePool: - global INFERENCE_POOLS - - with thread_lock(): - while process_manager.is_checking(): - sleep(0.5) - app_context = detect_app_context() - inference_context = get_inference_context(model_context) - - if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context): - INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context) - if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context): - INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context) - if not INFERENCE_POOLS.get(app_context).get(inference_context): - INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, state_manager.get_item('execution_device_id'), state_manager.get_item('execution_providers')) - - return INFERENCE_POOLS.get(app_context).get(inference_context) + return INFERENCE_POOLS.get(app_context).get(inference_context) def create_inference_pool(model_sources : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferencePool: diff --git a/facefusion/processors/modules/deep_swapper.py b/facefusion/processors/modules/deep_swapper.py index b433d6c..3d04bc3 100755 --- a/facefusion/processors/modules/deep_swapper.py +++ b/facefusion/processors/modules/deep_swapper.py @@ -238,10 +238,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: return model_set -def has_inference_model(model_name : str) -> bool: - return inference_manager.has_inference_model(__name__, model_name) - - def get_inference_pool() -> InferencePool: model_sources = get_model_options().get('sources') return inference_manager.get_inference_pool(__name__, model_sources) @@ -361,12 +357,11 @@ def forward(crop_vision_frame : VisionFrame, deep_swapper_morph : DeepSwapperMor def has_morph_input() -> bool: - if has_inference_model('deep_swapper'): - deep_swapper = get_inference_pool().get('deep_swapper') + deep_swapper = get_inference_pool().get('deep_swapper') - for deep_swapper_input in deep_swapper.get_inputs(): - if deep_swapper_input.name == 'morph_value:0': - return True + for deep_swapper_input in deep_swapper.get_inputs(): + if deep_swapper_input.name == 'morph_value:0': + return True return False diff --git a/facefusion/processors/modules/face_enhancer.py b/facefusion/processors/modules/face_enhancer.py index 529252f..48a896a 100755 --- a/facefusion/processors/modules/face_enhancer.py +++ b/facefusion/processors/modules/face_enhancer.py @@ -221,10 +221,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: } -def has_inference_model(model_name : str) -> bool: - return inference_manager.has_inference_model(__name__, model_name) - - def get_inference_pool() -> InferencePool: model_sources = get_model_options().get('sources') return inference_manager.get_inference_pool(__name__, model_sources) @@ -328,12 +324,11 @@ def forward(crop_vision_frame : VisionFrame, face_enhancer_weight : FaceEnhancer def has_weight_input() -> bool: - if has_inference_model('face_enhancer'): - face_enhancer = get_inference_pool().get('face_enhancer') + face_enhancer = get_inference_pool().get('face_enhancer') - for deep_swapper_input in face_enhancer.get_inputs(): - if deep_swapper_input.name == 'weight': - return True + for deep_swapper_input in face_enhancer.get_inputs(): + if deep_swapper_input.name == 'weight': + return True return False diff --git a/facefusion/uis/components/deep_swapper_options.py b/facefusion/uis/components/deep_swapper_options.py index 11ad70c..590d424 100755 --- a/facefusion/uis/components/deep_swapper_options.py +++ b/facefusion/uis/components/deep_swapper_options.py @@ -6,7 +6,6 @@ from facefusion import state_manager, wording from facefusion.common_helper import calc_int_step from facefusion.processors import choices as processors_choices from facefusion.processors.core import load_processor_module -from facefusion.processors.modules.deep_swapper import has_morph_input from facefusion.processors.typing import DeepSwapperModel from facefusion.uis.core import get_ui_component, register_ui_component @@ -31,7 +30,7 @@ def render() -> None: step = calc_int_step(processors_choices.deep_swapper_morph_range), minimum = processors_choices.deep_swapper_morph_range[0], maximum = processors_choices.deep_swapper_morph_range[-1], - visible = has_morph_input() + visible = has_deep_swapper and load_processor_module('deep_swapper').get_inference_pool() and load_processor_module('deep_swapper').has_morph_input() ) register_ui_component('deep_swapper_model_dropdown', DEEP_SWAPPER_MODEL_DROPDOWN) register_ui_component('deep_swapper_morph_slider', DEEP_SWAPPER_MORPH_SLIDER) @@ -48,7 +47,7 @@ def listen() -> None: def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Slider]: has_deep_swapper = 'deep_swapper' in processors - return gradio.Dropdown(visible = has_deep_swapper), gradio.Slider(visible = has_morph_input()) + return gradio.Dropdown(visible = has_deep_swapper), gradio.Slider(visible = has_deep_swapper and load_processor_module('deep_swapper').get_inference_pool() and load_processor_module('deep_swapper').has_morph_input()) def update_deep_swapper_model(deep_swapper_model : DeepSwapperModel) -> Tuple[gradio.Dropdown, gradio.Slider]: @@ -57,7 +56,7 @@ def update_deep_swapper_model(deep_swapper_model : DeepSwapperModel) -> Tuple[gr state_manager.set_item('deep_swapper_model', deep_swapper_model) if deep_swapper_module.pre_check(): - return gradio.Dropdown(value = state_manager.get_item('deep_swapper_model')), gradio.Slider(visible = has_morph_input()) + return gradio.Dropdown(value = state_manager.get_item('deep_swapper_model')), gradio.Slider(visible = deep_swapper_module.has_morph_input()) return gradio.Dropdown(), gradio.Slider() diff --git a/facefusion/uis/components/face_enhancer_options.py b/facefusion/uis/components/face_enhancer_options.py index 54a06ce..65d20af 100755 --- a/facefusion/uis/components/face_enhancer_options.py +++ b/facefusion/uis/components/face_enhancer_options.py @@ -6,7 +6,6 @@ from facefusion import state_manager, wording from facefusion.common_helper import calc_float_step, calc_int_step from facefusion.processors import choices as processors_choices from facefusion.processors.core import load_processor_module -from facefusion.processors.modules.face_enhancer import has_weight_input from facefusion.processors.typing import FaceEnhancerModel from facefusion.uis.core import get_ui_component, register_ui_component @@ -41,7 +40,7 @@ def render() -> None: step = calc_float_step(processors_choices.face_enhancer_weight_range), minimum = processors_choices.face_enhancer_weight_range[0], maximum = processors_choices.face_enhancer_weight_range[-1], - visible = has_face_enhancer and has_weight_input() + visible = has_face_enhancer and load_processor_module('face_enhancer').get_inference_pool() and load_processor_module('face_enhancer').has_weight_input() ) register_ui_component('face_enhancer_model_dropdown', FACE_ENHANCER_MODEL_DROPDOWN) register_ui_component('face_enhancer_blend_slider', FACE_ENHANCER_BLEND_SLIDER) @@ -60,7 +59,7 @@ def listen() -> None: def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Slider, gradio.Slider]: has_face_enhancer = 'face_enhancer' in processors - return gradio.Dropdown(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer and has_weight_input()) + return gradio.Dropdown(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer and load_processor_module('face_enhancer').get_inference_pool() and load_processor_module('face_enhancer').has_weight_input()) def update_face_enhancer_model(face_enhancer_model : FaceEnhancerModel) -> Tuple[gradio.Dropdown, gradio.Slider]: @@ -69,7 +68,7 @@ def update_face_enhancer_model(face_enhancer_model : FaceEnhancerModel) -> Tuple state_manager.set_item('face_enhancer_model', face_enhancer_model) if face_enhancer_module.pre_check(): - return gradio.Dropdown(value = state_manager.get_item('face_enhancer_model')), gradio.Slider(visible = has_weight_input()) + return gradio.Dropdown(value = state_manager.get_item('face_enhancer_model')), gradio.Slider(visible = face_enhancer_module.has_weight_input()) return gradio.Dropdown(), gradio.Slider()