1 Star 0 Fork 0

123/MAML-Pytorch

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
克隆/下载
miniimagenet_train.py 4.19 KB
一键复制 编辑 原始数据 按行查看 历史
dragen_d592 提交于 2018-12-19 16:25 . resolve bug:losses_q.append(loss_q)
import torch, os
import numpy as np
from MiniImagenet import MiniImagenet
import scipy.stats
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
import random, sys, pickle
import argparse
from meta import Meta
def mean_confidence_interval(accs, confidence=0.95):
    """Compute the mean and confidence-interval half-width of a 1-D array.

    Args:
        accs: 1-D numpy array of accuracy values (one per test episode).
        confidence: confidence level for the two-sided interval (default 0.95).

    Returns:
        (m, h): sample mean and half-width, i.e. the interval is [m - h, m + h].
    """
    n = accs.shape[0]
    m, se = np.mean(accs), scipy.stats.sem(accs)
    # Use the public `t.ppf` (percent-point function / inverse CDF) rather than
    # the private `t._ppf`, which skips argument validation and is not a stable API.
    h = se * scipy.stats.t.ppf((1 + confidence) / 2, n - 1)
    return m, h
def main():
    """Meta-train a MAML model on MiniImagenet and periodically evaluate it.

    Reads hyperparameters from the module-level ``args`` namespace (set by the
    argparse block in ``__main__``). Trains for ``args.epoch // 10000`` passes
    over a 10,000-episode training set, printing training accuracy every 30
    steps and running a full test-set evaluation every 500 steps.
    """
    # Fix seeds for reproducibility of episode sampling and weight init.
    torch.manual_seed(222)
    torch.cuda.manual_seed_all(222)
    np.random.seed(222)

    print(args)

    # 4-conv backbone (32 filters each) + linear head, as in the MAML paper's
    # MiniImagenet setup. Layout per entry:
    #   conv2d:     [ch_out, ch_in, kernel_h, kernel_w, stride, padding]
    #   max_pool2d: [kernel, stride, padding]
    #   linear:     [out_features, in_features]
    # NOTE(review): the 32 * 5 * 5 flatten size assumes imgsz=84 — confirm if
    # a different --imgsz is used.
    config = [
        ('conv2d', [32, 3, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 2, 0]),
        ('conv2d', [32, 32, 3, 3, 1, 0]),
        ('relu', [True]),
        ('bn', [32]),
        ('max_pool2d', [2, 1, 0]),
        ('flatten', []),
        ('linear', [args.n_way, 32 * 5 * 5])
    ]

    # Fall back to CPU when CUDA is unavailable instead of crashing;
    # identical behavior on GPU machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    maml = Meta(args, config).to(device)

    # Count trainable parameters for a quick sanity printout.
    tmp = filter(lambda x: x.requires_grad, maml.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(maml)
    print('Total trainable tensors:', num)

    # batchsz here means total episode number; episodes are sampled up front.
    mini = MiniImagenet('/home/i/tmp/MAML-Pytorch/miniimagenet/', mode='train', n_way=args.n_way, k_shot=args.k_spt,
                        k_query=args.k_qry,
                        batchsz=10000, resize=args.imgsz)
    mini_test = MiniImagenet('/home/i/tmp/MAML-Pytorch/miniimagenet/', mode='test', n_way=args.n_way, k_shot=args.k_spt,
                             k_query=args.k_qry,
                             batchsz=100, resize=args.imgsz)

    for epoch in range(args.epoch // 10000):
        # fetch meta_batchsz num of episode each time
        db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)

        for step, (x_spt, y_spt, x_qry, y_qry) in enumerate(db):
            x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

            # Meta-update over a batch of tasks; returns per-inner-step accuracies.
            accs = maml(x_spt, y_spt, x_qry, y_qry)

            if step % 30 == 0:
                print('step:', step, '\ttraining acc:', accs)

            if step % 500 == 0:  # evaluation
                db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
                accs_all_test = []

                for x_spt, y_spt, x_qry, y_qry in db_test:
                    # Batch size is 1, so drop the leading batch dimension.
                    x_spt, y_spt, x_qry, y_qry = x_spt.squeeze(0).to(device), y_spt.squeeze(0).to(device), \
                                                 x_qry.squeeze(0).to(device), y_qry.squeeze(0).to(device)

                    accs = maml.finetunning(x_spt, y_spt, x_qry, y_qry)
                    accs_all_test.append(accs)

                # [b, update_step+1] -> mean accuracy per fine-tuning step.
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Test acc:', accs)
if __name__ == '__main__':
    # Hyperparameter specs: (flag, type, help text, default).
    arg_specs = [
        ('--epoch', int, 'epoch number', 60000),
        ('--n_way', int, 'n way', 5),
        ('--k_spt', int, 'k shot for support set', 1),
        ('--k_qry', int, 'k shot for query set', 15),
        ('--imgsz', int, 'imgsz', 84),
        ('--imgc', int, 'imgc', 3),
        ('--task_num', int, 'meta batch size, namely task num', 4),
        ('--meta_lr', float, 'meta-level outer learning rate', 1e-3),
        ('--update_lr', float, 'task-level inner update learning rate', 0.01),
        ('--update_step', int, 'task-level inner update steps', 5),
        ('--update_step_test', int, 'update steps for finetunning', 10),
    ]

    argparser = argparse.ArgumentParser()
    for flag, flag_type, help_text, default in arg_specs:
        argparser.add_argument(flag, type=flag_type, help=help_text, default=default)

    # Module-level `args` is read by main() and the network config.
    args = argparser.parse_args()

    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/changqixin/MAML-Pytorch.git
git@gitee.com:changqixin/MAML-Pytorch.git
changqixin
MAML-Pytorch
MAML-Pytorch
master

搜索帮助