1 Star 0 Fork 5

lss616263/基于深度学习的课堂行为图像分类

加入 Gitee
与超过 1200 万开发者一起发现、参与优秀开源项目，私有仓库也完全免费 :)
免费加入
克隆/下载
train_mobileNet.py 2.91 KB
一键复制 编辑 原始数据 按行查看 历史
VIT 提交于 2023-04-15 02:04 . 模型训练-v1.0
import os
import pandas as pd
import tensorflow as tf
from util import get_data, F1_score
from tensorflow.python.keras.callbacks import EarlyStopping
# import keras_metrics as km
import keras_metrics as km
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.python.keras.metrics import Precision, Recall, AUC
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.applications.resnet import ResNet152
from tensorflow.python.keras.applications.densenet import DenseNet121
from tensorflow.python.keras.applications.mobilenet import MobileNet
from tensorflow.python.keras.callbacks import EarlyStopping
import keras_metrics as km
import numpy as np
from keras.callbacks import Callback
from sklearn.metrics import f1_score, precision_score, recall_score
from tensorflow.python.keras.layers import Dense, Flatten, BatchNormalization, MaxPooling2D, GlobalMaxPool2D
class Metrics(Callback):
    """Keras callback that records validation F1, recall and precision per epoch.

    After every epoch the model predicts on the validation inputs, the
    predictions are binarised, and scikit-learn compares them with the
    validation targets:

    * ``val_f1s`` -- micro-averaged F1 score (one float per epoch);
    * ``val_recalls`` / ``val_precisions`` -- per-label scores
      (``average=None``, one array per epoch).

    Args:
        validation_data: optional ``(x, y)`` pair to evaluate on. Newer Keras
            versions leave ``Callback.validation_data`` unset during ``fit``,
            so the data should be passed here. When omitted, the callback
            falls back to ``self.validation_data`` (older Keras behaviour).
            NOTE(review): assumed to be indexable arrays rather than a
            ``tf.data.Dataset`` -- confirm against callers.
        argmax: when False (default, original behaviour) predictions are
            binarised with ``round()``, which suits independent sigmoid
            outputs. When True, each row of 2-D predictions becomes a one-hot
            vector at its argmax -- the correct decoding for single-label
            softmax outputs, where ``round()`` turns any row whose largest
            probability is below 0.5 into all zeros.
    """

    def __init__(self, validation_data=None, argmax=False):
        super().__init__()
        # Stored under a private name so the base class cannot overwrite it.
        self._explicit_validation_data = validation_data
        self._argmax = argmax

    def on_train_begin(self, logs=None):
        """Reset the per-epoch metric histories at the start of training."""
        self.val_f1s = []
        self.val_recalls = []
        self.val_precisions = []

    def _get_validation_data(self):
        """Return the ``(x, y)`` pair to evaluate, or raise if none is available."""
        data = self._explicit_validation_data
        if data is None:
            # Legacy Keras sets this attribute on callbacks during fit().
            data = getattr(self, 'validation_data', None)
        if data is None:
            raise ValueError(
                'Metrics callback has no validation data; pass '
                'validation_data=(x, y) when constructing Metrics.')
        return data[0], data[1]

    def _binarize(self, probs):
        """Turn model scores into 0/1 indicator predictions."""
        if self._argmax:
            one_hot = np.zeros_like(probs)
            one_hot[np.arange(probs.shape[0]), probs.argmax(axis=1)] = 1
            return one_hot
        return probs.round()

    def on_epoch_end(self, epoch, logs=None):
        """Score the model on the validation data and append the results."""
        val_x, val_targ = self._get_validation_data()
        val_predict = self._binarize(np.asarray(self.model.predict(val_x)))
        _val_f1 = f1_score(val_targ, val_predict, average='micro')
        _val_recall = recall_score(val_targ, val_predict, average=None)
        _val_precision = precision_score(val_targ, val_predict, average=None)
        self.val_f1s.append(_val_f1)
        self.val_recalls.append(_val_recall)
        self.val_precisions.append(_val_precision)
        print("— val_f1: %f " % _val_f1)
# tf.data tuning constant. Not referenced in the visible part of this script;
# kept at module level in case other code imports it.
AUTOTUNE = tf.data.experimental.AUTOTUNE


def main():
    """Train a 4-class head on a frozen MobileNet backbone and save the results.

    Steps:
      1. load the training and validation datasets via ``util.get_data``;
      2. stack a small dense classification head on a frozen MobileNet;
      3. train with early stopping on validation accuracy and a
         learning-rate reduction on plateau;
      4. write the training history to ``mobileNet_history.csv`` and the
         model to ``mobileNet.h5`` in the current working directory.
    """
    # Imported here so every Keras object in this script comes from the same
    # ``tensorflow.python.keras`` namespace as the model and EarlyStopping;
    # on newer TF releases ``tf.keras`` can resolve to a separate package.
    from tensorflow.python.keras.callbacks import ReduceLROnPlateau

    # NOTE(review): presumably get_data() returns batched (image, one-hot
    # label) datasets of 224x224x3 images, preprocessed as MobileNet
    # expects -- confirm in util.py.
    train_ds, val_ds = get_data()

    # Backbone without its classifier top, frozen so only the new head trains.
    mobile_net = MobileNet(input_shape=(224, 224, 3), include_top=False)
    mobile_net.trainable = False

    # GlobalMaxPool2D already yields a (batch, channels) tensor, so the
    # Flatten layer previously stacked after it was a no-op and is omitted.
    model = Sequential([
        mobile_net,
        GlobalMaxPool2D(),
        Dense(1000, activation='relu'),
        BatchNormalization(),
        Dense(200, activation='relu'),
        BatchNormalization(),
        Dense(4, activation='softmax'),  # one unit per output class
    ])
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy', F1_score(), Recall(), Precision(), AUC()],
    )

    # NOTE(review): in some Keras versions restore_best_weights only takes
    # effect when stopping is actually triggered, not when all epochs run --
    # confirm for the installed version.
    early_stopping = EarlyStopping(
        monitor='val_accuracy',
        verbose=1,
        patience=80,
        restore_best_weights=True,
    )
    # monitor='val_loss' is the Keras default, written out for visibility.
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, min_lr=0.00001)

    history = model.fit(
        train_ds,
        epochs=2000,  # upper bound; EarlyStopping is expected to end training sooner
        callbacks=[early_stopping, reduce_lr],
        validation_data=val_ds,
    )

    pd.DataFrame(history.history).to_csv('mobileNet_history.csv')
    model.save('mobileNet.h5')


if __name__ == '__main__':
    main()
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/lss616263/Classroom-behavior-image-classification.git
git@gitee.com:lss616263/Classroom-behavior-image-classification.git
lss616263
Classroom-behavior-image-classification
基于深度学习的课堂行为图像分类
master

搜索帮助