Qa/follow set naming (#867)

* Follow set naming

* Follow set naming

* Disable type hints

* Uniform order
This commit is contained in:
Henry Ruhs
2025-02-09 09:35:56 +01:00
committed by henryruhs
parent 1bdc02014c
commit f3bbd3e16f
26 changed files with 172 additions and 157 deletions

View File

@@ -43,8 +43,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ 'yolo_nsfw' ] model_names = [ 'yolo_nsfw' ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -57,10 +58,10 @@ def get_model_options() -> ModelOptions:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def analyse_stream(vision_frame : VisionFrame, video_fps : Fps) -> bool: def analyse_stream(vision_frame : VisionFrame, video_fps : Fps) -> bool:

View File

@@ -149,11 +149,11 @@ def force_download() -> ErrorCode:
for module in common_modules + processor_modules: for module in common_modules + processor_modules:
if hasattr(module, 'create_static_model_set'): if hasattr(module, 'create_static_model_set'):
for model in module.create_static_model_set(state_manager.get_item('download_scope')).values(): for model in module.create_static_model_set(state_manager.get_item('download_scope')).values():
model_hashes = model.get('hashes') model_hash_set = model.get('hashes')
model_sources = model.get('sources') model_source_set = model.get('sources')
if model_hashes and model_sources: if model_hash_set and model_source_set:
if not conditional_download_hashes(model_hashes) or not conditional_download_sources(model_sources): if not conditional_download_hashes(model_hash_set) or not conditional_download_sources(model_source_set):
return 1 return 1
return 0 return 0

View File

@@ -70,17 +70,17 @@ def ping_static_url(url : str) -> bool:
return process.returncode == 0 return process.returncode == 0
def conditional_download_hashes(hashes : DownloadSet) -> bool: def conditional_download_hashes(hash_set : DownloadSet) -> bool:
hash_paths = [ hashes.get(hash_key).get('path') for hash_key in hashes.keys() ] hash_paths = [ hash_set.get(hash_key).get('path') for hash_key in hash_set.keys() ]
process_manager.check() process_manager.check()
_, invalid_hash_paths = validate_hash_paths(hash_paths) _, invalid_hash_paths = validate_hash_paths(hash_paths)
if invalid_hash_paths: if invalid_hash_paths:
for index in hashes: for index in hash_set:
if hashes.get(index).get('path') in invalid_hash_paths: if hash_set.get(index).get('path') in invalid_hash_paths:
invalid_hash_url = hashes.get(index).get('url') invalid_hash_url = hash_set.get(index).get('url')
if invalid_hash_url: if invalid_hash_url:
download_directory_path = os.path.dirname(hashes.get(index).get('path')) download_directory_path = os.path.dirname(hash_set.get(index).get('path'))
conditional_download(download_directory_path, [ invalid_hash_url ]) conditional_download(download_directory_path, [ invalid_hash_url ])
valid_hash_paths, invalid_hash_paths = validate_hash_paths(hash_paths) valid_hash_paths, invalid_hash_paths = validate_hash_paths(hash_paths)
@@ -97,17 +97,17 @@ def conditional_download_hashes(hashes : DownloadSet) -> bool:
return not invalid_hash_paths return not invalid_hash_paths
def conditional_download_sources(sources : DownloadSet) -> bool: def conditional_download_sources(source_set : DownloadSet) -> bool:
source_paths = [ sources.get(source_key).get('path') for source_key in sources.keys() ] source_paths = [ source_set.get(source_key).get('path') for source_key in source_set.keys() ]
process_manager.check() process_manager.check()
_, invalid_source_paths = validate_source_paths(source_paths) _, invalid_source_paths = validate_source_paths(source_paths)
if invalid_source_paths: if invalid_source_paths:
for index in sources: for index in source_set:
if sources.get(index).get('path') in invalid_source_paths: if source_set.get(index).get('path') in invalid_source_paths:
invalid_source_url = sources.get(index).get('url') invalid_source_url = source_set.get(index).get('url')
if invalid_source_url: if invalid_source_url:
download_directory_path = os.path.dirname(sources.get(index).get('path')) download_directory_path = os.path.dirname(source_set.get(index).get('path'))
conditional_download(download_directory_path, [ invalid_source_url ]) conditional_download(download_directory_path, [ invalid_source_url ])
valid_source_paths, invalid_source_paths = validate_source_paths(source_paths) valid_source_paths, invalid_source_paths = validate_source_paths(source_paths)

View File

@@ -43,8 +43,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ 'fairface' ] model_names = [ 'fairface' ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -57,10 +58,10 @@ def get_model_options() -> ModelOptions:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def classify_face(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Gender, Age, Race]: def classify_face(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Gender, Age, Race]:

View File

@@ -79,8 +79,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('face_detector_model') ] model_names = [ state_manager.get_item('face_detector_model') ]
_, model_sources = collect_model_downloads() _, model_source_set = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -89,22 +90,22 @@ def clear_inference_pool() -> None:
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
model_hashes = {}
model_sources = {}
model_set = create_static_model_set('full') model_set = create_static_model_set('full')
model_hash_set = {}
model_source_set = {}
for face_detector_model in [ 'retinaface', 'scrfd', 'yolo_face' ]: for face_detector_model in [ 'retinaface', 'scrfd', 'yolo_face' ]:
if state_manager.get_item('face_detector_model') in [ 'many', face_detector_model ]: if state_manager.get_item('face_detector_model') in [ 'many', face_detector_model ]:
model_hashes[face_detector_model] = model_set.get(face_detector_model).get('hashes').get(face_detector_model) model_hash_set[face_detector_model] = model_set.get(face_detector_model).get('hashes').get(face_detector_model)
model_sources[face_detector_model] = model_set.get(face_detector_model).get('sources').get(face_detector_model) model_source_set[face_detector_model] = model_set.get(face_detector_model).get('sources').get(face_detector_model)
return model_hashes, model_sources return model_hash_set, model_source_set
def pre_check() -> bool: def pre_check() -> bool:
model_hashes, model_sources = collect_model_downloads() model_hash_set, model_source_set = collect_model_downloads()
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def detect_faces(vision_frame : VisionFrame) -> Tuple[List[BoundingBox], List[Score], List[FaceLandmark5]]: def detect_faces(vision_frame : VisionFrame) -> Tuple[List[BoundingBox], List[Score], List[FaceLandmark5]]:

View File

@@ -7,7 +7,7 @@ from cv2.typing import Size
from facefusion.types import Anchors, Angle, BoundingBox, Distance, FaceDetectorModel, FaceLandmark5, FaceLandmark68, Mask, Matrix, Points, Scale, Score, Translation, VisionFrame, WarpTemplate, WarpTemplateSet from facefusion.types import Anchors, Angle, BoundingBox, Distance, FaceDetectorModel, FaceLandmark5, FaceLandmark68, Mask, Matrix, Points, Scale, Score, Translation, VisionFrame, WarpTemplate, WarpTemplateSet
WARP_TEMPLATES : WarpTemplateSet =\ WARP_TEMPLATE_SET : WarpTemplateSet =\
{ {
'arcface_112_v1': numpy.array( 'arcface_112_v1': numpy.array(
[ [
@@ -69,7 +69,7 @@ WARP_TEMPLATES : WarpTemplateSet =\
def estimate_matrix_by_face_landmark_5(face_landmark_5 : FaceLandmark5, warp_template : WarpTemplate, crop_size : Size) -> Matrix: def estimate_matrix_by_face_landmark_5(face_landmark_5 : FaceLandmark5, warp_template : WarpTemplate, crop_size : Size) -> Matrix:
normed_warp_template = WARP_TEMPLATES.get(warp_template) * crop_size normed_warp_template = WARP_TEMPLATE_SET.get(warp_template) * crop_size
affine_matrix = cv2.estimateAffinePartial2D(face_landmark_5, normed_warp_template, method = cv2.RANSAC, ransacReprojThreshold = 100)[0] affine_matrix = cv2.estimateAffinePartial2D(face_landmark_5, normed_warp_template, method = cv2.RANSAC, ransacReprojThreshold = 100)[0]
return affine_matrix return affine_matrix

View File

@@ -80,8 +80,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('face_landmarker_model'), 'fan_68_5' ] model_names = [ state_manager.get_item('face_landmarker_model'), 'fan_68_5' ]
_, model_sources = collect_model_downloads() _, model_source_set = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -91,27 +92,27 @@ def clear_inference_pool() -> None:
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
model_set = create_static_model_set('full') model_set = create_static_model_set('full')
model_hashes =\ model_hash_set =\
{ {
'fan_68_5': model_set.get('fan_68_5').get('hashes').get('fan_68_5') 'fan_68_5': model_set.get('fan_68_5').get('hashes').get('fan_68_5')
} }
model_sources =\ model_source_set =\
{ {
'fan_68_5': model_set.get('fan_68_5').get('sources').get('fan_68_5') 'fan_68_5': model_set.get('fan_68_5').get('sources').get('fan_68_5')
} }
for face_landmarker_model in [ '2dfan4', 'peppa_wutz' ]: for face_landmarker_model in [ '2dfan4', 'peppa_wutz' ]:
if state_manager.get_item('face_landmarker_model') in [ 'many', face_landmarker_model ]: if state_manager.get_item('face_landmarker_model') in [ 'many', face_landmarker_model ]:
model_hashes[face_landmarker_model] = model_set.get(face_landmarker_model).get('hashes').get(face_landmarker_model) model_hash_set[face_landmarker_model] = model_set.get(face_landmarker_model).get('hashes').get(face_landmarker_model)
model_sources[face_landmarker_model] = model_set.get(face_landmarker_model).get('sources').get(face_landmarker_model) model_source_set[face_landmarker_model] = model_set.get(face_landmarker_model).get('sources').get(face_landmarker_model)
return model_hashes, model_sources return model_hash_set, model_source_set
def pre_check() -> bool: def pre_check() -> bool:
model_hashes, model_sources = collect_model_downloads() model_hash_set, model_source_set = collect_model_downloads()
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def detect_face_landmarks(vision_frame : VisionFrame, bounding_box : BoundingBox, face_angle : Angle) -> Tuple[FaceLandmark68, Score]: def detect_face_landmarks(vision_frame : VisionFrame, bounding_box : BoundingBox, face_angle : Angle) -> Tuple[FaceLandmark68, Score]:

View File

@@ -122,8 +122,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [state_manager.get_item('face_occluder_model'), state_manager.get_item('face_parser_model')] model_names = [state_manager.get_item('face_occluder_model'), state_manager.get_item('face_parser_model')]
_, model_sources = collect_model_downloads() _, model_source_set = collect_model_downloads()
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -132,27 +133,27 @@ def clear_inference_pool() -> None:
def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]: def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
model_hashes = {}
model_sources = {}
model_set = create_static_model_set('full') model_set = create_static_model_set('full')
model_hash_set = {}
model_source_set = {}
for face_occluder_model in [ 'xseg_1', 'xseg_2', 'xseg_3' ]: for face_occluder_model in [ 'xseg_1', 'xseg_2', 'xseg_3' ]:
if state_manager.get_item('face_occluder_model') == face_occluder_model: if state_manager.get_item('face_occluder_model') == face_occluder_model:
model_hashes[face_occluder_model] = model_set.get(face_occluder_model).get('hashes').get('face_occluder') model_hash_set[face_occluder_model] = model_set.get(face_occluder_model).get('hashes').get('face_occluder')
model_sources[face_occluder_model] = model_set.get(face_occluder_model).get('sources').get('face_occluder') model_source_set[face_occluder_model] = model_set.get(face_occluder_model).get('sources').get('face_occluder')
for face_parser_model in [ 'bisenet_resnet_18', 'bisenet_resnet_34' ]: for face_parser_model in [ 'bisenet_resnet_18', 'bisenet_resnet_34' ]:
if state_manager.get_item('face_parser_model') == face_parser_model: if state_manager.get_item('face_parser_model') == face_parser_model:
model_hashes[face_parser_model] = model_set.get(face_parser_model).get('hashes').get('face_parser') model_hash_set[face_parser_model] = model_set.get(face_parser_model).get('hashes').get('face_parser')
model_sources[face_parser_model] = model_set.get(face_parser_model).get('sources').get('face_parser') model_source_set[face_parser_model] = model_set.get(face_parser_model).get('sources').get('face_parser')
return model_hashes, model_sources return model_hash_set, model_source_set
def pre_check() -> bool: def pre_check() -> bool:
model_hashes, model_sources = collect_model_downloads() model_hash_set, model_source_set = collect_model_downloads()
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
@lru_cache(maxsize = None) @lru_cache(maxsize = None)

View File

@@ -41,8 +41,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ 'arcface' ] model_names = [ 'arcface' ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -55,10 +56,10 @@ def get_model_options() -> ModelOptions:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def calc_embedding(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Embedding, Embedding]: def calc_embedding(temp_vision_frame : VisionFrame, face_landmark_5 : FaceLandmark5) -> Tuple[Embedding, Embedding]:

View File

@@ -10,15 +10,15 @@ from facefusion.execution import create_inference_session_providers
from facefusion.filesystem import is_file from facefusion.filesystem import is_file
from facefusion.types import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet from facefusion.types import DownloadSet, ExecutionProvider, InferencePool, InferencePoolSet
INFERENCE_POOLS : InferencePoolSet =\ INFERENCE_POOL_SET : InferencePoolSet =\
{ {
'cli': {}, #type:ignore[typeddict-item] 'cli': {},
'ui': {} #type:ignore[typeddict-item] 'ui': {}
} }
def get_inference_pool(module_name : str, model_names : List[str], model_sources : DownloadSet) -> InferencePool: def get_inference_pool(module_name : str, model_names : List[str], model_source_set : DownloadSet) -> InferencePool:
global INFERENCE_POOLS global INFERENCE_POOL_SET
while process_manager.is_checking(): while process_manager.is_checking():
sleep(0.5) sleep(0.5)
@@ -27,21 +27,21 @@ def get_inference_pool(module_name : str, model_names : List[str], model_sources
app_context = detect_app_context() app_context = detect_app_context()
inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers) inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers)
if app_context == 'cli' and INFERENCE_POOLS.get('ui').get(inference_context): if app_context == 'cli' and INFERENCE_POOL_SET.get('ui').get(inference_context):
INFERENCE_POOLS['cli'][inference_context] = INFERENCE_POOLS.get('ui').get(inference_context) INFERENCE_POOL_SET['cli'][inference_context] = INFERENCE_POOL_SET.get('ui').get(inference_context)
if app_context == 'ui' and INFERENCE_POOLS.get('cli').get(inference_context): if app_context == 'ui' and INFERENCE_POOL_SET.get('cli').get(inference_context):
INFERENCE_POOLS['ui'][inference_context] = INFERENCE_POOLS.get('cli').get(inference_context) INFERENCE_POOL_SET['ui'][inference_context] = INFERENCE_POOL_SET.get('cli').get(inference_context)
if not INFERENCE_POOLS.get(app_context).get(inference_context): if not INFERENCE_POOL_SET.get(app_context).get(inference_context):
INFERENCE_POOLS[app_context][inference_context] = create_inference_pool(model_sources, execution_device_id, execution_providers) INFERENCE_POOL_SET[app_context][inference_context] = create_inference_pool(model_source_set, execution_device_id, execution_providers)
return INFERENCE_POOLS.get(app_context).get(inference_context) return INFERENCE_POOL_SET.get(app_context).get(inference_context)
def create_inference_pool(model_sources : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferencePool: def create_inference_pool(model_source_set : DownloadSet, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferencePool:
inference_pool : InferencePool = {} inference_pool : InferencePool = {}
for model_name in model_sources.keys(): for model_name in model_source_set.keys():
model_path = model_sources.get(model_name).get('path') model_path = model_source_set.get(model_name).get('path')
if is_file(model_path): if is_file(model_path):
inference_pool[model_name] = create_inference_session(model_path, execution_device_id, execution_providers) inference_pool[model_name] = create_inference_session(model_path, execution_device_id, execution_providers)
@@ -49,15 +49,15 @@ def create_inference_pool(model_sources : DownloadSet, execution_device_id : str
def clear_inference_pool(module_name : str, model_names : List[str]) -> None: def clear_inference_pool(module_name : str, model_names : List[str]) -> None:
global INFERENCE_POOLS global INFERENCE_POOL_SET
execution_device_id = state_manager.get_item('execution_device_id') execution_device_id = state_manager.get_item('execution_device_id')
execution_providers = resolve_execution_providers(module_name) execution_providers = resolve_execution_providers(module_name)
app_context = detect_app_context() app_context = detect_app_context()
inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers) inference_context = get_inference_context(module_name, model_names, execution_device_id, execution_providers)
if INFERENCE_POOLS.get(app_context).get(inference_context): if INFERENCE_POOL_SET.get(app_context).get(inference_context):
del INFERENCE_POOLS[app_context][inference_context] del INFERENCE_POOL_SET[app_context][inference_context]
def create_inference_session(model_path : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferenceSession: def create_inference_session(model_path : str, execution_device_id : str, execution_providers : List[ExecutionProvider]) -> InferenceSession:

View File

@@ -82,11 +82,11 @@ def delete_jobs(halt_on_error : bool) -> bool:
def find_jobs(job_status : JobStatus) -> JobSet: def find_jobs(job_status : JobStatus) -> JobSet:
job_ids = find_job_ids(job_status) job_ids = find_job_ids(job_status)
jobs : JobSet = {} job_set : JobSet = {}
for job_id in job_ids: for job_id in job_ids:
jobs[job_id] = read_job_file(job_id) job_set[job_id] = read_job_file(job_id)
return jobs return job_set
def find_job_ids(job_status : JobStatus) -> List[str]: def find_job_ids(job_status : JobStatus) -> List[str]:
@@ -188,7 +188,6 @@ def set_step_status(job_id : str, step_index : int, step_status : JobStepStatus)
if job: if job:
steps = job.get('steps') steps = job.get('steps')
if has_step(job_id, step_index): if has_step(job_id, step_index):
steps[step_index]['status'] = step_status steps[step_index]['status'] = step_status
return update_job_file(job_id, job) return update_job_file(job_id, job)

View File

@@ -101,12 +101,12 @@ def clean_steps(job_id: str) -> bool:
def collect_output_set(job_id : str) -> JobOutputSet: def collect_output_set(job_id : str) -> JobOutputSet:
steps = job_manager.get_steps(job_id) steps = job_manager.get_steps(job_id)
output_set : JobOutputSet = {} job_output_set : JobOutputSet = {}
for index, step in enumerate(steps): for index, step in enumerate(steps):
output_path = step.get('args').get('output_path') output_path = step.get('args').get('output_path')
if output_path: if output_path:
step_output_path = job_manager.get_step_output_path(job_id, index, output_path) step_output_path = job_manager.get_step_output_path(job_id, index, output_path)
output_set.setdefault(output_path, []).append(step_output_path) job_output_set.setdefault(output_path, []).append(step_output_path)
return output_set return job_output_set

View File

@@ -65,8 +65,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('age_modifier_model') ] model_names = [ state_manager.get_item('age_modifier_model') ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -93,10 +94,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool: def pre_process(mode : ProcessMode) -> bool:

View File

@@ -241,8 +241,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('deep_swapper_model') ] model_names = [ state_manager.get_item('deep_swapper_model') ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -278,11 +279,11 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
if model_hashes and model_sources: if model_hash_set and model_source_set:
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
return True return True

View File

@@ -76,8 +76,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('expression_restorer_model') ] model_names = [ state_manager.get_item('expression_restorer_model') ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -104,10 +105,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool: def pre_process(mode : ProcessMode) -> bool:

View File

@@ -106,8 +106,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('face_editor_model') ] model_names = [ state_manager.get_item('face_editor_model') ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -160,10 +161,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool: def pre_process(mode : ProcessMode) -> bool:

View File

@@ -223,8 +223,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('face_enhancer_model') ] model_names = [ state_manager.get_item('face_enhancer_model') ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -253,10 +254,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool: def pre_process(mode : ProcessMode) -> bool:

View File

@@ -337,8 +337,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ get_face_swapper_model() ] model_names = [ get_face_swapper_model() ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -375,10 +376,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool: def pre_process(mode : ProcessMode) -> bool:

View File

@@ -130,8 +130,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('frame_colorizer_model') ] model_names = [ state_manager.get_item('frame_colorizer_model') ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -166,10 +167,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool: def pre_process(mode : ProcessMode) -> bool:

View File

@@ -387,8 +387,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ get_frame_enhancer_model() ] model_names = [ get_frame_enhancer_model() ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -428,10 +429,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool: def pre_process(mode : ProcessMode) -> bool:

View File

@@ -75,8 +75,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ state_manager.get_item('lip_syncer_model') ] model_names = [ state_manager.get_item('lip_syncer_model') ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -101,10 +102,10 @@ def apply_args(args : Args, apply_state_item : ApplyStateItem) -> None:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def pre_process(mode : ProcessMode) -> bool: def pre_process(mode : ProcessMode) -> bool:

View File

@@ -1,37 +1,37 @@
from typing import Any, Union from typing import Any, Union
from facefusion.app_context import detect_app_context from facefusion.app_context import detect_app_context
from facefusion.processors.types import ProcessorState, ProcessorStateKey from facefusion.processors.types import ProcessorState, ProcessorStateKey, ProcessorStateSet
from facefusion.types import State, StateKey, StateSet from facefusion.types import State, StateKey, StateSet
STATES : Union[StateSet, ProcessorState] =\ STATE_SET : Union[StateSet, ProcessorStateSet] =\
{ {
'cli': {}, #type:ignore[typeddict-item] 'cli': {}, #type:ignore[assignment]
'ui': {} #type:ignore[typeddict-item] 'ui': {} #type:ignore[assignment]
} }
def get_state() -> Union[State, ProcessorState]: def get_state() -> Union[State, ProcessorState]:
app_context = detect_app_context() app_context = detect_app_context()
return STATES.get(app_context) #type:ignore return STATE_SET.get(app_context)
def init_item(key : Union[StateKey, ProcessorStateKey], value : Any) -> None: def init_item(key : Union[StateKey, ProcessorStateKey], value : Any) -> None:
STATES['cli'][key] = value #type:ignore STATE_SET['cli'][key] = value #type:ignore[literal-required]
STATES['ui'][key] = value #type:ignore STATE_SET['ui'][key] = value #type:ignore[literal-required]
def get_item(key : Union[StateKey, ProcessorStateKey]) -> Any: def get_item(key : Union[StateKey, ProcessorStateKey]) -> Any:
return get_state().get(key) #type:ignore return get_state().get(key) #type:ignore[literal-required]
def set_item(key : Union[StateKey, ProcessorStateKey], value : Any) -> None: def set_item(key : Union[StateKey, ProcessorStateKey], value : Any) -> None:
app_context = detect_app_context() app_context = detect_app_context()
STATES[app_context][key] = value #type:ignore STATE_SET[app_context][key] = value #type:ignore[literal-required]
def sync_item(key : Union[StateKey, ProcessorStateKey]) -> None: def sync_item(key : Union[StateKey, ProcessorStateKey]) -> None:
STATES['cli'][key] = STATES.get('ui').get(key) #type:ignore STATE_SET['cli'][key] = STATE_SET.get('ui').get(key) #type:ignore[literal-required]
def clear_item(key : Union[StateKey, ProcessorStateKey]) -> None: def clear_item(key : Union[StateKey, ProcessorStateKey]) -> None:

View File

@@ -92,7 +92,7 @@ def create_and_run_job(step_args : Args) -> bool:
job_id = job_helper.suggest_job_id('ui') job_id = job_helper.suggest_job_id('ui')
for key in job_store.get_job_keys(): for key in job_store.get_job_keys():
state_manager.sync_item(key) #type:ignore state_manager.sync_item(key) #type:ignore[arg-type]
return job_manager.create_job(job_id) and job_manager.add_step(job_id, step_args) and job_manager.submit_job(job_id) and job_runner.run_job(job_id, process_step) return job_manager.create_job(job_id) and job_manager.add_step(job_id, step_args) and job_manager.submit_job(job_id) and job_runner.run_job(job_id, process_step)

View File

@@ -84,7 +84,7 @@ def run(job_action : JobRunnerAction, job_id : str) -> Tuple[gradio.Button, grad
job_id = convert_str_none(job_id) job_id = convert_str_none(job_id)
for key in job_store.get_job_keys(): for key in job_store.get_job_keys():
state_manager.sync_item(key) #type:ignore state_manager.sync_item(key) #type:ignore[arg-type]
if job_action == 'job-run': if job_action == 'job-run':
logger.info(wording.get('running_job').format(job_id = job_id), __name__) logger.info(wording.get('running_job').format(job_id = job_id), __name__)

View File

@@ -39,8 +39,9 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
def get_inference_pool() -> InferencePool: def get_inference_pool() -> InferencePool:
model_names = [ 'kim_vocal_2' ] model_names = [ 'kim_vocal_2' ]
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return inference_manager.get_inference_pool(__name__, model_names, model_sources)
return inference_manager.get_inference_pool(__name__, model_names, model_source_set)
def clear_inference_pool() -> None: def clear_inference_pool() -> None:
@@ -53,10 +54,10 @@ def get_model_options() -> ModelOptions:
def pre_check() -> bool: def pre_check() -> bool:
model_hashes = get_model_options().get('hashes') model_hash_set = get_model_options().get('hashes')
model_sources = get_model_options().get('sources') model_source_set = get_model_options().get('sources')
return conditional_download_hashes(model_hashes) and conditional_download_sources(model_sources) return conditional_download_hashes(model_hash_set) and conditional_download_sources(model_source_set)
def batch_extract_voice(audio : Audio, chunk_size : int, step_size : int) -> Audio: def batch_extract_voice(audio : Audio, chunk_size : int, step_size : int) -> Audio:

View File

@@ -4,7 +4,7 @@ import pytest
from onnxruntime import InferenceSession from onnxruntime import InferenceSession
from facefusion import content_analyser, state_manager from facefusion import content_analyser, state_manager
from facefusion.inference_manager import INFERENCE_POOLS, get_inference_pool from facefusion.inference_manager import INFERENCE_POOL_SET, get_inference_pool
@pytest.fixture(scope = 'module', autouse = True) @pytest.fixture(scope = 'module', autouse = True)
@@ -17,16 +17,16 @@ def before_all() -> None:
def test_get_inference_pool() -> None: def test_get_inference_pool() -> None:
model_names = [ 'yolo_nsfw' ] model_names = [ 'yolo_nsfw' ]
model_sources = content_analyser.get_model_options().get('sources') model_source_set = content_analyser.get_model_options().get('sources')
with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'): with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'):
get_inference_pool('facefusion.content_analyser', model_names, model_sources) get_inference_pool('facefusion.content_analyser', model_names, model_source_set)
assert isinstance(INFERENCE_POOLS.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession) assert isinstance(INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession)
with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'): with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'):
get_inference_pool('facefusion.content_analyser', model_names, model_sources) get_inference_pool('facefusion.content_analyser', model_names, model_source_set)
assert isinstance(INFERENCE_POOLS.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession) assert isinstance(INFERENCE_POOL_SET.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession)
assert INFERENCE_POOLS.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser') == INFERENCE_POOLS.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser') assert INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser') == INFERENCE_POOL_SET.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser')