# K线图样式 csv+mplfinance https://blog.csdn.net/qq_41437512/article/details/105319421
# tushare+mplfinance——最小案例 https://blog.csdn.net/ooobenooo/article/details/107754092
# datatime——时间与字符串处理 https://www.cnblogs.com/chenqionghe/p/12235277.html
# mplfinance官网文档优秀翻译 https://blog.csdn.net/wuwei_201/article/details/105781844
# dataframe基础操作 https://www.jianshu.com/p/4ff1d2b23ab3
from datetime import datetime, timedelta, date

import matplotlib as mpl
import mplfinance as mpf
import pandas as pd
import tushare as ts

api = None  # tushare 获取数据的接口对象
data = None  # 用于保存股票数据
data_insert = None  # 用于保存预测的那一条股票数据
sCode = None  # 用于保存股票代码
g_start_date = None  # 用于保存查询到的数据的开始时间
g_end_date = None  # 用于数据的结束时间
times = None  # 用于保存数据的时间段的字符串
tk = {}  # 用于保存预测值的 'open' 'close' 'high' 'low' 'volume'的字典对象


# 从命令行读取股票代码,并保存至全局变量 sCode
def get_sCode():
    global sCode  # 显式引用全局变量 sCode
    # 从命令行读取股票代码,并保存至全局变量 sCode
    sCode = input("请按格式输入股票代码:%s\n" % ('300059.SZ'))

    # 如果是正常的股票代码,应该是9个字符
    # 当然,我们无法保证9个字符的输入就一定是正确的股票代码,不过没关系,可以不管它
    if len(sCode) != 9:    # 如果输入不是9个字符
        # 就把 sCode 设置为写死的300059.SZ
        print("您输入了不正确的股票代码 >>> %s\n已修改股票代码为默认值 >>> 300059.SZ" % sCode)
        sCode = '300059.SZ'


# 设置时间。之后要按着这个时间来去查询股票数据
# 设置参数 next,next 意为是否取下一天,next 默认值为 False
# 若 next=True 则 g_end_date 取下一个工作日
# 若 next=False 则 g_end_date 取今天(而不是下一个工作日)
def set_time(next=False):
    # 显式引用全局变量 g_start_date g_end_date times
    global g_start_date
    global g_end_date
    global times

    # 字符串::开始时间 = (当前时间 - 270天).格式化为"年月日"
    g_start_date = (datetime.now() - timedelta(270)).strftime('%Y%m%d')

    # 若 next=True 则 字符串::结束时间 取下一个工作日
    # 若 next=False 则 字符串::结束时间 取今天(而不是下一个工作日)
    g_end_date = next_workday().strftime('%Y%m%d') \
        if next else datetime.now().strftime('%Y%m%d')

    # 字符串::times =
    # (按格式解析时间字符串(字符串::开始时间,格式字符串::"年月日"))转换成字符串.取[0,-9)
    # + ' ~ ' +
    # (按格式解析时间字符串(字符串::结束时间,格式字符串::"年月日"))转换成字符串.取[0,-9)
    times = str(datetime.strptime(g_start_date, '%Y%m%d'))[:-9] + \
        " ~ " + str(datetime.strptime(g_end_date, '%Y%m%d'))[:-9]


# tushare 初始化设置
def init_tushare():
    # 显式引用全局变量 api
    global api
    # tushare 用token设置初始化
    ts.set_token('dd8ab61a5210d4fa6ed707fd36c2db22986e7d3909e28b71e58f2f99')
    # 初始化完成以后,就可以用 .pro_api() 来拿到一个对象
    # 将这个对象赋值给 api ,留着以后用
    api = ts.pro_api()


