1 Star 0 Fork 1

RaymondCHENG/dino_ms

forked from kate/dino_ms 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
build_model.py 5.16 KB
一键复制 编辑 原始数据 按行查看 历史
kate 提交于 2022-12-15 17:15 . init
import copy
from models.backbone import build_backbone
from models.transformer import build_transformer
from models.matcher import build_matcher
from models.loss import SetCriterion
from models.post_process import PostProcess
from models.dino import DINO
def build_dino(args):
    """Build the DINO detector, its training criterion, and post-processors.

    Args:
        args: parsed config namespace. Required attributes include
            ``num_classes``, ``num_queries``, ``num_feature_levels``,
            ``nheads``, ``two_stage_type``, loss coefficients, etc.
            Several attributes are optional and default sensibly
            (``match_unstable_error``, ``dn_labelbook_size``,
            ``dec_pred_class_embed_share``, ``dec_pred_bbox_embed_share``,
            ``no_interm_box_loss``, ``interm_loss_coef``).

    Returns:
        tuple: ``(model, criterion, postprocessors)`` where ``model`` is a
        :class:`DINO` instance, ``criterion`` a :class:`SetCriterion`, and
        ``postprocessors`` a dict keyed by output kind (at least ``'bbox'``).

    Note on ``num_classes``: it is ``max_obj_id + 1``, not the count of
    classes — e.g. COCO's max id is 90, so 91 is passed. See
    https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
    """
    num_classes = args.num_classes
    backbone = build_backbone(args)
    transformer = build_transformer(args)

    # Read optional config attributes independently. The original code read
    # two attributes inside a single bare `try/except`, so a missing second
    # attribute silently reset the first one to its default as well; per-name
    # getattr() avoids that coupling and no longer swallows unrelated errors.
    match_unstable_error = getattr(args, 'match_unstable_error', True)  # read for parity; not used below
    dn_labelbook_size = getattr(args, 'dn_labelbook_size', num_classes)
    dec_pred_class_embed_share = getattr(args, 'dec_pred_class_embed_share', True)
    dec_pred_bbox_embed_share = getattr(args, 'dec_pred_bbox_embed_share', True)

    model = DINO(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        aux_loss=True,
        iter_update=True,
        query_dim=4,
        random_refpoints_xy=args.random_refpoints_xy,
        fix_refpoints_hw=args.fix_refpoints_hw,
        num_feature_levels=args.num_feature_levels,
        nheads=args.nheads,
        dec_pred_class_embed_share=dec_pred_class_embed_share,
        dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
        # two stage
        two_stage_type=args.two_stage_type,
        # box_share
        two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
        two_stage_class_embed_share=args.two_stage_class_embed_share,
        decoder_sa_type=args.decoder_sa_type,
        num_patterns=args.num_patterns,
        # denoising training is disabled entirely by passing dn_number=0
        dn_number=args.dn_number if args.use_dn else 0,
        dn_box_noise_scale=args.dn_box_noise_scale,
        dn_label_noise_ratio=args.dn_label_noise_ratio,
        dn_labelbook_size=dn_labelbook_size,
    )

    if args.masks:
        # NOTE(review): `DETRsegm` is not imported anywhere in this file, so
        # enabling args.masks raises NameError — confirm the intended import
        # (in upstream DETR it lives in the segmentation module).
        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))

    matcher = build_matcher(args)

    # ---- loss weight dictionary -------------------------------------------
    # Base (matched-query) losses.
    weight_dict = {'loss_ce': args.cls_loss_coef, 'loss_bbox': args.bbox_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef
    # Snapshot before DN/mask entries are added: the interm (encoder) losses
    # below must NOT include the denoising terms.
    clean_weight_dict_wo_dn = copy.deepcopy(weight_dict)

    # Denoising (DN) training losses mirror the base ones with a `_dn` suffix.
    if args.use_dn:
        weight_dict['loss_ce_dn'] = args.cls_loss_coef
        weight_dict['loss_bbox_dn'] = args.bbox_loss_coef
        weight_dict['loss_giou_dn'] = args.giou_loss_coef

    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef
    clean_weight_dict = copy.deepcopy(weight_dict)

    # Auxiliary decoder layers: duplicate every loss weight per intermediate
    # layer with an `_{i}` suffix (dec_layers - 1 intermediate layers).
    # TODO this is a hack (inherited from DETR).
    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in clean_weight_dict.items()})
        weight_dict.update(aux_weight_dict)

    # Two-stage: add `_interm` weights for the encoder proposal losses.
    if args.two_stage_type != 'no':
        no_interm_box_loss = getattr(args, 'no_interm_box_loss', False)
        interm_loss_coef = getattr(args, 'interm_loss_coef', 1.0)
        # Box/giou interm losses can be switched off wholesale.
        _coeff_weight_dict = {
            'loss_ce': 1.0,
            'loss_bbox': 0.0 if no_interm_box_loss else 1.0,
            'loss_giou': 0.0 if no_interm_box_loss else 1.0,
        }
        interm_weight_dict = {
            k + '_interm': v * interm_loss_coef * _coeff_weight_dict[k]
            for k, v in clean_weight_dict_wo_dn.items()
        }
        weight_dict.update(interm_weight_dict)

    losses = ['labels', 'boxes', 'cardinality']
    if args.masks:
        losses += ["masks"]

    criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,
                             focal_alpha=args.focal_alpha, losses=losses,
                             )

    postprocessors = {'bbox': PostProcess(num_select=args.num_select, nms_iou_threshold=args.nms_iou_threshold)}
    if args.masks:
        # NOTE(review): `PostProcessSegm` / `PostProcessPanoptic` are also not
        # imported in this file — confirm before enabling mask/panoptic paths.
        postprocessors['segm'] = PostProcessSegm()
        if args.dataset_file == "coco_panoptic":
            # COCO panoptic convention: category ids <= 90 are "things".
            is_thing_map = {i: i <= 90 for i in range(201)}
            postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)

    return model, criterion, postprocessors
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/raymondzxyu/dino_ms.git
git@gitee.com:raymondzxyu/dino_ms.git
raymondzxyu
dino_ms
dino_ms
master

搜索帮助