林冰漫/sklearn
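This repository does not declare an open-source license file (LICENSE); before using it, check the project description and the upstream dependencies of its code.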
Treesklearn.py 2.03 KB
林冰漫 committed on 2023-03-28 00:25: Initial commit
import pandas as pd
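# Load the dataset
data = pd.read_csv('student-mat.csv', sep=';')
# Descriptive statistics
print(data.describe())
# Correlation of each numeric attribute with the target variable G3
# (numeric_only=True skips the string columns, which would otherwise raise in pandas 2.x)
print(data.corr(numeric_only=True)['G3'].sort_values())
# Check for missing values
print(data.isnull().sum())
# Check for outliers with box plots of the numeric columns
import matplotlib.pyplot as plt
data.boxplot()
plt.show()
# One-hot encode the categorical variables
data_encoded = pd.get_dummies(data)
# Scale every column to [0, 1]; this includes the target G3, so the RMSE values below are in scaled units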
from sklearn.preprocessing import MinMaxScaler
scaler = MinMaxScaler()
data_scaled = pd.DataFrame(scaler.fit_transform(data_encoded), columns=data_encoded.columns)
# Split into training, validation, and test sets
from sklearn.model_selection import train_test_split
X = data_scaled.drop('G3', axis=1)
y = data_scaled['G3']
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=0.15, random_state=42)
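# 0.1765 is roughly 0.15 / 0.85, so the validation set is about 15% of the full data (about 70/15/15 overall)
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.1765, random_state=42)
# Build the decision tree model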
from sklearn.tree import DecisionTreeRegressor
from sklearn.metrics import mean_squared_error
tree_reg = DecisionTreeRegressor(random_state=42)
tree_reg.fit(X_train, y_train)
# Evaluate on the validation set
import numpy as np
y_val_pred = tree_reg.predict(X_val)
mse = mean_squared_error(y_val, y_val_pred)
rmse = np.sqrt(mse)
print('Validation set RMSE:', rmse)
# Grid search for the best hyperparameter combination
from sklearn.model_selection import GridSearchCV
params = {'max_depth': range(1, 21), 'min_samples_leaf': range(1, 11)}
tree_reg_grid = DecisionTreeRegressor(random_state=42)
grid_search = GridSearchCV(tree_reg_grid, params, cv=5, scoring='neg_mean_squared_error', return_train_score=False)
grid_search.fit(X_train_val, y_train_val)
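# Print the best hyperparameter combination
print(grid_search.best_params_)
# Train the final model with the best hyperparameters found above and evaluate it on the test set
best_reg = DecisionTreeRegressor(**grid_search.best_params_, random_state=42)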
best_reg.fit(X_train_val, y_train_val)
y_test_pred = best_reg.predict(X_test)
mse = mean_squared_error(y_test, y_test_pred)
rmse = np.sqrt(mse)
print('Test set RMSE:', rmse)
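
The second `train_test_split` call in the script uses `test_size=0.1765`, which appears to be chosen so that the validation set ends up at about 15% of the full dataset once 15% has already been held out for testing. The sketch below shows that arithmetic. The helper name `nested_test_size` is hypothetical and is not part of the original script or of scikit-learn.

```python
def nested_test_size(val_fraction, test_fraction):
    """Fraction to pass as test_size in the second split.

    Hypothetical helper: after holding out `test_fraction` of the data,
    the second split sees only the remaining (1 - test_fraction), so the
    validation share has to be rescaled by that remainder.
    """
    remaining = 1.0 - test_fraction
    if not 0.0 < val_fraction < remaining:
        raise ValueError("val_fraction must be between 0 and 1 - test_fraction")
    return val_fraction / remaining


# Target split of 70% train, 15% validation, 15% test
second_split = nested_test_size(val_fraction=0.15, test_fraction=0.15)
print(round(second_split, 4))
```

With the script's values this rounds to 0.1765, and the resulting validation share of the full data is `second_split * 0.85`, which is 0.15.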