【机器学习系列】聊聊决策树
決策樹是簡單易學(xué)且具有良好解釋性的模型,但實話說,我在工作中用的不多,通常會選擇更加復(fù)雜一些的模型,如隨機森林、XGBoots之類的模型,但要理解這些模型,對決策樹的學(xué)習(xí)是必不可少的,所以本文就基于sklearn(Scikit Learn)來討論一下決策樹相關(guān)的內(nèi)容。
決策樹基本使用與可視化
為了方便,我們直接使用sklearn提供的鳶尾花數(shù)據(jù)集來展示決策樹的使用。
首先,導(dǎo)入鳶尾花數(shù)據(jù)集:
from?sklearn.datasets?import?load_iris iris?=?load_iris()如果第一次使用,sklearn會自動幫我們下載,你需要等待一下則可。
獲得鳶尾花數(shù)據(jù)集后,我們使用數(shù)據(jù)集中花瓣的長度和寬度作為特征,將花瓣的種類作為target,然后使用sklearn的DecisionTreeClassifier構(gòu)建分類決策樹,該決策樹會基于特征,對花瓣的種類進行分類,代碼如下:
from?sklearn.tree?import?DecisionTreeClassifierX?=?iris.data[:,?2:]??#?花瓣的長與寬 y?=?iris.target??#?花的種類#?分類決策樹,max_depth=2表示決策樹最大高度為2 tree_clf?=?DecisionTreeClassifier(max_depth=2) tree_clf.fit(X,?y)構(gòu)建完后,可以通過sklearn提供的plot_tree方法可視化決策樹:
from?sklearn?import?tree tree.plot_tree(tree_clf)???圖1
如果覺得sklearn的plot_tree方法繪制出的決策樹不太美觀,可以甩graphviz進行繪制,graphviz是一個繪圖軟件,需要先自行安裝,在MacOS下,安裝簡單:
brew?install?graphviz pip?install?graphviz安裝完后,變可以進行決策樹的繪制了:
from?graphviz?import?Source from?sklearn.tree?import?export_graphvizexport_graphviz(tree_clf,out_file=os.path.join(IMAGES_PATH,?"iris_tree.dot"),feature_names=iris.feature_names[2:],??#?特征class_names=iris.target_names,??#?分類rounded=True,??#?圓角filled=True,??#?顏色填充)Source.from_file(os.path.join(IMAGES_PATH,?"iris_tree.dot"))效果如圖:
圖2
怎么使用這棵決策樹?
假設(shè)你手里有一朵剛摘的鳶尾花,要對其進行分類,你會從決策樹的根節(jié)點開始(深度為0),判斷該花的花瓣寬度是否小于0.8cm,如果小于,那么就來到根的左子節(jié)點(圖中橙色節(jié)點,深度為1),該節(jié)點是葉子節(jié)點(沒有其他子節(jié)點),葉子節(jié)點的class值便是當(dāng)前這棵決策樹對當(dāng)前這朵花的預(yù)測,它認為這朵花的種類是setosa。
假設(shè)你還有另外一朵花,發(fā)現(xiàn)它的花瓣寬度大于0.8cm,那么你還需要繼續(xù)判斷花瓣寬度是否小于等于1.75cm,如果小于,那么就來的圖中綠色的葉子節(jié)點,此時,決策樹預(yù)測這朵花的種類是versicolor。
純度
先不糾結(jié)決策樹的細節(jié),來思考一個關(guān)鍵問題:決策樹是基于什么來做決策的?,比如圖2中的根節(jié)點,分裂成了2個子節(jié)點,左邊的子節(jié)點,不再分裂,而右邊的子節(jié)點繼續(xù)分裂,這里的分裂的依據(jù)是什么?
這就涉及到純度的概念,對于分類決策樹而言,通常會使用基尼系數(shù)或熵來判斷某個節(jié)點的純度,而對于回歸決策樹而言,通常會使用均方誤差(MSE)來判斷某節(jié)點純度。
圖3
基尼系數(shù)
我們先從基尼系數(shù)聊起。
百度搜索基尼系數(shù),會得到如下定義:
基尼系數(shù)(英文:Gini index、Gini Coefficient)是指國際上通用的、用以衡量一個國家或地區(qū)居民收入差距的常用指標(biāo)。基尼系數(shù)最大為“1”,最小等于“0”。
在決策樹中,基本性質(zhì)不變,只是它被用來表示決策樹中節(jié)點的純度,這里的純度正如其字面意思,如果決策樹的葉子節(jié)點里,只有一種類別,那么它就是純的,該葉子節(jié)點的基尼系數(shù)為0。
我們看回圖2,圖2中的gini表示當(dāng)前節(jié)點的基尼系數(shù),看到橙色節(jié)點(深度為1的左節(jié)點),可以發(fā)現(xiàn)gini為0,這是因為該節(jié)點的訓(xùn)練實例數(shù)量(節(jié)點中的samples屬性描述)為50個,而這50個都是種類setosa,如果訓(xùn)練實例的個數(shù)不全部屬于同一類,那么節(jié)點就是不純的,使用基尼系數(shù)表示不純度的公式為:
公式中表示決策樹第i個節(jié)點中實例數(shù)量與k類實例的比率。
還是以圖2為例,就公式代入,我們來計算一下圖2中綠色節(jié)點的gini。
In?[1]:?import?numpyIn?[2]:?1?-?numpy.square(49/54)?-?numpy.square(5/54) Out[2]:?0.1680384087791495熵
熵的概念源于熱力學(xué),很多硬科幻小說中也會出現(xiàn)它以及它涉及的熱力學(xué)第二原理。
在熱力學(xué)領(lǐng)域,熵主要用于描述分子混亂程度,而這里我們提及的熵主要是在信息論中的定義。在信息論里,熵被用于衡量一條信息的平均信息內(nèi)容,一條有價值的信息會讓決策的可能性(混亂程度)降低,這便是熵減的過程,在信息論里,將熵減過程稱為信息增益。
決策樹也會使用熵作為不純度的一種指標(biāo),如果節(jié)點的數(shù)據(jù)集中只包含一種類型的實例,那么熵為0,純度為0,用熵表示不純度的公式為:
sklearn中使用的決策樹默認會使用基尼系數(shù)來作為節(jié)點不純度的指標(biāo),但我們可以將criterion參數(shù)設(shè)置為「entropy」,讓sklearn使用熵來作為不純度計算方式,代碼如下:
from?sklearn.tree?import?DecisionTreeClassifierX?=?iris.data[:,?2:]??#?花瓣的長與寬 y?=?iris.target??#?花的種類#?分類決策樹,max_depth=2表示決策樹最大高度為2 tree_clf2?=?DecisionTreeClassifier(max_depth=2,?criterion="entropy") tree_clf2.fit(X,?y)相同的方式繪制一下:
圖4
獲得圖4后,我們來使用一下熵的公式來手動計算一下紫色節(jié)點的熵(entropy)。
In?[4]:?import?mathIn?[5]:?-(1/46)*math.log((1/46),?2)-(45/46)*math.log((45/46),2) Out[5]:?0.15109697051711368從圖2和圖4看,使用基尼系數(shù)或熵訓(xùn)練出的分類決策樹是一樣的,那我們該使用基尼系數(shù)還是使用熵呢?一旦腦海里出現(xiàn)這種問題,一律使用默認的,sklearn默認使用基尼系數(shù),那我們就使用基尼系數(shù),這是經(jīng)驗之談,sklearn的開發(fā)者有我們目前知識量暫時無法理解的考慮。
其實呢,大多數(shù)情況下,使用基尼系數(shù)還是熵,差異不大,都會產(chǎn)生相似的決策樹,只是基尼系數(shù)計算會快一些,所以作為sklearn的默認值,當(dāng)然,差異還是有的,使用基尼系數(shù)會讓你的決策樹從樹枝中分裂出最常見的類別,而使用熵則會傾向于生成更平衡的決策樹。
均方誤差(MSE)
當(dāng)我們通過決策樹來解決回歸類型的問題時,通常會使用MSE來作為節(jié)點是否向下分裂的依據(jù),這里已經(jīng)跟純度這個概念沒啥關(guān)系了,所謂純度其樸素的理解就建立在類別上的,某個數(shù)據(jù)集中,都是同一類數(shù)據(jù),那么就是高純度的,但回歸問題上,沒有類別的概念,所以也與純度的概念無關(guān)。
首先,我們利用sklearn構(gòu)建一顆回歸決策樹:
import?numpy?as?np from?sklearn.tree?import?DecisionTreeRegressornp.random.seed(42) m?=?200 #?隨機生成特征數(shù)據(jù) X?=?np.random.rand(m,?1) y?=?4?*?(X?-?0.5)?**?2 #?隨機生成對應(yīng)的目標(biāo) y?=?y?+?np.random.randn(m,?1)?/?10#?構(gòu)建回歸決策樹 tree_reg?=?DecisionTreeRegressor(max_depth=2,?random_state=42) tree_reg.fit(X,?y)然后用老方法,將其可視化顯示出來:
from?graphviz?import?Source from?sklearn.tree?import?export_graphvizexport_graphviz(tree_reg,out_file=os.path.join(IMAGES_PATH,?"reg_tree.dot"),feature_names='X',??#?特征class_names='y',??#?目標(biāo)rounded=True,??#?圓角filled=True,??#?顏色填充)Source.from_file(os.path.join(IMAGES_PATH,?"reg_tree.dot"))圖5
回歸決策樹長的很像分類決策樹,主要的差別是,回歸決策樹中的節(jié)點不再預(yù)測某個分類,而是預(yù)測一個具體的值,比如你想對一個的實例進行預(yù)測,那么從根節(jié)點看,最后會來到value=0.111的葉子節(jié)點上,在該葉子節(jié)點上,均分誤差為0.015。
CART訓(xùn)練算法
sklearn使用分類和回歸樹(Classification and Regression Tree,CART)算法來訓(xùn)練決策樹。
對于分類決策樹,CART工作原理為:
先使用單個特征和閾值(如:花瓣長度)將訓(xùn)練集劃分為兩個子集,如何選擇與呢?
算法會對與進行搜索比較,然后找到可以產(chǎn)生最純子集的一對與,其最小化的成本函數(shù)為:
其中:
、測量左右子集的不純度
、策略左右子集的實例數(shù)
CART算法通過上述邏輯將訓(xùn)練集分成兩部分后,會使用相同的邏輯對子集進行分割,然后一直分割下去,直到達到最大深度(由max_depath定義)
對于回歸決策樹,CART算法通過最小化MSE的方式來拆分訓(xùn)練集,公式如下:
其中:
從原理上理解,很容易發(fā)現(xiàn)CART算法是一種貪心算法,它會從最頂層開始搜索最優(yōu)的分裂方式,然后對子集也進行同樣的處理,多次分裂后,CART算法不會審視自己目前這樣的分裂產(chǎn)出的不純度是否全局最優(yōu)的。
通常,貪心算法可以獲得一個不錯的解,但不能保證該解是最優(yōu)解,為了便于理解,我再舉一個例子,假設(shè)你要從廣州去上海,你在每個節(jié)點上都選擇最短的路,但這樣選擇下來的總路徑可能不是最短,這便是貪心算法面臨的情況。
決策邊界可視化
除了前文中提及的將決策樹本身可視化外,還有另一種常見的可視化方式,那便是將決策樹的決策邊界可視化出來。當(dāng)然,如果數(shù)據(jù)量很大,還需要對數(shù)據(jù)進行采樣后在進行可視化處理。
編寫一個用于可視化決策邊界的函數(shù):
#?To?plot?pretty?figures %matplotlib?inline import?matplotlib?as?mpl import?matplotlib.pyplot?as?plt mpl.rc('axes',?labelsize=14) mpl.rc('xtick',?labelsize=12) mpl.rc('ytick',?labelsize=12)from?matplotlib.colors?import?ListedColormapdef?plot_decision_boundary(clf,?X,?y,?axes=[0,?7.5,?0,?3],?iris=True,?legend=False,?plot_training=True):#?指定間隔內(nèi),返回均勻的數(shù)字x1s?=?np.linspace(axes[0],?axes[1],?100)x2s?=?np.linspace(axes[2],?axes[3],?100)#?meshgrid函數(shù):用兩個坐標(biāo)軸上的點在平面上畫網(wǎng)格,其實返回的是矩陣x1,?x2?=?np.meshgrid(x1s,?x2s)#?按行連接兩個矩陣,就是把兩矩陣左右相加,要求行數(shù)相等X_new?=?np.c_[x1.ravel(),?x2.ravel()]#?預(yù)測y_pred?=?clf.predict(X_new).reshape(x1.shape)custom_cmap?=?ListedColormap(['#fafab0','#9898ff','#a0faa0'])plt.contourf(x1,?x2,?y_pred,?alpha=0.3,?cmap=custom_cmap)if?not?iris:custom_cmap2?=?ListedColormap(['#7d7d58','#4c4c7f','#507d50'])plt.contour(x1,?x2,?y_pred,?cmap=custom_cmap2,?alpha=0.8)if?plot_training:plt.plot(X[:,?0][y==0],?X[:,?1][y==0],?"yo",?label="Iris?setosa")plt.plot(X[:,?0][y==1],?X[:,?1][y==1],?"bs",?label="Iris?versicolor")plt.plot(X[:,?0][y==2],?X[:,?1][y==2],?"g^",?label="Iris?virginica")plt.axis(axes)if?iris:plt.xlabel("Petal?length",?fontsize=14)plt.ylabel("Petal?width",?fontsize=14)else:plt.xlabel(r"$x_1$",?fontsize=18)plt.ylabel(r"$x_2$",?fontsize=18,?rotation=0)if?legend:plt.legend(loc="lower?right",?fontsize=14)基于鳶尾花數(shù)據(jù)集進行可視化,代碼如下:
plt.figure(figsize=(8,?4)) X?=?iris.data[:,?2:]??#?花瓣的長與寬 y?=?iris.target??#?花的種類 plot_decision_boundary(tree_clf,?X,?y) plt.plot([2.45,?2.45],?[0,?3],?"k-",?linewidth=2) plt.plot([2.45,?7.5],?[1.75,?1.75],?"k--",?linewidth=2) plt.plot([4.95,?4.95],?[0,?1.75],?"k:",?linewidth=2) plt.plot([4.85,?4.85],?[1.75,?3],?"k:",?linewidth=2) plt.text(1.40,?1.0,?"Depth=0",?fontsize=15) plt.text(3.2,?1.80,?"Depth=1",?fontsize=13) plt.text(4.05,?0.5,?"(Depth=2)",?fontsize=11) plt.show()效果:
決策樹的問題
容易過擬合
決策樹的特點是,它極少對訓(xùn)練數(shù)據(jù)本身做出假設(shè),對比看線性模型,如果你選擇使用線性模型,其實你就假設(shè)了訓(xùn)練數(shù)據(jù)是線性變化的,否則線性模型不可能得出好的結(jié)果,而決策樹不會有這樣的假設(shè),這個特點容易讓決策樹出現(xiàn)過擬合的問題。
以一個具體的例子來展示決策樹過擬合的情況:
首先,我通過sklearn的make_moons方法生成半環(huán)形分布式的數(shù)據(jù)集,直觀理解如下:
from?sklearn.datasets?import?make_moons?plt.subplot(122)?? x1,y1=make_moons(n_samples=1000,noise=0.1)?? plt.title('make_moons?function?example')?? plt.scatter(x1[:,0],x1[:,1],marker='o',c=y1)?? plt.show()上述代碼效果如下圖:
有了半環(huán)形分布的數(shù)據(jù)后,訓(xùn)練分類決策樹,代碼如下:
決策樹對半環(huán)形分布的數(shù)據(jù),其決策邊界如下:
從圖可以看出,決策樹對該數(shù)據(jù)有明顯的過擬合情況。
解決方法也很簡單,就是使用各種超參數(shù)對模型進行正則化調(diào)整,即限制模型的擬合能力,從而希望獲得具有更好泛化的模型。sklearn對決策樹提供了max_depth(最大深度)、min_samples_split(分裂前節(jié)點必須有的最小樣本數(shù))、min_samples_leaf(葉節(jié)點必須要有的最小樣本數(shù)),等等超參數(shù)用于正則化。
這里我使用max_depth和min_samples_leaf對分類決策樹做了相應(yīng)的正則化。
from?sklearn.datasets?import?make_moons Xm,?ym?=?make_moons(n_samples=100,?noise=0.25,?random_state=53)#?分類決策樹 tree_clf?=?DecisionTreeClassifier(random_state=42,?max_depth=5,?min_samples_leaf=4) tree_clf.fit(Xm,?ym)#?可視化 plt.figure(figsize=(8,?4)) plot_decision_boundary(tree_clf,?Xm,?ym,?axes=[-1.5,?2.4,?-1,?1.5],?iris=False) plt.title("regularization",?fontsize=16) plt.show()可視化效果如圖:
過擬合的問題不只是在分類決策樹上,在回歸決策樹上也會有,這里可視化的展示一下,讓你有更直觀的了解,如下圖:
上圖中,左半部分,無疑是過擬合的,而右半部分,使用了min_smaples_leaf做正則化的限制,效果還可以。
此外,從右半部分的圖也可以看出,因為回歸決策樹使用MSE來做節(jié)點分裂標(biāo)準(zhǔn),所以決策樹的預(yù)測值都是對應(yīng)區(qū)域內(nèi)實例的目標(biāo)平均值。
不穩(wěn)定性
前文中,我們已經(jīng)展示,決策樹對多種數(shù)據(jù)類型的處理情況并可視化的展示出其決策邊界,仔細觀察會發(fā)現(xiàn),無論是分類決策樹還是回歸決策樹,其決策邊界都喜歡垂直于X軸或Y軸,這使得他們對訓(xùn)練集數(shù)據(jù)的旋轉(zhuǎn)產(chǎn)生的變化特別敏感,一個具體的例子:
np.random.seed(6) Xs?=?np.random.rand(100,?2)?-?0.5 ys?=?(Xs[:,?0]?>?0).astype(np.float32)?*?2angle?=?np.pi?/?4 rotation_matrix?=?np.array([[np.cos(angle),?-np.sin(angle)],?[np.sin(angle),?np.cos(angle)]]) Xsr?=?Xs.dot(rotation_matrix)tree_clf_s?=?DecisionTreeClassifier(random_state=42) tree_clf_s.fit(Xs,?ys) tree_clf_sr?=?DecisionTreeClassifier(random_state=42) tree_clf_sr.fit(Xsr,?ys)fig,?axes?=?plt.subplots(ncols=2,?figsize=(10,?4),?sharey=True) plt.sca(axes[0]) plot_decision_boundary(tree_clf_s,?Xs,?ys,?axes=[-0.7,?0.7,?-0.7,?0.7],?iris=False) plt.sca(axes[1]) plot_decision_boundary(tree_clf_sr,?Xsr,?ys,?axes=[-0.7,?0.7,?-0.7,?0.7],?iris=False) plt.ylabel("") plt.show()從上圖可以看出,左邊的圖,決策樹使用一條線就將數(shù)據(jù)做好了分類,但我們將數(shù)據(jù)旋轉(zhuǎn)一下,獲得右邊的圖,再使用決策樹去處理,會發(fā)現(xiàn),決策樹需要繪制多條線才能將數(shù)據(jù)做好分類,即右邊數(shù)據(jù)訓(xùn)練出的決策樹模型,可能無法很好的泛化。
更概括的說,決策樹對訓(xùn)練集中微小的數(shù)據(jù)變化都非常敏感,一個具體例子,如果我們直接使用鳶尾花數(shù)據(jù)集進行可視化,效果如圖(前面展示過):
但我們使用相同的數(shù)據(jù)、類似的代碼,會得到與上圖大為不同的效果:
上述代碼可視化的效果如下:
從上圖可知,即便是相同的訓(xùn)練數(shù)據(jù)上,如果random_state不同(sklearn選擇特征集的算法是隨機的,通過random_state參數(shù)控制),獲得的決策樹模型也完全不同了。
隨機森林可以通過對多個樹進行平均預(yù)測來限制這種不穩(wěn)定性,關(guān)于隨機森林的內(nèi)容,我們后面的文章會討論。
結(jié)尾
本文涉及代碼已提交到:https://github.com/ayuLiao/machine_learning_interstellar_journey 項目中。
我是二兩,下篇文章見。
總結(jié)
以上是生活随笔為你收集整理的【机器学习系列】聊聊决策树的全部內(nèi)容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 听说学习是件苦差事——Linux第一天
- 下一篇: 程序员的奋斗史(十二)——谈信念