Fix broken inference pools part2
This commit is contained in:
@@ -64,13 +64,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ state_manager.get_item('age_modifier_model') ]
|
||||
model_sources = get_model_options().get('sources')
|
||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_sources = get_model_options().get('sources')
|
||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||
model_names = [ state_manager.get_item('age_modifier_model') ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
|
||||
@@ -240,13 +240,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ state_manager.get_item('deep_swapper_model') ]
|
||||
model_sources = get_model_options().get('sources')
|
||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_sources = get_model_options().get('sources')
|
||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||
model_names = [ state_manager.get_item('deep_swapper_model') ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
|
||||
@@ -75,13 +75,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ state_manager.get_item('expression_restorer_model') ]
|
||||
model_sources = get_model_options().get('sources')
|
||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_sources = get_model_options().get('sources')
|
||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||
model_names = [ state_manager.get_item('expression_restorer_model') ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
|
||||
@@ -105,13 +105,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ state_manager.get_item('face_editor_model') ]
|
||||
model_sources = get_model_options().get('sources')
|
||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_sources = get_model_options().get('sources')
|
||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||
model_names = [ state_manager.get_item('face_editor_model') ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
|
||||
@@ -222,13 +222,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ state_manager.get_item('face_enhancer_model') ]
|
||||
model_sources = get_model_options().get('sources')
|
||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_sources = get_model_options().get('sources')
|
||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||
model_names = [ state_manager.get_item('face_enhancer_model') ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
|
||||
@@ -336,21 +336,27 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ get_face_swapper_model() ]
|
||||
model_sources = get_model_options().get('sources')
|
||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_sources = get_model_options().get('sources')
|
||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||
model_names = [ get_face_swapper_model() ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
face_swapper_model = get_face_swapper_model()
|
||||
return create_static_model_set('full').get(face_swapper_model)
|
||||
|
||||
|
||||
def get_face_swapper_model() -> str:
|
||||
face_swapper_model = state_manager.get_item('face_swapper_model')
|
||||
|
||||
if has_execution_provider('coreml') and face_swapper_model == 'inswapper_128_fp16':
|
||||
return create_static_model_set('full').get('inswapper_128')
|
||||
return create_static_model_set('full').get(face_swapper_model)
|
||||
return 'inswapper_128'
|
||||
return face_swapper_model
|
||||
|
||||
|
||||
def register_args(program : ArgumentParser) -> None:
|
||||
|
||||
@@ -129,13 +129,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ state_manager.get_item('frame_colorizer_model') ]
|
||||
model_sources = get_model_options().get('sources')
|
||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_sources = get_model_options().get('sources')
|
||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||
model_names = [ state_manager.get_item('frame_colorizer_model') ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def resolve_execution_providers() -> List[ExecutionProvider]:
|
||||
|
||||
@@ -386,26 +386,32 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ get_frame_enhancer_model() ]
|
||||
model_sources = get_model_options().get('sources')
|
||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_sources = get_model_options().get('sources')
|
||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||
model_names = [ get_frame_enhancer_model() ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
frame_enhancer_model = get_frame_enhancer_model()
|
||||
return create_static_model_set('full').get(frame_enhancer_model)
|
||||
|
||||
|
||||
def get_frame_enhancer_model() -> str:
|
||||
frame_enhancer_model = state_manager.get_item('frame_enhancer_model')
|
||||
|
||||
if has_execution_provider('coreml'):
|
||||
if frame_enhancer_model == 'real_esrgan_x2_fp16':
|
||||
return create_static_model_set('full').get('real_esrgan_x2')
|
||||
return 'real_esrgan_x2'
|
||||
if frame_enhancer_model == 'real_esrgan_x4_fp16':
|
||||
return create_static_model_set('full').get('real_esrgan_x4')
|
||||
return 'real_esrgan_x4'
|
||||
if frame_enhancer_model == 'real_esrgan_x8_fp16':
|
||||
return create_static_model_set('full').get('real_esrgan_x8')
|
||||
return create_static_model_set('full').get(frame_enhancer_model)
|
||||
return 'real_esrgan_x8'
|
||||
return frame_enhancer_model
|
||||
|
||||
|
||||
def register_args(program : ArgumentParser) -> None:
|
||||
|
||||
@@ -74,13 +74,14 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
|
||||
|
||||
def get_inference_pool() -> InferencePool:
|
||||
model_names = [ state_manager.get_item('lip_syncer_model') ]
|
||||
model_sources = get_model_options().get('sources')
|
||||
return inference_manager.get_inference_pool(__name__, model_sources)
|
||||
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
|
||||
|
||||
|
||||
def clear_inference_pool() -> None:
|
||||
model_sources = get_model_options().get('sources')
|
||||
inference_manager.clear_inference_pool(__name__, model_sources)
|
||||
model_names = [ state_manager.get_item('lip_syncer_model') ]
|
||||
inference_manager.clear_inference_pool(__name__, model_names)
|
||||
|
||||
|
||||
def get_model_options() -> ModelOptions:
|
||||
|
||||
Reference in New Issue
Block a user