1 Star 0 Fork 0

mona2544/adv_train_wbc

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
exp2_plot_spatial_STM2+BO.py 19.15 KB
一键复制 编辑 原始数据 按行查看 历史
mona2544 提交于 2023-03-05 19:32 . asdf
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572
# %%
from cv2 import rotate
from utils.plotFunction import test_report
from statistics import mean
from foolbox import accuracy
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import (
accuracy_score,
auc,
average_precision_score,
precision_recall_curve,
roc_auc_score,
roc_curve,
)
from utils.tools import caculate_all_accuracy, caculate_all_roi, caculate_sAdd_accuracy
# Select matplotlib's non-interactive Agg backend before any figure is made.
plt.switch_backend('agg')

# NOTE(review): this is a plain module variable, not an environment variable;
# nothing in this file reads it -- confirm whether os.environ was intended.
CUDA_VISIBLE_DEVICES = 0

# Column stores filled by the CSV-loading loops below, keyed "<result>_<field>".
res_dict, res_dict_change = {}, {}
def _prob_entries(model, clean_label):
    """Build the "Prob_*" result-file entries for one model.

    The returned dict holds, in order: the clean-run key (mapped to
    ``clean_label``), six keys for the rotation sweep and six keys for the
    sweep the original listing labels "trans". Sweep keys map to the
    placeholder ``"tmp"``.
    """
    entries = {f"Prob_clean_{model}": clean_label}
    for degrees in range(30, 181, 30):
        entries[f"Prob_spatial_STM+BO_{degrees}.0_1000_1.0_-1_{model}"] = "tmp"
    for shift in range(2, 13, 2):
        entries[f"Prob_spatial_STM+BO_1.0_-1_{shift}.0_1000_{model}"] = "tmp"
    return entries


# Result-file stem -> display name. Only the two "Prob_clean_*" names are
# read back by the ROC / PR cells; the sweep entries carry a placeholder.
model_dict = {
    # =================== Standard
    **_prob_entries("EfficientNet-B3", "clean_EfficientNet-B3"),
    # =================== AT_STM
    **_prob_entries(
        "['', 'STM']_120_withoutTrades_EfficientNet-B3",
        "AT_STM+BO_EfficientNet-B3",
    ),
}
def _change_keys(model):
    """List the "Change_*" result-file stems for one model.

    Six rotation-sweep stems come first, followed by the six stems of the
    sweep the original listing labels "trans".
    """
    rotation = [
        f"Change_spatial_STM+BO_{degrees}.0_1000_1.0_-1_{model}"
        for degrees in range(30, 181, 30)
    ]
    trans = [
        f"Change_spatial_STM+BO_1.0_-1_{shift}.0_1000_{model}"
        for shift in range(2, 13, 2)
    ]
    return rotation + trans


# Result-file stem -> placeholder; only the keys are used (to locate CSVs).
model_dict_change = dict.fromkeys(
    # Standard model first, then the AT_STM model.
    _change_keys("EfficientNet-B3")
    + _change_keys("['', 'STM']_120_withoutTrades_EfficientNet-B3"),
    "tmp",
)
# Line colour per clean-result key, looked up by the plotting cells.
# Colours kept for reference from other (currently disabled) backbones:
#   Prob_clean_ResNet18       '#e30039'  (bright red)
#   Prob_clean_ResNet50       '#00994e'
#   Prob_clean_SE-ResNet50Xt  '#00a8e1'
#   Prob_clean_VGG16          '#fcd300'
color_dict = dict(
    [
        ("Prob_clean_EfficientNet-B3", "#99cc00"),
        ("Prob_clean_['', 'STM']_120_withoutTrades_EfficientNet-B3", "#fcd300"),
    ]
)
# Directory that holds one space-separated, header-less CSV per result name.
_RESULT_DIR = './result/midLog/spatial/exp2_spatial_useSTMtrain_useSTM+BOattack'


def _read_result_table(result_name):
    """Parse ``<_RESULT_DIR>/<result_name>.csv`` once and return a 2-D array.

    Each file used to be re-read for every column extracted from it; reading
    it a single time yields the same values with a fraction of the I/O.
    """
    return pd.read_csv(
        f'{_RESULT_DIR}/{result_name}.csv', header=None, index_col=None, sep=' '
    ).to_numpy()


# "Prob_*" files: column 0 is stored as the label, column 2 as the probability.
# Consumers further down slice off row 0 before casting to numbers.
for m in model_dict:
    table = _read_result_table(m)
    res_dict[m + '_label'] = table[:, 0]
    res_dict[m + '_proba'] = table[:, 2]
