代码拉取完成,页面将自动刷新
from keras.layers import Input, Dense, GlobalMaxPool2D, GlobalAvgPool2D, Concatenate, Multiply, Dropout, Subtract, Add
from keras.models import Model
from keras_vggface.utils import preprocess_input
from keras_vggface.vggface import VGGFace
from keras.optimizers import Adam
from keras.preprocessing import image
import numpy as np
from random import choice, sample
import pandas as pd
import os
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
import gc
from keras import backend as K
from keras import models
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
import tensorflow as tf
import sys
# Decision threshold used by cal(): a model score >= threshold is counted as a
# positive prediction (1), anything below it as a negative prediction (0).
# NOTE(review): the origin of this exact value is not shown here - presumably it
# was tuned on validation data (roc_curve/auc are imported); confirm.
threshold = 0.49847739934921265
def log_file(path: str, data: object) -> None:
    """Append ``data`` as a single line to the text file at ``path``.

    Args:
        path: File to append to; it is created if it does not exist.
        data: Value to log. Strings are written verbatim; any other object
            is converted with ``str()`` first.
    """
    text = data if isinstance(data, str) else str(data)
    # Append-only mode is enough (the file is never read back here), and an
    # explicit encoding keeps the output independent of the platform locale.
    with open(path, "a", encoding="utf-8") as log:
        log.write(text + '\n')
# Per-group evaluation counters, each stored as [sample_count, correct_count].
# get_test() fills sample_count (index 0); cal() increments correct_count
# (index 1). Every key gets its own list so the counters never alias.
res = {group: [0, 0] for group in ('FMS', 'FMD', 'total')}
def focal_loss(gamma=2., alpha=.25):
    """Build a binary focal-loss function usable as a Keras loss.

    Args:
        gamma: Focusing exponent applied to the per-element modulating factor.
        alpha: Weight of the positive-target term; the negative-target term
            is weighted by ``1 - alpha``.

    Returns:
        ``focal_loss_fixed(y_true, y_pred)``, which returns the loss summed
        (not averaged) over all elements. The inner name is kept as-is
        because ``__main__`` registers it under that key in ``custom_objects``.
    """
    def focal_loss_fixed(y_true, y_pred):
        # Where the target is not 1, substitute 1 so that (1 - p)^gamma is 0
        # and the element contributes nothing to the positive term.
        positive_probs = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
        # Where the target is not 0, substitute 0 so that p^gamma is 0 and the
        # element contributes nothing to the negative term.
        negative_probs = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
        # K.epsilon() keeps both logarithms finite at p == 0 / p == 1.
        positive_term = K.sum(
            alpha * K.pow(1. - positive_probs, gamma) * K.log(K.epsilon() + positive_probs))
        negative_term = K.sum(
            (1 - alpha) * K.pow(negative_probs, gamma) * K.log(1. - negative_probs + K.epsilon()))
        return -positive_term - negative_term
    return focal_loss_fixed
def baseline_model(model_name):
    """Load and return the saved Keras model stored at ``model_name``.

    Unlike the focal-loss checkpoint loaded in ``__main__``, no
    ``custom_objects`` are passed to ``load_model`` here.
    """
    return models.load_model(model_name)
def get_test(test_file_path="./test.txt", prefix='test/test_faces/'):
    """Read the tab-separated test list and record per-group sample counts.

    Each non-blank line of ``test_file_path`` is split on tabs. The first
    three fields are image file names and get ``prefix`` prepended. Field 3
    is the group key into the module-level ``res`` counters. Field 4 is read
    later by ``gen`` as an integer label.

    Side effects:
        Sets ``res['total'][0]`` to the number of rows and increments
        ``res[<group>][0]`` once per row.

    Args:
        test_file_path: Path of the list file (defaults to ``./test.txt``).
        prefix: Directory prefix prepended to the three image fields.

    Returns:
        ``(rows, count)``: the parsed rows (lists of str) and ``len(rows)``.

    Raises:
        KeyError: if a row's group field is not a key of ``res``.
        IndexError: if a row has fewer than four tab-separated fields.
    """
    rows = []
    # Read-only access suffices; "r+" needlessly required write permission.
    with open(test_file_path, "r", encoding='utf-8') as fh:
        for raw in fh:
            line = raw.rstrip('\r\n')
            # A blank line is skipped, not treated as end-of-file. Previously
            # it silently dropped every row after it.
            if not line:
                continue
            fields = line.split('\t')
            for col in range(3):
                fields[col] = prefix + fields[col]
            rows.append(fields)
    res['total'][0] = len(rows)
    for fields in rows:
        res[fields[3]][0] += 1
    return rows, len(rows)
def read_img(path):
    """Load the image at ``path`` resized to 224x224 and preprocess it.

    The pixels are converted to a float64 NumPy array and passed to the
    VGGFace ``preprocess_input`` with ``version=2``. NOTE(review): version 2
    is presumably the ResNet50/SENet50 preprocessing, matching the resnet50
    checkpoint loaded in ``__main__``; confirm.
    """
    img = image.load_img(path, target_size=(224, 224))
    # ``np.float`` was only an alias of the builtin float (float64) and was
    # removed in NumPy 1.24, where it raises AttributeError; name the dtype.
    pixels = np.array(img).astype(np.float64)
    return preprocess_input(pixels, version=2)
def gen(list_tuples, batch_size, total):
    """Yield evaluation batches from parsed test rows, then a ``None`` sentinel.

    Rows are indexed as follows: fields 0-2 are image paths, field 3 is the
    group key, and field 4 is an integer label.

    Args:
        list_tuples: Sequence of rows as returned by ``get_test``.
        batch_size: Maximum number of rows per batch (must be >= 1).
        total: Number of leading rows of ``list_tuples`` to process.

    Yields:
        ``([X1, X2, X3], labels, classes)`` per batch. Each ``Xk`` stacks the
        ``read_img`` outputs for field k, ``labels`` is a NumPy array of ints,
        and ``classes`` is a list of group keys. After the last batch it yields
        ``(None, None, None)`` exactly once and then stops, so callers that
        break on the sentinel and callers that simply iterate both terminate.

    Raises:
        ValueError: if ``batch_size`` < 1, which previously looped forever.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))
    start = 0
    while start < total:
        end = min(start + batch_size, total)
        batch = list_tuples[start:end]
        paths = [(row[0], row[1], row[2]) for row in batch]
        labels = [int(row[4]) for row in batch]
        classes = [row[3] for row in batch]
        X1 = np.array([read_img(p[0]) for p in paths])
        X2 = np.array([read_img(p[1]) for p in paths])
        X3 = np.array([read_img(p[2]) for p in paths])
        yield [X1, X2, X3], np.array(labels), classes
        start = end
        if start < total:
            # Free memory between batches, as before. NOTE(review): clearing
            # the Keras session while the model is still in use is unusual;
            # kept to preserve the original side effects - confirm it is needed.
            gc.collect()
            K.clear_session()
    yield None, None, None
def cal(pred, labels, classes):
    """Accumulate correct thresholded predictions into the global ``res``.

    Each score in ``pred`` becomes a 0/1 prediction via the module-level
    ``threshold`` (score >= threshold -> 1). When that prediction equals the
    label at the same index, both ``res['total'][1]`` and the sample's group
    counter ``res[classes[i]][1]`` are incremented.

    Args:
        pred: Model scores, one per sample. NOTE(review): in ``__main__``
            this is ``model(...).numpy()``, presumably of shape (batch, 1);
            confirm.
        labels: Integer ground-truth labels, indexed in step with ``pred``.
        classes: Group key per sample (a key of ``res``), indexed in step
            with ``pred``.
    """
    for idx, score in enumerate(pred):
        predicted = 1 if score >= threshold else 0
        if predicted != labels[idx]:
            continue
        res['total'][1] += 1
        res[classes[idx]][1] += 1
if __name__ == "__main__":
    # Expose only GPU index 1 to the process.
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    # Alternative checkpoint, loaded without custom objects:
    # model = baseline_model('final_resnet50_bce_old_image_2.h5')
    # This checkpoint references the custom loss by the name
    # 'focal_loss_fixed', so it must be supplied for deserialization.
    model = models.load_model("final_resnet50_focal_old_image_2.h5", custom_objects={'focal_loss_fixed': focal_loss()})
    test, total = get_test()
    for datas, labels, classes in gen(test, 30, total):
        # gen() marks the end of the data with a (None, None, None) sentinel.
        if datas is None:
            break
        pred = model(datas).numpy()
        cal(pred, labels, classes)
    # Report accuracy per group; a group with no samples would otherwise
    # raise ZeroDivisionError and abort the whole report.
    for key, (count, correct) in res.items():
        print(key, ':', correct / count if count else 'n/a')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。