# 获取数据
def get_data():
    # 需要先用 api.query(...) 查询数据,然后把查询到的数据转换成可以绘制的格式
    # 显式引用全局变量 data sCode times g_start_date g_end_date
    global data
    global sCode
    global times
    global g_start_date
    global g_end_date

    # 设置时间为180个交易日至今天
    set_time(next=False)
    # Columns: [ts_code, trade_date, open, high, low, close, pre_close, change, pct_chg, vol, amount]
    # 查询数据并保存到 DataFrame::df
    # api.query()用于查询股票数据,有多个配置参数
    # api_name='daily' 以天为精度来查询股票数据(我的账号的积分只能支持这个精度的)
    # ts_code=sCode 设置查询目标——股票代码为全局变量股票代码
    # start_date=g_start_date 设置查询目标——开始时间为全局变量开始时间
    # end_date=g_end_date 设置查询目标——结束时间为全局变量结束时间
    df = api.query(api_name='daily',
                   ts_code=sCode,
                   start_date=g_start_date,
                   end_date=g_end_date)[::-1]
    # [::-1]实际上是简单的转换,倒序取得数据的开头从开始到最后一行
    # 数据内容并没有变化(还是全部),但是数据的顺序变成了原先的倒序。

    # 如果查到数据以后发现数据为空,那么证明两件事情:
    # 第一 不是网络错误或是其他错误,否则在执行 api.query(...) 就会直接报错程序中止
    # 第二 确实查到了数据,不过数据为空,那就只能是股票代码有问题
    if df.empty:
        # 打印提示信息 就说股票代码在某时段查询失败
        print("\n\nERR!!! 股票:%s 在 %s 的数据查询失败,请仔细检查股票代码\n\n" % (sCode, times))
        # 同时抛出异常,就说股票代码在某时段查询失败
        raise Exception("股票:%s 在 %s 的数据查询失败,请仔细检查股票代码" % (sCode, times))
    else:
        # 如果一切正常,那么可以打印出本次查询到数据的基础信息
        print("股票:%s 在 %s 的数据 已查得\n数据共有 %d 行\n" %
              (sCode, times, df.size / 11))
        # df.size 是 df 里面共有多少数据
        # 11 是 df 共有 11 列,所以每行有 11 个数据
        # 因此 df.size / 11 就可以得到df共有几行数据

    # 为df创建一个新的列 date
    # 这一列的数据就是 df 的 trade_date 这一列的每一项作为参数传递给lambda函数后计算的结果
    # 这行 lambda 函数将时间字符串按格式解析成 datetime 对象
    df['date'] = df['trade_date'].apply(
        lambda x: datetime.strptime(x, '%Y%m%d'))

    # 然后创建一个新的 DataFrame 对象 data
    # 对这个data对象进行格式整理,以便能作为绘图数据被调用

    # 从 df 查数据 df.loc[哪行,哪列]
    # df.loc[:,[...]] 相当于 df.loc[所有行,[这些列]]
    # 得到的数据传递给 data(一个新的 DataFrame 变量)
    data = df.loc[:, ['date', 'open', 'close', 'high', 'low', 'vol']]
    # 格式要求:vol 这一列的列名改为 volume
    data = data.rename(columns={"vol": "volume"})
    # 格式要求,将 date 这一列设置为索引列
    data.set_index('date', inplace=True)
    # 格式要求,将索引列的名字更改为 Date
    data.index.name = "Date"
    # 格式要求,将索引列的数据的数据类型更改为 DatetimeIndex
    data.index = pd.DatetimeIndex(data.index)
    # 最最重要的是,索引列index::Date 会被作为绘图时的X轴数据

    # data.axes # 可打印 data 的索引列信息
    # data.shape # 可打印 data 的 shape
    # print(df)
    # print(data)


# 把 data 数据绘制成股票趋势图
# 参数castmode 设置是否是预测图 默认值 False
# 若 castmode=True,则绘制的是预测图
# 若 castmode=False,则绘制的是真实数据的图
def draw(castmode=False):
    # show_nontrading:是否显示非交易日,默认False
    global data
    global sCode
    global times

    # 设置全局变量 time
    # 根据参数castmode 来设置时间的结束日期是今天还是下一个工作日
    # 详情请见 set_time 函数的注释部分
    set_time(next=castmode)

    # print(mpl.rcParams) # 打印参数列表,开发程序的时候用的。
    # 尝试设置标题能正常显示中文,失败了
    # mpl.rcParams["font.family"] = 'Arial Unicode MS' # 据说这是mac上设置中文 然并卵

    # 下面是对mplfinance的调用,具体的代码注释请见上文的 第三方依赖库的介绍 => mplfinance
    s = mpf.make_mpf_style(base_mpf_style='nightclouds', marketcolors=mpf.make_marketcolors(up='r', down='g', edge='i', wick='i'), mavcolors=[
        '#ffffff', '#f7d652', '#eb60f9', '#adadad', '#0647ef'])
    mpl.rcParams['lines.linewidth'] = 0.1
    mpl.rcParams['figure.dpi'] = 2000
    mpl.rcParams['savefig.dpi'] = 2000

    tableName = "%s\n%s" % (sCode, times)
    # figpath = ('/Users/wangweijie/Desktop/股票预测走势绘制/Forcast:%s:%s' % (sCode, times) + '.png') \
    #     if castmode else'/Users/wangweijie/Desktop/股票预测走势绘制/%s:%s' % (sCode, times) + '.png'
    config = dict(
        type='candle', style=s, mav=(
            5, 10, 20, 60, 120), figratio=(5, 2), title=tableName, figscale=1.4, datetime_format='%m/%d', xrotation=0
    )

    print("股票:%s 在 %s 的数据 已读取 \n现有数据共有 %d 行\n" % (sCode, times, data.size / 5))
    print(data.head(3))
    print(data.tail(3))
    # mpf.plot(data, **config, savefig=figpath)

    if castmode:
        # 如果当前模式是预测模式,那就渲染并弹出图片窗口
        mpf.plot(data, **config)
    else:
        # 如果当前模式是原图模式,那就询问是否要看原始数据图
        print("\n\n是否希望看到 原始数据 的走势图")
        print("Yes | yes | Y | y >>> 是")
        print("其他内容(例如回车) >>> 不,我选择不看原始数据图")
        key = input("请根据提示输入:  ")
        if key == "Yes" or key == "yes" or key == "Y" or key == "y":
            print("将弹出原始数据的走势图")
            print("按q关闭图片窗口")
            print("在图片窗口被关闭前,将无法开始预测")
            mpf.plot(data, **config)
        else:
            print("跳过展示 原始数据 的走势图")


