代码拉取完成,页面将自动刷新
同步操作将从 VIT/基于深度学习的课堂行为图像分类 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
import os
import pandas as pd
import tensorflow as tf
from util import get_data, F1_score
from tensorflow.python.keras.callbacks import EarlyStopping
# import keras_metrics as km
import keras_metrics as km
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.python.keras.metrics import Precision, Recall, AUC
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.applications.resnet import ResNet152
from tensorflow.python.keras.applications.densenet import DenseNet121
from tensorflow.python.keras.applications.mobilenet import MobileNet
from tensorflow.python.keras.callbacks import EarlyStopping
import keras_metrics as km
import numpy as np
from keras.callbacks import Callback
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.python.keras.layers import Dense, Flatten, BatchNormalization, MaxPooling2D, GlobalMaxPool2D
class Metrics(Callback):
    """Keras callback that scores the validation set with scikit-learn metrics.

    At the end of every epoch the model predicts on the validation inputs,
    the predictions are rounded, and three scores are appended to history
    lists on the callback:

    * ``val_f1s``        -- micro-averaged F1 (a float per epoch)
    * ``val_recalls``    -- per-class recall (``average=None``, an array per epoch)
    * ``val_precisions`` -- per-class precision (``average=None``, an array per epoch)

    Only the F1 value is printed.

    Args:
        validation_data: Optional indexable pair where ``[0]`` is the model
            input and ``[1]`` the targets. Recent Keras versions no longer set
            ``self.validation_data`` on callbacks, so pass the data here. When
            omitted, the callback falls back to ``self.validation_data`` as
            older Keras versions populated it during ``fit``.
    """

    def __init__(self, validation_data=None):
        super().__init__()
        # Stored under a private name so older Keras, which assigns
        # ``self.validation_data`` itself during fit, cannot overwrite it.
        self._explicit_validation_data = validation_data

    def _get_validation_data(self):
        """Return the validation pair, preferring the constructor argument."""
        if self._explicit_validation_data is not None:
            return self._explicit_validation_data
        legacy = getattr(self, 'validation_data', None)
        if legacy is None:
            raise ValueError(
                'Metrics callback has no validation data: pass '
                'validation_data=(x, y) to Metrics(...)')
        return legacy

    def on_train_begin(self, logs=None):
        """Reset the per-epoch metric histories at the start of training."""
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def on_epoch_end(self, epoch, logs=None):
        """Predict on the validation data and record F1/recall/precision."""
        data = self._get_validation_data()
        # NOTE(review): rounding thresholds each output at 0.5 (multi-label
        # style). With a softmax head, a row whose largest probability is
        # <= 0.5 becomes all zeros; consider argmax if the task is
        # single-label -- TODO confirm the intended label format.
        val_predict = np.asarray(self.model.predict(data[0])).round()
        val_targ = data[1]
        _val_f1 = f1_score(val_targ, val_predict, average='micro')
        _val_recall = recall_score(val_targ, val_predict, average=None)
        _val_precision = precision_score(val_targ, val_predict, average=None)
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
        print("— val_f1: %f " % _val_f1)
# tf.data autotune sentinel; not referenced in the visible code of this file.
AUTOTUNE = tf.data.experimental.AUTOTUNE
if __name__ == '__main__':
    # Take ReduceLROnPlateau from the same Keras namespace as the model,
    # layers and EarlyStopping (tensorflow.python.keras). Mixing in a
    # tf.keras callback combines two Keras implementations on TF versions
    # where they are separate packages.
    from tensorflow.python.keras.callbacks import ReduceLROnPlateau

    # Presumably tf.data datasets of 224x224x3 images with one-hot labels
    # for 4 classes (matches the input_shape, the Dense(4) head and the
    # categorical_crossentropy loss below) -- TODO confirm against util.get_data.
    train_ds, val_ds = get_data()

    # MobileNet backbone without its classification head.
    mobile_net = MobileNet(input_shape=(224, 224, 3), include_top=False)
    # Freeze the backbone parameters so only the new head below is trained.
    mobile_net.trainable = False

    model = Sequential([
        mobile_net,
        GlobalMaxPool2D(),
        # GlobalMaxPool2D already yields a 2-D (batch, channels) tensor, so
        # this Flatten is effectively a no-op; it is kept so the saved
        # model's layer layout stays unchanged.
        Flatten(),
        Dense(1000, activation='relu'),
        BatchNormalization(),
        Dense(200, activation='relu'),
        BatchNormalization(),
        Dense(4, activation='softmax')])

    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy', F1_score(), Recall(), Precision(), AUC()])

    # Stop once val_accuracy has not improved for 80 epochs, then roll the
    # model back to the best weights seen.
    early_stopping = EarlyStopping(
        monitor='val_accuracy',
        verbose=1,
        patience=80,
        restore_best_weights=True
    )
    # Shrink the learning rate by 5x on plateau (default monitor), never
    # below 1e-5.
    reduce_lr = ReduceLROnPlateau(min_lr=0.00001,
                                  factor=0.2)

    # epochs is an upper bound; early stopping is expected to end training.
    history = model.fit(train_ds, epochs=2000,
                        callbacks=[early_stopping, reduce_lr],
                        validation_data=val_ds)

    # Persist the per-epoch training curves and the trained model.
    hist_df = pd.DataFrame(history.history)
    hist_df.to_csv('mobileNet_history.csv')
    model.save('mobileNet.h5')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。