Qa/follow set naming (#867)

* Follow set naming

* Follow set naming

* Disable type hints

* Uniform order
This commit is contained in:
Henry Ruhs
2025-02-09 09:35:56 +01:00
committed by henryruhs
parent 1bdc02014c
commit f3bbd3e16f
26 changed files with 172 additions and 157 deletions

View File

@@ -43,8 +43,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ 'yolo_nsfw' ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -57,10 +58,10 @@ def get_model_options() -> ModelOptions:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def analyse_stream(vision_frame : VisionFrame, video_fps : Fps) -> bool:

View File

@@ -149,11 +149,11 @@ def force_download() -> ErrorCode:
for module in common_modules + processor_modules:
if hasattr(module, 'create_static_model_set'):
for model in module.create_static_model_set(state_manager.get_item('download_scope')).values():
model_hashes = model.get('hashes')
model_sources = model.get('sources')
model_hash_set = model.get('hashes')
model_source_set = model.get('sources')
if model_hashes and model_sources:
if not conditional_download_hashes(model_hashes) or not conditional_download_sources(model_sources):
if model_hash_set and model_source_set:
if not conditional_download_hashes(model_hash_set) or not conditional_download_sources(model_source_set):
return 1
return 0

View File

@@ -70,17 +70,17 @@ def ping_static_url(url : str) -> bool:
return process.returncode == 0
def conditional_download_hashes(hashes : DownloadSet) -> bool:
hash_paths = [ hashes.get(hash_key).get('path') for hash_key in hashes.keys() ]
def conditional_download_hashes(hash_set : DownloadSet) -> bool:
hash_paths = [ hash_set.get(hash_key).get('path') for hash_key in hash_set.keys() ]
process_manager.check()
_, invalid_hash_paths = validate_hash_paths(hash_paths)
if invalid_hash_paths:
for index in hashes:
if hashes.get(index).get('path') in invalid_hash_paths:
invalid_hash_url = hashes.get(index).get('url')
for index in hash_set:
if hash_set.get(index).get('path') in invalid_hash_paths:
invalid_hash_url = hash_set.get(index).get('url')
if invalid_hash_url:
download_directory_path = os.path.dirname(hashes.get(index).get('path'))
download_directory_path = os.path.dirname(hash_set.get(index).get('path'))
conditional_download(download_directory_path, [ invalid_hash_url ])
valid_hash_paths, invalid_hash_paths = validate_hash_paths(hash_paths)
@@ -97,17 +97,17 @@ def conditional_download_hashes(hashes : DownloadSet) -> bool:
return not invalid_hash_paths
def conditional_download_sources(sources : DownloadSet) -> bool:
source_paths = [ sources.get(source_key).get('path') for source_key in sources.keys() ]
def conditional_download_sources(source_set : DownloadSet) -> bool:
source_paths = [ source_set.get(source_key).get('path') for source_key in source_set.keys() ]
process_manager.check()
_, invalid_source_paths = validate_source_paths(source_paths)
if invalid_source_paths:
for index in sources:
if sources.get(index).get('path') in invalid_source_paths:
invalid_source_url = sources.get(index).get('url')
for index in source_set:
if source_set.get(index).get('path') in invalid_source_paths:
invalid_source_url = source_set.get(index).get('url')
if invalid_source_url:
download_directory_path = os.path.dirname(sources.get(index).get('path'))
download_directory_path = os.path.dirname(source_set.get(index).get('path'))
conditional_download(download_directory_path, [ invalid_source_url ])
valid_source_paths, invalid_source_paths = validate_source_paths(source_paths)

View File

@@ -43,8 +43,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ 'fairface' ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -57,10 +58,10 @@ def get_model_options() -> ModelOptions:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def classify_face(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Gender, Age, Race]:

View File

@@ -79,8 +79,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('face_detector_model') ]
_, model_sources = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
_, model_source_set = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -89,22 +90,22 @@ def clear_inference_pool() -> None:
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
model_hashes = {}
model_sources = {}
model_set = create_static_model_set('full')
model_hash_set = {}
model_source_set = {}
for face_detector_model in [ 'retinaface', 'scrfd', 'yolo_face' ]:
if state_manager.get_item('face_detector_model') in [ 'many', face_detector_model ]:
model_hashes[face_detector_model] = model_set.get(face_detector_model).get('hashes').get(face_detector_model)
model_sources[face_detector_model] = model_set.get(face_detector_model).get('sources').get(face_detector_model)
model_hash_set[face_detector_model] = model_set.get(face_detector_model).get('hashes').get(face_detector_model)
model_source_set[face_detector_model] = model_set.get(face_detector_model).get('sources').get(face_detector_model)
return model_hashes, model_sources
return model_hash_set, model_source_set
def pre_check() -> bool:
model_hashes, model_sources = collect_model_downloads()
model_hash_set, model_source_set = collect_model_downloads()
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def detect_faces(vision_frame : VisionFrame) -> Tuple[List[BoundingBox], List[Score], List[FaceLandmark5]]:

View File

@@ -7,7 +7,7 @@ from cv2.typing import Size
from facefusion.types import Anchors, Angle, BoundingBox, Distance, FaceDetectorModel, FaceLandmark5, FaceLandmark68, Mask, Matrix, Points, Scale, Score, Translation, VisionFrame, WarpTemplate, WarpTemplateSet
WARP_TEMPLATES : WarpTemplateSet =\
WARP_TEMPLATE_SET : WarpTemplateSet =\
{
'arcface_112_v1': numpy.array(
[
@@ -69,7 +69,7 @@ WARP_TEMPLATES : WarpTemplateSet =\
def estimate_matrix_by_face_landmark_5(face_landmark_5 : FaceLandmark5, warp_template : WarpTemplate, crop_size : Size) -> Matrix:
normed_warp_template = WARP_TEMPLATES.get(warp_template) * crop_size
normed_warp_template = WARP_TEMPLATE_SET.get(warp_template) * crop_size
affine_matrix = cv2.estimateAffinePartial2D(face_landmark_5, normed_warp_template, method = cv2.RANSAC, ransacReprojThreshold = 100)[0]
return affine_matrix

View File

@@ -80,8 +80,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('face_landmarker_model'), 'fan_68_5' ]
_, model_sources = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
_, model_source_set = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -91,27 +92,27 @@ def clear_inference_pool() -> None:
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
model_set = create_static_model_set('full')
model_hashes =\
model_hash_set =\
{
'fan_68_5': model_set.get('fan_68_5').get('hashes').get('fan_68_5')
}
model_sources =\
model_source_set =\
{
'fan_68_5': model_set.get('fan_68_5').get('sources').get('fan_68_5')
}
for face_landmarker_model in [ '2dfan4', 'peppa_wutz' ]:
if state_manager.get_item('face_landmarker_model') in [ 'many', face_landmarker_model ]:
model_hashes[face_landmarker_model] = model_set.get(face_landmarker_model).get('hashes').get(face_landmarker_model)
model_sources[face_landmarker_model] = model_set.get(face_landmarker_model).get('sources').get(face_landmarker_model)
model_hash_set[face_landmarker_model] = model_set.get(face_landmarker_model).get('hashes').get(face_landmarker_model)
model_source_set[face_landmarker_model] = model_set.get(face_landmarker_model).get('sources').get(face_landmarker_model)
return model_hashes, model_sources
return model_hash_set, model_source_set
def pre_check() -> bool:
model_hashes, model_sources = collect_model_downloads()
model_hash_set, model_source_set = collect_model_downloads()
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def detect_face_landmarks(vision_frame : VisionFrame, bounding_box : BoundingBox, face_angle : Angle) -> Tuple[FaceLandmark68, Score]:

View File

@@ -122,8 +122,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [state_manager.get_item('face_occluder_model'), state_manager.get_item('face_parser_model')]
_, model_sources = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
_, model_source_set = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -132,27 +133,27 @@ def clear_inference_pool() -> None:
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
model_hashes = {}
model_sources = {}
model_set = create_static_model_set('full')
model_hash_set = {}
model_source_set = {}
for face_occluder_model in [ 'xseg_1', 'xseg_2', 'xseg_3' ]:
if state_manager.get_item('face_occluder_model') == face_occluder_model:
model_hashes[face_occluder_model] = model_set.get(face_occluder_model).get('hashes').get('face_occluder')
model_sources[face_occluder_model] = model_set.get(face_occluder_model).get('sources').get('face_occluder')
model_hash_set[face_occluder_model] = model_set.get(face_occluder_model).get('hashes').get('face_occluder')
model_source_set[face_occluder_model] = model_set.get(face_occluder_model).get('sources').get('face_occluder')
for face_parser_model in [ 'bisenet_resnet_18', 'bisenet_resnet_34' ]:
if state_manager.get_item('face_parser_model') == face_parser_model:
model_hashes[face_parser_model] = model_set.get(face_parser_model).get('hashes').get('face_parser')
model_sources[face_parser_model] = model_set.get(face_parser_model).get('sources').get('face_parser')
model_hash_set[face_parser_model] = model_set.get(face_parser_model).get('hashes').get('face_parser')
model_source_set[face_parser_model] = model_set.get(face_parser_model).get('sources').get('face_parser')
return model_hashes, model_sources
return model_hash_set, model_source_set
def pre_check() -> bool:
model_hashes, model_sources = collect_model_downloads()
model_hash_set, model_source_set = collect_model_downloads()
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
@lru_cache(maxsize = None)

View File

@@ -41,8 +41,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ 'arcface' ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -55,10 +56,10 @@ def get_model_options() -> ModelOptions:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def calc_embedding(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Embedding, Embedding]:

View File

@@ -10,15 +10,15 @@ from facefusion.execution import create_inference_session_providers
from facefusion.filesystem import is_file
from facefusion.types import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet
INFERENCE_POOLS : InferencePoolSet =\
INFERENCE_POOL_SET : InferencePoolSet =\
{
'cli': {}, #type:ignore[typeddict-item]
'ui': {} #type:ignore[typeddict-item]
'cli': {},
'ui': {}
}
def get_inference_pool(module_name : str, model_names : List[str], model_sources : DownloadSet) -> InferencePool:
global INFERENCE_POOLS
def get_inference_pool(module_name : str, model_names : List[str], model_source_set : DownloadSet) -> InferencePool:
global INFERENCE_POOL_SET
while process_manager.is_checking():
sleep(0.5)
@@ -27,21 +27,21 @@ def get_inference_pool(module_name : str, model_names : List[str], model_sources
app_context = detect_app_context()
inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers)
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context):
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context)
if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context):
INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context)
if not INFERENCE_POOLS.get(app_context).get(inference_context):
INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, execution_device_id, execution_providers)
if app_context == 'cli' and INFERENCE_POOL_SET.get('ui').get(inference_context):
INFERENCE_POOL_SET['cli'][inference_context] = INFERENCE_POOL_SET.get('ui').get(inference_context)
if app_context == 'ui' and INFERENCE_POOL_SET.get('cli').get(inference_context):
INFERENCE_POOL_SET['ui'][inference_context] = INFERENCE_POOL_SET.get('cli').get(inference_context)
if not INFERENCE_POOL_SET.get(app_context).get(inference_context):
INFERENCE_POOL_SET[app_context][inference_context] = create_inference_pool(model_source_set, execution_device_id, execution_providers)
return INFERENCE_POOLS.get(app_context).get(inference_context)
return INFERENCE_POOL_SET.get(app_context).get(inference_context)
def create_inference_pool(model_sources : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferencePool:
def create_inference_pool(model_source_set : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferencePool:
inference_pool : InferencePool = {}
for model_name in model_sources.keys():
model_path = model_sources.get(model_name).get('path')
for model_name in model_source_set.keys():
model_path = model_source_set.get(model_name).get('path')
if is_file(model_path):
inference_pool[model_name] = create_inference_session(model_path, execution_device_id, execution_providers)
@@ -49,15 +49,15 @@ def create_inference_pool(model_sources : DownloadSet, execution_device_id : str
def clear_inference_pool(module_name : str, model_names : List[str]) -> None:
global INFERENCE_POOLS
global INFERENCE_POOL_SET
execution_device_id = state_manager.get_item('execution_device_id')
execution_providers = resolve_execution_providers(module_name)
app_context = detect_app_context()
inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers)
if INFERENCE_POOLS.get(app_context).get(inference_context):
del INFERENCE_POOLS[app_context][inference_context]
if INFERENCE_POOL_SET.get(app_context).get(inference_context):
del INFERENCE_POOL_SET[app_context][inference_context]
def create_inference_session(model_path : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferenceSession:

View File

@@ -82,11 +82,11 @@ def delete_jobs(halt_on_error : bool) -> bool:
def find_jobs(job_status : JobStatus) -> JobSet:
job_ids = find_job_ids(job_status)
jobs : JobSet = {}
job_set : JobSet = {}
for job_id in job_ids:
jobs[job_id] = read_job_file(job_id)
return jobs
job_set[job_id] = read_job_file(job_id)
return job_set
def find_job_ids(job_status : JobStatus) -> List[str]:
@@ -188,7 +188,6 @@ def set_step_status(job_id : str, step_index : int, step_status : JobStepStatus)
if job:
steps = job.get('steps')
if has_step(job_id, step_index):
steps[step_index]['status'] = step_status
return update_job_file(job_id, job)

View File

@@ -101,12 +101,12 @@ def clean_steps(job_id: str) -> bool:
def collect_output_set(job_id : str) -> JobOutputSet:
steps = job_manager.get_steps(job_id)
output_set : JobOutputSet = {}
job_output_set : JobOutputSet = {}
for index, step in enumerate(steps):
output_path = step.get('args').get('output_path')
if output_path:
step_output_path = job_manager.get_step_output_path(job_id, index, output_path)
output_set.setdefault(output_path, []).append(step_output_path)
return output_set
job_output_set.setdefault(output_path, []).append(step_output_path)
return job_output_set

View File

@@ -65,8 +65,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('age_modifier_model') ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -93,10 +94,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool:

View File

@@ -241,8 +241,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('deep_swapper_model') ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -278,11 +279,11 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
if model_hashes and model_sources:
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
if model_hash_set and model_source_set:
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
return True

View File

@@ -76,8 +76,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('expression_restorer_model') ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -104,10 +105,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool:

View File

@@ -106,8 +106,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('face_editor_model') ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -160,10 +161,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool:

View File

@@ -223,8 +223,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('face_enhancer_model') ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -253,10 +254,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool:

View File

@@ -337,8 +337,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ get_face_swapper_model() ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -375,10 +376,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool:

View File

@@ -130,8 +130,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('frame_colorizer_model') ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -166,10 +167,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool:

View File

@@ -387,8 +387,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ get_frame_enhancer_model() ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -428,10 +429,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool:

View File

@@ -75,8 +75,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('lip_syncer_model') ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -101,10 +102,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool:

View File

@@ -1,37 +1,37 @@
from typing import Any, Union
from facefusion.app_context import detect_app_context
from facefusion.processors.types import ProcessorState, ProcessorStateKey
from facefusion.processors.types import ProcessorState, ProcessorStateKey, ProcessorStateSet
from facefusion.types import State, StateKey, StateSet
STATES : Union[StateSet, ProcessorState] =\
STATE_SET : Union[StateSet, ProcessorStateSet] =\
{
'cli': {}, #type:ignore[typeddict-item]
'ui': {} #type:ignore[typeddict-item]
'cli': {}, #type:ignore[assignment]
'ui': {} #type:ignore[assignment]
}
def get_state() -> Union[State, ProcessorState]:
app_context = detect_app_context()
return STATES.get(app_context) #type:ignore
return STATE_SET.get(app_context)
def init_item(key : Union[StateKey, ProcessorStateKey], value : Any) -> None:
STATES['cli'][key] = value #type:ignore
STATES['ui'][key] = value #type:ignore
STATE_SET['cli'][key] = value #type:ignore[literal-required]
STATE_SET['ui'][key] = value #type:ignore[literal-required]
def get_item(key : Union[StateKey, ProcessorStateKey]) -> Any:
return get_state().get(key) #type:ignore
return get_state().get(key) #type:ignore[literal-required]
def set_item(key : Union[StateKey, ProcessorStateKey], value : Any) -> None:
app_context = detect_app_context()
STATES[app_context][key] = value #type:ignore
STATE_SET[app_context][key] = value #type:ignore[literal-required]
def sync_item(key : Union[StateKey, ProcessorStateKey]) -> None:
STATES['cli'][key] = STATES.get('ui').get(key) #type:ignore
STATE_SET['cli'][key] = STATE_SET.get('ui').get(key) #type:ignore[literal-required]
def clear_item(key : Union[StateKey, ProcessorStateKey]) -> None:

View File

@@ -92,7 +92,7 @@ def create_and_run_job(step_args : Args) -> bool:
job_id = job_helper.suggest_job_id('ui')
for key in job_store.get_job_keys():
state_manager.sync_item(key) #type:ignore
state_manager.sync_item(key) #type:ignore[arg-type]
return job_manager.create_job(job_id) and job_manager.add_step(job_id, step_args) and job_manager.submit_job(job_id) and job_runner.run_job(job_id, process_step)

View File

@@ -84,7 +84,7 @@ def run(job_action : JobRunnerAction, job_id : str) -> Tuple[gradio.Button, grad
job_id = convert_str_none(job_id)
for key in job_store.get_job_keys():
state_manager.sync_item(key) #type:ignore
state_manager.sync_item(key) #type:ignore[arg-type]
if job_action == 'job-run':
logger.info(wording.get('running_job').format(job_id = job_id), __name__)

View File

@@ -39,8 +39,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool:
model_names = [ 'kim_vocal_2' ]
model_sources = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None:
@@ -53,10 +54,10 @@ def get_model_options() -> ModelOptions:
def pre_check() -> bool:
model_hashes = get_model_options().get('hashes')
model_sources = get_model_options().get('sources')
model_hash_set = get_model_options().get('hashes')
model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources)
return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def batch_extract_voice(audio : Audio, chunk_size : int, step_size : int) -> Audio: