1 Star 1 Fork 0

zhugeliang1/UNET-ZOO

Create your Gitee Account
Explore and code with more than 12 million developers,Free private repositories !:)
Sign up
文件
This repository doesn't specify license. Please pay attention to the specific project description and its upstream code dependency when using it.
Clone or Download
r2unet.py 3.41 KB
Copy Edit Raw Blame History
andyzhu authored 2020-02-11 13:13 . Add files via upload
import torch
import torch.nn as nn
import torch.nn.functional as F
class up_conv(nn.Module):
def __init__(self,ch_in,ch_out):
super(up_conv,self).__init__()
self.up = nn.Sequential(
nn.Upsample(scale_factor=2),
nn.Conv2d(ch_in,ch_out,kernel_size=3,stride=1,padding=1,bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self,x):
x = self.up(x)
return x
class Recurrent_block(nn.Module):
def __init__(self, ch_out, t=2):
super(Recurrent_block, self).__init__()
self.t = t
self.ch_out = ch_out
self.conv = nn.Sequential(
nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
nn.BatchNorm2d(ch_out),
nn.ReLU(inplace=True)
)
def forward(self, x):
for i in range(self.t):
if i == 0:
x1 = self.conv(x)
x1 = self.conv(x + x1)
return x1
class RRCNN_block(nn.Module):
def __init__(self,ch_in,ch_out,t=2):
super(RRCNN_block,self).__init__()
self.RCNN = nn.Sequential(
Recurrent_block(ch_out,t=t),
Recurrent_block(ch_out,t=t)
)
self.Conv_1x1 = nn.Conv2d(ch_in,ch_out,kernel_size=1,stride=1,padding=0)
def forward(self,x):
x = self.Conv_1x1(x)
x1 = self.RCNN(x)
return x+x1
class R2U_Net(nn.Module):
def __init__(self, img_ch=3, output_ch=1, t=2):
super(R2U_Net, self).__init__()
self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
self.Upsample = nn.Upsample(scale_factor=2)
self.RRCNN1 = RRCNN_block(ch_in=img_ch, ch_out=64, t=t)
self.RRCNN2 = RRCNN_block(ch_in=64, ch_out=128, t=t)
self.RRCNN3 = RRCNN_block(ch_in=128, ch_out=256, t=t)
self.RRCNN4 = RRCNN_block(ch_in=256, ch_out=512, t=t)
self.RRCNN5 = RRCNN_block(ch_in=512, ch_out=1024, t=t)
self.Up5 = up_conv(ch_in=1024, ch_out=512)
self.Up_RRCNN5 = RRCNN_block(ch_in=1024, ch_out=512, t=t)
self.Up4 = up_conv(ch_in=512, ch_out=256)
self.Up_RRCNN4 = RRCNN_block(ch_in=512, ch_out=256, t=t)
self.Up3 = up_conv(ch_in=256, ch_out=128)
self.Up_RRCNN3 = RRCNN_block(ch_in=256, ch_out=128, t=t)
self.Up2 = up_conv(ch_in=128, ch_out=64)
self.Up_RRCNN2 = RRCNN_block(ch_in=128, ch_out=64, t=t)
self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# encoding path
x1 = self.RRCNN1(x)
x2 = self.Maxpool(x1)
x2 = self.RRCNN2(x2)
x3 = self.Maxpool(x2)
x3 = self.RRCNN3(x3)
x4 = self.Maxpool(x3)
x4 = self.RRCNN4(x4)
x5 = self.Maxpool(x4)
x5 = self.RRCNN5(x5)
# decoding + concat path
d5 = self.Up5(x5)
d5 = torch.cat((x4, d5), dim=1)
d5 = self.Up_RRCNN5(d5)
d4 = self.Up4(d5)
d4 = torch.cat((x3, d4), dim=1)
d4 = self.Up_RRCNN4(d4)
d3 = self.Up3(d4)
d3 = torch.cat((x2, d3), dim=1)
d3 = self.Up_RRCNN3(d3)
d2 = self.Up2(d3)
d2 = torch.cat((x1, d2), dim=1)
d2 = self.Up_RRCNN2(d2)
d1 = self.Conv_1x1(d2)
d1 = self.sigmoid(d1)
return d1
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/wsq_chd/UNET-ZOO.git
git@gitee.com:wsq_chd/UNET-ZOO.git
wsq_chd
UNET-ZOO
UNET-ZOO
master

Search