代码拉取完成,页面将自动刷新
from sklearn.metrics import mean_squared_error
import numpy as np
import matplotlib.pyplot as plt
def pred(input_data, target, weights):
return ((input_data * weights).sum())
def get_slope(input_data, target, weights):
preds = pred(input_data, target, weights)
error = target - preds
slope = 2 * input_data * error
return slope
def get_mse(input_data, target, weights):
preds = pred(input_data, target, weights)
return mean_squared_error([preds], [target])
weights = np.array([0, 2, 1])
input_data = np.array([1, 2, 3])
target = 0
n_updates = 20
mse_hist = []
for i in range(n_updates):
slope = get_slope(input_data, target, weights)
weights = weights + (learning_rate * slope)
mse = get_mse(input_data, target, weights)
mse_hist.append(mse)
plt.plot(mse_hist)
plt.xlabel('Iterations')
plt.ylabel('Mean Squared Error')
plt.show()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。