Data Mining: The Decision Tree Induction Algorithm

Author's note: this International Week lecture was seriously hardcore, so I am writing up my notes right away to settle my nerves.
1. Motivation
- Basic idea: recursively partition the input space during the training step, then traverse the tree with each test data point to make a prediction
- Classification problem setup:
  - training dataset
  - testing dataset
  - validation dataset
  - model
- Transparent method: a tree-like structure that emulates a human's decision-making flow
  - Can be converted into decision rules
  - Similar to association rules
2. Decision Tree Structure
- Node: attribute splitting
  - Root, leaf, and internal nodes
- Branch: attribute condition testing
  - Binary or multiway splits
3. Framework of Supervised Learning
- Induction: building a model from training data
  - Specific -> general
- Deduction: using the model to make predictions on testing data
  - General -> specific
- Eager vs. lazy learning: distinguished by whether an explicit induction (model-building) step happens before prediction time
4. Application
Major applications of the decision tree induction algorithm:
- Improving business decision making and decision support in many industries: finance, banking, insurance, healthcare, etc.
- Enhancing customer service levels
- Serving as a knowledge management platform that makes knowledge easier to find
5. Algorithm Summary (Hunt's Algorithm)
- Goal: improve dataset purity by recursively splitting on attributes
- Check whether a dataset dT is pure: if yes, label it as a leaf node; if not, continue
- Choose the attribute, and (in the case of numerical attributes) the split point, that maximizes information gain, and split the dataset on it
- Keep splitting until one of the stop conditions is met (see the sketch after this list):
  - all the data points belong to the same class
  - all the records have the same attribute values
- Early termination: set by model parameters (e.g. minsplit, minbucket, maxdepth) that control pruning
- Other algorithms in this family: ID3, C4.5, C5.0, CART
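To make the recursion concrete, here is a minimal sketch in R (my own illustration, not the rpart implementation). It assumes a data frame whose last column is the class label and only categorical attributes; entropy() and hunt() are hypothetical helper names.

entropy <- function(labels) {
  p <- table(labels) / length(labels)
  -sum(ifelse(p > 0, p * log2(p), 0))
}

hunt <- function(d, attrs) {
  labels <- d[[ncol(d)]]
  # Stop condition 1: all data points belong to the same class
  if (length(unique(labels)) == 1) return(as.character(labels[1]))
  # Stop condition 2: no attributes left to split on; label with the majority class
  if (length(attrs) == 0) return(names(which.max(table(labels))))
  # Pick the attribute with the highest information gain
  gains <- sapply(attrs, function(a) {
    groups <- split(labels, d[[a]])
    entropy(labels) -
      sum(sapply(groups, function(g) length(g) / length(labels) * entropy(g)))
  })
  best <- attrs[which.max(gains)]
  # Recurse into each child partition, dropping the used attribute
  lapply(split(d, d[[best]]), function(sub) hunt(sub, setdiff(attrs, best)))
}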
6. Attributes for Decision Trees
- Categorical attributes
  - Binary attributes: Classification And Regression Tree (CART) constructs binary trees
  - Multinomial attributes: group values to reduce the number of child nodes
- Numerical attributes
  - Often discretized into a binary attribute
  - Pick a splitting point (cutoff) on the attribute; a sketch of the cutoff search follows this list
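One common way to pick the cutoff (a hedged sketch of my own, reusing the entropy() helper defined above) is to evaluate the midpoints between consecutive sorted attribute values and keep the one with the highest information gain:

best_cutoff <- function(x, labels) {
  v <- sort(unique(x))
  cuts <- (head(v, -1) + tail(v, -1)) / 2   # candidate midpoints
  gains <- sapply(cuts, function(cut) {
    left  <- labels[x <  cut]
    right <- labels[x >= cut]
    entropy(labels) -
      (length(left)  / length(labels)) * entropy(left) -
      (length(right) / length(labels)) * entropy(right)
  })
  cuts[which.max(gains)]
}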
7. Data Impurity Measure: Entropy
- Entropy is a property of a dataset D and the classification C: Entropy(D) = -sum over classes i of p_i * log2(p_i), where p_i is the fraction of records in D belonging to class i
- The entropy curve for binary classification peaks at 1 bit when the two classes are equally frequent (p = 0.5) and drops to 0 when the data is pure (p = 0 or p = 1); a short worked example follows
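As a quick worked example (my own numbers, not from the lecture): a dataset with 9 records of one class and 5 of the other gives

p <- c(9, 5) / 14
-sum(p * log2(p))   # about 0.940 bits: impure, close to the 1-bit maximum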
8. Other Impurity Measure: Gini Index
8.1 Common characteristics of data impurity metrics
- Correlate with data purity with respect to the target class label
- If the data is more pure/homogeneous, the metric has a lower value; if the data is less pure/heterogeneous, the metric has a higher value
8.2 Gini index
- Gini(D) = 1 - sum over classes i of p_i^2
- Special cases: 0 for a pure dataset; 1 - 1/k for k equally frequent classes (0.5 in the binary case)
- Used in CART (Classification And Regression Trees); a small helper follows below
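A minimal R helper for the Gini index, analogous to the entropy() sketch above (my own illustration):

gini <- function(labels) {
  p <- table(labels) / length(labels)
  1 - sum(p^2)
}
gini(c("yes", "yes", "no", "no"))    # 0.5, the binary maximum
gini(c("yes", "yes", "yes", "yes")) # 0, a pure node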
9. Information Gain
- Information gain: a function of the entropy of (D, C) and an attribute A; the entropy of D minus the weighted average entropy of the partitions of D induced by A
  - Adopted in the ID3 algorithm
- Gain ratio: information gain adjusted to control for the number of groups after splitting
  - Adopted in the C4.5 algorithm (see the sketch below)
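A hedged sketch of the gain ratio, reusing entropy() from above (gain_ratio is a hypothetical helper name): the information gain is divided by the "split information", i.e. the entropy of the attribute's own value distribution, which penalizes attributes that shatter the data into many small groups.

gain_ratio <- function(attr, labels) {
  groups <- split(labels, attr)
  gain <- entropy(labels) -
    sum(sapply(groups, function(g) length(g) / length(labels) * entropy(g)))
  split_info <- entropy(attr)   # entropy of the attribute value distribution itself
  gain / split_info
}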
10. Occam's Razor
- Given similar training accuracy, smaller models are preferred
- The complexity of a decision tree is defined as the number of splits in the tree
- Pruning: reduce the size of the decision tree
  - Pre-pruning: halt tree construction early; requires setting thresholds that stop attribute splitting
  - Post-pruning: remove branches from a "fully grown" tree
11. Overfitting
- Training accuracy vs. testing accuracy: an overfitted tree keeps gaining training accuracy while its testing accuracy stalls or degrades; a sketch of the comparison follows
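A quick way to check for this gap (a sketch of my own, assuming the dt_model, churnTrain, and churnTest objects from the demo in section 14):

train_acc <- mean(predict(dt_model, newdata = churnTrain) == churnTrain$churn)
test_acc  <- mean(predict(dt_model, newdata = churnTest)  == churnTest$churn)
# A training accuracy far above the testing accuracy signals overfitting.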
12. Model Parameters
- Set by the rpart.control() function in the rpart package:

rpart.control(
  minsplit = 20,
  minbucket = round(minsplit/3),
  cp = 0.01,
  maxdepth = 30,
  ...
)

- minbucket: the minimum number of observations in any terminal node
- minsplit: the minimum number of observations that must exist in a node in order for a split to be attempted
- maxdepth: the maximum depth of any node of the final tree, with the root node counted as depth 0
- Complexity parameter (cp): the minimum improvement in model fit required to create a new branch
  - The lower cp is set, the more complex the model can become; therefore increase cp to prune
  - Question: how do we set cp to obtain a fully grown tree? Set it to a negative value (see the sketch below)
- To avoid overfitting, increase minbucket, minsplit, or cp, or decrease maxdepth
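A hedged sketch of growing a fully grown tree and then pruning it back (assuming churnTrain from the demo below): a negative cp, together with permissive minsplit and minbucket, means no split is ever rejected for insufficient improvement.

library(rpart)
full_tree <- rpart(churn ~ ., data = churnTrain,
                   control = rpart.control(cp = -1, minsplit = 2, minbucket = 1))
# Post-prune back to the cp value that minimizes cross-validated error:
best_cp <- full_tree$cptable[which.min(full_tree$cptable[, "xerror"]), "CP"]
pruned  <- prune(full_tree, cp = best_cp)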
13. Properties of the Algorithm
- Greedy algorithm: top-down, recursive partitioning strategy
- Rectilinear decision boundaries (rectangles or hyper-rectangles)
- Data fragmentation: each split leaves fewer records in the child nodes
- Slow training process to build the model, but fast prediction
- Robust to outliers
- Non-parametric model: no underlying distributional assumptions
- Outputs a model either as a tree or as a set of rules (similar to association rules)
14. Demo
14.1 The churn dataset from the C50 package
# install.packages("C50")
> library(C50)
> data(churn)
> churn <- rbind(churnTrain, churnTest)
> str(churnTrain)
'data.frame': 3333 obs. of 20 variables:
 $ state                        : Factor w/ 51 levels "AK","AL","AR",..: 17 36 32 36 37 2 20 25 19 50 ...
 $ account_length               : int 128 107 137 84 75 118 121 147 117 141 ...
 $ area_code                    : Factor w/ 3 levels "area_code_408",..: 2 2 2 1 2 3 3 2 1 2 ...
 $ international_plan           : Factor w/ 2 levels "no","yes": 1 1 1 2 2 2 1 2 1 2 ...
 $ voice_mail_plan              : Factor w/ 2 levels "no","yes": 2 2 1 1 1 1 2 1 1 2 ...
 $ number_vmail_messages        : int 25 26 0 0 0 0 24 0 0 37 ...
 $ total_day_minutes            : num 265 162 243 299 167 ...
 $ total_day_calls              : int 110 123 114 71 113 98 88 79 97 84 ...
 $ total_day_charge             : num 45.1 27.5 41.4 50.9 28.3 ...
 $ total_eve_minutes            : num 197.4 195.5 121.2 61.9 148.3 ...
 $ total_eve_calls              : int 99 103 110 88 122 101 108 94 80 111 ...
 $ total_eve_charge             : num 16.78 16.62 10.3 5.26 12.61 ...
 $ total_night_minutes          : num 245 254 163 197 187 ...
 $ total_night_calls            : int 91 103 104 89 121 118 118 96 90 97 ...
 $ total_night_charge           : num 11.01 11.45 7.32 8.86 8.41 ...
 $ total_intl_minutes           : num 10 13.7 12.2 6.6 10.1 6.3 7.5 7.1 8.7 11.2 ...
 $ total_intl_calls             : int 3 3 5 7 3 6 7 6 4 5 ...
 $ total_intl_charge            : num 2.7 3.7 3.29 1.78 2.73 1.7 2.03 1.92 2.35 3.02 ...
 $ number_customer_service_calls: int 1 1 0 2 3 0 3 0 1 0 ...
 $ churn                        : Factor w/ 2 levels "yes","no": 2 2 2 2 2 2 2 2 2 2 ...
14.2 Model Training
> library(caret)
> library(rpart)
> library(e1071)
> dt_model <- train(churn ~ ., data = churnTrain, metric = "Accuracy", method = "rpart")
> typeof(dt_model)
[1] "list"
> names(dt_model)
 [1] "method"       "modelInfo"    "modelType"    "results"      "pred"         "bestTune"
 [7] "call"         "dots"         "metric"       "control"      "finalModel"   "preProcess"
[13] "trainingData" "resample"     "resampledCM"  "perfNames"    "maximize"     "yLimits"
[19] "times"        "levels"       "terms"        "coefnames"    "contrasts"    "xlevels"
14.3 Check Decision Tree Classifiers
> print(dt_model)
CART

3333 samples
  19 predictor
   2 classes: 'yes', 'no'

No pre-processing
Resampling: Bootstrapped (25 reps)
Summary of sample sizes: 3333, 3333, 3333, 3333, 3333, 3333, ...
Resampling results across tuning parameters:

  cp          Accuracy   Kappa
  0.07867495  0.8741209  0.3072049
  0.08488613  0.8683224  0.2475440
  0.08902692  0.8653671  0.2178997

Accuracy was used to select the optimal model using the largest value.
The final value used for the model was cp = 0.07867495.
14.4 Check Decision Tree Classifier Details
> print(dt_model$finalModel)
n= 3333

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 3333 483 no (0.1449145 0.8550855)
  2) total_day_minutes>=264.45 211 84 yes (0.6018957 0.3981043)
    4) voice_mail_planyes< 0.5 158 37 yes (0.7658228 0.2341772) *
    5) voice_mail_planyes>=0.5 53 6 no (0.1132075 0.8867925) *
  3) total_day_minutes< 264.45 3122 356 no (0.1140295 0.8859705) *
14.5 Model Prediction (1)
> dt_predict <- predict(dt_model, newdata = churnTest, na.action = na.omit, type = "prob")
> head(dt_predict, 5)
        yes        no
1 0.1140295 0.8859705
2 0.1140295 0.8859705
3 0.1132075 0.8867925
4 0.1140295 0.8859705
5 0.1140295 0.8859705
14.6 Model Prediction (2)
> dt_predict2 <- predict(dt_model, newdata = churnTest, type = "raw")
> head(dt_predict2)
[1] no no no no no no
Levels: yes no
14.7 Model Tuning (1)
> dt_model_tune <- train(churn ~ ., data = churnTrain, method = "rpart",
                         metric = "Accuracy", tuneLength = 8)
> print(dt_model_tune$finalModel)
n= 3333

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 3333 483 no (0.14491449 0.85508551)
  2) total_day_minutes>=264.45 211 84 yes (0.60189573 0.39810427)
    4) voice_mail_planyes< 0.5 158 37 yes (0.76582278 0.23417722)
      8) total_eve_minutes>=187.75 101 5 yes (0.95049505 0.04950495) *
      9) total_eve_minutes< 187.75 57 25 no (0.43859649 0.56140351)
        18) total_day_minutes>=277.7 32 11 yes (0.65625000 0.34375000)
          36) total_eve_minutes>=144.35 24 4 yes (0.83333333 0.16666667) *
          37) total_eve_minutes< 144.35 8 1 no (0.12500000 0.87500000) *
        19) total_day_minutes< 277.7 25 4 no (0.16000000 0.84000000) *
    5) voice_mail_planyes>=0.5 53 6 no (0.11320755 0.88679245) *
  3) total_day_minutes< 264.45 3122 356 no (0.11402947 0.88597053)
    6) number_customer_service_calls>=3.5 251 124 yes (0.50597610 0.49402390)
      12) total_day_minutes< 160.2 102 13 yes (0.87254902 0.12745098) *
      13) total_day_minutes>=160.2 149 38 no (0.25503356 0.74496644)
        26) total_eve_minutes< 141.75 19 5 yes (0.73684211 0.26315789) *
        27) total_eve_minutes>=141.75 130 24 no (0.18461538 0.81538462)
          54) total_day_minutes< 175.75 34 14 no (0.41176471 0.58823529)
            108) total_eve_minutes< 212.15 16 2 yes (0.87500000 0.12500000) *
            109) total_eve_minutes>=212.15 18 0 no (0.00000000 1.00000000) *
          55) total_day_minutes>=175.75 96 10 no (0.10416667 0.89583333) *
    7) number_customer_service_calls< 3.5 2871 229 no (0.07976315 0.92023685)
      14) international_planyes>=0.5 267 101 no (0.37827715 0.62172285)
        28) total_intl_calls< 2.5 51 0 yes (1.00000000 0.00000000) *
        29) total_intl_calls>=2.5 216 50 no (0.23148148 0.76851852)
          58) total_intl_minutes>=13.1 43 0 yes (1.00000000 0.00000000) *
          59) total_intl_minutes< 13.1 173 7 no (0.04046243 0.95953757) *
      15) international_planyes< 0.5 2604 128 no (0.04915515 0.95084485)
        30) total_day_minutes>=223.25 383 68 no (0.17754569 0.82245431)
          60) total_eve_minutes>=259.8 51 17 yes (0.66666667 0.33333333)
            120) voice_mail_planyes< 0.5 40 6 yes (0.85000000 0.15000000) *
            121) voice_mail_planyes>=0.5 11 0 no (0.00000000 1.00000000) *
          61) total_eve_minutes< 259.8 332 34 no (0.10240964 0.89759036) *
        31) total_day_minutes< 223.25 2221 60 no (0.02701486 0.97298514) *
14.8 Model Tuning (2)
> dt_model_tune2 <- train(churn ~ ., data = churnTrain, method = "rpart",
                          tuneGrid = expand.grid(cp = seq(0, 0.1, 0.01)))
> print(dt_model_tune2$finalModel)
n= 3333

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 3333 483 no (0.14491449 0.85508551)
  2) total_day_minutes>=264.45 211 84 yes (0.60189573 0.39810427)
    4) voice_mail_planyes< 0.5 158 37 yes (0.76582278 0.23417722)
      8) total_eve_minutes>=187.75 101 5 yes (0.95049505 0.04950495) *
      9) total_eve_minutes< 187.75 57 25 no (0.43859649 0.56140351)
        18) total_day_minutes>=277.7 32 11 yes (0.65625000 0.34375000)
          36) total_eve_minutes>=144.35 24 4 yes (0.83333333 0.16666667) *
          37) total_eve_minutes< 144.35 8 1 no (0.12500000 0.87500000) *
        19) total_day_minutes< 277.7 25 4 no (0.16000000 0.84000000) *
    5) voice_mail_planyes>=0.5 53 6 no (0.11320755 0.88679245) *
  3) total_day_minutes< 264.45 3122 356 no (0.11402947 0.88597053)
    6) number_customer_service_calls>=3.5 251 124 yes (0.50597610 0.49402390)
      12) total_day_minutes< 160.2 102 13 yes (0.87254902 0.12745098) *
      13) total_day_minutes>=160.2 149 38 no (0.25503356 0.74496644)
        26) total_eve_minutes< 141.75 19 5 yes (0.73684211 0.26315789) *
        27) total_eve_minutes>=141.75 130 24 no (0.18461538 0.81538462)
          54) total_day_minutes< 175.75 34 14 no (0.41176471 0.58823529)
            108) total_eve_minutes< 212.15 16 2 yes (0.87500000 0.12500000) *
            109) total_eve_minutes>=212.15 18 0 no (0.00000000 1.00000000) *
          55) total_day_minutes>=175.75 96 10 no (0.10416667 0.89583333) *
    7) number_customer_service_calls< 3.5 2871 229 no (0.07976315 0.92023685)
      14) international_planyes>=0.5 267 101 no (0.37827715 0.62172285)
        28) total_intl_calls< 2.5 51 0 yes (1.00000000 0.00000000) *
        29) total_intl_calls>=2.5 216 50 no (0.23148148 0.76851852)
          58) total_intl_minutes>=13.1 43 0 yes (1.00000000 0.00000000) *
          59) total_intl_minutes< 13.1 173 7 no (0.04046243 0.95953757) *
      15) international_planyes< 0.5 2604 128 no (0.04915515 0.95084485)
        30) total_day_minutes>=223.25 383 68 no (0.17754569 0.82245431)
          60) total_eve_minutes>=259.8 51 17 yes (0.66666667 0.33333333)
            120) voice_mail_planyes< 0.5 40 6 yes (0.85000000 0.15000000) *
            121) voice_mail_planyes>=0.5 11 0 no (0.00000000 1.00000000) *
          61) total_eve_minutes< 259.8 332 34 no (0.10240964 0.89759036) *
        31) total_day_minutes< 223.25 2221 60 no (0.02701486 0.97298514) *
14.9 Model Pre-Pruning
> dt_model_preprune <- train(churn ~ ., data = churnTrain, method = "rpart",
                             metric = "Accuracy", tuneLength = 8,
                             control = rpart.control(minsplit = 50, minbucket = 20, maxdepth = 5))
> print(dt_model_preprune$finalModel)
n= 3333

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 3333 483 no (0.14491449 0.85508551)
  2) total_day_minutes>=264.45 211 84 yes (0.60189573 0.39810427)
    4) voice_mail_planyes< 0.5 158 37 yes (0.76582278 0.23417722)
      8) total_eve_minutes>=187.75 101 5 yes (0.95049505 0.04950495) *
      9) total_eve_minutes< 187.75 57 25 no (0.43859649 0.56140351)
        18) total_day_minutes>=277.7 32 11 yes (0.65625000 0.34375000) *
        19) total_day_minutes< 277.7 25 4 no (0.16000000 0.84000000) *
    5) voice_mail_planyes>=0.5 53 6 no (0.11320755 0.88679245) *
  3) total_day_minutes< 264.45 3122 356 no (0.11402947 0.88597053)
    6) number_customer_service_calls>=3.5 251 124 yes (0.50597610 0.49402390)
      12) total_day_minutes< 160.2 102 13 yes (0.87254902 0.12745098) *
      13) total_day_minutes>=160.2 149 38 no (0.25503356 0.74496644)
        26) total_eve_minutes< 155.5 29 11 yes (0.62068966 0.37931034) *
        27) total_eve_minutes>=155.5 120 20 no (0.16666667 0.83333333) *
    7) number_customer_service_calls< 3.5 2871 229 no (0.07976315 0.92023685)
      14) international_planyes>=0.5 267 101 no (0.37827715 0.62172285)
        28) total_intl_calls< 2.5 51 0 yes (1.00000000 0.00000000) *
        29) total_intl_calls>=2.5 216 50 no (0.23148148 0.76851852)
          58) total_intl_minutes>=13.1 43 0 yes (1.00000000 0.00000000) *
          59) total_intl_minutes< 13.1 173 7 no (0.04046243 0.95953757) *
      15) international_planyes< 0.5 2604 128 no (0.04915515 0.95084485)
        30) total_day_minutes>=223.25 383 68 no (0.17754569 0.82245431)
          60) total_eve_minutes>=259.8 51 17 yes (0.66666667 0.33333333) *
          61) total_eve_minutes< 259.8 332 34 no (0.10240964 0.89759036) *
        31) total_day_minutes< 223.25 2221 60 no (0.02701486 0.97298514) *
14.10 Model Post-Pruning
> dt_model_postprune <- prune(dt_model$finalModel, cp = 0.2)
> print(dt_model_postprune)
n= 3333

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 3333 483 no (0.1449145 0.8550855) *
14.11 Check Decision Tree Classifier (1)
> plot(dt_model$finalModel)
> text(dt_model$finalModel)
14.12 Check Decision Tree Classifier (2)
> library(rattle)
> fancyRpartPlot(dt_model$finalModel)
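A possible final step (my own addition, not part of the original demo): evaluate the classifier on the held-out test set with caret's confusionMatrix(), which reports testing accuracy alongside sensitivity and specificity, using the raw predictions from 14.6.

> confusionMatrix(dt_predict2, churnTest$churn)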