1 Star 2 Fork 1

Kirito/线性回归模型预测N天后股票收盘价

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
预测股票价格.py 2.30 KB
一键复制 编辑 原始数据 按行查看 历史
Kirito 提交于 2023-11-25 12:11 . 预测股票价格
import pandas as pd
import matplotlib.pyplot as plt
import plotly.graph_objs as go
from plotly.offline import plot
from sklearn.linear_model import LinearRegression
from sklearn import preprocessing #用于数据标准化
df = pd.read_csv('data3.csv')
df.head()
df['trade_date'] = pd.to_datetime(df['trade_date'])
df = df.set_index('trade_date')
# 将数据按照日期升序排列
df.sort_values(by=['trade_date'], ascending=True, inplace=True)
df.tail()
df.info()
df.drop_duplicates(inplace=True)
print(df.shape)
Min_date = df.index.min()
Max_date = df.index.max()
print("First date is:", Min_date)
print("Last date is:", Max_date)
print("时间跨度:", Max_date - Min_date, "天")
#k线图
# trace = go.Ohlc(x=df.index, open=df['open'], high=df['high'], low=df['low'], close=df['close'])
# data = [trace]
#
# plot(data, filename='simple_ohlc.html')
N = 3 #预测3天之后的股票收盘价
df['label'] = df['close'].shift(-N)
Data = df.drop([ 'change', 'pct_chg'], axis=1)
X = Data.values #转换成矩阵格式
scaler = preprocessing.StandardScaler()
X_scaled = scaler.fit_transform(X) # 标准化数据,均值为0,标准差为1
# 最后N三行数据没有label值
df.dropna(inplace=True)
Target = df.label
y = Target.values #转换成矩阵格式
# 为与y长度保持一致,X的最后N行也要去除
X = X[:-N]
# 将数据分为训练数据和测试数据
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)
# 使用线性回归模型进行训练
lr = LinearRegression()
lr.fit(X_train, y_train)
lr.score(X_test, y_test)
print(lr.score(X_test, y_test))
# 理解模型
# 输出各个特征参数
for idx, col_name in enumerate(['open', 'high', 'close', 'low', 'vol', 'ma5', 'ma10', 'ma20', 'ma_v_5', 'ma_v_10', 'ma_v_20']):
print("The coefficient for {} is {}".format(col_name, lr.coef_[idx]))
print(lr.intercept_)
data = df.dropna()
data.index = df.index
data['forecast'] = lr.predict(Data[:-N].values)
# 画预测值和实际值
data['close'].plot(color='green', linewidth=1)
data['forecast'].plot(color='orange', linewidth=2)
plt.xlabel('Time')
plt.ylabel('Price')
plt.show()
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/kirito-yukiasuna/gupiao.git
git@gitee.com:kirito-yukiasuna/gupiao.git
kirito-yukiasuna
gupiao
线性回归模型预测N天后股票收盘价
master

搜索帮助