# 获取下一个工作日
# 函数返回值是一个 时间对象 就是下一个工作日
def next_workday():
    # 首先得到 今天 这个时间对象
    today = datetime.now()
    if (today.weekday() == 4):
        # 如果今天是周五,那下一个工作日是往后数三天
        nextwork_day = date.today() + timedelta(days=3)
    elif (today.weekday() == 5):
        # 如果今天是周六,那下一个工作日是往后数两天
        nextwork_day = date.today() + timedelta(days=2)
    else:
        # 否则,那下一个工作日是往后数一天
        nextwork_day = date.today() + timedelta(days=1)
    # 返回计算出的下一个工作日
    return nextwork_day


# 获取预测值数据并保存到全局变量 tk 这个字典里
def get_tk():
    global tk  # 显式引用全局变量tk
    # 设置死循环,不输入正确不让退出
    while True:
        try:
            print("请按提示仔细输入")
            tk['open'] = float(input("请输入:开盘价\n"))
            tk['close'] = float(input("请输入:收盘价\n"))
            tk['high'] = float(input("请输入:当日最高价\n"))
            tk['low'] = float(input("请输入:当日最低价\n"))
        except ValueError:
            # input(...) 函数的返回值就是用户在命令行中的输入,是字符串
            # 需要用 float(...) 将这个字符串转换为浮点数再赋值到字典 tk
            # 如果输入的数字不能解析成浮点数,就会触发 ValueError 这个异常
            # 那么打印"输入值未满足格式要求:自然数或小数"这句话提醒用户检查输入的内容
            # 然后开始新一轮循环
            print("输入值未满足格式要求:自然数或小数")
            pass  # 如果输入不能转换成浮点数,就重新来触发输入
        else:
            # 如果没有触发数据转换为数字的异常,也要保证数据满足基本的大小值关系
            if not (tk['high'] >= tk['open'] and tk['high'] >= tk['close'] and tk['low'] <= tk['open'] and tk['low'] <= tk['close'] and tk['low'] <= tk['high']):
                print("输入值未满足大小关系:最低价 <= (开盘价,收盘价) <= 最高价")
                pass  # 大小关系不满足,则重新输入
            else:
                print("成功设置:开盘价、收盘价、当日最高价、当日最低价 ")
                print(tk)
                break  # 一切顺利,完成输入,退出死循环


# 启动读入预测数据 然后修改数据 并重新渲染走势图窗口
def forecast_data():
    # 显式引用全局变量 tk 和 data 和 data_insert
    global tk
    global data
    global data_insert

    # 从命令行读取预测的四个值
    get_tk()
    # 创建 data_insert 为一个新的DataFrame对象,里面只有一行数据,就是预测的那一行
    data_insert = pd.DataFrame(
        # 设置它的索引是 下一个工作日
        index=[next_workday()],
        data={
            # 设置 data 的各个栏目为相应的值
            'open': [tk['open']],
            'close': [tk['close']],
            'high': [tk['high']],
            'low': [tk['low']],
            # 成交量数据不重要 干脆置为零
            'volume': [0]
        })
    # append并不修改原data,只是返回一个新的DataFrame
    # 所以,需要进行赋值,将返回值赋给data
    data = data.append(data_insert, ignore_index=False)
    # 因为新添加的这一项的index是date而不是datetimeIndex,所以整个data的index就变成object了
    # 用下面这行来重新转换为datatimeindex
    data.index = pd.DatetimeIndex(data.index)
    # data.axes 可以查询现在的index状况

    # 调用我们的绘图函数,castmode=True 即选择 预测图模式
    draw(castmode=True)


# 之前导入的包,定义的全局变量,定义好的函数,都可以在此处调用
# 之前定义的函数内的语句,是不会自动执行的,需要在这里调用函数,他们才会执行
if __name__ == "__main__":
    get_sCode()  # 从命令行读取股票代码到全局变量 sCode
    set_time()  # 默认值 next 为 false,即,并不会设置为下一个交易日
    init_tushare()  # tushare 为了能读取数据,需要初始化
    get_data()  # 根据读取今天往前数270天的数据,约为180个交易日
    draw()  # 默认值 castmode 为 false,即,并不会绘制和或保存预测数据的图片
    while True:
        print("\n\n是否要输入新的预测值")
        print("Yes | yes | Y | y >>> 是")
        print("No | no | N | n >>> 不,我选择退出程序")
        key = input("请根据提示输入:  ")
        if key == "Yes" or key == "yes" or key == "Y" or key == "y":
            print("开始新的一次预测")
            print("在预测图渲染图片窗口弹出后,可以按q关闭图片窗口")
            print("在图片窗口被关闭前,将无法开始下一轮预测")
            forecast_data()  # 触发预测数据
        elif key == "No" or key == "no" or key == "N" or key == "n":
            print("退出预测 程序中止")
            break
        else:
            print("输入值不合法 <<< %s 请按提示重新输入" % key)