代码拉取完成,页面将自动刷新
import tensorflow as tf
from tensorflow import keras
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
from matplotlib.font_manager import FontProperties
# Font passed to the plotting helpers for Chinese axis-label and legend text.
# NOTE(review): hard-coded Windows font path (STXingkai); this file will not
# exist on non-Windows machines — confirm the intended deployment platform.
font = FontProperties(fname='C:/Windows/Fonts/STXINGKA.ttf')
# Load the datasets
def gen_datas(train_path="housing.csv", predict_path="001.csv"):
    """Load the training table and the table to run predictions on.

    The last column of ``train_path`` is treated as the regression target and
    every preceding column as an input feature.

    Args:
        train_path: CSV file holding the training data (features + target).
        predict_path: CSV file whose rows are later fed to the model for
            prediction.

    Returns:
        A tuple ``(data1, inputs, outputs)``: ``data1`` is the prediction
        table as a NumPy array, ``inputs`` is a 2-D array of every training
        column except the last, and ``outputs`` is a 1-D array of the last
        training column.
    """
    data = pd.read_csv(train_path).to_numpy()
    data1 = pd.read_csv(predict_path).to_numpy()
    print(data1)
    # Features are all columns but the last; the last column is the target.
    inputs = data[:, :-1]
    outputs = data[:, -1]
    return data1, inputs, outputs
def complie_model(model, learning_rate=0.01):
    """Compile ``model`` in place for regression training.

    Configures the Adam optimizer, mean-squared-error loss, and MAE/MSE
    metrics. (The function keeps its historical spelling so existing callers
    keep working.)

    Args:
        model: The Keras model to configure.
        learning_rate: Step size for the Adam optimizer. Defaults to 0.01,
            the value previously hard-coded here.
    """
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss='mse',
        metrics=["mae", "mse"],
    )
