Add another xseg model, Simplify download mapping (#851)
This commit is contained in:
@@ -17,7 +17,7 @@ face_selector_modes : List[FaceSelectorMode] = [ 'many', 'one', 'reference' ]
|
||||
face_selector_orders : List[FaceSelectorOrder] = [ 'left-right', 'right-left', 'top-bottom', 'bottom-top', 'small-large', 'large-small', 'best-worst', 'worst-best' ]
|
||||
face_selector_genders : List[Gender] = [ 'female', 'male' ]
|
||||
face_selector_races : List[Race] = [ 'white', 'black', 'latino', 'asian', 'indian', 'arabic' ]
|
||||
face_occluder_models : List[FaceOccluderModel] = [ 'xseg_1', 'xseg_2' ]
|
||||
face_occluder_models : List[FaceOccluderModel] = [ 'xseg_1', 'xseg_2', 'xseg_3' ]
|
||||
face_parser_models : List[FaceParserModel] = [ 'bisenet_resnet_18', 'bisenet_resnet_34' ]
|
||||
face_mask_types : List[FaceMaskType] = [ 'box', 'occlusion', 'region' ]
|
||||
face_mask_region_set : FaceMaskRegionSet =\
|
||||
|
||||
@@ -91,17 +91,10 @@ def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||
model_sources = {}
|
||||
model_set = create_static_model_set('full')
|
||||
|
||||
if state_manager.get_item('face_detector_model') in [ 'many', 'retinaface' ]:
|
||||
model_hashes['retinaface'] = model_set.get('retinaface').get('hashes').get('retinaface')
|
||||
model_sources['retinaface'] = model_set.get('retinaface').get('sources').get('retinaface')
|
||||
|
||||
if state_manager.get_item('face_detector_model') in [ 'many', 'scrfd' ]:
|
||||
model_hashes['scrfd'] = model_set.get('scrfd').get('hashes').get('scrfd')
|
||||
model_sources['scrfd'] = model_set.get('scrfd').get('sources').get('scrfd')
|
||||
|
||||
if state_manager.get_item('face_detector_model') in [ 'many', 'yoloface' ]:
|
||||
model_hashes['yoloface'] = model_set.get('yoloface').get('hashes').get('yoloface')
|
||||
model_sources['yoloface'] = model_set.get('yoloface').get('sources').get('yoloface')
|
||||
for face_detector_model in [ 'retinaface', 'scrfd', 'yoloface' ]:
|
||||
if state_manager.get_item('face_detector_model') in [ 'many', face_detector_model ]:
|
||||
model_hashes[face_detector_model] = model_set.get(face_detector_model).get('hashes').get(face_detector_model)
|
||||
model_sources[face_detector_model] = model_set.get(face_detector_model).get('sources').get(face_detector_model)
|
||||
|
||||
return model_hashes, model_sources
|
||||
|
||||
|
||||
@@ -98,13 +98,10 @@ def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||
'fan_68_5': model_set.get('fan_68_5').get('sources').get('fan_68_5')
|
||||
}
|
||||
|
||||
if state_manager.get_item('face_landmarker_model') in [ 'many', '2dfan4' ]:
|
||||
model_hashes['2dfan4'] = model_set.get('2dfan4').get('hashes').get('2dfan4')
|
||||
model_sources['2dfan4'] = model_set.get('2dfan4').get('sources').get('2dfan4')
|
||||
|
||||
if state_manager.get_item('face_landmarker_model') in [ 'many', 'peppa_wutz' ]:
|
||||
model_hashes['peppa_wutz'] = model_set.get('peppa_wutz').get('hashes').get('peppa_wutz')
|
||||
model_sources['peppa_wutz'] = model_set.get('peppa_wutz').get('sources').get('peppa_wutz')
|
||||
for face_landmarker_model in [ '2dfan4', 'peppa_wutz' ]:
|
||||
if state_manager.get_item('face_landmarker_model') in [ 'many', face_landmarker_model ]:
|
||||
model_hashes[face_landmarker_model] = model_set.get(face_landmarker_model).get('hashes').get(face_landmarker_model)
|
||||
model_sources[face_landmarker_model] = model_set.get(face_landmarker_model).get('sources').get(face_landmarker_model)
|
||||
|
||||
return model_hashes, model_sources
|
||||
|
||||
|
||||
@@ -57,6 +57,26 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
|
||||
},
|
||||
'size': (256, 256)
|
||||
},
|
||||
'xseg_3':
|
||||
{
|
||||
'hashes':
|
||||
{
|
||||
'face_occluder':
|
||||
{
|
||||
'url': resolve_download_url('models-3.2.0', 'xseg_3.hash'),
|
||||
'path': resolve_relative_path('../.assets/models/xseg_3.hash')
|
||||
}
|
||||
},
|
||||
'sources':
|
||||
{
|
||||
'face_occluder':
|
||||
{
|
||||
'url': resolve_download_url('models-3.2.0', 'xseg_3.onnx'),
|
||||
'path': resolve_relative_path('../.assets/models/xseg_3.onnx')
|
||||
}
|
||||
},
|
||||
'size': (256, 256)
|
||||
},
|
||||
'bisenet_resnet_18':
|
||||
{
|
||||
'hashes':
|
||||
@@ -114,21 +134,15 @@ def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
|
||||
model_sources = {}
|
||||
model_set = create_static_model_set('full')
|
||||
|
||||
if state_manager.get_item('face_occluder_model') == 'xseg_1':
|
||||
model_hashes['xseg_1'] = model_set.get('xseg_1').get('hashes').get('face_occluder')
|
||||
model_sources['xseg_1'] = model_set.get('xseg_1').get('sources').get('face_occluder')
|
||||
for face_occluder_model in [ 'xseg_1', 'xseg_2', 'xseg_3' ]:
|
||||
if state_manager.get_item('face_occluder_model') == face_occluder_model:
|
||||
model_hashes[face_occluder_model] = model_set.get(face_occluder_model).get('hashes').get('face_occluder')
|
||||
model_sources[face_occluder_model] = model_set.get(face_occluder_model).get('sources').get('face_occluder')
|
||||
|
||||
if state_manager.get_item('face_occluder_model') == 'xseg_2':
|
||||
model_hashes['xseg_2'] = model_set.get('xseg_2').get('hashes').get('face_occluder')
|
||||
model_sources['xseg_2'] = model_set.get('xseg_2').get('sources').get('face_occluder')
|
||||
|
||||
if state_manager.get_item('face_parser_model') == 'bisenet_resnet_18':
|
||||
model_hashes['bisenet_resnet_18'] = model_set.get('bisenet_resnet_18').get('hashes').get('face_parser')
|
||||
model_sources['bisenet_resnet_18'] = model_set.get('bisenet_resnet_18').get('sources').get('face_parser')
|
||||
|
||||
if state_manager.get_item('face_parser_model') == 'bisenet_resnet_34':
|
||||
model_hashes['bisenet_resnet_34'] = model_set.get('bisenet_resnet_34').get('hashes').get('face_parser')
|
||||
model_sources['bisenet_resnet_34'] = model_set.get('bisenet_resnet_34').get('sources').get('face_parser')
|
||||
for face_parser_model in [ 'bisenet_resnet_18', 'bisenet_resnet_34' ]:
|
||||
if state_manager.get_item('face_parser_model') == face_parser_model:
|
||||
model_hashes[face_parser_model] = model_set.get(face_parser_model).get('hashes').get('face_parser')
|
||||
model_sources[face_parser_model] = model_set.get(face_parser_model).get('sources').get('face_parser')
|
||||
|
||||
return model_hashes, model_sources
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ FaceLandmarkerModel = Literal['many', '2dfan4', 'peppa_wutz']
|
||||
FaceDetectorSet = Dict[FaceDetectorModel, List[str]]
|
||||
FaceSelectorMode = Literal['many', 'one', 'reference']
|
||||
FaceSelectorOrder = Literal['left-right', 'right-left', 'top-bottom', 'bottom-top', 'small-large', 'large-small', 'best-worst', 'worst-best']
|
||||
FaceOccluderModel = Literal['xseg_1', 'xseg_2']
|
||||
FaceOccluderModel = Literal['xseg_1', 'xseg_2', 'xseg_3']
|
||||
FaceParserModel = Literal['bisenet_resnet_18', 'bisenet_resnet_34']
|
||||
FaceMaskType = Literal['box', 'occlusion', 'region']
|
||||
FaceMaskRegion = Literal['skin', 'left-eyebrow', 'right-eyebrow', 'left-eye', 'right-eye', 'glasses', 'nose', 'mouth', 'upper-lip', 'lower-lip']
|
||||
|
||||
@@ -78,5 +78,5 @@ def create_tqdm_output(self : tqdm) -> Optional[str]:
|
||||
|
||||
def read_logs() -> str:
|
||||
LOG_BUFFER.seek(0)
|
||||
logs = LOG_BUFFER.read().rstrip()
|
||||
logs = LOG_BUFFER.read().strip()
|
||||
return logs
|
||||
|
||||
Reference in New Issue
Block a user