Use module_name everywhere

This commit is contained in:
henryruhs
2025-04-16 12:06:36 +02:00
parent 84b9b60e6e
commit 72a0edb6ba
10 changed files with 32 additions and 32 deletions

View File

@@ -76,8 +76,8 @@ def clear_inference_pool() -> None:
def get_model_options() -> ModelOptions:
age_modifier_model = state_manager.get_item('age_modifier_model')
return create_static_model_set('full').get(age_modifier_model)
model_name = state_manager.get_item('age_modifier_model')
return create_static_model_set('full').get(model_name)
def register_args(program : ArgumentParser) -> None:

View File

@@ -252,8 +252,8 @@ def clear_inference_pool() -> None:
def get_model_options() -> ModelOptions:
deep_swapper_model = state_manager.get_item('deep_swapper_model')
return create_static_model_set('full').get(deep_swapper_model)
model_name = state_manager.get_item('deep_swapper_model')
return create_static_model_set('full').get(model_name)
def get_model_size() -> Size:

View File

@@ -87,8 +87,8 @@ def clear_inference_pool() -> None:
def get_model_options() -> ModelOptions:
expression_restorer_model = state_manager.get_item('expression_restorer_model')
return create_static_model_set('full').get(expression_restorer_model)
model_name = state_manager.get_item('expression_restorer_model')
return create_static_model_set('full').get(model_name)
def register_args(program : ArgumentParser) -> None:

View File

@@ -117,8 +117,8 @@ def clear_inference_pool() -> None:
def get_model_options() -> ModelOptions:
face_editor_model = state_manager.get_item('face_editor_model')
return create_static_model_set('full').get(face_editor_model)
model_name = state_manager.get_item('face_editor_model')
return create_static_model_set('full').get(model_name)
def register_args(program : ArgumentParser) -> None:

View File

@@ -234,8 +234,8 @@ def clear_inference_pool() -> None:
def get_model_options() -> ModelOptions:
face_enhancer_model = state_manager.get_item('face_enhancer_model')
return create_static_model_set('full').get(face_enhancer_model)
model_name = state_manager.get_item('face_enhancer_model')
return create_static_model_set('full').get(model_name)
def register_args(program : ArgumentParser) -> None:

View File

@@ -336,28 +336,28 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ get_face_swapper_model() ]
model_names = [ get_model_name() ]
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
model_names = [ get_face_swapper_model() ]
model_names = [ get_model_name() ]
inference_manager.clear_inference_pool(__name__, model_names)
def get_model_options() -> ModelOptions:
face_swapper_model = get_face_swapper_model()
return create_static_model_set('full').get(face_swapper_model)
model_name = get_model_name()
return create_static_model_set('full').get(model_name)
def get_face_swapper_model() -> str:
face_swapper_model = state_manager.get_item('face_swapper_model')
def get_model_name() -> str:
model_name = state_manager.get_item('face_swapper_model')
if has_execution_provider('coreml') and face_swapper_model == 'inswapper_128_fp16':
if has_execution_provider('coreml') and model_name == 'inswapper_128_fp16':
return 'inswapper_128'
return face_swapper_model
return model_name
def register_args(program : ArgumentParser) -> None:

View File

@@ -147,8 +147,8 @@ def resolve_execution_providers() -> List[ExecutionProvider]:
def get_model_options() -> ModelOptions:
frame_colorizer_model = state_manager.get_item('frame_colorizer_model')
return create_static_model_set('full').get(frame_colorizer_model)
model_name = state_manager.get_item('frame_colorizer_model')
return create_static_model_set('full').get(model_name)
def register_args(program : ArgumentParser) -> None:

View File

@@ -398,8 +398,8 @@ def clear_inference_pool() -> None:
def get_model_options() -> ModelOptions:
frame_enhancer_model = get_frame_enhancer_model()
return create_static_model_set('full').get(frame_enhancer_model)
model_name = get_frame_enhancer_model()
return create_static_model_set('full').get(model_name)
def get_frame_enhancer_model() -> str:

View File

@@ -86,8 +86,8 @@ def clear_inference_pool() -> None:
def get_model_options() -> ModelOptions:
lip_syncer_model = state_manager.get_item('lip_syncer_model')
return create_static_model_set('full').get(lip_syncer_model)
model_name = state_manager.get_item('lip_syncer_model')
return create_static_model_set('full').get(model_name)
def register_args(program : ArgumentParser) -> None: