This commit is contained in:
harisreedhar
2025-06-10 18:05:20 +05:30
parent 2309b4d79a
commit 6b03388f76

View File

@@ -16,17 +16,17 @@ def before_all() -> None:
def test_get_inference_pool() -> None: def test_get_inference_pool() -> None:
model_names = [ 'yolo_nsfw' ] model_names = [ 'yolo_11m', 'marqo', 'freepik' ]
model_source_set = content_analyser.get_model_options().get('sources') _, model_source_set = content_analyser.collect_model_downloads()
with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'): with patch('facefusion.inference_manager.detect_app_context', return_value = 'cli'):
get_inference_pool('facefusion.content_analyser', model_names, model_source_set) get_inference_pool('facefusion.content_analyser', model_names, model_source_set)
assert isinstance(INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession) assert isinstance(INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.yolo_11m.marqo.freepik.0.cpu').get('yolo_11m'), InferenceSession)
with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'): with patch('facefusion.inference_manager.detect_app_context', return_value = 'ui'):
get_inference_pool('facefusion.content_analyser', model_names, model_source_set) get_inference_pool('facefusion.content_analyser', model_names, model_source_set)
assert isinstance(INFERENCE_POOL_SET.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser'), InferenceSession) assert isinstance(INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.yolo_11m.marqo.freepik.0.cpu').get('yolo_11m'), InferenceSession)
assert INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser') == INFERENCE_POOL_SET.get('ui').get('facefusion.content_analyser.yolo_nsfw.0.cpu').get('content_analyser') assert INFERENCE_POOL_SET.get('cli').get('facefusion.content_analyser.yolo_11m.marqo.freepik.0.cpu').get('yolo_11m') == INFERENCE_POOL_SET.get('ui').get('facefusion.content_analyser.yolo_11m.marqo.freepik.0.cpu').get('yolo_11m')