diff --git a/facefusion/content_analyser.py b/facefusion/content_analyser.py index cef3095..b539dd2 100644 --- a/facefusion/content_analyser.py +++ b/facefusion/content_analyser.py @@ -43,8 +43,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ 'yolo_nsfw' ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -57,10 +58,10 @@ def get_model_options() -> ModelOptions: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def analyse_stream(vision_frame : VisionFrame, video_fps : Fps) -> bool: diff --git a/facefusion/core.py b/facefusion/core.py index 9cf9662..8133619 100755 --- a/facefusion/core.py +++ b/facefusion/core.py @@ -149,11 +149,11 @@ def force_download() -> ErrorCode: for module in common_modules + processor_modules: if hasattr(module, 'create_static_model_set'): for model in module.create_static_model_set(state_manager.get_item('download_scope')).values(): - model_hashes = model.get('hashes') - model_sources = model.get('sources') + model_hash_set = model.get('hashes') + model_source_set = model.get('sources') - if model_hashes and model_sources: - if not conditional_download_hashes(model_hashes) or not conditional_download_sources(model_sources): + if model_hash_set and model_source_set: + if not conditional_download_hashes(model_hash_set) or not conditional_download_sources(model_source_set): return 1 return 0 diff --git a/facefusion/download.py b/facefusion/download.py index cafd6bc..f0c92f4 100644 --- a/facefusion/download.py +++ b/facefusion/download.py @@ -70,17 +70,17 @@ def ping_static_url(url : str) -> bool: return process.returncode == 0 -def conditional_download_hashes(hashes : DownloadSet) -> bool: - hash_paths = [ hashes.get(hash_key).get('path') for hash_key in hashes.keys() ] +def conditional_download_hashes(hash_set : DownloadSet) -> bool: + hash_paths = [ hash_set.get(hash_key).get('path') for hash_key in hash_set.keys() ] process_manager.check() _, invalid_hash_paths = validate_hash_paths(hash_paths) if invalid_hash_paths: - for index in hashes: - if hashes.get(index).get('path') in invalid_hash_paths: - invalid_hash_url = hashes.get(index).get('url') + for index in hash_set: + if hash_set.get(index).get('path') in invalid_hash_paths: + invalid_hash_url = hash_set.get(index).get('url') if invalid_hash_url: - download_directory_path = os.path.dirname(hashes.get(index).get('path')) + download_directory_path = os.path.dirname(hash_set.get(index).get('path')) conditional_download(download_directory_path, [ invalid_hash_url ]) valid_hash_paths, invalid_hash_paths = validate_hash_paths(hash_paths) @@ -97,17 +97,17 @@ def conditional_download_hashes(hashes : DownloadSet) -> bool: return not invalid_hash_paths -def conditional_download_sources(sources : DownloadSet) -> bool: - source_paths = [ sources.get(source_key).get('path') for source_key in sources.keys() ] +def conditional_download_sources(source_set : DownloadSet) -> bool: + source_paths = [ source_set.get(source_key).get('path') for source_key in source_set.keys() ] process_manager.check() _, invalid_source_paths = validate_source_paths(source_paths) if invalid_source_paths: - for index in sources: - if sources.get(index).get('path') in invalid_source_paths: - invalid_source_url = sources.get(index).get('url') + for index in source_set: + if source_set.get(index).get('path') in invalid_source_paths: + invalid_source_url = source_set.get(index).get('url') if invalid_source_url: - download_directory_path = os.path.dirname(sources.get(index).get('path')) + download_directory_path = os.path.dirname(source_set.get(index).get('path')) conditional_download(download_directory_path, [ invalid_source_url ]) valid_source_paths, invalid_source_paths = validate_source_paths(source_paths) diff --git a/facefusion/face_classifier.py b/facefusion/face_classifier.py index c1c7637..3b09990 100644 --- a/facefusion/face_classifier.py +++ b/facefusion/face_classifier.py @@ -43,8 +43,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ 'fairface' ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -57,10 +58,10 @@ def get_model_options() -> ModelOptions: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def classify_face(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Gender, Age, Race]: diff --git a/facefusion/face_detector.py b/facefusion/face_detector.py index 3d0518d..c3532fd 100644 --- a/facefusion/face_detector.py +++ b/facefusion/face_detector.py @@ -79,8 +79,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ state_manager.get_item('face_detector_model') ] - _, model_sources = collect_model_downloads() - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + _, model_source_set = collect_model_downloads() + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -89,22 +90,22 @@ def clear_inference_pool() -> None: def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: - model_hashes = {} - model_sources = {} model_set = create_static_model_set('full') + model_hash_set = {} + model_source_set = {} for face_detector_model in [ 'retinaface', 'scrfd', 'yolo_face' ]: if state_manager.get_item('face_detector_model') in [ 'many', face_detector_model ]: - model_hashes[face_detector_model] = model_set.get(face_detector_model).get('hashes').get(face_detector_model) - model_sources[face_detector_model] = model_set.get(face_detector_model).get('sources').get(face_detector_model) + model_hash_set[face_detector_model] = model_set.get(face_detector_model).get('hashes').get(face_detector_model) + model_source_set[face_detector_model] = model_set.get(face_detector_model).get('sources').get(face_detector_model) - return model_hashes, model_sources + return model_hash_set, model_source_set def pre_check() -> bool: - model_hashes, model_sources = collect_model_downloads() + model_hash_set, model_source_set = collect_model_downloads() - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def detect_faces(vision_frame : VisionFrame) -> Tuple[List[BoundingBox], List[Score], List[FaceLandmark5]]: diff --git a/facefusion/face_helper.py b/facefusion/face_helper.py index 846d481..6020f97 100644 --- a/facefusion/face_helper.py +++ b/facefusion/face_helper.py @@ -7,7 +7,7 @@ from cv2.typing import Size from facefusion.types import Anchors, Angle, BoundingBox, Distance, FaceDetectorModel, FaceLandmark5, FaceLandmark68, Mask, Matrix, Points, Scale, Score, Translation, VisionFrame, WarpTemplate, WarpTemplateSet -WARP_TEMPLATES : WarpTemplateSet =\ +WARP_TEMPLATE_SET : WarpTemplateSet =\ { 'arcface_112_v1': numpy.array( [ @@ -69,7 +69,7 @@ WARP_TEMPLATES : WarpTemplateSet =\ def estimate_matrix_by_face_landmark_5(face_landmark_5 : FaceLandmark5, warp_template : WarpTemplate, crop_size : Size) -> Matrix: - normed_warp_template = WARP_TEMPLATES.get(warp_template) * crop_size + normed_warp_template = WARP_TEMPLATE_SET.get(warp_template) * crop_size affine_matrix = cv2.estimateAffinePartial2D(face_landmark_5, normed_warp_template, method = cv2.RANSAC, ransacReprojThreshold = 100)[0] return affine_matrix diff --git a/facefusion/face_landmarker.py b/facefusion/face_landmarker.py index a90948e..45e9093 100644 --- a/facefusion/face_landmarker.py +++ b/facefusion/face_landmarker.py @@ -80,8 +80,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ state_manager.get_item('face_landmarker_model'), 'fan_68_5' ] - _, model_sources = collect_model_downloads() - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + _, model_source_set = collect_model_downloads() + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -91,27 +92,27 @@ def clear_inference_pool() -> None: def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: model_set = create_static_model_set('full') - model_hashes =\ + model_hash_set =\ { 'fan_68_5': model_set.get('fan_68_5').get('hashes').get('fan_68_5') } - model_sources =\ + model_source_set =\ { 'fan_68_5': model_set.get('fan_68_5').get('sources').get('fan_68_5') } for face_landmarker_model in [ '2dfan4', 'peppa_wutz' ]: if state_manager.get_item('face_landmarker_model') in [ 'many', face_landmarker_model ]: - model_hashes[face_landmarker_model] = model_set.get(face_landmarker_model).get('hashes').get(face_landmarker_model) - model_sources[face_landmarker_model] = model_set.get(face_landmarker_model).get('sources').get(face_landmarker_model) + model_hash_set[face_landmarker_model] = model_set.get(face_landmarker_model).get('hashes').get(face_landmarker_model) + model_source_set[face_landmarker_model] = model_set.get(face_landmarker_model).get('sources').get(face_landmarker_model) - return model_hashes, model_sources + return model_hash_set, model_source_set def pre_check() -> bool: - model_hashes, model_sources = collect_model_downloads() + model_hash_set, model_source_set = collect_model_downloads() - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def detect_face_landmarks(vision_frame : VisionFrame, bounding_box : BoundingBox, face_angle : Angle) -> Tuple[FaceLandmark68, Score]: diff --git a/facefusion/face_masker.py b/facefusion/face_masker.py index 04ead79..09f5cb8 100755 --- a/facefusion/face_masker.py +++ b/facefusion/face_masker.py @@ -122,8 +122,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [state_manager.get_item('face_occluder_model'), state_manager.get_item('face_parser_model')] - _, model_sources = collect_model_downloads() - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + _, model_source_set = collect_model_downloads() + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -132,27 +133,27 @@ def clear_inference_pool() -> None: def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: - model_hashes = {} - model_sources = {} model_set = create_static_model_set('full') + model_hash_set = {} + model_source_set = {} for face_occluder_model in [ 'xseg_1', 'xseg_2', 'xseg_3' ]: if state_manager.get_item('face_occluder_model') == face_occluder_model: - model_hashes[face_occluder_model] = model_set.get(face_occluder_model).get('hashes').get('face_occluder') - model_sources[face_occluder_model] = model_set.get(face_occluder_model).get('sources').get('face_occluder') + model_hash_set[face_occluder_model] = model_set.get(face_occluder_model).get('hashes').get('face_occluder') + model_source_set[face_occluder_model] = model_set.get(face_occluder_model).get('sources').get('face_occluder') for face_parser_model in [ 'bisenet_resnet_18', 'bisenet_resnet_34' ]: if state_manager.get_item('face_parser_model') == face_parser_model: - model_hashes[face_parser_model] = model_set.get(face_parser_model).get('hashes').get('face_parser') - model_sources[face_parser_model] = model_set.get(face_parser_model).get('sources').get('face_parser') + model_hash_set[face_parser_model] = model_set.get(face_parser_model).get('hashes').get('face_parser') + model_source_set[face_parser_model] = model_set.get(face_parser_model).get('sources').get('face_parser') - return model_hashes, model_sources + return model_hash_set, model_source_set def pre_check() -> bool: - model_hashes, model_sources = collect_model_downloads() + model_hash_set, model_source_set = collect_model_downloads() - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) @lru_cache(maxsize = None) diff --git a/facefusion/face_recognizer.py b/facefusion/face_recognizer.py index a794b0d..c289026 100644 --- a/facefusion/face_recognizer.py +++ b/facefusion/face_recognizer.py @@ -41,8 +41,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ 'arcface' ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -55,10 +56,10 @@ def get_model_options() -> ModelOptions: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def calc_embedding(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Embedding, Embedding]: diff --git a/facefusion/inference_manager.py b/facefusion/inference_manager.py index cf15f27..8c24f7e 100644 --- a/facefusion/inference_manager.py +++ b/facefusion/inference_manager.py @@ -10,15 +10,15 @@ from facefusion.execution import create_inference_session_providers from facefusion.filesystem import is_file from facefusion.types import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet -INFERENCE_POOLS : InferencePoolSet =\ +INFERENCE_POOL_SET : InferencePoolSet =\ { - 'cli': {}, #type:ignore[typeddict-item] - 'ui': {} #type:ignore[typeddict-item] + 'cli': {}, + 'ui': {} } -def get_inference_pool(module_name : str, model_names : List[str], model_sources : DownloadSet) -> InferencePool: - global INFERENCE_POOLS +def get_inference_pool(module_name : str, model_names : List[str], model_source_set : DownloadSet) -> InferencePool: + global INFERENCE_POOL_SET while process_manager.is_checking(): sleep(0.5) @@ -27,21 +27,21 @@ def get_inference_pool(module_name : str, model_names : List[str], model_sources app_context = detect_app_context() inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers) - if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context): - INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context) - if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context): - INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context) - if not INFERENCE_POOLS.get(app_context).get(inference_context): - INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, execution_device_id, execution_providers) + if app_context == 'cli' and INFERENCE_POOL_SET.get('ui').get(inference_context): + INFERENCE_POOL_SET['cli'][inference_context] = INFERENCE_POOL_SET.get('ui').get(inference_context) + if app_context == 'ui' and INFERENCE_POOL_SET.get('cli').get(inference_context): + INFERENCE_POOL_SET['ui'][inference_context] = INFERENCE_POOL_SET.get('cli').get(inference_context) + if not INFERENCE_POOL_SET.get(app_context).get(inference_context): + INFERENCE_POOL_SET[app_context][inference_context] = create_inference_pool(model_source_set, execution_device_id, execution_providers) - return INFERENCE_POOLS.get(app_context).get(inference_context) + return INFERENCE_POOL_SET.get(app_context).get(inference_context) -def create_inference_pool(model_sources : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferencePool: +def create_inference_pool(model_source_set : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferencePool: inference_pool : InferencePool = {} - for model_name in model_sources.keys(): - model_path = model_sources.get(model_name).get('path') + for model_name in model_source_set.keys(): + model_path = model_source_set.get(model_name).get('path') if is_file(model_path): inference_pool[model_name] = create_inference_session(model_path, execution_device_id, execution_providers) @@ -49,15 +49,15 @@ def create_inference_pool(model_sources : DownloadSet, execution_device_id : str def clear_inference_pool(module_name : str, model_names : List[str]) -> None: - global INFERENCE_POOLS + global INFERENCE_POOL_SET execution_device_id = state_manager.get_item('execution_device_id') execution_providers = resolve_execution_providers(module_name) app_context = detect_app_context() inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers) - if INFERENCE_POOLS.get(app_context).get(inference_context): - del INFERENCE_POOLS[app_context][inference_context] + if INFERENCE_POOL_SET.get(app_context).get(inference_context): + del INFERENCE_POOL_SET[app_context][inference_context] def create_inference_session(model_path : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferenceSession: diff --git a/facefusion/jobs/job_manager.py b/facefusion/jobs/job_manager.py index 1a45a49..58f46e5 100644 --- a/facefusion/jobs/job_manager.py +++ b/facefusion/jobs/job_manager.py @@ -82,11 +82,11 @@ def delete_jobs(halt_on_error : bool) -> bool: def find_jobs(job_status : JobStatus) -> JobSet: job_ids = find_job_ids(job_status) - jobs : JobSet = {} + job_set : JobSet = {} for job_id in job_ids: - jobs[job_id] = read_job_file(job_id) - return jobs + job_set[job_id] = read_job_file(job_id) + return job_set def find_job_ids(job_status : JobStatus) -> List[str]: @@ -188,7 +188,6 @@ def set_step_status(job_id : str, step_index : int, step_status : JobStepStatus) if job: steps = job.get('steps') - if has_step(job_id, step_index): steps[step_index]['status'] = step_status return update_job_file(job_id, job) diff --git a/facefusion/jobs/job_runner.py b/facefusion/jobs/job_runner.py index 30ad14b..23a0e38 100644 --- a/facefusion/jobs/job_runner.py +++ b/facefusion/jobs/job_runner.py @@ -101,12 +101,12 @@ def clean_steps(job_id: str) -> bool: def collect_output_set(job_id : str) -> JobOutputSet: steps = job_manager.get_steps(job_id) - output_set : JobOutputSet = {} + job_output_set : JobOutputSet = {} for index, step in enumerate(steps): output_path = step.get('args').get('output_path') if output_path: step_output_path = job_manager.get_step_output_path(job_id, index, output_path) - output_set.setdefault(output_path, []).append(step_output_path) - return output_set + job_output_set.setdefault(output_path, []).append(step_output_path) + return job_output_set diff --git a/facefusion/processors/modules/age_modifier.py b/facefusion/processors/modules/age_modifier.py index 1172878..4712964 100755 --- a/facefusion/processors/modules/age_modifier.py +++ b/facefusion/processors/modules/age_modifier.py @@ -65,8 +65,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ state_manager.get_item('age_modifier_model') ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -93,10 +94,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def pre_process(mode : ProcessMode) -> bool: diff --git a/facefusion/processors/modules/deep_swapper.py b/facefusion/processors/modules/deep_swapper.py index 5d6e595..85d883f 100755 --- a/facefusion/processors/modules/deep_swapper.py +++ b/facefusion/processors/modules/deep_swapper.py @@ -241,8 +241,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ state_manager.get_item('deep_swapper_model') ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -278,11 +279,11 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - if model_hashes and model_sources: - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + if model_hash_set and model_source_set: + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) return True diff --git a/facefusion/processors/modules/expression_restorer.py b/facefusion/processors/modules/expression_restorer.py index c83ca76..ad3556f 100755 --- a/facefusion/processors/modules/expression_restorer.py +++ b/facefusion/processors/modules/expression_restorer.py @@ -76,8 +76,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ state_manager.get_item('expression_restorer_model') ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -104,10 +105,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def pre_process(mode : ProcessMode) -> bool: diff --git a/facefusion/processors/modules/face_editor.py b/facefusion/processors/modules/face_editor.py index 87af4d3..835b60e 100755 --- a/facefusion/processors/modules/face_editor.py +++ b/facefusion/processors/modules/face_editor.py @@ -106,8 +106,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ state_manager.get_item('face_editor_model') ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -160,10 +161,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def pre_process(mode : ProcessMode) -> bool: diff --git a/facefusion/processors/modules/face_enhancer.py b/facefusion/processors/modules/face_enhancer.py index bfdb06a..1634efe 100755 --- a/facefusion/processors/modules/face_enhancer.py +++ b/facefusion/processors/modules/face_enhancer.py @@ -223,8 +223,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ state_manager.get_item('face_enhancer_model') ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -253,10 +254,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def pre_process(mode : ProcessMode) -> bool: diff --git a/facefusion/processors/modules/face_swapper.py b/facefusion/processors/modules/face_swapper.py index 0adfc76..c899bbb 100755 --- a/facefusion/processors/modules/face_swapper.py +++ b/facefusion/processors/modules/face_swapper.py @@ -337,8 +337,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ get_face_swapper_model() ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -375,10 +376,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def pre_process(mode : ProcessMode) -> bool: diff --git a/facefusion/processors/modules/frame_colorizer.py b/facefusion/processors/modules/frame_colorizer.py index 1daf35c..1b00af9 100644 --- a/facefusion/processors/modules/frame_colorizer.py +++ b/facefusion/processors/modules/frame_colorizer.py @@ -130,8 +130,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ state_manager.get_item('frame_colorizer_model') ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -166,10 +167,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def pre_process(mode : ProcessMode) -> bool: diff --git a/facefusion/processors/modules/frame_enhancer.py b/facefusion/processors/modules/frame_enhancer.py index fd36a16..645963a 100644 --- a/facefusion/processors/modules/frame_enhancer.py +++ b/facefusion/processors/modules/frame_enhancer.py @@ -387,8 +387,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ get_frame_enhancer_model() ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -428,10 +429,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def pre_process(mode : ProcessMode) -> bool: diff --git a/facefusion/processors/modules/lip_syncer.py b/facefusion/processors/modules/lip_syncer.py index d8e7929..3c4f9a4 100755 --- a/facefusion/processors/modules/lip_syncer.py +++ b/facefusion/processors/modules/lip_syncer.py @@ -75,8 +75,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ state_manager.get_item('lip_syncer_model') ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -101,10 +102,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def pre_process(mode : ProcessMode) -> bool: diff --git a/facefusion/state_manager.py b/facefusion/state_manager.py index 23ba08a..aba6c57 100644 --- a/facefusion/state_manager.py +++ b/facefusion/state_manager.py @@ -1,37 +1,37 @@ from typing import Any, Union from facefusion.app_context import detect_app_context -from facefusion.processors.types import ProcessorState, ProcessorStateKey +from facefusion.processors.types import ProcessorState, ProcessorStateKey, ProcessorStateSet from facefusion.types import State, StateKey, StateSet -STATES : Union[StateSet, ProcessorState] =\ +STATE_SET : Union[StateSet, ProcessorStateSet] =\ { - 'cli': {}, #type:ignore[typeddict-item] - 'ui': {} #type:ignore[typeddict-item] + 'cli': {}, #type:ignore[assignment] + 'ui': {} #type:ignore[assignment] } def get_state() -> Union[State, ProcessorState]: app_context = detect_app_context() - return STATES.get(app_context) #type:ignore + return STATE_SET.get(app_context) def init_item(key : Union[StateKey, ProcessorStateKey], value : Any) -> None: - STATES['cli'][key] = value #type:ignore - STATES['ui'][key] = value #type:ignore + STATE_SET['cli'][key] = value #type:ignore[literal-required] + STATE_SET['ui'][key] = value #type:ignore[literal-required] def get_item(key : Union[StateKey, ProcessorStateKey]) -> Any: - return get_state().get(key) #type:ignore + return get_state().get(key) #type:ignore[literal-required] def set_item(key : Union[StateKey, ProcessorStateKey], value : Any) -> None: app_context = detect_app_context() - STATES[app_context][key] = value #type:ignore + STATE_SET[app_context][key] = value #type:ignore[literal-required] def sync_item(key : Union[StateKey, ProcessorStateKey]) -> None: - STATES['cli'][key] = STATES.get('ui').get(key) #type:ignore + STATE_SET['cli'][key] = STATE_SET.get('ui').get(key) #type:ignore[literal-required] def clear_item(key : Union[StateKey, ProcessorStateKey]) -> None: diff --git a/facefusion/uis/components/instant_runner.py b/facefusion/uis/components/instant_runner.py index fce29bd..71a3f7a 100644 --- a/facefusion/uis/components/instant_runner.py +++ b/facefusion/uis/components/instant_runner.py @@ -92,7 +92,7 @@ def create_and_run_job(step_args : Args) -> bool: job_id = job_helper.suggest_job_id('ui') for key in job_store.get_job_keys(): - state_manager.sync_item(key) #type:ignore + state_manager.sync_item(key) #type:ignore[arg-type] return job_manager.create_job(job_id) and job_manager.add_step(job_id, step_args) and job_manager.submit_job(job_id) and job_runner.run_job(job_id, process_step) diff --git a/facefusion/uis/components/job_runner.py b/facefusion/uis/components/job_runner.py index 55a53b9..df69eb0 100644 --- a/facefusion/uis/components/job_runner.py +++ b/facefusion/uis/components/job_runner.py @@ -84,7 +84,7 @@ def run(job_action : JobRunnerAction, job_id : str) -> Tuple[gradio.Button, grad job_id = convert_str_none(job_id) for key in job_store.get_job_keys(): - state_manager.sync_item(key) #type:ignore + state_manager.sync_item(key) #type:ignore[arg-type] if job_action == 'job-run': logger.info(wording.get('running_job').format(job_id = job_id), __name__) diff --git a/facefusion/voice_extractor.py b/facefusion/voice_extractor.py index a1f2ab1..6fca54a 100644 --- a/facefusion/voice_extractor.py +++ b/facefusion/voice_extractor.py @@ -39,8 +39,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet: def get_inference_pool() -> InferencePool: model_names = [ 'kim_vocal_2' ] - model_sources = get_model_options().get('sources') - return inference_manager.get_inference_pool(__name__, model_names, model_sources) + model_source_set = get_model_options().get('sources') + + return inference_manager.get_inference_pool(__name__, model_names, model_source_set) def clear_inference_pool() -> None: @@ -53,10 +54,10 @@ def get_model_options() -> ModelOptions: def pre_check() -> bool: - model_hashes = get_model_options().get('hashes') - model_sources = get_model_options().get('sources') + model_hash_set = get_model_options().get('hashes') + model_source_set = get_model_options().get('sources') - return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) + return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set) def batch_extract_voice(audio : Audio, chunk_size : int, step_size : int) -> Audio: diff --git a/tests/test_inference_manager.py b/tests/test_inference_manager.py index 797c808..62bd509 100644 --- a/tests/test_inference_manager.py +++ b/tests/test_inference_manager.py @@ -4,7 +4,7 @@ import pytest from onnxruntime import InferenceSession from facefusion import content_analyser, state_manager -from facefusion.inference_manager import INFERENCE_POOLS, get_inference_pool +from facefusion.inference_manager import INFERENCE_POOL_SET, get_inference_pool @pytest.fixture(scope = 'module', autouse = True) @@ -17,16 +17,16 @@ def before_all() -> None: def test_get_inference_pool() -> None: model_names = [ 'yolo_nsfw' ] - model_sources = content_analyser.get_model_options().get('sources') + model_source_set = content_analyser.get_model_options().get('sources') with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'): - get_inference_pool('facefusion.content_analyser', model_names, model_sources) + get_inference_pool('facefusion.content_analyser', model_names, model_source_set) - assert isinstance(INFERENCE_POOLS.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession) + assert isinstance(INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession) with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'): - get_inference_pool('facefusion.content_analyser', model_names, model_sources) + get_inference_pool('facefusion.content_analyser', model_names, model_source_set) - assert isinstance(INFERENCE_POOLS.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession) + assert isinstance(INFERENCE_POOL_SET.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession) - assert INFERENCE_POOLS.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser') + assert INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser') == INFERENCE_POOL_SET.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser')