Add another xseg model, Simplify download mapping (#851)

This commit is contained in:
Henry Ruhs
2025-01-10 12:18:46 +01:00
committed by henryruhs
parent 71092cb951
commit 7f90ca72bb
6 changed files with 39 additions and 35 deletions

View File

@@ -57,6 +57,26 @@ def create_static_model_set(download_scope : DownloadScope) -> ModelSet:
},
'size': (256, 256)
},
'xseg_3':
{
'hashes':
{
'face_occluder':
{
'url': resolve_download_url('models-3.2.0', 'xseg_3.hash'),
'path': resolve_relative_path('../.assets/models/xseg_3.hash')
}
},
'sources':
{
'face_occluder':
{
'url': resolve_download_url('models-3.2.0', 'xseg_3.onnx'),
'path': resolve_relative_path('../.assets/models/xseg_3.onnx')
}
},
'size': (256, 256)
},
'bisenet_resnet_18':
{
'hashes':
@@ -114,21 +134,15 @@ def collect_model_downloads() -> Tuple[DownloadSet, DownloadSet]:
model_sources = {}
model_set = create_static_model_set('full')
if state_manager.get_item('face_occluder_model') == 'xseg_1':
model_hashes['xseg_1'] = model_set.get('xseg_1').get('hashes').get('face_occluder')
model_sources['xseg_1'] = model_set.get('xseg_1').get('sources').get('face_occluder')
for face_occluder_model in [ 'xseg_1', 'xseg_2', 'xseg_3' ]:
if state_manager.get_item('face_occluder_model') == face_occluder_model:
model_hashes[face_occluder_model] = model_set.get(face_occluder_model).get('hashes').get('face_occluder')
model_sources[face_occluder_model] = model_set.get(face_occluder_model).get('sources').get('face_occluder')
if state_manager.get_item('face_occluder_model') == 'xseg_2':
model_hashes['xseg_2'] = model_set.get('xseg_2').get('hashes').get('face_occluder')
model_sources['xseg_2'] = model_set.get('xseg_2').get('sources').get('face_occluder')
if state_manager.get_item('face_parser_model') == 'bisenet_resnet_18':
model_hashes['bisenet_resnet_18'] = model_set.get('bisenet_resnet_18').get('hashes').get('face_parser')
model_sources['bisenet_resnet_18'] = model_set.get('bisenet_resnet_18').get('sources').get('face_parser')
if state_manager.get_item('face_parser_model') == 'bisenet_resnet_34':
model_hashes['bisenet_resnet_34'] = model_set.get('bisenet_resnet_34').get('hashes').get('face_parser')
model_sources['bisenet_resnet_34'] = model_set.get('bisenet_resnet_34').get('sources').get('face_parser')
for face_parser_model in [ 'bisenet_resnet_18', 'bisenet_resnet_34' ]:
if state_manager.get_item('face_parser_model') == face_parser_model:
model_hashes[face_parser_model] = model_set.get(face_parser_model).get('hashes').get('face_parser')
model_sources[face_parser_model] = model_set.get(face_parser_model).get('sources').get('face_parser')
return model_hashes, model_sources