1 Star 4 Fork 0

章一砚/股票预测走势绘制

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
有完整注释的python脚本.py 14.83 KB
一键复制 编辑 原始数据 按行查看 历史
# K线图样式 csv+mplfinance https://blog.csdn.net/qq_41437512/article/details/105319421
# tushare+mplfinance——最小案例 https://blog.csdn.net/ooobenooo/article/details/107754092
# datatime——时间与字符串处理 https://www.cnblogs.com/chenqionghe/p/12235277.html
# mplfinance官网文档优秀翻译 https://blog.csdn.net/wuwei_201/article/details/105781844
# dataframe基础操作 https://www.jianshu.com/p/4ff1d2b23ab3
from datetime import datetime, timedelta, date
import matplotlib as mpl
import mplfinance as mpf
import pandas as pd
import tushare as ts
api = None # tushare 获取数据的接口对象
data = None # 用于保存股票数据
data_insert = None # 用于保存预测的那一条股票数据
sCode = None # 用于保存股票代码
g_start_date = None # 用于保存查询到的数据的开始时间
g_end_date = None # 用于数据的结束时间
times = None # 用于保存数据的时间段的字符串
tk = {} # 用于保存预测值的 'open' 'close' 'high' 'low' 'volume'的字典对象
# 从命令行读取股票代码,并保存至全局变量 sCode
def get_sCode():
global sCode # 显式引用全局变量 sCode
# 从命令行读取股票代码,并保存至全局变量 sCode
sCode = input("请按格式输入股票代码:%s\n" % ('300059.SZ'))
# 如果是正常的股票代码,应该是9个字符
# 当然,我们无法保证9个字符的输入就一定是正确的股票代码,不过没关系,可以不管它
if len(sCode) != 9: # 如果输入不是9个字符
# 就把 sCode 设置为写死的300059.SZ
print("您输入了不正确的股票代码 >>> %s\n已修改股票代码为默认值 >>> 300059.SZ" % sCode)
sCode = '300059.SZ'
# 设置时间。之后要按着这个时间来去查询股票数据
# 设置参数 next,next 意为是否取下一天,next 默认值为 False
# 若 next=True 则 g_end_date 取下一个工作日
# 若 next=False 则 g_end_date 取今天(而不是下一个工作日)
def set_time(next=False):
# 显式引用全局变量 g_start_date g_end_date times
global g_start_date
global g_end_date
global times
# 字符串::开始时间 = (当前时间 - 270天).格式化为"年月日"
g_start_date = (datetime.now() - timedelta(270)).strftime('%Y%m%d')
# 若 next=True 则 字符串::结束时间 取下一个工作日
# 若 next=False 则 字符串::结束时间 取今天(而不是下一个工作日)
g_end_date = next_workday().strftime('%Y%m%d') \
if next else datetime.now().strftime('%Y%m%d')
# 字符串::times =
# (按格式解析时间字符串(字符串::开始时间,格式字符串::"年月日"))转换成字符串.取[0,-9)
# + ' ~ ' +
# (按格式解析时间字符串(字符串::结束时间,格式字符串::"年月日"))转换成字符串.取[0,-9)
times = str(datetime.strptime(g_start_date, '%Y%m%d'))[:-9] + \
" ~ " + str(datetime.strptime(g_end_date, '%Y%m%d'))[:-9]
# tushare 初始化设置
def init_tushare():
# 显式引用全局变量 api
global api
# tushare 用token设置初始化
ts.set_token('dd8ab61a5210d4fa6ed707fd36c2db22986e7d3909e28b71e58f2f99')
# 初始化完成以后,就可以用 .pro_api() 来拿到一个对象
# 将这个对象赋值给 api ,留着以后用
api = ts.pro_api()
# 获取数据
def get_data():
# 需要先用 api.query(...) 查询数据,然后把查询到的数据转换成可以绘制的格式
# 显式引用全局变量 data sCode times g_start_date g_end_date
global data
global sCode
global times
global g_start_date
global g_end_date
# 设置时间为180个交易日至今天
set_time(next=False)
# Columns: [ts_code, trade_date, open, high, low, close, pre_close, change, pct_chg, vol, amount]
# 查询数据并保存到 DataFrame::df
# api.query()用于查询股票数据,有多个配置参数
# api_name='daily' 以天为精度来查询股票数据(我的账号的积分只能支持这个精度的)
# ts_code=sCode 设置查询目标——股票代码为全局变量股票代码
# start_date=g_start_date 设置查询目标——开始时间为全局变量开始时间
# end_date=g_end_date 设置查询目标——结束时间为全局变量结束时间
df = api.query(api_name='daily',
ts_code=sCode,
start_date=g_start_date,
end_date=g_end_date)[::-1]
# [::-1]实际上是简单的转换,倒序取得数据的开头从开始到最后一行
# 数据内容并没有变化(还是全部),但是数据的顺序变成了原先的倒序。
# 如果查到数据以后发现数据为空,那么证明两件事情:
# 第一 不是网络错误或是其他错误,否则在执行 api.query(...) 就会直接报错程序中止
# 第二 确实查到了数据,不过数据为空,那就只能是股票代码有问题
if df.empty:
# 打印提示信息 就说股票代码在某时段查询失败
print("\n\nERR!!! 股票:%s 在 %s 的数据查询失败,请仔细检查股票代码\n\n" % (sCode, times))
# 同时抛出异常,就说股票代码在某时段查询失败
raise Exception("股票:%s 在 %s 的数据查询失败,请仔细检查股票代码" % (sCode, times))
else:
# 如果一切正常,那么可以打印出本次查询到数据的基础信息
print("股票:%s 在 %s 的数据 已查得\n数据共有 %d 行\n" %
(sCode, times, df.size / 11))
# df.size 是 df 里面共有多少数据
# 11 是 df 共有 11 列,所以每行有 11 个数据
# 因此 df.size / 11 就可以得到df共有几行数据
# 为df创建一个新的列 date
# 这一列的数据就是 df 的 trade_date 这一列的每一项作为参数传递给lambda函数后计算的结果
# 这行 lambda 函数将时间字符串按格式解析成 datetime 对象
df['date'] = df['trade_date'].apply(
lambda x: datetime.strptime(x, '%Y%m%d'))
# 然后创建一个新的 DataFrame 对象 data
# 对这个data对象进行格式整理,以便能作为绘图数据被调用
# 从 df 查数据 df.loc[哪行,哪列]
# df.loc[:,[...]] 相当于 df.loc[所有行,[这些列]]
# 得到的数据传递给 data(一个新的 DataFrame 变量)
data = df.loc[:, ['date', 'open', 'close', 'high', 'low', 'vol']]
# 格式要求:vol 这一列的列名改为 volume
data = data.rename(columns={"vol": "volume"})
# 格式要求,将 date 这一列设置为索引列
data.set_index('date', inplace=True)
# 格式要求,将索引列的名字更改为 Date
data.index.name = "Date"
# 格式要求,将索引列的数据的数据类型更改为 DatetimeIndex
data.index = pd.DatetimeIndex(data.index)
# 最最重要的是,索引列index::Date 会被作为绘图时的X轴数据
# data.axes # 可打印 data 的索引列信息
# data.shape # 可打印 data 的 shape
# print(df)
# print(data)
# 把 data 数据绘制成股票趋势图
# 参数castmode 设置是否是预测图 默认值 False
# 若 castmode=True,则绘制的是预测图
# 若 castmode=False,则绘制的是真实数据的图
def draw(castmode=False):
# show_nontrading:是否显示非交易日,默认False
global data
global sCode
global times
# 设置全局变量 time
# 根据参数castmode 来设置时间的结束日期是今天还是下一个工作日
# 详情请见 set_time 函数的注释部分
set_time(next=castmode)
# print(mpl.rcParams) # 打印参数列表,开发程序的时候用的。
# 尝试设置标题能正常显示中文,失败了
# mpl.rcParams["font.family"] = 'Arial Unicode MS' # 据说这是mac上设置中文 然并卵
# 下面是对mplfinance的调用,具体的代码注释请见上文的 第三方依赖库的介绍 => mplfinance
s = mpf.make_mpf_style(base_mpf_style='nightclouds', marketcolors=mpf.make_marketcolors(up='r', down='g', edge='i', wick='i'), mavcolors=[
'#ffffff', '#f7d652', '#eb60f9', '#adadad', '#0647ef'])
mpl.rcParams['lines.linewidth'] = 0.1
mpl.rcParams['figure.dpi'] = 2000
mpl.rcParams['savefig.dpi'] = 2000
tableName = "%s\n%s" % (sCode, times)
# figpath = ('/Users/wangweijie/Desktop/股票预测走势绘制/Forcast:%s:%s' % (sCode, times) + '.png') \
# if castmode else'/Users/wangweijie/Desktop/股票预测走势绘制/%s:%s' % (sCode, times) + '.png'
config = dict(
type='candle', style=s, mav=(
5, 10, 20, 60, 120), figratio=(5, 2), title=tableName, figscale=1.4, datetime_format='%m/%d', xrotation=0
)
print("股票:%s 在 %s 的数据 已读取 \n现有数据共有 %d 行\n" % (sCode, times, data.size / 5))
print(data.head(3))
print(data.tail(3))
# mpf.plot(data, **config, savefig=figpath)
if castmode:
# 如果当前模式是预测模式,那就渲染并弹出图片窗口
mpf.plot(data, **config)
else:
# 如果当前模式是原图模式,那就询问是否要看原始数据图
print("\n\n是否希望看到 原始数据 的走势图")
print("Yes | yes | Y | y >>> 是")
print("其他内容(例如回车) >>> 不,我选择不看原始数据图")
key = input("请根据提示输入: ")
if key == "Yes" or key == "yes" or key == "Y" or key == "y":
print("将弹出原始数据的走势图")
print("按q关闭图片窗口")
print("在图片窗口被关闭前,将无法开始预测")
mpf.plot(data, **config)
else:
print("跳过展示 原始数据 的走势图")
# 获取下一个工作日
# 函数返回值是一个 时间对象 就是下一个工作日
def next_workday():
# 首先得到 今天 这个时间对象
today = datetime.now()
if (today.weekday() == 4):
# 如果今天是周五,那下一个工作日是往后数三天
nextwork_day = date.today() + timedelta(days=3)
elif (today.weekday() == 5):
# 如果今天是周六,那下一个工作日是往后数两天
nextwork_day = date.today() + timedelta(days=2)
else:
# 否则,那下一个工作日是往后数一天
nextwork_day = date.today() + timedelta(days=1)
# 返回计算出的下一个工作日
return nextwork_day
# 获取预测值数据并保存到全局变量 tk 这个字典里
def get_tk():
global tk # 显式引用全局变量tk
# 设置死循环,不输入正确不让退出
while True:
try:
print("请按提示仔细输入")
tk['open'] = float(input("请输入:开盘价\n"))
tk['close'] = float(input("请输入:收盘价\n"))
tk['high'] = float(input("请输入:当日最高价\n"))
tk['low'] = float(input("请输入:当日最低价\n"))
except ValueError:
# input(...) 函数的返回值就是用户在命令行中的输入,是字符串
# 需要用 float(...) 将这个字符串转换为浮点数再赋值到字典 tk
# 如果输入的数字不能解析成浮点数,就会触发 ValueError 这个异常
# 那么打印"输入值未满足格式要求:自然数或小数"这句话提醒用户检查输入的内容
# 然后开始新一轮循环
print("输入值未满足格式要求:自然数或小数")
pass # 如果输入不能转换成浮点数,就重新来触发输入
else:
# 如果没有触发数据转换为数字的异常,也要保证数据满足基本的大小值关系
if not (tk['high'] >= tk['open'] and tk['high'] >= tk['close'] and tk['low'] <= tk['open'] and tk['low'] <= tk['close'] and tk['low'] <= tk['high']):
print("输入值未满足大小关系:最低价 <= (开盘价,收盘价) <= 最高价")
pass # 大小关系不满足,则重新输入
else:
print("成功设置:开盘价、收盘价、当日最高价、当日最低价 ")
print(tk)
break # 一切顺利,完成输入,退出死循环
# 启动读入预测数据 然后修改数据 并重新渲染走势图窗口
def forecast_data():
# 显式引用全局变量 tk 和 data 和 data_insert
global tk
global data
global data_insert
# 从命令行读取预测的四个值
get_tk()
# 创建 data_insert 为一个新的DataFrame对象,里面只有一行数据,就是预测的那一行
data_insert = pd.DataFrame(
# 设置它的索引是 下一个工作日
index=[next_workday()],
data={
# 设置 data 的各个栏目为相应的值
'open': [tk['open']],
'close': [tk['close']],
'high': [tk['high']],
'low': [tk['low']],
# 成交量数据不重要 干脆置为零
'volume': [0]
})
# append并不修改原data,只是返回一个新的DataFrame
# 所以,需要进行赋值,将返回值赋给data
data = data.append(data_insert, ignore_index=False)
# 因为新添加的这一项的index是date而不是datetimeIndex,所以整个data的index就变成object了
# 用下面这行来重新转换为datatimeindex
data.index = pd.DatetimeIndex(data.index)
# data.axes 可以查询现在的index状况
# 调用我们的绘图函数,castmode=True 即选择 预测图模式
draw(castmode=True)
# 之前导入的包,定义的全局变量,定义好的函数,都可以在此处调用
# 之前定义的函数内的语句,是不会自动执行的,需要在这里调用函数,他们才会执行
if __name__ == "__main__":
get_sCode() # 从命令行读取股票代码到全局变量 sCode
set_time() # 默认值 next 为 false,即,并不会设置为下一个交易日
init_tushare() # tushare 为了能读取数据,需要初始化
get_data() # 根据读取今天往前数270天的数据,约为180个交易日
draw() # 默认值 castmode 为 false,即,并不会绘制和或保存预测数据的图片
while True:
print("\n\n是否要输入新的预测值")
print("Yes | yes | Y | y >>> 是")
print("No | no | N | n >>> 不,我选择退出程序")
key = input("请根据提示输入: ")
if key == "Yes" or key == "yes" or key == "Y" or key == "y":
print("开始新的一次预测")
print("在预测图渲染图片窗口弹出后,可以按q关闭图片窗口")
print("在图片窗口被关闭前,将无法开始下一轮预测")
forecast_data() # 触发预测数据
elif key == "No" or key == "no" or key == "N" or key == "n":
print("退出预测 程序中止")
break
else:
print("输入值不合法 <<< %s 请按提示重新输入" % key)
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/SevDaisy/stock_forecast_trend.git
git@gitee.com:SevDaisy/stock_forecast_trend.git
SevDaisy
stock_forecast_trend
股票预测走势绘制
master

搜索帮助

0d507c66 1850385 C8b1a773 1850385