def create_model(input_dim=3):
    """Build and compile a small fully connected regression network.

    Architecture: ``input_dim`` features -> Dense(10, ReLU) -> Dense(1).

    Args:
        input_dim: Number of input features per sample. Defaults to 3, the
            value previously hard-coded. It must equal the number of feature
            columns later passed to ``model.fit``.

    Returns:
        The compiled ``tf.keras.Sequential`` model.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation=tf.nn.relu,
                              input_shape=(input_dim,), name='layer1'),
        tf.keras.layers.Dense(1, name="outputs"),
    ])
    complie_model(model)
    return model
def display_nn_structure(model, nn_structure_path):
    """Print a text summary of ``model`` and write a diagram of it to disk.

    Args:
        model: The Keras model to describe.
        nn_structure_path: Destination image file for the diagram.
    """
    # show_shapes=True annotates each layer in the diagram with its shapes.
    # NOTE(review): plot_model usually needs pydot/Graphviz installed —
    # confirm they are available in the target environment.
    plot_kwargs = {"to_file": nn_structure_path, "show_shapes": True}
    model.summary()
    keras.utils.plot_model(model, **plot_kwargs)
def callback_only_params(model_path, monitor="loss"):
    """Create a checkpoint callback that keeps only the best model seen so far.

    Args:
        model_path: File path (or path pattern) the checkpoint is written to.
        monitor: Logged quantity that decides which epoch is "best".
            Defaults to the training ``"loss"``. ``ModelCheckpoint`` would
            otherwise monitor ``"val_loss"``, which is never logged when
            ``fit`` runs without validation data, so ``save_best_only=True``
            would skip every save.

    Returns:
        A configured ``tf.keras.callbacks.ModelCheckpoint``.
    """
    # NOTE(review): despite the "only_params" name, save_weights_only is not
    # set, so the whole model is saved — confirm which was intended.
    ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=model_path,
        monitor=monitor,
        verbose=1,
        save_best_only=True,
        save_freq='epoch',
    )
    return ckpt_callback
def tb_callback(model_path):
    """Return a TensorBoard callback that writes its logs under ``model_path``.

    ``histogram_freq=1`` asks TensorBoard for weight histograms every epoch.
    """
    return tf.keras.callbacks.TensorBoard(log_dir=model_path, histogram_freq=1)
def plot_history(history):
    """Plot the per-epoch training MSE recorded in a Keras ``History``.

    Args:
        history: The object returned by ``model.fit``. Its ``history`` dict
            must contain an ``"mse"`` entry.
    """
    frame = pd.DataFrame(history.history)
    frame['epoch'] = history.epoch
    # New figure; Chinese label text is rendered with the module-level font.
    plt.figure()
    plt.xlabel("训练次数", fontproperties=font)
    plt.ylabel("损失值", fontproperties=font)
    epochs, mse = frame["epoch"], frame["mse"]
    plt.plot(epochs, mse, label='loss')
    plt.legend(prop=font)
    plt.show()
def train_model(model, inputs, outputs, model_path, log_path, epochs=10):
    """Train ``model`` with checkpoint and TensorBoard callbacks, then plot the loss.

    Args:
        model: Compiled Keras model to train.
        inputs: Training features.
        outputs: Training targets.
        model_path: Path given to the checkpoint callback. The untrained
            weights are also written here before training starts.
        log_path: TensorBoard log directory.
        epochs: Number of training epochs. Defaults to 10, the value
            previously hard-coded.

    Returns:
        The ``History`` object produced by ``model.fit`` (callers that
        ignore the return value are unaffected).
    """
    ckpt_callback = callback_only_params(model_path)
    tensorboard_callback = tb_callback(log_path)
    # Save the initial weights. ``format`` fills an ``{epoch}`` placeholder
    # with 0 when model_path contains one; a plain path is left unchanged.
    model.save_weights(model_path.format(epoch=0))
    history = model.fit(
        inputs,
        outputs,
        epochs=epochs,
        callbacks=[ckpt_callback, tensorboard_callback],
        verbose=0,
    )
    plot_history(history)
    return history
def load_model(model, model_path, restore=False):
    """Find the newest checkpoint in ``model_path`` and optionally restore it.

    By default this only prints the checkpoint path. That matches the earlier
    behaviour, where the ``load_weights`` call was commented out, so the
    model keeps whatever weights it already has.

    Args:
        model: Keras model whose weights may be restored.
        model_path: Directory searched by ``tf.train.latest_checkpoint``.
        restore: When True and a checkpoint is found, load its weights into
            ``model``.

    Returns:
        The checkpoint path, or None when no checkpoint was found.
    """
    latest = tf.train.latest_checkpoint(model_path)
    print("latest:{}".format(latest))
    # latest_checkpoint yields None when nothing is found; passing None to
    # load_weights would fail, so only restore a real path.
    if restore and latest is not None:
        model.load_weights(latest)
    return latest
def prediction(model, model_path, inputs):
    """Call ``load_model`` for ``model_path`` and then predict on ``inputs``.

    Returns:
        Whatever ``model.predict(inputs)`` returns.
    """
    load_model(model, model_path)
    return model.predict(inputs)
def plot_prediction(model, model_path, inputs, outputs):
    """Plot model predictions for ``inputs`` alongside the reference ``outputs``.

    Both series are drawn against their sample index. The raw predictions
    are printed after the plot window is shown.

    Args:
        model: Trained Keras model.
        model_path: Checkpoint directory forwarded to ``prediction``.
        inputs: Samples to predict on.
        outputs: Reference values drawn next to the predictions.
    """
    # NOTE(review): nothing here checks that inputs and outputs describe the
    # same samples or have the same length — confirm at the call site.
    predicted = prediction(model, model_path, inputs)
    for series, caption in ((outputs, "实际值"), (predicted, "预测结果")):
        plt.plot(series, label=caption)
    plt.xlabel("输入数据", fontproperties=font)
    plt.ylabel("预测值", fontproperties=font)
    plt.legend(prop=font)
    plt.show()
    print(predicted)
if __name__ == "__main__":
    # Timestamp for optionally suffixing the output paths (see the "+ stamp"
    # notes below). It uses no ':' characters, which are invalid in Windows
    # file names and these paths are on a Windows drive.
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    base_dir = "D:/Machine-Learning-Model"
    model_path = base_dir + "/line-fit-high"  # + stamp
    log_path = base_dir + "/line-fit-high"  # + stamp
    data1, inputs, outputs = gen_datas()
    model = create_model()
    display_nn_structure(model, base_dir + "/001.png")
    train_model(model, inputs, outputs, model_path, log_path)
    # Prediction looks up checkpoints in a different directory than the one
    # training wrote to.
    model_path = base_dir + "/high"
    # NOTE(review): predictions are made on data1 (001.csv) but plotted
    # against the training targets from housing.csv; confirm both refer to
    # the same samples.
    plot_prediction(model, model_path, data1, outputs)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。