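# Source: capsnet.py (6.12 KB) from the Gitee repository xuzhiyang483/Pytorch-CapsuleNet.
# Last commit: jindongwang, 2018-04-10 14:34, "add: cifar-10".
# (The Gitee page chrome captured with this file -- star/fork counters, sign-up
# banner, clone/download and view-mode buttons -- is not part of the code.)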
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
# Module-wide flag: True when torch reports a usable CUDA device, else False.
USE_CUDA = torch.cuda.is_available()
class ConvLayer(nn.Module):
    """Single stride-1 2-D convolution followed by a ReLU."""

    def __init__(self, in_channels=1, out_channels=256, kernel_size=9):
        """Create the convolution.

        Args:
            in_channels: number of channels in the input image.
            out_channels: number of feature maps produced.
            kernel_size: side length of the square convolution kernel.
        """
        super(ConvLayer, self).__init__()
        # Attribute name is part of the state_dict layout; keep it as ``conv``.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=1)

    def forward(self, x):
        """Convolve ``x`` and rectify the result.

        Args:
            x: input accepted by ``nn.Conv2d``.

        Returns:
            The non-negative feature maps ``relu(conv(x))``.
        """
        feature_maps = self.conv(x)
        return F.relu(feature_maps)
class PrimaryCaps(nn.Module):
    """Primary capsule layer.

    Runs ``num_capsules`` parallel stride-2 convolutions over the input,
    regroups their stacked outputs into ``num_routes`` vectors and applies
    the squash non-linearity to each vector.
    """

    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, num_routes=32 * 6 * 6):
        """Create the parallel convolutions.

        Args:
            num_capsules: number of parallel convolutions.
            in_channels: channels of the incoming feature maps.
            out_channels: channels produced by each convolution.
            kernel_size: side length of each square kernel.
            num_routes: number of vectors the stacked output is reshaped
                into; must divide the number of elements per sample.
        """
        super(PrimaryCaps, self).__init__()
        self.num_routes = num_routes
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=2, padding=0)
            for _ in range(num_capsules)
        ])

    def forward(self, x):
        """Return squashed capsule vectors of shape (batch, num_routes, -1).

        Args:
            x: feature maps accepted by the ``nn.Conv2d`` capsules.
        """
        # Stacked shape: (batch, num_capsules, out_channels, H', W').
        u = torch.stack([capsule(x) for capsule in self.capsules], dim=1)
        # NOTE(review): this is a plain memory-order view (no permute), so each
        # output vector is built from consecutive elements of the stacked
        # tensor. Kept unchanged so existing checkpoints behave the same.
        u = u.view(x.size(0), self.num_routes, -1)
        return self.squash(u)

    def squash(self, input_tensor, eps=1e-8):
        """Shrink each vector along the last axis to a length below 1.

        Computes ``v * |v| / (1 + |v|^2)``, which equals the original
        ``|v|^2 * v / ((1 + |v|^2) * |v|)`` but contains no division by
        ``|v|``. The original returned NaN for an all-zero vector.

        Args:
            input_tensor: tensor whose last axis holds the vectors.
            eps: small constant added under the square root so the gradient
                stays finite when a vector is exactly zero.

        Returns:
            Tensor of the same shape with the direction of each vector kept.
        """
        squared_norm = (input_tensor ** 2).sum(-1, keepdim=True)
        scale = torch.sqrt(squared_norm + eps) / (1. + squared_norm)
        return input_tensor * scale
class DigitCaps(nn.Module):
    """Output capsule layer with three rounds of routing-by-agreement."""

    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16):
        """Create the transformation weights.

        Args:
            num_capsules: number of output capsules.
            num_routes: number of input capsule vectors.
            in_channels: length of each input capsule vector.
            out_channels: length of each output capsule vector.
        """
        super(DigitCaps, self).__init__()
        self.in_channels = in_channels
        self.num_routes = num_routes
        self.num_capsules = num_capsules
        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, x):
        """Route input capsules to output capsules.

        Args:
            x: tensor of shape (batch, num_routes, in_channels).

        Returns:
            Tensor of shape (batch, num_capsules, out_channels, 1).
        """
        # (batch, routes, in) -> (batch, routes, 1, in, 1). matmul broadcasts
        # against W's leading 1 and the capsule axis, so neither W nor x has
        # to be physically replicated.
        x = x.unsqueeze(2).unsqueeze(4)
        u_hat = torch.matmul(self.W, x)  # (batch, routes, capsules, out, 1)

        # Routing logits live on the same device/dtype as the activations
        # instead of following a global CUDA flag.
        b_ij = torch.zeros(1, self.num_routes, self.num_capsules, 1,
                           dtype=u_hat.dtype, device=u_hat.device)

        num_iterations = 3
        for iteration in range(num_iterations):
            # NOTE(review): the softmax runs over the route axis (dim=1);
            # Sabour et al. normalise over output capsules. Left unchanged --
            # confirm before altering, it changes trained behaviour.
            c_ij = F.softmax(b_ij, dim=1).unsqueeze(4)
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)  # (batch, 1, capsules, out, 1)
            # The capsule vector sits on axis 3; the last axis has size 1.
            v_j = self.squash(s_j, dim=3)
            if iteration < num_iterations - 1:
                # <u_hat, v_j> per (route, capsule); v_j broadcasts over routes.
                a_ij = torch.matmul(u_hat.transpose(3, 4), v_j)
                # NOTE(review): agreement is averaged over the batch, so all
                # samples share one set of logits. Left unchanged.
                b_ij = b_ij + a_ij.squeeze(4).mean(dim=0, keepdim=True)
        return v_j.squeeze(1)

    def squash(self, input_tensor, dim=-1, eps=1e-8):
        """Shrink each vector along ``dim`` to a length below 1.

        Computes ``v * |v| / (1 + |v|^2)``, which avoids the division by
        ``|v|`` that produced NaN for all-zero vectors.

        Args:
            input_tensor: tensor holding the vectors.
            dim: axis along which the vector norm is taken. Defaults to the
                last axis, as before.
            eps: small constant added under the square root so the gradient
                stays finite for exactly-zero vectors.

        Returns:
            Tensor of the same shape as ``input_tensor``.
        """
        squared_norm = (input_tensor ** 2).sum(dim, keepdim=True)
        scale = torch.sqrt(squared_norm + eps) / (1. + squared_norm)
        return input_tensor * scale
class Decoder(nn.Module):
    """Reconstructs an image from the longest output capsule."""

    def __init__(self, input_width=28, input_height=28, input_channel=1):
        """Create the three-layer fully connected reconstruction network.

        Args:
            input_width: width of the image to reconstruct.
            input_height: height of the image to reconstruct.
            input_channel: number of image channels.
        """
        super(Decoder, self).__init__()
        self.input_width = input_width
        self.input_height = input_height
        self.input_channel = input_channel
        # The attribute name (including its spelling) is part of the
        # state_dict layout, so it is kept as-is.
        self.reconstraction_layers = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            # width * height (the original used height * height, which only
            # matched for square images).
            nn.Linear(1024, self.input_width * self.input_height * self.input_channel),
            nn.Sigmoid(),
        )

    def forward(self, x, data):
        """Mask all but the longest capsule and decode the remainder.

        Args:
            x: capsule outputs of shape (batch, num_classes, capsule_dim, 1);
                the flattened size per sample must equal the 160 inputs of
                the first linear layer.
            data: unused; kept for interface compatibility.

        Returns:
            Tuple ``(reconstructions, masked)``: images of shape
            (batch, input_channel, input_height, input_width) and the one-hot
            (batch, num_classes) mask of the selected capsule per sample.
        """
        # Squared length is monotonic in length, so its argmax picks the same
        # capsule. The argmax is taken per sample; the earlier softmax over
        # the batch axis could change which class won.
        squared_lengths = (x ** 2).sum(2)
        _, max_length_indices = squared_lengths.max(dim=1)
        winners = max_length_indices.squeeze(1)

        # One-hot rows, created on the same device/dtype as the capsules.
        identity = torch.eye(x.size(1), dtype=x.dtype, device=x.device)
        masked = identity.index_select(dim=0, index=winners)

        t = (x * masked[:, :, None, None]).view(x.size(0), -1)
        reconstructions = self.reconstraction_layers(t)
        reconstructions = reconstructions.view(
            -1, self.input_channel, self.input_height, self.input_width)
        return reconstructions, masked
class CapsNet(nn.Module):
    """Capsule network: conv layer, primary capsules, digit capsules, decoder."""

    def __init__(self, config=None):
        """Assemble the sub-modules.

        Args:
            config: optional object exposing the attributes read below
                (``cnn_*``, ``pc_*``, ``dc_*``, ``input_width``,
                ``input_height``). When falsy, every sub-module is built
                with its own defaults.
        """
        super(CapsNet, self).__init__()
        if config:
            self.conv_layer = ConvLayer(config.cnn_in_channels, config.cnn_out_channels, config.cnn_kernel_size)
            self.primary_capsules = PrimaryCaps(config.pc_num_capsules, config.pc_in_channels, config.pc_out_channels,
                                                config.pc_kernel_size, config.pc_num_routes)
            self.digit_capsules = DigitCaps(config.dc_num_capsules, config.dc_num_routes, config.dc_in_channels,
                                            config.dc_out_channels)
            self.decoder = Decoder(config.input_width, config.input_height, config.cnn_in_channels)
        else:
            self.conv_layer = ConvLayer()
            self.primary_capsules = PrimaryCaps()
            self.digit_capsules = DigitCaps()
            self.decoder = Decoder()
        self.mse_loss = nn.MSELoss()

    def forward(self, data):
        """Run the full network.

        Returns:
            Tuple ``(output, reconstructions, masked)``: the digit-capsule
            output plus the two values returned by the decoder.
        """
        output = self.digit_capsules(self.primary_capsules(self.conv_layer(data)))
        reconstructions, masked = self.decoder(output, data)
        return output, reconstructions, masked

    def loss(self, data, x, target, reconstructions):
        """Return margin loss on ``x`` plus the scaled reconstruction loss."""
        return self.margin_loss(x, target) + self.reconstruction_loss(data, reconstructions)

    def margin_loss(self, x, labels, size_average=True):
        """Margin loss on capsule lengths.

        Args:
            x: capsule outputs; the vector norm is taken over axis 2.
            labels: per-class targets, broadcastable against
                (batch, num_classes); presumably one-hot -- verify against
                the caller.
            size_average: when True (default) the per-sample losses are
                averaged over the batch; when False they are summed. The
                flag was previously accepted but ignored.

        Returns:
            Scalar loss tensor.
        """
        batch_size = x.size(0)
        v_c = torch.sqrt((x ** 2).sum(dim=2, keepdim=True))
        # NOTE(review): Sabour et al. (2017) square both hinge terms; they are
        # unsquared here. Left unchanged -- it alters the training objective.
        left = F.relu(0.9 - v_c).view(batch_size, -1)
        right = F.relu(v_c - 0.1).view(batch_size, -1)
        per_sample = (labels * left + 0.5 * (1.0 - labels) * right).sum(dim=1)
        return per_sample.mean() if size_average else per_sample.sum()

    def reconstruction_loss(self, data, reconstructions):
        """Return the MSE between flattened reconstructions and data, times 0.0005."""
        batch_size = reconstructions.size(0)
        # reshape (not view) so non-contiguous inputs are accepted too.
        loss = self.mse_loss(reconstructions.reshape(batch_size, -1), data.reshape(batch_size, -1))
        return loss * 0.0005
# Repository clone URLs:
#   https://gitee.com/xuzhiyang483/Pytorch-CapsuleNet.git
#   git@gitee.com:xuzhiyang483/Pytorch-CapsuleNet.git
# Owner: xuzhiyang483 / Project: Pytorch-CapsuleNet / Branch: master
# (The Gitee page footer captured with this file -- AI-assistant menu, search
# help link and tracking identifiers -- is not part of the code.)