Drop has_inference_model and solve issue on Gradio side

This commit is contained in:
henryruhs
2025-01-14 00:44:35 +01:00
parent b11cb07aea
commit ed8e25dbb2
5 changed files with 24 additions and 49 deletions

View File

@@ -17,35 +17,22 @@ INFERENCE_POOLS : InferencePoolSet =\
}
def has_inference_model(model_context : str, model_name : str) -> bool:
def get_inference_pool(model_context : str, model_sources : DownloadSet) -> InferencePool:
global INFERENCE_POOLS
while process_manager.is_checking():
sleep(0.5)
app_context = detect_app_context()
inference_context = get_inference_context(model_context)
inference_pool = INFERENCE_POOLS.get(app_context).get(inference_context)
if inference_pool:
return model_name in inference_pool
return False
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context):
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context)
if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context):
INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context)
if not INFERENCE_POOLS.get(app_context).get(inference_context):
INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, state_manager.get_item('execution_device_id'), state_manager.get_item('execution_providers'))
def get_inference_pool(model_context : str, model_sources : DownloadSet) -> InferencePool:
global INFERENCE_POOLS
with thread_lock():
while process_manager.is_checking():
sleep(0.5)
app_context = detect_app_context()
inference_context = get_inference_context(model_context)
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context):
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context)
if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context):
INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context)
if not INFERENCE_POOLS.get(app_context).get(inference_context):
INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, state_manager.get_item('execution_device_id'), state_manager.get_item('execution_providers'))
return INFERENCE_POOLS.get(app_context).get(inference_context)
return INFERENCE_POOLS.get(app_context).get(inference_context)
def create_inference_pool(model_sources : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferencePool:

View File

@@ -238,10 +238,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
return model_set
def has_inference_model(model_name : str) -> bool:
return inference_manager.has_inference_model(__name__, model_name)
def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources)
@@ -361,12 +357,11 @@ def forward(crop_vision_frame : VisionFrame, deep_swapper_morph : DeepSwapperMor
def has_morph_input() -> bool:
if has_inference_model('deep_swapper'):
deep_swapper = get_inference_pool().get('deep_swapper')
deep_swapper = get_inference_pool().get('deep_swapper')
for deep_swapper_input in deep_swapper.get_inputs():
if deep_swapper_input.name == 'morph_value:0':
return True
for deep_swapper_input in deep_swapper.get_inputs():
if deep_swapper_input.name == 'morph_value:0':
return True
return False

View File

@@ -221,10 +221,6 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
}
def has_inference_model(model_name : str) -> bool:
return inference_manager.has_inference_model(__name__, model_name)
def get_inference_pool() -> InferencePool:
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_sources)
@@ -328,12 +324,11 @@ def forward(crop_vision_frame : VisionFrame, face_enhancer_weight : FaceEnhancer
def has_weight_input() -> bool:
if has_inference_model('face_enhancer'):
face_enhancer = get_inference_pool().get('face_enhancer')
face_enhancer = get_inference_pool().get('face_enhancer')
for deep_swapper_input in face_enhancer.get_inputs():
if deep_swapper_input.name == 'weight':
return True
for deep_swapper_input in face_enhancer.get_inputs():
if deep_swapper_input.name == 'weight':
return True
return False

View File

@@ -6,7 +6,6 @@ from facefusion import state_manager, wording
from facefusion.common_helper import calc_int_step
from facefusion.processors import choices as processors_choices
from facefusion.processors.core import load_processor_module
from facefusion.processors.modules.deep_swapper import has_morph_input
from facefusion.processors.typing import DeepSwapperModel
from facefusion.uis.core import get_ui_component, register_ui_component
@@ -31,7 +30,7 @@ def render() -> None:
step = calc_int_step(processors_choices.deep_swapper_morph_range),
minimum = processors_choices.deep_swapper_morph_range[0],
maximum = processors_choices.deep_swapper_morph_range[-1],
visible = has_morph_input()
visible = has_deep_swapper and load_processor_module('deep_swapper').get_inference_pool() and load_processor_module('deep_swapper').has_morph_input()
)
register_ui_component('deep_swapper_model_dropdown', DEEP_SWAPPER_MODEL_DROPDOWN)
register_ui_component('deep_swapper_morph_slider', DEEP_SWAPPER_MORPH_SLIDER)
@@ -48,7 +47,7 @@ def listen() -> None:
def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Slider]:
has_deep_swapper = 'deep_swapper' in processors
return gradio.Dropdown(visible = has_deep_swapper), gradio.Slider(visible = has_morph_input())
return gradio.Dropdown(visible = has_deep_swapper), gradio.Slider(visible = has_deep_swapper and load_processor_module('deep_swapper').get_inference_pool() and load_processor_module('deep_swapper').has_morph_input())
def update_deep_swapper_model(deep_swapper_model : DeepSwapperModel) -> Tuple[gradio.Dropdown, gradio.Slider]:
@@ -57,7 +56,7 @@ def update_deep_swapper_model(deep_swapper_model : DeepSwapperModel) -> Tuple[gr
state_manager.set_item('deep_swapper_model', deep_swapper_model)
if deep_swapper_module.pre_check():
return gradio.Dropdown(value = state_manager.get_item('deep_swapper_model')), gradio.Slider(visible = has_morph_input())
return gradio.Dropdown(value = state_manager.get_item('deep_swapper_model')), gradio.Slider(visible = deep_swapper_module.has_morph_input())
return gradio.Dropdown(), gradio.Slider()

View File

@@ -6,7 +6,6 @@ from facefusion import state_manager, wording
from facefusion.common_helper import calc_float_step, calc_int_step
from facefusion.processors import choices as processors_choices
from facefusion.processors.core import load_processor_module
from facefusion.processors.modules.face_enhancer import has_weight_input
from facefusion.processors.typing import FaceEnhancerModel
from facefusion.uis.core import get_ui_component, register_ui_component
@@ -41,7 +40,7 @@ def render() -> None:
step = calc_float_step(processors_choices.face_enhancer_weight_range),
minimum = processors_choices.face_enhancer_weight_range[0],
maximum = processors_choices.face_enhancer_weight_range[-1],
visible = has_face_enhancer and has_weight_input()
visible = has_face_enhancer and load_processor_module('face_enhancer').get_inference_pool() and load_processor_module('face_enhancer').has_weight_input()
)
register_ui_component('face_enhancer_model_dropdown', FACE_ENHANCER_MODEL_DROPDOWN)
register_ui_component('face_enhancer_blend_slider', FACE_ENHANCER_BLEND_SLIDER)
@@ -60,7 +59,7 @@ def listen() -> None:
def remote_update(processors : List[str]) -> Tuple[gradio.Dropdown, gradio.Slider, gradio.Slider]:
has_face_enhancer = 'face_enhancer' in processors
return gradio.Dropdown(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer and has_weight_input())
return gradio.Dropdown(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer), gradio.Slider(visible = has_face_enhancer and load_processor_module('face_enhancer').get_inference_pool() and load_processor_module('face_enhancer').has_weight_input())
def update_face_enhancer_model(face_enhancer_model : FaceEnhancerModel) -> Tuple[gradio.Dropdown, gradio.Slider]:
@@ -69,7 +68,7 @@ def update_face_enhancer_model(face_enhancer_model : FaceEnhancerModel) -> Tuple
state_manager.set_item('face_enhancer_model', face_enhancer_model)
if face_enhancer_module.pre_check():
return gradio.Dropdown(value = state_manager.get_item('face_enhancer_model')), gradio.Slider(visible = has_weight_input())
return gradio.Dropdown(value = state_manager.get_item('face_enhancer_model')), gradio.Slider(visible = face_enhancer_module.has_weight_input())
return gradio.Dropdown(), gradio.Slider()