Fix broken inference pools part2
This commit is contained in:
@@ -42,13 +42,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ 'yolo_nsfw' ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ 'yolo_nsfw' ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -42,13 +42,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ 'fairface' ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ 'fairface' ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -78,13 +78,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ state_manager.get_item('face_detector_model') ]
|
||||||
_, model_sources = collect_model_downloads()
|
_, model_sources = collect_model_downloads()
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
_, model_sources = collect_model_downloads()
|
model_names = [ state_manager.get_item('face_detector_model') ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||||
|
|||||||
@@ -79,13 +79,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ state_manager.get_item('face_landmarker_model'), 'fan_68_5' ]
|
||||||
_, model_sources = collect_model_downloads()
|
_, model_sources = collect_model_downloads()
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
_, model_sources = collect_model_downloads()
|
model_names = [ state_manager.get_item('face_landmarker_model'), 'fan_68_5' ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||||
|
|||||||
@@ -121,13 +121,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [state_manager.get_item('face_occluder_model'), state_manager.get_item('face_parser_model')]
|
||||||
_, model_sources = collect_model_downloads()
|
_, model_sources = collect_model_downloads()
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
_, model_sources = collect_model_downloads()
|
model_names = [ state_manager.get_item('face_occluder_model'), state_manager.get_item('face_parser_model') ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||||
|
|||||||
@@ -40,13 +40,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ 'arcface' ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ 'arcface' ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ INFERENCE_POOLS : InferencePoolSet =\
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_inference_pool(module_name : str, model_sources : DownloadSet) -> InferencePool:
|
def get_inference_pool(module_name : str, model_names : List[str], model_sources : DownloadSet) -> InferencePool:
|
||||||
global INFERENCE_POOLS
|
global INFERENCE_POOLS
|
||||||
|
|
||||||
while process_manager.is_checking():
|
while process_manager.is_checking():
|
||||||
@@ -25,7 +25,7 @@ def get_inference_pool(module_name : str, model_sources : DownloadSet) -> Infere
|
|||||||
execution_device_id = state_manager.get_item('execution_device_id')
|
execution_device_id = state_manager.get_item('execution_device_id')
|
||||||
execution_providers = resolve_execution_providers(module_name)
|
execution_providers = resolve_execution_providers(module_name)
|
||||||
app_context = detect_app_context()
|
app_context = detect_app_context()
|
||||||
inference_context = get_inference_context(module_name, model_sources, execution_device_id, execution_providers)
|
inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers)
|
||||||
|
|
||||||
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context):
|
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context):
|
||||||
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context)
|
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context)
|
||||||
@@ -48,13 +48,13 @@ def create_inference_pool(model_sources : DownloadSet, execution_device_id : str
|
|||||||
return inference_pool
|
return inference_pool
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool(module_name : str, model_sources : DownloadSet) -> None:
|
def clear_inference_pool(module_name : str, model_names : List[str]) -> None:
|
||||||
global INFERENCE_POOLS
|
global INFERENCE_POOLS
|
||||||
|
|
||||||
execution_device_id = state_manager.get_item('execution_device_id')
|
execution_device_id = state_manager.get_item('execution_device_id')
|
||||||
execution_providers = resolve_execution_providers(module_name)
|
execution_providers = resolve_execution_providers(module_name)
|
||||||
app_context = detect_app_context()
|
app_context = detect_app_context()
|
||||||
inference_context = get_inference_context(module_name, model_sources, execution_device_id, execution_providers)
|
inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers)
|
||||||
|
|
||||||
if INFERENCE_POOLS.get(app_context).get(inference_context):
|
if INFERENCE_POOLS.get(app_context).get(inference_context):
|
||||||
del INFERENCE_POOLS[app_context][inference_context]
|
del INFERENCE_POOLS[app_context][inference_context]
|
||||||
@@ -65,8 +65,8 @@ def create_inference_session(model_path : str, execution_device_id : str, execut
|
|||||||
return InferenceSession(model_path, providers = inference_session_providers)
|
return InferenceSession(model_path, providers = inference_session_providers)
|
||||||
|
|
||||||
|
|
||||||
def get_inference_context(module_name : str, model_sources : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str:
|
def get_inference_context(module_name : str, model_names : List[str], execution_device_id : str, execution_providers : List[ExecutionProvider]) -> str:
|
||||||
inference_context = '.'.join([ module_name ] + list(model_sources.keys()) + [ execution_device_id ] + list(execution_providers))
|
inference_context = '.'.join([ module_name ] + model_names + [ execution_device_id ] + list(execution_providers))
|
||||||
return inference_context
|
return inference_context
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -64,13 +64,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ state_manager.get_item('age_modifier_model') ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ state_manager.get_item('age_modifier_model') ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -240,13 +240,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ state_manager.get_item('deep_swapper_model') ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ state_manager.get_item('deep_swapper_model') ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -75,13 +75,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ state_manager.get_item('expression_restorer_model') ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ state_manager.get_item('expression_restorer_model') ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -105,13 +105,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ state_manager.get_item('face_editor_model') ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ state_manager.get_item('face_editor_model') ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -222,13 +222,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ state_manager.get_item('face_enhancer_model') ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ state_manager.get_item('face_enhancer_model') ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -336,21 +336,27 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ get_face_swapper_model() ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ get_face_swapper_model() ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
face_swapper_model = get_face_swapper_model()
|
||||||
|
return create_static_model_set('full').get(face_swapper_model)
|
||||||
|
|
||||||
|
|
||||||
|
def get_face_swapper_model() -> str:
|
||||||
face_swapper_model = state_manager.get_item('face_swapper_model')
|
face_swapper_model = state_manager.get_item('face_swapper_model')
|
||||||
|
|
||||||
if has_execution_provider('coreml') and face_swapper_model == 'inswapper_128_fp16':
|
if has_execution_provider('coreml') and face_swapper_model == 'inswapper_128_fp16':
|
||||||
return create_static_model_set('full').get('inswapper_128')
|
return 'inswapper_128'
|
||||||
return create_static_model_set('full').get(face_swapper_model)
|
return face_swapper_model
|
||||||
|
|
||||||
|
|
||||||
def register_args(program : ArgumentParser) -> None:
|
def register_args(program : ArgumentParser) -> None:
|
||||||
|
|||||||
@@ -129,13 +129,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ state_manager.get_item('frame_colorizer_model') ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ state_manager.get_item('frame_colorizer_model') ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def resolve_execution_providers() -> List[ExecutionProvider]:
|
def resolve_execution_providers() -> List[ExecutionProvider]:
|
||||||
|
|||||||
@@ -386,26 +386,32 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ get_frame_enhancer_model() ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ get_frame_enhancer_model() ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
frame_enhancer_model = get_frame_enhancer_model()
|
||||||
|
return create_static_model_set('full').get(frame_enhancer_model)
|
||||||
|
|
||||||
|
|
||||||
|
def get_frame_enhancer_model() -> str:
|
||||||
frame_enhancer_model = state_manager.get_item('frame_enhancer_model')
|
frame_enhancer_model = state_manager.get_item('frame_enhancer_model')
|
||||||
|
|
||||||
if has_execution_provider('coreml'):
|
if has_execution_provider('coreml'):
|
||||||
if frame_enhancer_model == 'real_esrgan_x2_fp16':
|
if frame_enhancer_model == 'real_esrgan_x2_fp16':
|
||||||
return create_static_model_set('full').get('real_esrgan_x2')
|
return 'real_esrgan_x2'
|
||||||
if frame_enhancer_model == 'real_esrgan_x4_fp16':
|
if frame_enhancer_model == 'real_esrgan_x4_fp16':
|
||||||
return create_static_model_set('full').get('real_esrgan_x4')
|
return 'real_esrgan_x4'
|
||||||
if frame_enhancer_model == 'real_esrgan_x8_fp16':
|
if frame_enhancer_model == 'real_esrgan_x8_fp16':
|
||||||
return create_static_model_set('full').get('real_esrgan_x8')
|
return 'real_esrgan_x8'
|
||||||
return create_static_model_set('full').get(frame_enhancer_model)
|
return frame_enhancer_model
|
||||||
|
|
||||||
|
|
||||||
def register_args(program : ArgumentParser) -> None:
|
def register_args(program : ArgumentParser) -> None:
|
||||||
|
|||||||
@@ -74,13 +74,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ state_manager.get_item('lip_syncer_model') ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ state_manager.get_item('lip_syncer_model') ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -38,13 +38,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
|||||||
|
|
||||||
|
|
||||||
def get_inference_pool() -> InferencePool:
|
def get_inference_pool() -> InferencePool:
|
||||||
|
model_names = [ 'kim_vocal_2' ]
|
||||||
model_sources = get_model_options().get('sources')
|
model_sources = get_model_options().get('sources')
|
||||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||||
|
|
||||||
|
|
||||||
def clear_inference_pool() -> None:
|
def clear_inference_pool() -> None:
|
||||||
model_sources = get_model_options().get('sources')
|
model_names = [ 'kim_vocal_2' ]
|
||||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
inference_manager.clear_inference_pool(__name__, model_names)
|
||||||
|
|
||||||
|
|
||||||
def get_model_options() -> ModelOptions:
|
def get_model_options() -> ModelOptions:
|
||||||
|
|||||||
@@ -16,16 +16,17 @@ def before_all() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_get_inference_pool() -> None:
|
def test_get_inference_pool() -> None:
|
||||||
|
model_names = [ 'yolo_nsfw' ]
|
||||||
model_sources = content_analyser.get_model_options().get('sources')
|
model_sources = content_analyser.get_model_options().get('sources')
|
||||||
|
|
||||||
with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'):
|
with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'):
|
||||||
get_inference_pool('test', model_sources)
|
get_inference_pool('facefusion.content_analyser', model_names, model_sources)
|
||||||
|
|
||||||
assert isinstance(INFERENCE_POOLS.get('cli').get('test.content_analyser.0.cpu').get('content_analyser'), InferenceSession)
|
assert isinstance(INFERENCE_POOLS.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession)
|
||||||
|
|
||||||
with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'):
|
with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'):
|
||||||
get_inference_pool('test', model_sources)
|
get_inference_pool('facefusion.content_analyser', model_names, model_sources)
|
||||||
|
|
||||||
assert isinstance(INFERENCE_POOLS.get('ui').get('test.content_analyser.0.cpu').get('content_analyser'), InferenceSession)
|
assert isinstance(INFERENCE_POOLS.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession)
|
||||||
|
|
||||||
assert INFERENCE_POOLS.get('cli').get('test.content_analyser.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('test.content_analyser.0.cpu').get('content_analyser')
|
assert INFERENCE_POOLS.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser')
|
||||||
|
|||||||
Reference in New Issue
Block a user