1 Star 0 Fork 0

yasuo_hao/Yet-Another-EfficientDet-Pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
backbone.py 3.59 KB
一键复制 编辑 原始数据 按行查看 历史
zylo117 提交于 2020-07-23 15:38 . supports efficientdet-d7x now
# Author: Zylo117
import torch
from torch import nn
from efficientdet.model import BiFPN, Regressor, Classifier, EfficientNet
from efficientdet.utils import Anchors
class EfficientDetBackbone(nn.Module):
    """EfficientDet network: EfficientNet features -> BiFPN stack -> box/class heads.

    The constructor holds one lookup table per hyper-parameter, each indexed by
    ``compound_coef`` (nine entries, indices 0-8), and builds the sub-modules
    for the selected entry.

    Args:
        num_classes: Number of classes handed to the ``Classifier`` head.
        compound_coef: Index into the per-variant tables below.
        load_weights: Forwarded unchanged to ``EfficientNet``.
        **kwargs: ``ratios`` and ``scales`` are read here to size the heads;
            the whole mapping is also forwarded to ``Anchors``.
    """

    def __init__(self, num_classes=80, compound_coef=0, load_weights=False, **kwargs):
        super(EfficientDetBackbone, self).__init__()
        self.compound_coef = compound_coef

        # Per-variant hyper-parameter tables; position i configures compound_coef == i.
        self.backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6, 7]
        self.fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384, 384]
        self.fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8, 8]
        self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536, 1536]
        self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5, 5]
        self.pyramid_levels = [5, 5, 5, 5, 5, 5, 5, 5, 6]
        self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5., 4.]

        self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
        self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))

        conv_channel_coef = {
            # the channels of P3/P4/P5.
            0: [40, 112, 320],
            1: [40, 112, 320],
            2: [48, 120, 352],
            3: [48, 136, 384],
            4: [56, 160, 448],
            5: [64, 176, 512],
            6: [72, 200, 576],
            7: [72, 200, 576],
            8: [80, 224, 640],
        }

        # Resolve every table once so the sub-module constructors below all
        # read the same values.
        fpn_filters = self.fpn_num_filters[compound_coef]
        head_layers = self.box_class_repeats[compound_coef]
        level_count = self.pyramid_levels[compound_coef]
        num_anchors = len(self.aspect_ratios) * self.num_scales

        # Stack of BiFPN cells. The third positional argument is True only for
        # the first cell; attention is enabled only for compound_coef < 6 and
        # use_p8 only for compound_coef > 7.
        self.bifpn = nn.Sequential(
            *[BiFPN(fpn_filters,
                    conv_channel_coef[compound_coef],
                    cell_idx == 0,
                    attention=compound_coef < 6,
                    use_p8=compound_coef > 7)
              for cell_idx in range(self.fpn_cell_repeats[compound_coef])])

        self.num_classes = num_classes
        self.regressor = Regressor(in_channels=fpn_filters, num_anchors=num_anchors,
                                   num_layers=head_layers,
                                   pyramid_levels=level_count)
        self.classifier = Classifier(in_channels=fpn_filters, num_anchors=num_anchors,
                                     num_classes=num_classes,
                                     num_layers=head_layers,
                                     pyramid_levels=level_count)

        # Anchor levels are the consecutive integers 3, 4, ..., 3 + level_count - 1.
        self.anchors = Anchors(anchor_scale=self.anchor_scale[compound_coef],
                               pyramid_levels=list(range(3, 3 + level_count)),
                               **kwargs)

        self.backbone_net = EfficientNet(self.backbone_compound_coef[compound_coef], load_weights)

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` sub-module into eval mode.

        Other modules keep their current train/eval mode.
        """
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def forward(self, inputs):
        """Run the full network on ``inputs``.

        Returns:
            A tuple ``(features, regression, classification, anchors)``:
            the BiFPN output, the two head outputs computed from it, and the
            result of calling the anchor module on ``inputs``.
        """
        # The first backbone output is not used by the BiFPN.
        _, p3, p4, p5 = self.backbone_net(inputs)

        features = self.bifpn((p3, p4, p5))

        regression = self.regressor(features)
        classification = self.classifier(features)
        anchors = self.anchors(inputs, inputs.dtype)

        return features, regression, classification, anchors

    def init_backbone(self, path):
        """Load weights from the checkpoint at ``path`` non-strictly.

        Missing or unexpected keys are tolerated (``strict=False``) and the
        result of ``load_state_dict`` is printed. A ``RuntimeError`` raised by
        ``load_state_dict`` is caught and reported with a message instead of
        propagating.

        Args:
            path: Location of a file written by ``torch.save`` holding a state dict.
        """
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        # map_location='cpu' lets a checkpoint saved on a GPU be read on a
        # CPU-only host; load_state_dict then copies the values into the
        # existing parameters on whatever device they already live on.
        state_dict = torch.load(path, map_location='cpu')
        try:
            ret = self.load_state_dict(state_dict, strict=False)
            print(ret)
        except RuntimeError as e:
            print('Ignoring "' + str(e) + '"')
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/yasuo_hao/Yet-Another-EfficientDet-Pytorch.git
git@gitee.com:yasuo_hao/Yet-Another-EfficientDet-Pytorch.git
yasuo_hao
Yet-Another-EfficientDet-Pytorch
Yet-Another-EfficientDet-Pytorch
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385