1 Star 0 Fork 0

cassuto/DEEPDUPA

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
Res_tar_final.py 6.77 KB
一键复制 编辑 原始数据 按行查看 历史
asrakin 提交于 2021-05-01 23:49 . Code released version
from __future__ import print_function
import numpy as np
import pandas as pd
import torch.nn as nn
import math
import torch.nn.functional as F
import torch
from torch.nn import init
from collections import OrderedDict
import time
import shutil
import xlwt
from xlwt import Workbook
import argparse
import torch.optim as optim
from torchvision import datasets, transforms
# from utils import AverageMeter, RecorderMeter, time_string, convert_secs2time
import torch.backends.cudnn as cudnn
# Let cuDNN benchmark and pick the fastest conv algorithms for the fixed
# input sizes used below (trades startup time for per-batch speed).
cudnn.benchmark = True
import random
# NOTE(review): only Python's `random` module is seeded here; torch/numpy
# RNGs are not, so full run-to-run reproducibility is not guaranteed.
random.seed(6)
from torch.autograd import Variable
from torchvision import models
import torch.nn.functional as F  # NOTE(review): duplicate of the F import above
import torch as th
# Project-local helpers: evaluation loops and weight/bit conversion utilities.
from module import validate,validate1,bin2int,weight_conversion,int2bin
# Project-local model definitions (quantized layers + CIFAR ResNet blocks).
from model import vgg11_bn,quan_Linear,quan_Conv2d,ResNetBasicblock,DownsampleA,CifarResNet
# DES_new implements the Deep-Dup progressive/evolutionary attack search.
from attack import DES_new
import argparse  # NOTE(review): duplicate of the argparse import above
# ---------------------- Command-line arguments ----------------------
# Interface (option names, types, defaults) is unchanged; only the help-text
# typo "successfull" -> "successful" is fixed.
parser = argparse.ArgumentParser(description='Deep Dup A')
parser.add_argument('--iteration', type=int, default=1000, help='Attack Iterations')
parser.add_argument('--z', type=int, default=500, help='evolution z')
parser.add_argument('--batch-size', type=int, default=256, help='input batch size for 256 default')
# f_p in the paper: chance that one hardware AWD fault injection lands.
parser.add_argument('--probab', type=float, default=1, help='probability of a successful hardware AWD attack at a location')
parser.add_argument('--data', type=str, default='./cifar10', help='data path')
parser.add_argument('--target', type=int, default=8, help='Target Class')
args = parser.parse_args()
print(args)
# Dataset location supplied on the command line.
dataset_path = args.data

# ---------------------- Hyper Parameter ---------------------------
iteration = args.iteration      # total number of attack iterations
picks = args.z                  # number of candidate weights picked initially
evolution = args.z              # evolution population size (= picks = z)
weight_p_clk = 2                # weights per package; constant throughout the paper
shift_p_clk = 1                 # clock shifts per iteration; constant throughout the paper
targeted = args.target          # index of the target class
BATCH_SIZE = args.batch_size    # mini-batch size
probab = args.probab            # AWD success probability $f_p$
# ------------------------------- model -------------------------------
# CIFAR-style ResNet-20 with 10 output classes, on the GPU together with
# the cross-entropy loss used by both validation and the attack.
model = CifarResNet(ResNetBasicblock, 20, 10).cuda()
criterion = torch.nn.CrossEntropyLoss().cuda()
# ---------------------------------- Data loading -------------------------------------
device = 1
# CIFAR-10 per-channel statistics used for input normalization.
mean = [0.4914, 0.4822, 0.4465]
std = [0.2023, 0.1994, 0.2010]
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])
# NOTE(review): shuffle=False keeps batch composition deterministic across
# runs — presumably intentional so the attack sees stable batches; confirm.
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(dataset_path, train=True, download=True,
                     transform=train_transform),
    batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(dataset_path, train=False,
                     transform=test_transform),
    batch_size=BATCH_SIZE, shuffle=False)
# (The loss criterion is created once in the model section above; the
# redundant re-instantiation that used to follow here was removed.)
#------------------------------- model loading ----------------------------------------------------
# model.load_state_dict(torch.load('./cifar_vgg_pretrain.pt', map_location='cpu'))
# Load the pretrained quantized ResNet-20 checkpoint.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# checkpoints. No map_location is given, so tensors restore to their saved
# device (a CUDA machine is assumed throughout this script).
pretrained_dict = torch.load('Resnet20_8_0.pkl')
model_dict = model.state_dict()
# 1. filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 3. load the new state dict
model.load_state_dict(model_dict)
n=0
# update the step size before validation
# Walk every quantized layer, count them (n), and re-initialize each layer's
# quantizer state via the project-defined __reset_stepsize__/__reset_weight__
# hooks — their exact semantics live in model.py; presumably they recompute
# the quantization step from current weights. TODO confirm.
for m in model.modules():
    if isinstance(m, quan_Conv2d) or isinstance(m, quan_Linear):
        n=n+1
        print(m.weight.size(),n)
        m.__reset_stepsize__()
        m.__reset_weight__()
