Fix broken inference pools (#866)
* Fix broken inference pools * Fix broken inference pools
This commit is contained in:
@@ -47,7 +47,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -47,7 +47,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -83,7 +83,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
_, model_sources = collect_model_downloads()
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||||
|
|||||||
@@ -84,7 +84,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
_, model_sources = collect_model_downloads()
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||||
|
|||||||
@@ -126,7 +126,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
_, model_sources = collect_model_downloads()
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||||
|
|||||||
@@ -45,7 +45,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -17,15 +17,15 @@ INFERENCE_POOLS : InferencePoolSet =\
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_inference_pool(model_context : str, model_sources : DownloadSet) -> InferencePool:
|
def get_inference_pool(module_name : str, model_sources : DownloadSet) -> InferencePool:
|
||||||
global INFERENCE_POOLS
|
global INFERENCE_POOLS
|
||||||
|
|
||||||
while process_manager.is_checking():
|
while process_manager.is_checking():
|
||||||
sleep(0.5)
|
sleep(0.5)
|
||||||
execution_device_id = state_manager.get_item('execution_device_id')
|
execution_device_id = state_manager.get_item('execution_device_id')
|
||||||
execution_providers = resolve_execution_providers(model_context)
|
execution_providers = resolve_execution_providers(module_name)
|
||||||
app_context = detect_app_context()
|
app_context = detect_app_context()
|
||||||
inference_context = get_inference_context(model_context, execution_device_id, execution_providers)
|
inference_context = get_inference_context(module_name, model_sources, execution_device_id, execution_providers)
|
||||||
|
|
||||||
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context):
|
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context):
|
||||||
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context)
|
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context)
|
||||||
@@ -48,13 +48,13 @@ def create_inference_pool(model_sources : DownloadSet, execution_device_id : str
|
|||||||
return inference_pool
|
return inference_pool
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool(model_context : str) -> None:
|
def clear_inference_pool(module_name : str, model_sources : DownloadSet) -> None:
|
||||||
global INFERENCE_POOLS
|
global INFERENCE_POOLS
|
||||||
|
|
||||||
execution_device_id = state_manager.get_item('execution_device_id')
|
execution_device_id = state_manager.get_item('execution_device_id')
|
||||||
execution_providers = resolve_execution_providers(model_context)
|
execution_providers = resolve_execution_providers(module_name)
|
||||||
app_context = detect_app_context()
|
app_context = detect_app_context()
|
||||||
inference_context = get_inference_context(model_context, execution_device_id, execution_providers)
|
inference_context = get_inference_context(module_name, model_sources, execution_device_id, execution_providers)
|
||||||
|
|
||||||
if INFERENCE_POOLS.get(app_context).get(inference_context):
|
if INFERENCE_POOLS.get(app_context).get(inference_context):
|
||||||
del INFERENCE_POOLS[app_context][inference_context]
|
del INFERENCE_POOLS[app_context][inference_context]
|
||||||
@@ -65,13 +65,13 @@ def create_inference_session(model_path : str, execution_device_id : str, execut
|
|||||||
return InferenceSession(model_path, providers = inference_session_providers)
|
return InferenceSession(model_path, providers = inference_session_providers)
|
||||||
|
|
||||||
|
|
||||||
def get_inference_context(model_context : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str:
|
def get_inference_context(module_name : str, model_sources : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str:
|
||||||
inference_context = model_context + '.' + execution_device_id + '.' + '_'.join(execution_providers)
|
inference_context = '.'.join([ module_name ] + list(model_sources.keys()) + [ execution_device_id ] + list(execution_providers))
|
||||||
return inference_context
|
return inference_context
|
||||||
|
|
||||||
|
|
||||||
def resolve_execution_providers(model_context : str) -> List[ExecutionProvider]:
|
def resolve_execution_providers(module_name : str) -> List[ExecutionProvider]:
|
||||||
module = importlib.import_module(model_context)
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
if hasattr(module, 'resolve_execution_providers'):
|
if hasattr(module, 'resolve_execution_providers'):
|
||||||
return getattr(module, 'resolve_execution_providers')()
|
return getattr(module, 'resolve_execution_providers')()
|
||||||
|
|||||||
@@ -69,7 +69,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -245,7 +245,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -80,7 +80,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -110,7 +110,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -227,7 +227,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -341,7 +341,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -134,7 +134,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def resolve_execution_providers() -> List[ExecutionProvider]:
|
def resolve_execution_providers() -> List[ExecutionProvider]:
|
||||||
|
|||||||
@@ -391,7 +391,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -79,7 +79,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -43,7 +43,8 @@ def get_inference_pool() -> InferencePool:
|
|||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
inference_manager.clear_inference_pool(__name__)
|
model_sources = get_model_options().get('sources')
|
||||||
|
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -21,11 +21,11 @@ def test_get_inference_pool() -> None:
|
|||||||
with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'):
|
with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'):
|
||||||
get_inference_pool('test', model_sources)
|
get_inference_pool('test', model_sources)
|
||||||
|
|
||||||
assert isinstance(INFERENCE_POOLS.get('cli').get('test.0.cpu').get('content_analyser'), InferenceSession)
|
assert isinstance(INFERENCE_POOLS.get('cli').get('test.content_analyser.0.cpu').get('content_analyser'), InferenceSession)
|
||||||
|
|
||||||
with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'):
|
with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'):
|
||||||
get_inference_pool('test', model_sources)
|
get_inference_pool('test', model_sources)
|
||||||
|
|
||||||
assert isinstance(INFERENCE_POOLS.get('ui').get('test.0.cpu').get('content_analyser'), InferenceSession)
|
assert isinstance(INFERENCE_POOLS.get('ui').get('test.content_analyser.0.cpu').get('content_analyser'), InferenceSession)
|
||||||
|
|
||||||
assert INFERENCE_POOLS.get('cli').get('test.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.0.cpu').get('content_analyser')
|
assert INFERENCE_POOLS.get('cli').get('test.content_analyser.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.content_analyser.0.cpu').get('content_analyser')
|
||||||
|
|||||||
Reference in New Issue
Block a user