res_df = pd.DataFrame(res_dict)

# "Change_*" files: columns 0 / 1 / 2 are stored as label / original
# prediction / adversarial prediction respectively.
for m in model_dict_change:
    table = _read_result_table(m)
    res_dict_change[m + '_label'] = table[:, 0]
    res_dict_change[m + '_oriPred'] = table[:, 1]
    res_dict_change[m + '_advPred'] = table[:, 2]
res_df_change = pd.DataFrame(res_dict_change)
# %% ROC
baseline_models = [
    "EfficientNet-B3",
    "['', 'STM']_120_withoutTrades_EfficientNet-B3",
]


def plot_roc(name, labels, predictions, **kwargs):
    """Draw one ROC curve on the current axes, labelled with its AUC.

    Args:
        name: Legend text. 'Proposed method' is drawn with a thicker line;
            'PredPHI' and 'PHIAF' are drawn dotted.
        labels: Ground-truth labels handed to ``roc_curve``/``roc_auc_score``.
        predictions: Scores handed to ``roc_curve``/``roc_auc_score``.
        **kwargs: Forwarded unchanged to ``plt.plot`` (e.g. ``color``).

    Besides plotting, this resets the axis labels, the [0, 1] limits, the
    grid and the legend of the current axes.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(labels, predictions)
    legend_text = name + f' (AUC = {roc_auc_score(labels, predictions):.3f})'

    if name == 'Proposed method':
        line_width = 1.5
    else:
        line_width = 1
    line_style = 'dotted' if name in ('PredPHI', 'PHIAF') else '-'

    plt.plot(
        false_pos_rate,
        true_pos_rate,
        label=legend_text,
        linewidth=line_width,
        linestyle=line_style,
        **kwargs,
    )

    # Axis cosmetics are re-applied on every call; the last call wins.
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.xlim([0, 1])
    plt.ylim([0, 1])
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.legend(fontsize=13)
# ROC curves of the baseline models on their clean predictions.
# The script selects the Agg backend, under which plt.show() neither displays
# nor clears anything; without an explicit new figure every later plot would
# be drawn on top of this one. Open a dedicated figure and close it afterwards.
plt.figure()
prefix = "Prob_clean_"
for m in baseline_models:
    plot_roc(
        model_dict[prefix + m],
        # Row 0 is dropped before the numeric cast (presumably a header row
        # read as data because the CSVs are loaded with header=None).
        res_df[prefix + m + '_label'][1:].astype(float).astype(int),
        res_df[prefix + m + '_proba'][1:].astype(float),
        color=color_dict[prefix + m],
    )
plt.title('ROC Curve', fontsize=18)
# plt.savefig('../result/ROC_Curve.eps', dpi=600, format='eps', bbox_inches = 'tight')
plt.show()
plt.close()
# %% PRC
baseline_models = [
    "EfficientNet-B3",
    "['', 'STM']_120_withoutTrades_EfficientNet-B3",
]


def plot_prc(name, labels, predictions, **kwargs):
    """Draw one precision-recall curve on the current axes, labelled with AUPR.

    Args:
        name: Legend text. 'Proposed method' is drawn with a thicker line;
            'PredPHI' and 'PHIAF' are drawn dotted.
        labels: Ground-truth labels handed to ``precision_recall_curve`` and
            ``average_precision_score``.
        predictions: Scores handed to the same two functions.
        **kwargs: Forwarded unchanged to ``plt.plot`` (e.g. ``color``).

    Besides plotting, this resets the axis labels, limits, grid and legend of
    the current axes.
    """
    precision, recall, _ = precision_recall_curve(labels, predictions)
    lw = 1.5 if name == 'Proposed method' else 0.8
    ls = 'dotted' if name in ('PredPHI', 'PHIAF') else '-'
    # Recall goes on x and precision on y so the data matches the axis labels
    # below (the two arrays used to be passed the other way round).
    plt.plot(
        recall,
        precision,
        label=name +
        f' (AUPR = {average_precision_score(labels, predictions):.3f})',
        linewidth=lw,
        linestyle=ls,
        **kwargs,
    )
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    # The limits follow the data they were tuned for: the [0.4, 1] window was
    # applied to the precision values, which now live on the y axis.
    plt.xlim([0, 1])
    plt.ylim([0.4, 1])
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.legend(fontsize=13)
# ax = plt.gca()
# ax.set_aspect('equal')
# Precision-recall curves of the baseline models on their clean predictions.
# Under the Agg backend plt.show() leaves the current figure in place, so a
# dedicated figure is opened here and closed at the end; otherwise these
# curves would share axes with whatever was plotted before.
plt.figure()
prefix = "Prob_clean_"
for m in baseline_models:
    plot_prc(
        model_dict[prefix + m],
        # Row 0 is dropped before the numeric cast (presumably a header row
        # read as data because the CSVs are loaded with header=None).
        res_df[prefix + m + '_label'][1:].astype(float).astype(int),
        res_df[prefix + m + '_proba'][1:].astype(float),
        color=color_dict[prefix + m],
    )
plt.title('PR Curve', fontsize=18)
# plt.savefig('../result/PR_Curve.eps', dpi=600, format='eps', bbox_inches = 'tight')
plt.show()
plt.close()
# %% test result csv 2 df ------------------------------
baseline_models = [
    "EfficientNet-B3",
    "['', 'STM']_120_withoutTrades_EfficientNet-B3",
]

# One single-row frame per baseline model. The rows are collected first and
# joined once with pd.concat: DataFrame.append was deprecated in pandas 1.4
# and removed in 2.0, and repeated appends copy the frame every time.
prefix = "Prob_clean_"
report_rows = []
for m in baseline_models:
    r = test_report(
        # Row 0 is dropped before the numeric cast (presumably a header row).
        res_df[prefix + m + '_label'][1:].astype(float).astype(int),
        res_df[prefix + m + '_proba'][1:].astype(float),
        print_flag=False,
    )
    report_rows.append(pd.DataFrame(r).T)
results = pd.concat(report_rows, ignore_index=True)

# NOTE(review): assumes test_report yields exactly these 16 values in this
# order -- confirm against utils.plotFunction.
results.columns = [
    'TN',
    'FP',
    'FN',
    'TP',
    'Precision_0',
    'Recall_0',
    'F1_0',
    'Precision_1',
    'Recall_1',
    'F1_1',
    'Accuracy',
    'AUC',
    'AUPR',
    'KS-Value',
    'MCC',
    'G-Mean',
]
results = results.round(3)
results.index = baseline_models
# Bare expressions: they only produce output when run as a notebook cell.
results
# For LaTeX input: the same table with the columns in reporting order.
results[
    ['Accuracy', 'AUC', 'Precision_0', 'Precision_1', 'Recall_0', 'Recall_1',
     'F1_0', 'F1_1',
     'TN',
     'FP',
     'FN',
     'TP', ]
]
# %% spatial STM(rotate)
def _rotation_series(model):
    """Return the clean result name followed by the six rotation-sweep names."""
    sweep = [
        f"spatial_STM+BO_{degrees}.0_1000_1.0_-1_{model}"
        for degrees in range(30, 181, 30)
    ]
    return [f"clean_{model}"] + sweep


# Model identifiers; index i of model_names corresponds to index i of md1.
model_names = [
    "EfficientNet-B3",
    "['', 'STM']_120_withoutTrades_EfficientNet-B3",
]

# Standard model.
mdb5 = _rotation_series("EfficientNet-B3")
# ========== AT_STM
mdb55 = _rotation_series("['', 'STM']_120_withoutTrades_EfficientNet-B3")

md1 = [mdb5, mdb55]
def plot_diff_rotate_acc(
        name,
        y_label,
        x_label,
        all_accuracy,
        rotate_range,
        **kwargs):
    """Plot one model's accuracy values against the rotation settings.

    Args:
        name: Model identifier; mapped to a legend label below (unknown
            identifiers get no label). The first entry of the module-level
            ``model_names`` is drawn with 'v' markers, anything else with 'o'.
        y_label: Caption for the y axis.
        x_label: Caption for the x axis.
        all_accuracy: Y values, plotted against ``rotate_range``.
        rotate_range: X positions; also used as tick positions. The seven
            fixed tick labels assume it has seven entries.
        **kwargs: Forwarded unchanged to ``plt.plot`` (e.g. ``color``).
    """
    if name == "EfficientNet-B3":
        legend_label = name
    elif name == "['', 'STM']_120_withoutTrades_EfficientNet-B3":
        legend_label = "AT_STM_EfficientNet-B3"
    else:
        legend_label = None

    marker_shape = 'v' if name == model_names[0] else 'o'

    plt.plot(
        rotate_range,
        all_accuracy,
        label=legend_label,
        marker=marker_shape,
        markersize='10',
        linewidth=1.5,
        linestyle='-',
        **kwargs,
    )

    plt.xlabel(x_label, fontsize=18)
    plt.ylabel(y_label, fontsize=18)
    tick_labels = [str(degrees) for degrees in range(0, 181, 30)]
    plt.xticks(rotate_range, tick_labels, fontsize=18)
    plt.yticks([0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1], fontsize=18)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.legend(fontsize=13)
def plot_diff_rotate_roi(
        name,
        all_roi,
        rotate_range,
        **kwargs):
    """Plot one model's rate-of-invariance values against rotation settings.

    Args:
        name: Model identifier; mapped to a legend label below (unknown
            identifiers get no label). The first entry of the module-level
            ``model_names`` is drawn with 'v' markers, anything else with 'o'.
        all_roi: Y values, plotted against ``rotate_range``.
        rotate_range: X positions; also used as tick positions. The seven
            fixed tick labels assume it has seven entries.
        **kwargs: Forwarded unchanged to ``plt.plot`` (e.g. ``color``).
    """
    if name == "EfficientNet-B3":
        legend_label = name
    elif name == "['', 'STM']_120_withoutTrades_EfficientNet-B3":
        legend_label = "AT_STM_EfficientNet-B3"
    else:
        legend_label = None

    marker_shape = 'v' if name == model_names[0] else 'o'

    plt.plot(
        rotate_range,
        all_roi,
        label=legend_label,
        marker=marker_shape,
        markersize='10',
        linewidth=1.5,
        linestyle='-',
        **kwargs,
    )

    plt.xlabel('Rotation range', fontsize=18)
    plt.ylabel('Rate of invariance', fontsize=18)
    tick_labels = [str(degrees) for degrees in range(0, 181, 30)]
    plt.xticks(rotate_range, tick_labels, fontsize=18)
    plt.yticks([0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1], fontsize=18)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.legend(fontsize=13)
# =============================================== Accuracy
# Accuracy of each model across the rotation sweep, saved as an EPS file.
# The script runs on the Agg backend, where plt.show() does not clear the
# current figure; a dedicated figure is opened and closed here so the saved
# file contains only these curves.
plt.figure()
for i, mb in enumerate(md1):
    plot_diff_rotate_acc(
        model_names[i],
        y_label="Accuracy",
        x_label="Rotate range",
        all_accuracy=caculate_all_accuracy(mb, res_df),
        rotate_range=[0, 30, 60, 90, 120, 150, 180],
        color=color_dict["Prob_clean_" + model_names[i]],
    )
plt.title('STM+BO Attack', fontsize=18)
plt.savefig("./result/final_result/STMtrain_Eval_STM+BO_Rotate_acc.eps",
            dpi=600, format='eps', bbox_inches='tight')
plt.show()
plt.close()
# =============================================== S+ Accuracy
# (S+ denotes the correctly classified samples.)
# A dedicated figure is opened and closed here: on the Agg backend plt.show()
# does not clear the current figure, so the saved file would otherwise also
# contain curves from earlier plots.
plt.figure()
for i, mb in enumerate(md1):
    plot_diff_rotate_acc(
        model_names[i],
        y_label="S+ Accuracy",
        x_label="Rotate range",
        # mb[1:] skips the clean entry. NOTE(review): that leaves six result
        # names for seven x positions -- presumably the helper supplies the
        # zero-rotation point itself; confirm against utils.tools.
        all_accuracy=caculate_sAdd_accuracy(mb[1:], res_df_change),
        rotate_range=[0, 30, 60, 90, 120, 150, 180],
        color=color_dict["Prob_clean_" + model_names[i]],
    )
plt.title('STM+BO Attack', fontsize=18)
plt.savefig("./result/final_result/STMtrain_Eval_STM+BO_S+_Rotate_acc.eps",
            dpi=600, format='eps', bbox_inches='tight')
plt.show()
plt.close()
# =============================================== Rate of invariance
# Plot the invariance metric of each model across the rotation sweep.
# A dedicated figure is opened and closed here: on the Agg backend plt.show()
# does not clear the current figure, so the saved file would otherwise also
# contain curves from earlier plots.
plt.figure()
for (i, mb) in enumerate(md1):
    plot_diff_rotate_roi(
        model_names[i],
        # mb[1:] skips the clean entry. NOTE(review): six result names for
        # seven x positions -- presumably the helper supplies the
        # zero-rotation point itself; confirm against utils.tools.
        caculate_all_roi(mb[1:], res_df_change),
        rotate_range=[0, 30, 60, 90, 120, 150, 180],
        color=color_dict["Prob_clean_" + model_names[i]],
    )
plt.title('STM+BO Attack', fontsize=18)
plt.savefig("./result/final_result/STMtrain_Eval_STM+BO_Rotate_invariance.eps",
            dpi=600, format='eps', bbox_inches='tight')
plt.show()
plt.close()
# %% spatial STM(transition)
# Result names for the translation ("trans") sweep of the standard model:
# the clean run followed by the six sweep settings.
mdb10 = [
    "clean_EfficientNet-B3",
    # trans
    "spatial_STM+BO_1.0_-1_2.0_1000_EfficientNet-B3",
    "spatial_STM+BO_1.0_-1_4.0_1000_EfficientNet-B3",
    "spatial_STM+BO_1.0_-1_6.0_1000_EfficientNet-B3",
    "spatial_STM+BO_1.0_-1_8.0_1000_EfficientNet-B3",
    "spatial_STM+BO_1.0_-1_10.0_1000_EfficientNet-B3",
    "spatial_STM+BO_1.0_-1_12.0_1000_EfficientNet-B3",
]
# Same sweep for the AT_STM model. This list previously repeated the
# rotation-sweep names (30.0 ... 180.0), so the translation plots showed
# rotation results for this model; it now mirrors mdb10 entry for entry.
mdb100 = [
    # ========== AT_STM
    "clean_['', 'STM']_120_withoutTrades_EfficientNet-B3",
    # trans
    "spatial_STM+BO_1.0_-1_2.0_1000_['', 'STM']_120_withoutTrades_EfficientNet-B3",
    "spatial_STM+BO_1.0_-1_4.0_1000_['', 'STM']_120_withoutTrades_EfficientNet-B3",
    "spatial_STM+BO_1.0_-1_6.0_1000_['', 'STM']_120_withoutTrades_EfficientNet-B3",
    "spatial_STM+BO_1.0_-1_8.0_1000_['', 'STM']_120_withoutTrades_EfficientNet-B3",
    "spatial_STM+BO_1.0_-1_10.0_1000_['', 'STM']_120_withoutTrades_EfficientNet-B3",
    "spatial_STM+BO_1.0_-1_12.0_1000_['', 'STM']_120_withoutTrades_EfficientNet-B3",
]
md2 = [mdb10, mdb100]
# Index i of model_names corresponds to index i of md2.
model_names = [
    "EfficientNet-B3",
    "['', 'STM']_120_withoutTrades_EfficientNet-B3",
]
def plot_diff_trans_roi(
        name,
        all_roi,
        trans_range,
        **kwargs):
    """Plot one model's rate-of-invariance values against transition settings.

    Args:
        name: Model identifier; mapped to a legend label below (unknown
            identifiers get no label). The first entry of the module-level
            ``model_names`` is drawn with 'v' markers, anything else with 'o'.
        all_roi: Y values, plotted against ``trans_range``.
        trans_range: X positions; also used as tick positions. The seven
            fixed tick labels assume it has seven entries.
        **kwargs: Forwarded unchanged to ``plt.plot`` (e.g. ``color``).
    """
    if name == "EfficientNet-B3":
        legend_label = name
    elif name == "['', 'STM']_120_withoutTrades_EfficientNet-B3":
        legend_label = "AT_STM_EfficientNet-B3"
    else:
        legend_label = None

    marker_shape = 'v' if name == model_names[0] else 'o'

    plt.plot(
        trans_range,
        all_roi,
        label=legend_label,
        marker=marker_shape,
        markersize='10',
        linewidth=1.5,
        linestyle='-',
        **kwargs,
    )

    plt.xlabel('Transition range', fontsize=18)
    plt.ylabel('Rate of invariance', fontsize=18)
    tick_labels = ["0"] + [f"{percent}%" for percent in range(2, 13, 2)]
    plt.xticks(trans_range, tick_labels, fontsize=18)
    plt.yticks([0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1], fontsize=18)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.legend(fontsize=13)
def plot_diff_trans_acc(
        name,
        y_label,
        x_label,
        all_accuracy,
        trans_range,
        **kwargs):
    """Plot one model's accuracy values against the transition settings.

    Args:
        name: Model identifier; mapped to a legend label below (unknown
            identifiers get no label). The first entry of the module-level
            ``model_names`` is drawn with 'v' markers, anything else with 'o'.
        y_label: Caption for the y axis.
        x_label: Caption for the x axis.
        all_accuracy: Y values, plotted against ``trans_range``.
        trans_range: X positions; also used as tick positions. The seven
            fixed tick labels assume it has seven entries.
        **kwargs: Forwarded unchanged to ``plt.plot`` (e.g. ``color``).
    """
    if name == "EfficientNet-B3":
        legend_label = name
    elif name == "['', 'STM']_120_withoutTrades_EfficientNet-B3":
        legend_label = "AT_STM_EfficientNet-B3"
    else:
        legend_label = None

    marker_shape = 'v' if name == model_names[0] else 'o'

    plt.plot(
        trans_range,
        all_accuracy,
        label=legend_label,
        marker=marker_shape,
        markersize='10',
        linewidth=1.5,
        linestyle='-',
        **kwargs,
    )

    plt.xlabel(x_label, fontsize=18)
    plt.ylabel(y_label, fontsize=18)
    tick_labels = ["0"] + [f"{percent}%" for percent in range(2, 13, 2)]
    plt.xticks(trans_range, tick_labels, fontsize=18)
    plt.yticks([0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1], fontsize=18)
    plt.grid(True, linestyle='--', alpha=0.5)
    plt.legend(fontsize=13)
# =============================================== Accuracy
# Accuracy of each model across the transition sweep, saved as an EPS file.
# The script runs on the Agg backend, where plt.show() does not clear the
# current figure; a dedicated figure is opened and closed here so the saved
# file contains only these curves.
plt.figure()
for i, mb in enumerate(md2):
    plot_diff_trans_acc(
        model_names[i],
        y_label="Accuracy",
        x_label="Transition range",
        all_accuracy=caculate_all_accuracy(mb, res_df),
        trans_range=[0, 2, 4, 6, 8, 10, 12],
        color=color_dict["Prob_clean_" + model_names[i]],
    )
plt.title('STM+BO Attack', fontsize=18)
plt.savefig("./result/final_result/STMtrain_Eval_STM+BO_Trans_acc.eps",
            dpi=600, format='eps', bbox_inches='tight')
plt.show()
plt.close()
# =============================================== S+ Accuracy
# (S+ denotes the correctly classified samples.)
# A dedicated figure is opened and closed here: on the Agg backend plt.show()
# does not clear the current figure, so the saved file would otherwise also
# contain curves from earlier plots.
plt.figure()
for i, mb in enumerate(md2):
    plot_diff_trans_acc(
        model_names[i],
        y_label="S+ Accuracy",
        x_label="Transition range",
        # mb[1:] skips the clean entry. NOTE(review): that leaves six result
        # names for seven x positions -- presumably the helper supplies the
        # zero-shift point itself; confirm against utils.tools.
        all_accuracy=caculate_sAdd_accuracy(mb[1:], res_df_change),
        trans_range=[0, 2, 4, 6, 8, 10, 12],
        color=color_dict["Prob_clean_" + model_names[i]],
    )
plt.title('STM+BO Attack', fontsize=18)
plt.savefig("./result/final_result/STMtrain_Eval_STM+BO_S+_Trans_acc.eps",
            dpi=600, format='eps', bbox_inches='tight')
plt.show()
plt.close()
# =============================================== Rate of invariance
# Invariance metric of each model across the transition sweep.
# A dedicated figure is opened and closed here: on the Agg backend plt.show()
# does not clear the current figure, so the saved file would otherwise also
# contain curves from earlier plots.
plt.figure()
for (i, mb) in enumerate(md2):
    plot_diff_trans_roi(
        model_names[i],
        # mb[1:] skips the clean entry. NOTE(review): six result names for
        # seven x positions -- presumably the helper supplies the zero-shift
        # point itself; confirm against utils.tools.
        caculate_all_roi(mb[1:], res_df_change),
        trans_range=[0, 2, 4, 6, 8, 10, 12],
        color=color_dict["Prob_clean_" + model_names[i]],
    )
plt.title('STM+BO Attack', fontsize=18)
plt.savefig("./result/final_result/STMtrain_Eval_STM+BO_Trans_invariance.eps",
            dpi=600, format='eps', bbox_inches='tight')
plt.show()
plt.close()
# %%
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/mona2544/adv_train_wbc.git
git@gitee.com:mona2544/adv_train_wbc.git
mona2544
adv_train_wbc
adv_train_wbc
master

搜索帮助