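# 代码拉取完成,页面将自动刷新
# (Scraped web-page residue, not part of the module: "Code pull complete; the page will refresh automatically.")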
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.autograd import Variable
import torch.nn.init as init
def to_var(x, requires_grad=True):
    """Return ``x`` as a fresh autograd leaf, moved to the GPU when available.

    This replaces the deprecated ``torch.autograd.Variable`` wrapper, whose
    legacy constructor detached its input and then set ``requires_grad``.
    The result is cut from any existing graph. It shares storage with ``x``
    after the optional device move.

    Args:
        x: tensor to wrap.
        requires_grad: value for the returned leaf's ``requires_grad`` flag.

    Returns:
        A leaf tensor (``grad_fn is None``) on CUDA if
        ``torch.cuda.is_available()``, otherwise on ``x``'s device.
    """
    if torch.cuda.is_available():
        x = x.cuda()
    return x.detach().requires_grad_(requires_grad)
class MetaModule(nn.Module):
    """Base class for modules whose tensors can be functionally replaced.

    Adopted from Adrien Ecoffet (https://github.com/AdrienLE).

    Subclasses that own tensors report them through ``named_leaves``.
    ``named_params`` walks the module tree and collects those leaves. For
    children without ``named_leaves`` it uses their ordinary
    ``_parameters``. ``update_params`` then reassigns each one via
    ``set_param``.
    """

    def params(self):
        """Yield every tensor reported by ``named_params`` (names dropped)."""
        for _, param in self.named_params(self):
            yield param

    def named_leaves(self):
        """Return ``(name, tensor)`` pairs owned directly by this module.

        The base implementation owns nothing; leaf layers override this.
        """
        return []

    def named_submodules(self):
        """Return an empty list (kept for interface compatibility)."""
        return []

    def named_params(self, curr_module=None, memo=None, prefix=''):
        """Yield ``(dotted_name, tensor)`` for every tensor under a module.

        Args:
            curr_module: module to walk. Defaults to ``self``; previously a
                call without arguments crashed on ``None._parameters``.
            memo: set of tensors already yielded. A tensor reachable twice
                is reported only once. Created fresh when ``None``.
            prefix: dotted name prefix of ``curr_module``.

        Yields:
            ``(name, tensor)`` pairs. ``None`` entries are skipped.
        """
        if curr_module is None:
            curr_module = self
        if memo is None:
            memo = set()
        if hasattr(curr_module, 'named_leaves'):
            own = curr_module.named_leaves()
        else:
            own = curr_module._parameters.items()
        for name, p in own:
            if p is not None and p not in memo:
                memo.add(p)
                yield prefix + ('.' if prefix else '') + name, p
        for mname, module in curr_module.named_children():
            submodule_prefix = prefix + ('.' if prefix else '') + mname
            yield from self.named_params(module, memo, submodule_prefix)

    def update_params(self,
                      lr_inner,
                      first_order=False,
                      source_params=None,
                      detach=False):
        """Take one SGD-style step ``t <- t - lr_inner * grad`` on every tensor.

        Args:
            lr_inner: step size.
            first_order: if True, each gradient is detached and re-wrapped
                with ``to_var``. The update then does not backpropagate
                through the gradient itself.
            source_params: optional iterable of gradients. They are matched
                positionally against ``named_params(self)`` (``zip``, so
                extra items on either side are ignored). When omitted, each
                tensor's own ``.grad`` is used.
            detach: only consulted when ``source_params`` is None. If True,
                every tensor is detached in place instead of being updated.
        """
        if source_params is not None:
            for (name_t, param_t), grad in zip(self.named_params(self),
                                               source_params):
                if first_order:
                    grad = to_var(grad.detach().data)
                self.set_param(self, name_t, param_t - lr_inner * grad)
            return

        for name, param in self.named_params(self):
            if detach:
                # In-place detach turns the tensor back into a graph leaf.
                # https://blog.csdn.net/qq_39709535/article/details/81866686
                self.set_param(self, name, param.detach_())
                continue
            grad = param.grad
            if first_order:
                grad = to_var(grad.detach().data)
            self.set_param(self, name, param - lr_inner * grad)

    def set_param(self, curr_mod, name, param):
        """Assign ``param`` to the dotted attribute path ``name`` under ``curr_mod``.

        Does nothing if an intermediate child module named in the path is
        not found.
        """
        if '.' in name:
            module_name, rest = name.split('.', 1)
            for child_name, child in curr_mod.named_children():
                if child_name == module_name:
                    self.set_param(child, rest, param)
                    break
        else:
            setattr(curr_mod, name, param)

    def detach_params(self):
        """Replace every tensor with a detached (out-of-place) copy of itself."""
        for name, param in self.named_params(self):
            self.set_param(self, name, param.detach())

    def copy(self, other, same_var=False):
        """Copy every tensor of ``other`` onto the same dotted path in ``self``.

        Args:
            other: a MetaModule with the same structure as ``self``.
            same_var: if True, the very same tensor objects are shared.
                Otherwise each is cloned into a fresh leaf that requires
                grad.
        """
        # Both calls previously failed: named_params() without a module
        # crashed, and set_param was missing its curr_mod argument.
        for name, param in other.named_params(other):
            if not same_var:
                param = to_var(param.data.clone(), requires_grad=True)
            self.set_param(self, name, param)
class MetaLinear(MetaModule):
    """``nn.Linear`` counterpart whose weight and bias are stored as buffers.

    Storing them as buffers lets ``MetaModule.set_param`` replace them with
    non-leaf tensors. Accepts exactly the arguments of ``nn.Linear``,
    including ``bias=False``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # A throwaway nn.Linear is built only to reuse its default init.
        ignore = nn.Linear(*args, **kwargs)
        self.in_features = ignore.in_features
        self.out_features = ignore.out_features
        self.register_buffer('weight',
                             to_var(ignore.weight.data, requires_grad=True))
        if ignore.bias is not None:
            self.register_buffer('bias',
                                 to_var(ignore.bias.data, requires_grad=True))
        else:
            # bias=False previously crashed on ``None.data``.
            self.register_buffer('bias', None)

    def forward(self, x):
        """Apply ``F.linear`` with the current (possibly replaced) tensors."""
        return F.linear(x, self.weight, self.bias)

    def named_leaves(self):
        """Return the tensors a meta-update may replace (bias may be None)."""
        return [('weight', self.weight), ('bias', self.bias)]
class MetaConv2d(MetaModule):
    """``nn.Conv2d`` counterpart whose weight and bias are stored as buffers.

    Accepts exactly the arguments of ``nn.Conv2d``. A temporary
    ``nn.Conv2d`` supplies the hyper-parameters and the initial tensors.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        template = nn.Conv2d(*args, **kwargs)
        # Mirror the hyper-parameters that forward() and callers read.
        for attr in ('in_channels', 'out_channels', 'stride', 'padding',
                     'dilation', 'groups', 'kernel_size'):
            setattr(self, attr, getattr(template, attr))
        self.register_buffer(
            'weight', to_var(template.weight.data, requires_grad=True))
        template_bias = template.bias
        self.register_buffer(
            'bias',
            None if template_bias is None
            else to_var(template_bias.data, requires_grad=True))

    def forward(self, x):
        """Convolve ``x`` with the current (possibly replaced) tensors."""
        return F.conv2d(x, self.weight, self.bias,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        groups=self.groups)

    def named_leaves(self):
        """Return the weight/bias pairs (bias is None when disabled)."""
        return [(key, getattr(self, key)) for key in ('weight', 'bias')]
class MetaConvTranspose2d(MetaModule):
    """``nn.ConvTranspose2d`` counterpart whose weight and bias are buffers.

    Accepts exactly the arguments of ``nn.ConvTranspose2d``.

    The original forward called ``self._output_padding``, which neither
    ``nn.Module`` nor ``MetaModule`` defines, so every forward raised
    AttributeError. The constructor also dropped ``output_padding``. Both
    are fixed here.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.ConvTranspose2d(*args, **kwargs)
        self.in_channels = ignore.in_channels
        self.out_channels = ignore.out_channels
        self.kernel_size = ignore.kernel_size
        self.stride = ignore.stride
        self.padding = ignore.padding
        self.output_padding = ignore.output_padding
        self.dilation = ignore.dilation
        self.groups = ignore.groups
        self.register_buffer('weight',
                             to_var(ignore.weight.data, requires_grad=True))
        if ignore.bias is not None:
            self.register_buffer('bias', to_var(ignore.bias.data,
                                                requires_grad=True))
        else:
            self.register_buffer('bias', None)

    def _output_padding(self, x, output_size):
        """Return the output padding needed to reach ``output_size``.

        With ``output_size=None`` the constructor's ``output_padding`` is
        used. Otherwise ``output_size`` is either the 2 spatial sizes or a
        full size whose last 2 entries are spatial. Each requested size must
        lie in ``[min_size, min_size + stride - 1]``.

        Raises:
            ValueError: if ``output_size`` has the wrong length or is not
                reachable for this input.
        """
        if output_size is None:
            return self.output_padding
        spatial_dims = 2
        output_size = list(output_size)
        if len(output_size) == spatial_dims + 2:
            output_size = output_size[2:]
        if len(output_size) != spatial_dims:
            raise ValueError(
                'output_size must have {} or {} elements (got {})'.format(
                    spatial_dims, spatial_dims + 2, len(output_size)))
        leading_dims = x.dim() - spatial_dims
        result = []
        for d in range(spatial_dims):
            in_size = x.size(leading_dims + d)
            min_size = ((in_size - 1) * self.stride[d] - 2 * self.padding[d] +
                        self.dilation[d] * (self.kernel_size[d] - 1) + 1)
            max_size = min_size + self.stride[d] - 1
            if not min_size <= output_size[d] <= max_size:
                raise ValueError(
                    'requested output size {} is not reachable: valid range '
                    'for spatial dim {} is [{}, {}]'.format(
                        output_size[d], d, min_size, max_size))
            result.append(output_size[d] - min_size)
        return tuple(result)

    def forward(self, x, output_size=None):
        """Apply ``F.conv_transpose2d``, optionally targeting ``output_size``."""
        output_padding = self._output_padding(x, output_size)
        return F.conv_transpose2d(x, self.weight, self.bias, self.stride,
                                  self.padding, output_padding, self.groups,
                                  self.dilation)

    def named_leaves(self):
        """Return the weight/bias pairs (bias is None when disabled)."""
        return [('weight', self.weight), ('bias', self.bias)]
class MetaBatchNorm2d(MetaModule):
    """``nn.BatchNorm2d`` counterpart whose affine tensors are buffers.

    Accepts exactly the arguments of ``nn.BatchNorm2d``. Running statistics
    start at zeros (mean) and ones (var) when ``track_running_stats`` is
    set.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        ignore = nn.BatchNorm2d(*args, **kwargs)
        self.num_features = ignore.num_features
        self.eps = ignore.eps
        self.momentum = ignore.momentum
        self.affine = ignore.affine
        self.track_running_stats = ignore.track_running_stats
        if self.affine:
            self.register_buffer('weight',
                                 to_var(ignore.weight.data, requires_grad=True))
            self.register_buffer('bias', to_var(ignore.bias.data,
                                                requires_grad=True))
        else:
            # affine=False previously left weight/bias unregistered, so
            # forward() and named_leaves() raised AttributeError.
            self.register_buffer('weight', None)
            self.register_buffer('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_var', torch.ones(self.num_features))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)

    def forward(self, x):
        """Batch-normalise ``x``.

        Batch statistics are used in training mode, or whenever running
        statistics are not tracked.
        """
        return F.batch_norm(x, self.running_mean, self.running_var, self.weight,
                            self.bias, self.training
                            or not self.track_running_stats, self.momentum,
                            self.eps)

    def named_leaves(self):
        """Return the affine tensors (both None when ``affine=False``)."""
        return [('weight', self.weight), ('bias', self.bias)]
def _weights_init(m):
    """Kaiming-normal initialise the weight of meta linear/conv layers.

    Intended for ``nn.Module.apply``. Other module types are left untouched.
    """
    if isinstance(m, (MetaLinear, MetaConv2d)):
        # ``init.kaiming_normal`` (no trailing underscore) is a deprecated
        # alias that only emits a warning and forwards to this function.
        init.kaiming_normal_(m.weight)
class LambdaLayer(MetaModule):
    """Wrap an arbitrary callable so it can sit inside a module tree."""

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        """Return the wrapped callable applied to ``x``."""
        fn = self.lambd
        return fn(x)
class BasicBlock(MetaModule):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    The shortcut is the identity unless ``stride != 1`` or
    ``in_planes != planes``. In that case:

    * option 'A': parameter-free; takes every second row/column of the input
      and zero-pads the channel dimension by ``planes // 4`` on each side.
    * option 'B': 1x1 strided convolution followed by batch norm.
    * any other option value keeps the identity shortcut.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, option='A'):
        super().__init__()
        self.conv1 = MetaConv2d(in_planes, planes, kernel_size=3,
                                stride=stride, padding=1, bias=False)
        self.bn1 = MetaBatchNorm2d(planes)
        self.conv2 = MetaConv2d(planes, planes, kernel_size=3,
                                stride=1, padding=1, bias=False)
        self.bn2 = MetaBatchNorm2d(planes)
        self.shortcut = nn.Sequential()

        shape_changes = stride != 1 or in_planes != planes
        if not shape_changes:
            return
        if option == 'A':
            channel_pad = planes // 4
            self.shortcut = LambdaLayer(
                lambda x: F.pad(x[:, :, ::2, ::2],
                                (0, 0, 0, 0, channel_pad, channel_pad),
                                "constant", 0))
        elif option == 'B':
            out_planes = self.expansion * planes
            projection = MetaConv2d(in_planes, out_planes, kernel_size=1,
                                    stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection,
                                          MetaBatchNorm2d(out_planes))

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``."""
        h = self.conv1(x)
        h = F.relu(self.bn1(h))
        h = self.bn2(self.conv2(h))
        h = h + self.shortcut(x)
        return F.relu(h)
class ResNet32(MetaModule):
    """CIFAR-style ResNet with stages of 16/32/64 channels, built from meta layers.

    Args:
        num_classes: number of outputs of the final linear layer.
        block: residual block class. It must expose ``expansion`` and accept
            ``(in_planes, planes, stride)``.
        num_blocks: number of blocks in each of the three stages. The default
            ``(5, 5, 5)`` gives 1 + 3*5*2 conv layers + 1 linear = 32 layers.
            A tuple replaces the former mutable list default.
    """

    def __init__(self, num_classes=10, block=BasicBlock, num_blocks=(5, 5, 5)):
        super(ResNet32, self).__init__()
        self.in_planes = 16
        self.conv1 = MetaConv2d(3,
                                16,
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                bias=False)
        self.bn1 = MetaBatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        self.linear = MetaLinear(64, num_classes)
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Side effect: advances ``self.in_planes`` to
        ``planes * block.expansion``.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        # Global average pool to 1x1. The former avg_pool2d(out, W) assumed
        # square feature maps and broke the linear layer otherwise.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.view(out.size(0), -1)
        return self.linear(out)
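# 此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
# (Scraped web-page residue: "Content here may be unsuitable for display and is hidden by the page. You can review and edit it with the editing tools.")
# 如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。
# ("If you confirm the content contains no inappropriate language, pure ad traffic, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or anything violating national laws and regulations, click submit to appeal and it will be handled as soon as possible.")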