# Convert weights to their integer/bit representation used by the attack.
weight_conversion(model)
# Sanity-check clean accuracy before attacking.
validate(model, device, criterion, test_loader, 0)
# see: https://discuss.pytorch.org/t/what-does-model-eval-do-for-batchnorm-layer/7146
model.eval()
import copy
# Keep a pristine deep copy of the clean model so the post-attack weight
# diff at the bottom of the script has a reference to compare against.
model1 = copy.deepcopy(model)
# Grab only the first test batch (equivalent to looping once and breaking).
data, target = next(iter(test_loader))
# ----------------------------- Attack Setup -------------------------------------------
attacker = DES_new(criterion,
                   k_top=picks,
                   w_clk=weight_p_clk,
                   s_clk=shift_p_clk,
                   evolution=evolution,
                   probab=probab)
xs, ys = [], []
# Per-iteration logs filled in during the attack loop below.
ASR = torch.zeros([iteration])
acc = torch.zeros([iteration])
# Re-create the test loader with batch_size=1 so individual samples can be
# routed into the attack/evaluation pools in the data-division step.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(dataset_path, train=False,
                     transform=test_transform),
    batch_size=1, shuffle=False)
# ------------------------------------------------------------ Data division -------------------------------------------------------------
# Split target-class test images into two disjoint pools:
#   datas/targets   : target-class samples 0..255    -> drive the attack
#   datas1/targets1 : target-class samples 500..999  -> held-out evaluation
# (Samples 256..499 are deliberately skipped to keep the pools separated.)
datas = torch.zeros([256, 3, 32, 32])    # attack batch
targets = torch.zeros([256])
datas1 = torch.zeros([500, 3, 32, 32])   # evaluation batch
targets1 = torch.zeros([500])
count = 0
for batch_idx, (data, target) in enumerate(test_loader):
    if target == targeted:
        if count < 256:
            datas[count, :, :, :] = data[0, :, :, :]
            targets[count] = target[0]
        # Upper bound added: guards against an IndexError on datas1 if the
        # target class ever had more than 1000 samples (CIFAR-10 has exactly
        # 1000 per class, so behavior is unchanged here).
        elif 500 <= count < 1000:
            datas1[count - 500, :, :, :] = data[0, :, :, :]
            targets1[count - 500] = target[0]
        count = count + 1
# Restore a batched test loader for whole-test-set validation during the attack.
test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10(dataset_path, train=False,
                     transform=test_transform),
    batch_size=256, shuffle=False)
# ------------------------------------------------------------ Attacking -------------------------------------------------------------
for i in range(iteration):
    print("epoch:",i+1)
    # One step of the progressive/evolutionary search over weight locations;
    # xs/ys accumulate the attack trajectory across iterations.
    xs,ys=attacker.progressive_search(model.cuda(), datas.cuda(), targets.long().cuda(),xs,ys)
    #print(xs[i],ys[i])
    # NOTE(review): the two labels below are identical, which looks like a
    # copy-paste slip — the first validate() runs on the full test loader,
    # while validate1() evaluates on the held-out target-class pool. Confirm
    # which metric each returns against module.py.
    print("Test Accuracy of Target Class (%)")
    _,ASR[i]=validate(model, device, criterion, test_loader, 0)
    print("Test Accuracy of Target Class (%)")
    _,acc[i] = validate1(model, device, criterion, test_loader,datas1.cuda(),targets1.long().cuda(), 0)
    # Stop once the held-out target-class accuracy collapses below 2%.
    if float(acc[i])< 2.00:
        break
## finally printing out exactly how many weights different compared to the original model
# The two models share one architecture, so the i-th quantized layer of the
# attacked `model` corresponds to the i-th quantized layer of the clean copy
# `model1`. Pairing them directly replaces the original O(n^2) nested module
# traversal while printing the exact same sequence of diff sizes.
attacked_layers = [m for m in model.modules()
                   if isinstance(m, (quan_Conv2d, quan_Linear))]
clean_layers = [h for h in model1.modules()
                if isinstance(h, (quan_Conv2d, quan_Linear))]
for m, h in zip(attacked_layers, clean_layers):
    zz = m.weight.data - h.weight.data
    # Size of the nonzero-diff tensor = number of weights changed in this layer.
    print(zz[zz != 0].size())
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/nullptr12/DEEPDUPA.git
git@gitee.com:nullptr12/DEEPDUPA.git
nullptr12
DEEPDUPA
DEEPDUPA
main

搜索帮助

0d507c66 1850385 C8b1a773 1850385