1 Star 0 Fork 1

stawary/ESPCN_Learning

forked from vegee/ESPCN_Learning 
加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
model.py 1.24 KB
一键复制 编辑 原始数据 按行查看 历史
vegee 提交于 2021-01-14 03:30 . 修正部分文件,为工程增加说明
# -*- coding: UTF-8 -*-
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
class Net(nn.Module):
def __init__(self, upscale_factor):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
self.conv2 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
self.conv3 = nn.Conv2d(32, 1 * (upscale_factor ** 2), (3, 3), (1, 1), (1, 1))
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
def forward(self, x):
x = torch.tanh(self.conv1(x))
x = torch.tanh(self.conv2(x))
x = torch.sigmoid(self.pixel_shuffle(self.conv3(x)))
return x
if __name__ == "__main__":
model = Net(upscale_factor=3)
print(model)
# 用此处代码测试时要改模型,网络的input_channel = 3
oritensor = torch.randn(1, 3, 33, 33)
oritensor = torch.clamp(oritensor, 0, 1)
newtensor = torch.clamp(model(oritensor), 0, 1)
orinumpy = oritensor.detach().squeeze().permute(1, 2, 0).numpy()
newnumpy = newtensor.detach().squeeze().permute(1, 2, 0).numpy()
plt.imshow(orinumpy)
plt.title('original noise')
plt.show()
plt.imshow(newnumpy)
plt.title('reconstruct noise')
plt.show()
print(orinumpy.shape)
print(newnumpy.shape)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/stawary/ESPCN_Learning.git
git@gitee.com:stawary/ESPCN_Learning.git
stawary
ESPCN_Learning
ESPCN_Learning
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385