2 Star 1 Fork 0

林锦星/airfoilprocess

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
utils.py 6.33 KB
一键复制 编辑 原始数据 按行查看 历史
LJX1111 提交于 2021-01-22 22:13 . 改bug+三次样条插值
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import re
import shutil
def merge(x1,x2):
x = []
for i, j in zip(x1,x2):
x.append(i)
x.append(j)
x.append(x1[-1])
return np.asarray(x)
def move(src, dst):
shutil.move(src, dst)
def cal_mid(a):
b = []
for i in range(len(a) - 1):
b.append((a[i] + a[i + 1]) / 2)
return np.asarray(b)
# 将.dat文件转化为.csv文件,并输出点数低于阈值的文件名
def prc_svcsv(read_path, save_path):
# 去除异常文件
if "n642415" in read_path:
print("remove: ", read_path)
return
elif "naca23015" in read_path:
print("remove: ", read_path)
return
elif "naca23018" in read_path:
print("remove: ", read_path)
return
elif "naca2412" in read_path:
print("remove: ", read_path)
return
elif "naca1.dat" in read_path:
print("remove: ", read_path)
return
elif "e664ex" in read_path:
print("remove: ", read_path)
return
elif "lrn1007" in read_path: # 科学计数法特殊处理
print("remove: ", read_path)
return
elif "naca64a010" in read_path:
print("remove: ", read_path)
return
elif "s1221" in read_path:
print("remove: ", read_path)
return
elif "sc1095r8" in read_path:
print("remove: ", read_path)
return
file = open(read_path, mode='r')
strs = file.readlines()
dict = {}
x = []
y = []
for tmp_str in strs:
if tmp_str == "":
continue
tmp_str = tmp_str.strip()
while " " in tmp_str:
tmp_str = tmp_str.replace(" ", " ")
tmp_strs = tmp_str.split(" ")
if " " in tmp_str:
tmp_strs = tmp_str.split(" ")
if re.match(r"^([-,0-9]{0,}[.][0-9]*)$", tmp_strs[0]) == None:
continue
if re.match(r"^([-,0-9]{0,}[.][0-9]*)$", tmp_strs[1]) == None:
continue
# 去除异常值
if float(tmp_strs[1]) > 10:
continue
x.append(float(tmp_strs[0]))
y.append(float(tmp_strs[1]))
df = pd.DataFrame(dict)
df['x'] = x
df['y'] = y
df.to_csv(save_path, index=False)
# 输出点数少于阈值的文件名及点数
# if df.shape[0] <= 17:
# print(read_path + "\t" + str(df.shape[0]))
# rename: from "*.DAT" to "*.dat"
def file_rename(file_name):
os.rename(file_name, file_name[:-3]+"dat")
# count
def csv_count(file_path):
df = pd.read_csv(file_path)
# if len(df['x']) <= 17:
# print(file_path, len(df['x']))
return len(df['x'])
# find the number out of range and modify
def csv_out_range(file_path):
df = pd.read_csv(file_path)
# test coordinate of x
list_x = df['x']
for i, j in enumerate(list_x):
if j > 1:
list_x[i] = 1
print(file_path, i, j)
elif j < 0:
list_x[i] = 0
print(file_path, i, j)
df.to_csv(file_path, index=False)
# test coordinate of y
list_y = df['y']
for i, j in enumerate(list_y):
if j > 0.45:
# list_y[i] = 1
print(file_path, i, j)
elif j < -0.45:
# list_y[i] = -1
print(file_path, i, j)
# df.to_csv("test1.csv", index=False)
def csi(X1, Y1, X2, Y2, file_name):
'''
三次样条插值
:param X1:
:param Y1:
:param X2:
:param Y2:
:param file_name:
:return:
'''
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
# print(X1)
# print(Y1)
# print(X2)
# print(Y2)
# new_x = np.arange(0, 1.01, 0.01) # 定义差值点
new_x_u = cal_mid(X1)
new_x_d = cal_mid(X2)
# 进行样条差值
import scipy.interpolate as spi
# 进行三次样条拟合
ipo3_u = spi.splrep(X1, Y1, k=3) # 样本点导入,生成参数
iy3_u = spi.splev(new_x_u, ipo3_u) # 根据观测点和样条参数,生成插值
ipo3_d = spi.splrep(X2, Y2, k=3) # 样本点导入,生成参数
iy3_d = spi.splev(new_x_d, ipo3_d) # 根据观测点和样条参数,生成插值
# 作图
plt.plot(X1, Y1, 'o', label='样本点上')
plt.plot(new_x_u, iy3_u, '*', label='插值点上')
plt.plot(X2, Y2, 'o', label='样本点下')
plt.plot(new_x_d, iy3_d, '*', label='插值点下')
plt.ylim(-0.3, 0.3)
plt.ylabel('指数')
plt.title('机翼数据三次样条插值拟合结果')
plt.legend()
plt.show()
# 将数据合并,最后保存为csv文件
x_up = merge(X1[::-1], new_x_u[::-1])
y_up = merge(Y1[::-1], iy3_u[::-1])
x_down = merge(X2, new_x_d)
y_down = merge(Y2, iy3_d)
x = np.concatenate((x_up, x_down))
y = np.concatenate((y_up, y_down))
dict = {'x': x, 'y': y}
df = pd.DataFrame(dict)
df.to_csv(file_name, index=False)
def hermite(xi, yi):
from scipy.interpolate import KroghInterpolator
x = np.linspace(0, 1, 20)
# xi = np.array([0, 0, 1, 1])
# yi = np.array([1, 0, 2, 3])
interpolant = KroghInterpolator(xi, yi)
plt.figure()
plt.plot(x, interpolant(x), 'r--', label='Hermite Interpolation')
plt.plot(xi, yi, 'go', label='nodes', markersize=8)
plt.legend(loc=9)
plt.xlim(0, 1)
plt.title('$埃尔米特插值$')
plt.show()
def lagrange(x, y):
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
from scipy.interpolate import lagrange
# x = [1, 2, 3, 4]
# y = [4, 15, 40, 85]
lag_ret = lagrange(x,y)
x0 = np.arange(0, 1, 0.02)
plt.figure()
plt.plot(x0, lag_ret(x0), label="Lagrange Interpolation")
plt.plot(x, y, 'go', label="notes", markersize=8)
plt.legend(loc=9)
plt.xlim(0, 1)
plt.ylim(-0.3, 0.3)
plt.title('$拉格朗日插值$')
plt.show()
def Monotonic(path):
df = pd.read_csv(path)
x = list(df['x'])
state_change_count = 0
state = x[0] > x[1] # True表示递减;False表示递增
for i in range(2, len(x)):
tmp_state = x[i-1] > x[i]
if tmp_state != state:
state_change_count += 1
print("{}: {}->{}".format(i, state, tmp_state))
state = tmp_state
return state_change_count
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Python
1
https://gitee.com/strugglejx98/airfoilprocess.git
git@gitee.com:strugglejx98/airfoilprocess.git
strugglejx98
airfoilprocess
airfoilprocess
master

搜索帮助