代码拉取完成,页面将自动刷新
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@file :linear01.py
@author :cat
@version :1.0
'''
import numpy as np
import matplotlib.pyplot as plt
"""
为了能够很好地模拟真实样本的观测误差,我们给模型添加误差自变量𝜖,它采样自均值为 0,标准差为 1 的高斯分布
𝑦 = 1.477𝑥 + 0.089 + 𝜖, 𝜖 ∼ 𝒩(0 , 1)
:return:
"""
def make_data():
data = []
for _ in range(200):
x = np.random.uniform(-10,10)
y = 1.4777 * x + 0.089 + np.random.normal(0,0.1)
data.append([x,y])
# 可视化一下数据
# plt.scatter([i[0] for i in data],[j[1] for j in data])
# plt.show()
data = np.array(data)
# print(data.shape) # (200,2)
return data
# 计算误差
def mse_error(w,b,data):
total_error = 0.0
for i in range(len(data)):
x = data[i,0]
y = data[i,1]
error = (y - (x * w + b)) ** 2
total_error += error
return total_error / float(len(data))
# 计算梯度
def step_gradient(w_current,b_current,data,lr):
b_gradient = 0
w_gradient = 0
for i in range(0,len(data)):
x = data[i,0]
y = data[i,1]
b_gradient += 2 * ((w_current * x + b_current) - y) * 0.01
w_gradient += 2 * x * ((w_current * x + b_current) - y) * 0.01
new_b = b_current - (lr * b_gradient)
new_w = w_current - (lr * w_gradient)
return new_b,new_w
# 更新梯度
def gradient_descent(data, starting_b, starting_w, learning_rate, num_iterations):
"""
在计算出误差函数在𝑤和𝑏处的梯度后,我们可以根据式(2.1)来更新𝑤和𝑏的值。
我们把 对数据集的所有样本训练一次称为一个 Epoch,共循环迭代 num_iterations 个 Epoch
:param data:
:param starting_b:
:param starting_w:
:param learning_rate:
:param num_iterations:
:return:
"""
b =starting_b
w = starting_w
for step in range(num_iterations):
b,w = step_gradient(w,b,np.array(data),learning_rate)
loss = mse_error(b,w,data)
print(f'iteration:{step + 1},loss:{loss}')
return b,w
def main():
lr = 1e-2
initial_b = 0
initial_w = 0
num_iterations = 1000
data = make_data()
# 训练优化 1000 次,返回最优 w*,b*和训练 Loss 的下降过程
[b, w] = gradient_descent(data, initial_b, initial_w, lr, num_iterations)
loss = mse_error(b, w, data) # 计算最优数值解 w,b 上的均方差
print(f'Final loss:{loss}, w:{w}, b:{b}')
# 绘制一下最终的图像
plt.scatter([i[0] for i in data],[j[1] for j in data],c='r',alpha=.5)
x_data = data[:,0]
y_data = w * x_data + b
plt.plot(x_data,y_data)
plt.show()
if __name__ == "__main__":
main()
# Final loss:64.66513807777812, w:1.4766844022375065, b:0.09135514886298757
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。