import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import re    #excel读取底部标签需要
import os
from scipy.optimize import curve_fit
# Locate the workbook relative to this script rather than the working directory.
script_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(script_dir, 'excel', '气动系数.xlsx')

# Coefficient tables live on sheet tabs named "<span>-<Mach>Ma":
# regex group 1 is the span length and group 2 the Mach number.
SHEET_NAME_PATTERN = re.compile(r'^(\d+)-(\d+)Ma$')


def _read_coefficient_sheet(workbook, sheet_name, span_length, mach):
    """Read one coefficient sheet into a DataFrame, one row per numeric alpha.

    Parameters
    ----------
    workbook : pd.ExcelFile
        Open workbook to read from.
    sheet_name : str
        Name of the sheet to read (must be passed explicitly, otherwise pandas
        silently reads the first sheet).
    span_length, mach : int
        Values parsed from the sheet name; repeated on every output row.

    Returns
    -------
    pd.DataFrame
        Columns SpanLength, Ma, Alpha, Delta, CL, CD, Cmz1.
    """
    # header=1: the second sheet row is the header, so data starts on row 3.
    # Only columns 0-20 are read, so column -1 is column 20 and -4 is column 17.
    df = workbook.parse(sheet_name, header=1, usecols=range(0, 21))

    # Keep only rows whose first column (alpha) parses as a number; this drops
    # blank or text rows mixed into the table.
    alpha = pd.to_numeric(df.iloc[:, 0], errors='coerce')
    rows = alpha.notna()
    n_rows = int(rows.sum())

    return pd.DataFrame({
        'SpanLength': np.full(n_rows, span_length, dtype=int),
        'Ma': np.full(n_rows, mach, dtype=int),
        'Alpha': alpha[rows].to_numpy(),
        'Delta': df.iloc[:, -1][rows].to_numpy(),
        'CL': df.iloc[:, 12][rows].to_numpy(),
        'CD': df.iloc[:, 13][rows].to_numpy(),
        'Cmz1': df.iloc[:, -4][rows].to_numpy(),
    })


frames = []
with pd.ExcelFile(file_path) as xls:
    sheet_names = [name for name in xls.sheet_names if SHEET_NAME_PATTERN.match(name)]
    for sheet_name in sheet_names:
        # The filter above guarantees the pattern matches here.
        span_text, mach_text = SHEET_NAME_PATTERN.match(sheet_name).groups()
        frames.append(
            _read_coefficient_sheet(xls, sheet_name, int(span_text), int(mach_text))
        )

# Concatenate once (concat inside the loop re-copies every accumulated row).
# pd.concat rejects an empty list, so fall back to an empty frame.
all_data = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
# print(all_data)
def funcL(x, a00, a01, a02, a03, a10, a11, a12, a20):
    """Lift-coefficient (CL) model: four linear terms plus four product terms.

    ``x`` is indexed as (alpha, Ma, delta, span), the order the fit passes it
    in; each item may be a scalar or an equal-length array. There is no
    constant term:

        CL = a00*x0 + a01*x1 + a02*x2 + a03*x3
           + a10*x0*x1 + a11*x0*x2 + a12*x1*x1 + a20*x0*x1*x2
    """
    v_alpha, v_ma, v_delta, v_span = x[0], x[1], x[2], x[3]
    # Accumulate left to right, in the same order as the formula above.
    # Avoid in-place `+=` so integer array inputs are never cast in place.
    result = a00 * v_alpha + a01 * v_ma + a02 * v_delta + a03 * v_span
    result = result + a10 * v_alpha * v_ma
    result = result + a11 * v_alpha * v_delta
    result = result + a12 * v_ma * v_ma
    result = result + a20 * v_alpha * v_ma * v_delta
    return result
# 2. 协方差矩阵:衡量两个随机变量之间的线性关系的指标
# 3. 怎么理解首尾相接?
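# When using a backslash (\) for explicit line continuation, nothing may follow it on the line — not even trailing whitespace (that is a SyntaxError).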
def funcD(x, b00, b01, b02, b03, b10, b11, b12, b13,b20,b21):
    """Drag-coefficient (CD) model: linear, square, cubic and product terms.

    ``x`` is indexed as (alpha, Ma, delta, span), the order the fit passes it
    in; each item may be a scalar or an equal-length array. There is no
    constant term:

        CD = b00*x0 + b01*x1 + b02*x2 + b03*x3
           + b10*x0^2 + b11*x1^2 + b12*x2^2 + b13*x0*x1
           + b20*x0^3 + b21*x0*x1*x2
    """
    xa, xm, xd, xs = x[0], x[1], x[2], x[3]
    terms = (
        b00 * xa, b01 * xm, b02 * xd, b03 * xs,        # linear
        b10 * xa * xa, b11 * xm * xm, b12 * xd * xd,   # squares
        b13 * xa * xm,                                 # alpha*Ma product
        b20 * xa * xa * xa,                            # alpha cubed
        b21 * xa * xm * xd,                            # alpha*Ma*delta product
    )
    # Sum strictly left to right so rounding matches the written formula.
    total = terms[0]
    for term in terms[1:]:
        total = total + term
    return total

def funcZ(x, b00, b01, b02, b03, b10, b11, b12, b13,b20,b21):
    """Pitching-moment (Cmz1) model; same polynomial form as the drag model.

    ``x`` is indexed as (alpha, Ma, delta, span), the order the fit passes it
    in; each item may be a scalar or an equal-length array. There is no
    constant term:

        Cmz1 = b00*x0 + b01*x1 + b02*x2 + b03*x3
             + b10*x0^2 + b11*x1^2 + b12*x2^2 + b13*x0*x1
             + b20*x0^3 + b21*x0*x1*x2
    """
    a, m, d, s = x[0], x[1], x[2], x[3]
    # First-order part.
    moment = b00 * a + b01 * m + b02 * d + b03 * s
    # Second-order part: squares and the alpha*Ma product.
    moment = moment + b10 * a * a + b11 * m * m + b12 * d * d + b13 * a * m
    # Third-order part: alpha cubed and the alpha*Ma*delta product.
    moment = moment + b20 * a * a * a + b21 * a * m * d
    return moment
# Model inputs as plain arrays. The tuple order (alpha, Ma, delta, span) is
# what the model functions read as x[0]..x[3].
alpha, Ma, delta, span = (
    all_data[column].to_numpy()
    for column in ('Alpha', 'Ma', 'Delta', 'SpanLength')
)
# Measured coefficients each model is fitted against.
cL_values, cD_values, cmz1_values = (
    all_data[column].to_numpy() for column in ('CL', 'CD', 'Cmz1')
)

# Least-squares fits. p0 (initial guess) starts every coefficient at 1, so its
# length must equal the number of coefficients in the model signature.
# popt* are the fitted coefficients; pcov* are their covariance matrices.
poptL, pcovL = curve_fit(funcL, (alpha, Ma, delta, span), cL_values, p0=[1] * 8)
poptD, pcovD = curve_fit(funcD, (alpha, Ma, delta, span), cD_values, p0=[1] * 10)
poptZ, pcovZ = curve_fit(funcZ, (alpha, Ma, delta, span), cmz1_values, p0=[1] * 10)
# Report every fitted coefficient with its one-standard-deviation error.
# pcov* is the covariance matrix returned by curve_fit; its diagonal holds the
# parameter variances, so the standard errors are their square roots.
# NOTE: per the SciPy docs, pcov may contain inf when the covariance cannot be
# estimated; such errors print as "inf".
perrL = np.sqrt(np.diag(pcovL))
perrD = np.sqrt(np.diag(pcovD))
perrZ = np.sqrt(np.diag(pcovZ))

# Label each listing. Otherwise the three fits print identical
# "Parameter i" lines and cannot be told apart.
for model_name, popt, perr in (
    ('CL', poptL, perrL),
    ('CD', poptD, perrD),
    ('Cmz1', poptZ, perrZ),
):
    print(f'{model_name} fit:')
    for i, (param, error) in enumerate(zip(popt, perr)):
        print(f'Parameter {i}: {param:.6f} ± {error:.6f}')


# Sweep range for each input (np.linspace default: 50 points, endpoints included).
alphaNEW = np.linspace(-5, 10, endpoint=True)
MaNEW = np.linspace(2, 8, endpoint=True)
deltaNEW = np.linspace(-30, 30, endpoint=True)
spanNEW = np.linspace(0, 100, endpoint=True)

# (axis label, observed values, sweep grid), in the x[0]..x[3] order the models use.
variables = [
    ('alpha', alpha, alphaNEW),
    ('Ma', Ma, MaNEW),
    ('delta', delta, deltaNEW),
    ('span', span, spanNEW),
]

# Each fitted curve varies ONE input and holds the other three at the mean of
# the observed data. Sweeping all four grids simultaneously would trace a
# diagonal through input space, which cannot be read as a response to the
# variable on the x-axis.
reference_point = [float(np.mean(np.asarray(values, dtype=float)))
                   for _, values, _ in variables]


def _sweep_prediction(func, popt, index, grid):
    """Evaluate ``func`` with input ``index`` swept over ``grid``.

    The other inputs are held at ``reference_point``. Returns an array with
    the same shape as ``grid``.
    """
    inputs = [np.full(np.shape(grid), value, dtype=float) for value in reference_point]
    inputs[index] = grid
    return func(tuple(inputs), *popt)


def _plot_fit(func, popt, true_values, ylabel, short_name, filename):
    """Draw measured vs. fitted values in a 2x2 grid (one panel per input) and save it."""
    fig, axs = plt.subplots(2, 2, figsize=(14, 10))
    for i, (label, x, xNEW) in enumerate(variables):
        ax = axs[i // 2, i % 2]
        ax.plot(x, true_values, 'o', label='True Value')
        ax.plot(xNEW, _sweep_prediction(func, popt, i, xNEW), '-', label='Fitted Value')
        ax.set_xlabel(label)
        ax.set_ylabel(ylabel)
        ax.set_title(f'{label} vs {short_name}')
        ax.legend()
    # Set the suptitle BEFORE tight_layout so the layout reserves room for it.
    fig.suptitle(ylabel)
    fig.tight_layout()
    fig.savefig(filename)
    return fig


_plot_fit(funcL, poptL, cL_values, 'Lift Coefficient (CL)', 'CL', 'lift_coefficient.png')
_plot_fit(funcD, poptD, cD_values, 'Drag Coefficient (CD)', 'CD', 'drag_coefficient.png')
_plot_fit(funcZ, poptZ, cmz1_values, 'Moment Coefficient (CMz1)', 'CMz1', 'moment_coefficient.png')

# A single blocking show() at the end displays all three figures together.
# Calling show() after each figure would require closing each window before
# the next figure is created.
plt.show()