1 Star 0 Fork 0

ideaoverflow/CRAFT-Reimplementation

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
craft.py 2.92 KB
一键复制 编辑 原始数据 按行查看 历史
严海 提交于 2019-09-02 14:39 . first commit
"""
Copyright (c) 2019-present NAVER Corp.
MIT License
"""
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import torch.nn.init as init
from torchutil import *
from basenet.vgg16_bn import vgg16_bn
class double_conv(nn.Module):
    """Squeeze-then-refine convolution block used by the CRAFT decoder.

    The block takes a tensor with ``in_ch + mid_ch`` channels. In ``CRAFT``
    this is the previous decoder output concatenated with a backbone skip map.
    The block applies:

    1. a 1x1 convolution down to ``mid_ch`` channels, then BatchNorm + ReLU;
    2. a 3x3 convolution (padding 1, so H and W are preserved) to ``out_ch``
       channels, then BatchNorm + ReLU.

    Args:
        in_ch: together with ``mid_ch``, sets the input channel count
            (``in_ch + mid_ch``).
        mid_ch: channel count after the 1x1 squeeze convolution.
        out_ch: channel count of the block's output.
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super().__init__()
        squeeze = [
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
        ]
        refine = [
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        # A single Sequential named ``conv`` (indices 0-5) keeps state_dict
        # keys such as ``conv.0.weight`` stable for saved checkpoints.
        self.conv = nn.Sequential(*squeeze, *refine)

    def forward(self, x):
        """Run the squeeze and refine stages on ``x``."""
        return self.conv(x)
class CRAFT(nn.Module):
    """CRAFT text detector: a ``vgg16_bn`` backbone followed by a U-Net style decoder.

    ``forward(x)`` returns ``(scores, feature)``:

    * ``scores`` is the 2-channel output of ``conv_cls``, permuted to
      channels-last ``(N, H, W, 2)``. These are presumably the region and
      affinity maps of the CRAFT paper; confirm against the training code.
    * ``feature`` is the 32-channel decoder output ``(N, 32, H, W)`` that
      ``conv_cls`` consumes.

    H and W equal the spatial size of the last backbone source map, not
    necessarily the input size.

    Assumes ``self.basenet(x)`` returns an indexable sequence of at least five
    4-D feature maps. The decoder widths imply that ``sources[0]`` and
    ``sources[1]`` together have 1536 channels, and that ``sources[2]``,
    ``sources[3]`` and ``sources[4]`` have 512, 256 and 128 channels.

    Args:
        pretrained: forwarded positionally to ``vgg16_bn``.
        freeze: forwarded positionally to ``vgg16_bn``.
    """

    def __init__(self, pretrained=True, freeze=False):
        super().__init__()
        # Backbone. Built first so the parameter-initialisation order (and
        # hence RNG consumption) matches the decoder/head construction below.
        self.basenet = vgg16_bn(pretrained, freeze)

        # Decoder stages: each consumes (previous output ++ backbone skip).
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        # Prediction head: 32 decoder channels -> 2 score channels.
        n_scores = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, n_scores, kernel_size=1),
        )

        # Re-initialise only the newly added layers; the backbone is left as
        # vgg16_bn built it.
        for part in (self.upconv1, self.upconv2, self.upconv3,
                     self.upconv4, self.conv_cls):
            init_weights(part.modules())

    def forward(self, x):
        """Return ``(channels-last scores, 32-channel decoder feature)`` for ``x``."""
        sources = self.basenet(x)

        # The two deepest backbone maps are fused directly, without resizing.
        y = self.upconv1(torch.cat([sources[0], sources[1]], dim=1))

        # Each later stage bilinearly resizes to its skip map's spatial size,
        # concatenates along channels, then refines.
        stages = (
            (self.upconv2, sources[2]),
            (self.upconv3, sources[3]),
            (self.upconv4, sources[4]),
        )
        for upconv, skip in stages:
            y = F.interpolate(y, size=skip.size()[2:], mode='bilinear',
                              align_corners=False)
            y = upconv(torch.cat([y, skip], dim=1))

        feature = y
        scores = self.conv_cls(feature)
        return scores.permute(0, 2, 3, 1), feature
if __name__ == '__main__':
    # Smoke test: build the model and print the score-map shape for a single
    # random 768x768 RGB input. Use the GPU when present and otherwise fall
    # back to the CPU, so the check also runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CRAFT(pretrained=True).to(device)
    # Inference only: eval() stops BatchNorm from updating its running stats,
    # and no_grad() avoids building an autograd graph for the large forward.
    model.eval()
    with torch.no_grad():
        output, _ = model(torch.randn(1, 3, 768, 768, device=device))
    print(output.shape)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/ideaoverflow/CRAFT-Reimplementation.git
git@gitee.com:ideaoverflow/CRAFT-Reimplementation.git
ideaoverflow
CRAFT-Reimplementation
CRAFT-Reimplementation
master

搜索帮助