# Decision Trees

## Load packages

```{r load_packages}
library(ggplot2)
library(rpart)
library(rpart.plot)
```

## Load data

Load the `train_x_class`, `train_y_class`, `test_x_class`, and `test_y_class` variables we defined in 02-preprocessing.Rmd for this *classification* task.

```{r setup_data}
# Objects: task_reg, task_class
load("data/preprocessed.RData")
```

## Overview

Decision trees are recursive partitioning methods that divide the predictor space into simpler regions and can be visualized in a tree-like structure. They classify data by splitting it into subsets based on the predictors, choosing splits that best separate the levels of the Y outcome variable.

Let's see how a decision tree classifies whether a person suffers from heart disease (`target` = 1) or not (`target` = 0).

## Fit Model

```{r}
set.seed(3)
tree = rpart::rpart(train_y_class ~ ., data = train_x_class,
                    # Use method = "anova" for a continuous outcome.
                    method = "class",
                    # Use split = "gini" for the Gini coefficient instead.
                    parms = list(split = "information"))
# https://stackoverflow.com/questions/4553947/decision-tree-on-information-gain

# Here is the text-based display of the decision tree. Yikes! :^(
print(tree)
```

Although interpreting the text can be intimidating, a decision tree's main strength is its tree-like plot, which is much easier to interpret.

## Investigate Results

```{r plot_tree}
rpart.plot::rpart.plot(tree)
```

We can also look inside `tree` to see what else it stores. `variable.importance` is one element we should check out!

```{r}
names(tree)
tree$variable.importance
```

Plot variable importance:

```{r}
# Turn the tree$variable.importance vector into a data frame.
tree_varimp = data.frame(tree$variable.importance)

# Add row names as their own column.
tree_varimp$x = rownames(tree_varimp)

# Reorder columns.
tree_varimp = tree_varimp[, c(2, 1)]

# Reset row names.
rownames(tree_varimp) = NULL

# Rename columns.
names(tree_varimp) = c("Variable", "Importance")

tree_varimp

# Plot.
ggplot(tree_varimp, aes(x = reorder(Variable, Importance), y = Importance)) +
  geom_bar(stat = "identity") +
  theme_bw() +
  coord_flip() +
  xlab("")
```

In decision trees the main hyperparameter (configuration setting) is the **complexity parameter** (CP). The name is a little counterintuitive: a high CP results in a simple decision tree with few splits, whereas a low CP results in a larger tree with many splits.

`rpart` uses cross-validation internally to estimate the accuracy at various CP settings. We can review those estimates to see which setting seems best.

Print the results for various CP settings - we want the one with the lowest "xerror" (a programmatic way to pick it is sketched at the end of this section). We can also plot the performance estimates for different CP settings.

```{r plotcp_tree}
# Show the estimated error rate at different complexity parameter settings.
printcp(tree)

# Plot those estimated error rates.
plotcp(tree)

# Trees of similar sizes might appear to be tied for the lowest "xerror",
# but a tree with fewer splits can be easier to interpret.
tree_pruned2 = prune(tree, cp = 0.028986) # 2 splits
tree_pruned6 = prune(tree, cp = 0.010870) # 6 splits
```

Print detailed results, variable importance, and a summary of splits.

```{r}
summary(tree_pruned2)
rpart.plot(tree_pruned2)
```

```{r}
summary(tree_pruned6)
rpart.plot(tree_pruned6)
```

You can also get more fine-grained control through the "control" argument of the `rpart` function; a sketch appears below. Type `?rpart` to learn more.

Be sure to check out [gormanalysis](https://www.gormanalysis.com/blog/decision-trees-in-r-using-rpart/)'s excellent overview to help internalize what you learned in this example.
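The two `cp` values above were copied by hand from the `printcp()` output. As a minimal sketch of a programmatic alternative (the names `best_cp` and `tree_pruned_best` are just illustrative), you can pull the CP with the lowest cross-validated error straight from `tree$cptable`:

```{r}
# Pick the CP with the lowest cross-validated error ("xerror") from the CP table.
best_cp = tree$cptable[which.min(tree$cptable[, "xerror"]), "CP"]
best_cp

# Prune with that CP instead of a hand-copied value.
tree_pruned_best = prune(tree, cp = best_cp)
rpart.plot(tree_pruned_best)
```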
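As a quick illustration of the "control" argument mentioned above, here is a minimal sketch; the specific settings are illustrative examples, not recommendations (see `?rpart.control` for the defaults and the full list):

```{r}
# Refit the tree with explicit control settings (values chosen for illustration only).
tree_ctrl = rpart::rpart(train_y_class ~ ., data = train_x_class,
                         method = "class",
                         control = rpart::rpart.control(
                           minsplit = 20,   # min. observations needed to attempt a split
                           minbucket = 7,   # min. observations in any terminal node
                           cp = 0.001,      # lower threshold grows a larger tree
                           xval = 10))      # number of cross-validation folds

printcp(tree_ctrl)
```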
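Finally, the preprocessed data also included `test_x_class` and `test_y_class`, which we have not used yet. Assuming they hold the held-out predictors and labels from 02-preprocessing.Rmd, a minimal sketch of evaluating the pruned tree on the test set looks like this:

```{r}
# Predict class labels for the held-out test set.
pred_y = predict(tree_pruned6, newdata = test_x_class, type = "class")

# Confusion table of predicted vs. actual classes.
table(predicted = pred_y, actual = test_y_class)

# Overall test-set accuracy.
mean(pred_y == test_y_class)
```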
## Challenge 2

Open Challenge 2 in the "Challenges" folder.