import copy

from models.backbone import build_backbone
from models.transformer import build_transformer
from models.matcher import build_matcher
from models.loss import SetCriterion
from models.post_process import PostProcess
from models.dino import DINO
# NOTE: the mask / panoptic branches below reference DETRsegm, PostProcessSegm
# and PostProcessPanoptic but no import is present. In the upstream DETR layout
# these live in a segmentation module; import them from wherever this fork
# keeps them, e.g.:
# from models.segmentation import DETRsegm, PostProcessSegm, PostProcessPanoptic
def build_dino(args):
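    """Build the DINO detector from an argparse-style config.

    Returns a (model, criterion, postprocessors) triple: the DINO model
    itself, the SetCriterion loss used for training, and a dict of
    post-processors that turn raw outputs into final detections.
    """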
    # the `num_classes` naming here is somewhat misleading.
    # it indeed corresponds to `max_obj_id + 1`, where max_obj_id
    # is the maximum id for a class in your dataset. For example,
    # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91.
    # As another example, for a dataset that has a single class with id 1,
    # you should pass `num_classes` to be 2 (max_obj_id + 1).
    # For more details on this, check the following discussion
    # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223
    # Historical per-dataset values, kept for reference:
    # num_classes = 20 if args.dataset_file != 'coco' else 91
    # if args.dataset_file == "coco_panoptic":
    #     # for panoptic, we just add a num_classes that is large enough to hold
    #     # max_obj_id + 1, but the exact value doesn't really matter
    #     num_classes = 250
    # if args.dataset_file == 'o365':
    #     num_classes = 366
    # if args.dataset_file == 'vanke':
    #     num_classes = 51
    num_classes = args.num_classes
    backbone = build_backbone(args)
    transformer = build_transformer(args)

    # Older configs may not define the following flags; fall back to defaults
    # with getattr instead of bare `except:` clauses, which would also swallow
    # unrelated errors.
    # (match_unstable_error is read for config parity; it is not used below.)
    match_unstable_error = getattr(args, 'match_unstable_error', True)
    dn_labelbook_size = getattr(args, 'dn_labelbook_size', num_classes)
    dec_pred_class_embed_share = getattr(args, 'dec_pred_class_embed_share', True)
    dec_pred_bbox_embed_share = getattr(args, 'dec_pred_bbox_embed_share', True)
    model = DINO(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=args.num_queries,
        aux_loss=True,
        iter_update=True,
        query_dim=4,
        random_refpoints_xy=args.random_refpoints_xy,
        fix_refpoints_hw=args.fix_refpoints_hw,
        num_feature_levels=args.num_feature_levels,
        nheads=args.nheads,
        dec_pred_class_embed_share=dec_pred_class_embed_share,
        dec_pred_bbox_embed_share=dec_pred_bbox_embed_share,
        # two stage
        two_stage_type=args.two_stage_type,
        # box_share
        two_stage_bbox_embed_share=args.two_stage_bbox_embed_share,
        two_stage_class_embed_share=args.two_stage_class_embed_share,
        decoder_sa_type=args.decoder_sa_type,
        num_patterns=args.num_patterns,
        dn_number=args.dn_number if args.use_dn else 0,
        dn_box_noise_scale=args.dn_box_noise_scale,
        dn_label_noise_ratio=args.dn_label_noise_ratio,
        dn_labelbook_size=dn_labelbook_size,
    )
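    # The dn_* arguments above configure denoising (DN) training: noisy copies
    # of ground-truth boxes/labels are fed to the decoder as extra queries and
    # reconstructed, which stabilizes bipartite matching during training.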
    if args.masks:
        # wrap the detector with a segmentation head (see import note above)
        model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
    matcher = build_matcher(args)

    # prepare weight dict
    weight_dict = {'loss_ce': args.cls_loss_coef, 'loss_bbox': args.bbox_loss_coef}
    weight_dict['loss_giou'] = args.giou_loss_coef
    clean_weight_dict_wo_dn = copy.deepcopy(weight_dict)
    # for DN training
    if args.use_dn:
        weight_dict['loss_ce_dn'] = args.cls_loss_coef
        weight_dict['loss_bbox_dn'] = args.bbox_loss_coef
        weight_dict['loss_giou_dn'] = args.giou_loss_coef
    if args.masks:
        weight_dict["loss_mask"] = args.mask_loss_coef
        weight_dict["loss_dice"] = args.dice_loss_coef
    clean_weight_dict = copy.deepcopy(weight_dict)
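    # clean_weight_dict snapshots the per-layer base weights before the
    # suffixed aux/interm copies are added below; clean_weight_dict_wo_dn is
    # the same snapshot without the DN terms and feeds the interm losses.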
    # TODO: this is a hack that replicates the weight dict once per
    # intermediate decoder layer so every layer's auxiliary output is supervised
    if args.aux_loss:
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in clean_weight_dict.items()})
        weight_dict.update(aux_weight_dict)
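    # e.g. with dec_layers=6 this adds 'loss_ce_0' ... 'loss_ce_4' (plus the
    # bbox/giou and, when DN is on, *_dn variants); the unsuffixed keys weight
    # the final decoder layer.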
    if args.two_stage_type != 'no':
        interm_weight_dict = {}
        # again, tolerate configs that predate these options
        no_interm_box_loss = getattr(args, 'no_interm_box_loss', False)
        interm_loss_coef = getattr(args, 'interm_loss_coef', 1.0)
        _coeff_weight_dict = {
            'loss_ce': 1.0,
            'loss_bbox': 1.0 if not no_interm_box_loss else 0.0,
            'loss_giou': 1.0 if not no_interm_box_loss else 0.0,
        }
        interm_weight_dict.update({
            k + '_interm': v * interm_loss_coef * _coeff_weight_dict[k]
            for k, v in clean_weight_dict_wo_dn.items()
        })
        weight_dict.update(interm_weight_dict)
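    # The '_interm' keys weight the losses on the first-stage (encoder
    # proposal) outputs that two-stage variants produce.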
    losses = ['labels', 'boxes', 'cardinality']
    if args.masks:
        losses += ["masks"]
    criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,
                             focal_alpha=args.focal_alpha, losses=losses)
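    # SetCriterion matches predictions to targets with `matcher` and computes
    # the requested losses; classification uses sigmoid focal loss with
    # alpha = args.focal_alpha.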
    postprocessors = {'bbox': PostProcess(num_select=args.num_select,
                                          nms_iou_threshold=args.nms_iou_threshold)}
    if args.masks:
        postprocessors['segm'] = PostProcessSegm()
        if args.dataset_file == "coco_panoptic":
            # in COCO panoptic, category ids <= 90 are "thing" classes
            is_thing_map = {i: i <= 90 for i in range(201)}
            postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85)
    return model, criterion, postprocessors
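
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, following the usual DETR-family training
# loop; `samples`, `targets`, `orig_target_sizes` and the arg namespace come
# from the surrounding code and are assumptions, not defined by this module):
#
#     model, criterion, postprocessors = build_dino(args)
#     outputs = model(samples)  # DN training variants may also pass targets
#     loss_dict = criterion(outputs, targets)
#     losses = sum(loss_dict[k] * criterion.weight_dict[k]
#                  for k in loss_dict if k in criterion.weight_dict)
#     results = postprocessors['bbox'](outputs, orig_target_sizes)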