【Machine Learning in Action --3】决策树ID3算法
1、簡單概念描述
?????? 決策樹的類型有很多,有CART、ID3和C4.5等,其中CART是基于基尼不純度(Gini)的,這里不做詳解,而ID3和C4.5都是基于信息熵的,它們兩個得到的結果都是一樣的,本次定義主要針對ID3算法。下面我們介紹信息熵的定義。
? ? ? p(ai):事件ai發生的概率
I(ai)=-log2(p(ai)):表示為事件ai的不確定程度,稱為ai的自信息量
H=sum(p(ai)*I(ai)):稱為信源S的平均信息量—信息熵
Gain = BaseEntropy – newEntropy:信息增益
決策樹學習采用的是自頂向下的遞歸方法,其基本思想是以信息熵為度量構造一棵熵值下降最快的樹,到葉子節點處的熵值為零,此時每個葉節點中的實例都屬于同一類。ID3的原理是基于信息熵增益Gain達到最大,設原始問題的標簽有正例和負例,p和n表示其相應的個數。則原始問題的信息熵為? ??其中N為該特征所取值的個數,比如{rain,sunny},則N即為2
ID3易出現的問題:如果是取值更多的屬性,更容易使得數據更“純”(尤其是連續型數值),其信息增益更大,決策樹會首先挑選這個屬性作為樹的頂點。結果訓練出來的形狀是一棵龐大且深度很淺的樹,這樣的劃分是極為不合理的。 此時可以采用C4.5來解決,C4.5的思想是最大化Gain除以下面這個公式即得到信息增益率:
其中底為2
2、決策樹的優缺點
優點:計算復雜度不高,輸出結果易于理解,對中間值缺失不敏感,可以處理不相關特征數據
缺點:可能產生過度匹配問題
適用數據類型:數值型和標稱型
3、python代碼的實現
以下的代碼根據這些數據理解
數據1中包含5個海洋動物,特征包括:不浮出水面是否可以生存,以及是否有腳蹼。我們可以將這些動物分成兩類:魚類和非魚類。
| ? | 不浮出水面是否可以生存 | 是否有腳蹼 | 屬于魚類 |
| 1 | 是 | 是 | 是 |
| 2 | 是 | 是 | 是 |
| 3 | 是 | 否 | 否 |
| 4 | 否 | 是 | 否 |
| 5 | 否 | 是 | 否 |
?
| ? | 特征[0](no surfacing) | 特征[1](flippers) | 特征[-1]fish |
| dataSet[0] | 1 | 1 | yes |
| dataSet[1] | 1 | 1 | yes |
| dataSet[2] | 0 | 1 | no |
| dataSet[3] | 0 | 1 | no |
| dataSet[4] | 0 | 1 | no |
創建名為trees.py的文件,下面代碼內容都在此文件中。
(1)計算信息熵
# -*- coding: utf-8 -*-#計算給定數據集的香農熵 def calcShannonEnt(dataSet): numEntries=len(dataSet) #數據實例總數labelCounts={} #對類別數量創建了一個數據字典,鍵值是最后一列的數值for featVec in dataSet: #featVec表示特征集currentLabel=featVec[-1] # currentLabel表示當前鍵值,featVec[-1]表示數據集中的最后一列#如果當前鍵值不存在,擴展字典將當前鍵值加入字典,設置當前鍵值表示的類別數量為0if currentLabel not in labelCounts.keys(): labelCounts[currentLabel]=0#如果當前鍵值存在,則類別數量累加labelCounts[currentLabel]+=1shannonEnt=0.0for key in labelCounts:prob=float(labelCounts[key])/numEntries #每個鍵值都記錄了當前類別出現的次數shannonEnt -=prob*log(prob,2)return shannonEnt
(2)創建數據集
#創建數據集 def createDataSet():dataSet=[[1,1,'yes'],[1,1,'yes'],[0,1,'no'],[0,1,'no'],[0,1,'no']]labels=['no surfacing','flippers']return dataSet,labels在python命令提示符下輸入下列命令:
1 >>> import trees 2 >>> reload(trees) 3 <module 'trees' from 'E:\python excise\trees.pyc'> 4 >>> myDat,labels=trees.createDataSet() 5 >>> myDat 6 [[1, 1, 'yes'], [1, 1, 'yes'], [0, 1, 'no'], [0, 1, 'no'], [0, 1, 'no']] 7 >>> trees.calcShannonEnt(myDat) 8 0.9709505944546686 9 >>>熵越高,則混合的數據越多,在數據集中添加更多的分類,觀察熵是如何變化的,這里增加第三個名為maybe的分類,測試熵的變化:
>>> myDat[0][-1]='maybe' >>> myDat [[1, 1, 'maybe'], [1, 1, 'yes'], [0, 1, 'no'], [0, 1, 'no'], [0, 1, 'no']] >>> trees.calcShannonEnt(myDat) 1.3709505944546687得到熵后,我們可以按照獲取最大信息增益的方法劃分數據集
(3)劃分數據集
?我們將對每個特征劃分數據集的結果計算一次信息熵,然后判斷按照哪個特征劃分數據集是最好的劃分方式
#按照給定特征劃分數據集 #dataSet:待劃分的數據集,axis:劃分數據集的特征,value:需要返回的特征的值 def splitDataSet(dataSet,axis,value):retDataSet=[] #為了不修改原始數據dataSet,創建一個新的列表對象for featVec in dataSet:if featVec[axis]==value: reducedFeatVec=featVec[:axis] #獲取從第0列到特征列的數據reducedFeatVec.extend(featVec[axis+1:]) #獲取從特征列之后的數據retDataSet.append(reducedFeatVec) #目前reducedFeatVec表示除了特征列的數據return retDataSet 1 >>> reload(trees) 2 <module 'trees' from 'E:\python excise\trees.pyc'> 3 >>> myDat,labels=trees.createDataSet() 4 >>> myDat 5 [[1, 1, 'yes'], [1, 1, 'yes'], [0, 1, 'no'], [0, 1, 'no'], [0, 1, 'no']] 6 >>> trees.splitDataSet(myDat,0,1) 7 [[1, 'yes'], [1, 'yes']] 8 >>> trees.splitDataSet(myDat,0,0) 9 [[1, 'no'], [1, 'no'], [1, 'no']](4)選擇最好的特征進行劃分
#選擇最好的數據集劃分方式 def chooseBestFeatureToSplit(dataSet):numFeatures=len(dataSet[0])-1 #減去類別那一列baseEntropy=calcShannonEnt(dataSet) #計算整個數據集的原始香農熵bestInfoGain=0.0;bestFeature=-1 #現在最好的特征是數據集中的最后一列#i=0,新熵,增益
#i=1,新熵,增益for i in range(numFeatures): #循環遍歷數據集中的所有特征featList=[example[i] for example in dataSet] #獲取第i個特征所有可能的取值,特征0一個列表,特征1一個列表...uniqueVals=set(featList) #集合數據類型(set)與列表類型相似,不同之處僅在于集合類型中每個值互不相同newEntropy=0.0for value in uniqueVals:subDataSet=splitDataSet(dataSet,i,value) #劃分后的數據集prob=len(subDataSet)/float(len(dataSet))newEntropy+=prob*calcShannonEnt(subDataSet) #求劃分完的數據集的熵infoGain=baseEntropy-newEntropyif(infoGain>bestInfoGain):bestInfoGain=infoGainbestFeature=i return bestFeature
注意:這里數據集需要滿足以下兩個辦法:
<1>所有的列元素都必須具有相同的數據長度
<2>數據的最后一列或者每個實例的最后一個元素是當前實例的類別標簽。
1 >>> reload(trees) 2 <module 'trees' from 'E:\python excise\trees.pyc'> 3 >>> myDat,labels=trees.createDataSet() 4 >>> trees.chooseBestFeatureToSplit(myDat) 5 0(5)創建樹的代碼
Python用字典類型來存儲樹的結構,返回的結果是myTree-字典
#創建樹的函數代碼 def createTree(dataSet,labels):classList=[example[-1] for example in dataSet]if classList.count(classList[0])==len(classList): #類別完全相同規則停止繼續劃分return classList[0]if len(dataSet[0])==1: #確認至少有數據集return majorityCnt(classList)bestFeat=chooseBestFeatureToSplit(dataSet)bestFeatLabel=labels[bestFeat]myTree={bestFeatLabel:{}}del(labels[bestFeat]) #得到列表包含的所有屬性featValues=[example[bestFeat] for example in dataSet]uniqueVals=set(featValues)for value in uniqueVals:subLabels=labels[:]myTree[bestFeatLabel][value]=createTree(splitDataSet(dataSet,bestFeat,value),subLabels)return myTree
其中遞歸結束當且僅當該類別中標簽完全相同或者遍歷所有的特征此時返回次數最多的
1 >>> reload(trees) 2 <module 'trees' from 'E:\python excise\trees.pyc'> 3 >>> myDat,labels=trees.createDataSet() 4 >>> myTree=trees.createTree(myDat,labels) 5 >>> myTree 6 {'no surfacing': {0: 'no', 1: 'yes'}}其中當所有的特征都用完時,采用多數表決的方法來決定該葉子節點的分類,即該葉節點中屬于某一類最多的樣本數,那么我們就說該葉節點屬于那一類。即為如果數據集已經處理了所有的屬性,但是類標簽依然不是唯一的,此時我們要決定如何定義該葉子節點,在這種情況下,我們通常采用多數表決的方法來決定該葉子節點的分類。代碼如下:
def majorityCnt(classList):classCount={}for vote in classList:if vote not in classCount.keys():classCount[vote]=0classCount[vote]+=1sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)return sortedClassCount[0][0](6)使用決策樹執行分類
#測試算法:使用決策樹執行分類 def classify(inputTree,featLabels,testVec):firstStr=inputTree.keys()[0]secondDict=inputTree[firstStr]featIndex=featLabels.index(firstStr)for key in secondDict.keys():if testVec[featIndex]==key:if type(secondDict[key]).__name__=='dict':classLabel=classify(secondDict[key],featLabels,testVec)else:classLabel=secondDict[key]return classLabel 1 >>> import trees 2 >>> myDat,labels=trees.createDataSet() 3 >>> labels 4 ['no surfacing', 'flippers'] 5 >>> trees.classify(myTree,labels,[1,0]) 6 'no' 7 >>> trees.classify(myTree,labels,[1,1]) 8 'yes'注意遞歸的思想很重要。
(7)決策樹的存儲
構造決策樹是一個很耗時的任務。為了節省計算時間,最好能夠在每次執行分類時調用已經構造好的決策樹。為了解決這個問題,需要使用python模塊pickle序列化對象,序列化對象可以在磁盤上保存對象,并在需要的時候讀取出來。
#使用算法:決策樹的存儲 def storeTree(inputTree,filename):import picklefw=open(filename,'w')pickle.dump(inputTree,fw)fw.close() def grabTree(filename):import picklefr=open(filename)return pickle.load(fr) 1 >>> reload(trees) 2 >>><module 'tree' from 'trees.py'> 3 >>> trees.storeTree(myTree,'classifierStorage.txt') 4 >>> trees.grabTree('classifierStorage.txt') 5 {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}classifierStorage.txt如下:
補充:
用matplotlib注解上述形成的決策樹
Matplotlib提供了一個注解工具annotations,非常有用,它可以在數據圖形上添加文本注釋。注解通常用于解釋數據的內容。
創建名為treePlotter.py文件,下面代碼都在此文件中
#!/usr/bin/python # -*- coding: utf-8 -*- import matplotlib.pyplot as plt from numpy import * import operator #定義文本框和箭頭格式 decisionNode=dict(boxstyle="sawtooth",fc="0.8") leafNode=dict(boxstyle="round4",fc="0.8") arrow_args=dict(arrowstyle="<-") #繪制箭頭的注解 def plotNode(nodeTxt,centerPt,parentPt,nodeType):createPlot.ax1.annotate(nodeTxt,xy=parentPt,xycoords='axes fraction',xytext=centerPt,textcoords='axes fraction',va="center",ha="center",bbox=nodeType,arrowprops=arrow_args) def createPlot():fig=plt.figure(1,facecolor='white')fig.clf()createPlot.ax1=plt.subplot(111,frameon=False)plotNode(U'決策節點',(0.5,0.1),(0.1,0.5),decisionNode)plotNode(U'葉節點',(0.8,0.1),(0.3,0.8),leafNode)plt.show() #獲取葉節點的數目和樹的層數 def getNumLeafs(myTree):numLeafs=0firstStr=myTree.keys()[0]secondDict=myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__=='dict':numLeafs += getNumLeafs(secondDict[key])else: numLeafs +=1return numLeafs def getTreeDepth(myTree):maxDepth=0firstStr=myTree.keys()[0]secondDict=myTree[firstStr]for key in secondDict.keys():if type(secondDict[key]).__name__=='dict':thisDepth=1+getTreeDepth(secondDict[key])else:thisDepth=1if thisDepth>maxDepth:maxDepth=thisDepthreturn maxDepthdef retrieveTree(i):listOfTrees=[{'no surfacing':{0:'no',1:{'flippers':{0:'no',1:'yes'}}}},\{'no surfacing':{0:'no',1:{'flippers':{0:{'head':{0:'no',1:'yes'}},1:'no'}}}}]return listOfTrees[i] #在父節點間填充文本信息 def plotMidText(cntrPt,parentPt,txtString):xMid=(parentPt[0]-cntrPt[0])/2.0+cntrPt[0]yMid=(parentPt[1]-cntrPt[1])/2.0+cntrPt[1]createPlot.ax1.text(xMid,yMid,txtString) #計算寬和高 def plotTree(myTree,parentPt,nodeTxt):numLeafs=getNumLeafs(myTree)depth=getTreeDepth(myTree)firstStr=myTree.keys()[0]cntrPt=(plotTree.xOff+(1.0+float(numLeafs))/2.0/plotTree.totalW,plotTree.yOff)plotMidText(cntrPt,parentPt,nodeTxt) #計算父節點和子節點的中間位置plotNode(firstStr,cntrPt,parentPt,decisionNode)secondDict=myTree[firstStr]plotTree.yOff=plotTree.yOff-1.0/plotTree.totalDfor key in secondDict.keys():if type(secondDict[key]).__name__=='dict':plotTree(secondDict[key],cntrPt,str(key))else:plotTree.xOff=plotTree.xOff+1.0/plotTree.totalWplotNode(secondDict[key],(plotTree.xOff,plotTree.yOff),cntrPt,leafNode)plotMidText((plotTree.xOff,plotTree.yOff),cntrPt,str(key))plotTree.yOff=plotTree.yOff+1.0/plotTree.totalD def createPlot(inTree):fig=plt.figure(1,facecolor='white')fig.clf()axprops=dict(xticks=[],yticks=[])createPlot.ax1=plt.subplot(111,frameon=False,**axprops)plotTree.totalW=float(getNumLeafs(inTree))plotTree.totalD=float(getTreeDepth(inTree))plotTree.xOff=-0.5/plotTree.totalW;plotTree.yOff=1.0;plotTree(inTree,(0.5,1.0),'')plt.show()?
其中index方法為查找當前列表中第一個匹配firstStr的元素 返回的為索引。
?
總結
以上是生活随笔為你收集整理的【Machine Learning in Action --3】决策树ID3算法的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: JSP页面空指针异常调错办法之weblo
- 下一篇: Openssh学习笔记