西瓜决策树-纯算法
文章目錄
- ID3決策樹算法
- 一、理論
- 二、代碼實現
- 1.引入數據和需要用到的包
- 2.函數
- 計算熵
- 拆分數據集
- 選擇最好的特征
- 尋找最多的,作為標簽
- 生成樹
- 初始化
- 畫圖
- 3.結果
- 三、參考
ID3決策樹算法
一、理論
純度(purity)
對于一個分支結點,如果該結點所包含的樣本都屬于同一類,那么它的純度為1,而我們總是希望純度越高越好,也就是盡可能多的樣本屬于同一類別。那么如何衡量“純度”呢?由此引入“信息熵”的概念。
信息熵(information entropy)
假定當前樣本集合D中第k類樣本所占的比例為pk(k=1,2,…,|y|),則D的信息熵定義為:
顯然,Ent(D)值越小,D的純度越高。因為0<=pk<= 1,故log2 pk<=0,Ent(D)>=0. 極限情況下,考慮D中樣本同屬于同一類,則此時的Ent(D)值為0(取到最小值)。當D中樣本都分別屬于不同類別時,Ent(D)取到最大值log2 |y|.
信息增益(information gain)
假定離散屬性a有V個可能的取值{a1,a2,…,aV}. 若使用a對樣本集D進行分類,則會產生V個分支結點,記Dv為第v個分支結點包含的D中所有在屬性a上取值為av的樣本。不同分支結點樣本數不同,我們給予分支結點不同的權重:|Dv|/|D|, 該權重賦予樣本數較多的分支結點更大的影響、由此,用屬性a對樣本集D進行劃分所獲得的信息增益定義為:
其中,Ent(D)是數據集D劃分前的信息熵,∑v=1 |Dv|/|D|·Ent(Dv)可以表示為劃分后的信息熵。“前-后”的結果表明了本次劃分所獲得的信息熵減少量,也就是純度的提升度。顯然,Gain(D,a) 越大,獲得的純度提升越大,此次劃分的效果越好。
增益率(gain ratio)
基于信息增益的最優屬性劃分原則——信息增益準則,對可取值數據較多的屬性有所偏好。C4.5算法使用增益率替代信息增益來選擇最優劃分屬性,增益率定義為:
其中
IV(a) = -∑v=1 |Dv|/|D|·log2 |Dv|/|D|稱為屬性a的固有值。屬性a的可能取值數目越多(即V越大),則IV(a)的值通常會越大。這在一定程度上消除了對可取值數據較多的屬性的偏好。
事實上,增益率準則對可取值數目較少的屬性有所偏好,C4.5算法并不是直接使用增益率準則,而是先從候選劃分屬性中找出信息增益高于平均水平的屬性,再從中選擇增益率最高的。基尼指數(Gini index)
CART決策樹算法使用基尼指數來選擇劃分屬性,基尼指數定義為:
可以這樣理解基尼指數:從數據集D中隨機抽取兩個樣本,其類別標記不一致的概率。Gini(D)越小,純度越高。
屬性a的基尼指數定義:
Gain_index(D,a) = ∑v=1 |Dv|/|D|·Gini(Dv)使用基尼指數選擇最優劃分屬性,即選擇使得劃分后基尼指數最小的屬性作為最優劃分屬性。
二、代碼實現
1.引入數據和需要用到的包
import numpy as np import pandas as pd import sklearn.tree as st import math data = pd.read_csv('./西瓜數據集.csv') data| 青綠 | 蜷縮 | 濁響 | 清晰 | 凹陷 | 硬滑 | 是 |
| 烏黑 | 蜷縮 | 沉悶 | 清晰 | 凹陷 | 硬滑 | 是 |
| 烏黑 | 蜷縮 | 濁響 | 清晰 | 凹陷 | 硬滑 | 是 |
| 青綠 | 蜷縮 | 沉悶 | 清晰 | 凹陷 | 硬滑 | 是 |
| 淺白 | 蜷縮 | 濁響 | 清晰 | 凹陷 | 硬滑 | 是 |
| 青綠 | 稍蜷 | 濁響 | 清晰 | 稍凹 | 軟粘 | 是 |
| 烏黑 | 稍蜷 | 濁響 | 稍糊 | 稍凹 | 軟粘 | 是 |
| 烏黑 | 稍蜷 | 濁響 | 清晰 | 稍凹 | 硬滑 | 是 |
| 烏黑 | 稍蜷 | 沉悶 | 稍糊 | 稍凹 | 硬滑 | 否 |
| 青綠 | 硬挺 | 清脆 | 清晰 | 平坦 | 軟粘 | 否 |
| 淺白 | 硬挺 | 清脆 | 模糊 | 平坦 | 硬滑 | 否 |
| 淺白 | 蜷縮 | 濁響 | 模糊 | 平坦 | 軟粘 | 否 |
| 青綠 | 稍蜷 | 濁響 | 稍糊 | 凹陷 | 硬滑 | 否 |
| 淺白 | 稍蜷 | 沉悶 | 稍糊 | 凹陷 | 硬滑 | 否 |
| 烏黑 | 稍蜷 | 濁響 | 清晰 | 稍凹 | 軟粘 | 否 |
| 淺白 | 蜷縮 | 濁響 | 模糊 | 平坦 | 硬滑 | 否 |
| 青綠 | 蜷縮 | 沉悶 | 稍糊 | 稍凹 | 硬滑 | 否 |
2.函數
計算熵
def calcEntropy(dataSet):mD = len(dataSet)dataLabelList = [x[-1] for x in dataSet]dataLabelSet = set(dataLabelList)ent = 0for label in dataLabelSet:mDv = dataLabelList.count(label)prop = float(mDv) / mDent = ent - prop * np.math.log(prop, 2)return ent拆分數據集
# index - 要拆分的特征的下標 # feature - 要拆分的特征 # 返回值 - dataSet中index所在特征為feature,且去掉index一列的集合 def splitDataSet(dataSet, index, feature):splitedDataSet = []mD = len(dataSet)for data in dataSet:if(data[index] == feature):sliceTmp = data[:index]sliceTmp.extend(data[index + 1:])splitedDataSet.append(sliceTmp)return splitedDataSet選擇最好的特征
# 返回值 - 最好的特征的下標 def chooseBestFeature(dataSet):entD = calcEntropy(dataSet)mD = len(dataSet)featureNumber = len(dataSet[0]) - 1maxGain = -100maxIndex = -1for i in range(featureNumber):entDCopy = entDfeatureI = [x[i] for x in dataSet]featureSet = set(featureI)for feature in featureSet:splitedDataSet = splitDataSet(dataSet, i, feature) # 拆分數據集mDv = len(splitedDataSet)entDCopy = entDCopy - float(mDv) / mD * calcEntropy(splitedDataSet)if(maxIndex == -1):maxGain = entDCopymaxIndex = ielif(maxGain < entDCopy):maxGain = entDCopymaxIndex = ireturn maxIndex尋找最多的,作為標簽
# 返回值 - 標簽 def mainLabel(labelList):labelRec = labelList[0]maxLabelCount = -1labelSet = set(labelList)for label in labelSet:if(labelList.count(label) > maxLabelCount):maxLabelCount = labelList.count(label)labelRec = labelreturn labelRec生成樹
def createFullDecisionTree(dataSet, featureNames, featureNamesSet, labelListParent):labelList = [x[-1] for x in dataSet]if(len(dataSet) == 0):return mainLabel(labelListParent)elif(len(dataSet[0]) == 1): #沒有可劃分的屬性了return mainLabel(labelList) #選出最多的label作為該數據集的標簽elif(labelList.count(labelList[0]) == len(labelList)): # 全部都屬于同一個Labelreturn labelList[0]bestFeatureIndex = chooseBestFeature(dataSet)bestFeatureName = featureNames.pop(bestFeatureIndex)myTree = {bestFeatureName: {}}featureList = featureNamesSet.pop(bestFeatureIndex)featureSet = set(featureList)for feature in featureSet:featureNamesNext = featureNames[:]featureNamesSetNext = featureNamesSet[:][:]splitedDataSet = splitDataSet(dataSet, bestFeatureIndex, feature)myTree[bestFeatureName][feature] = createFullDecisionTree(splitedDataSet, featureNamesNext, featureNamesSetNext, labelList)return myTree初始化
# 返回值 # dataSet 數據集 # featureNames 標簽 # featureNamesSet 列標簽 def readWatermelonDataSet():dataSet = data.values.tolist()featureNames =['色澤', '根蒂', '敲擊', '紋理', '臍部', '觸感']#獲取featureNamesSetfeatureNamesSet = []for i in range(len(dataSet[0]) - 1):col = [x[i] for x in dataSet]colSet = set(col)featureNamesSet.append(list(colSet))return dataSet, featureNames, featureNamesSet畫圖
# 能夠顯示中文 matplotlib.rcParams['font.sans-serif'] = ['SimHei'] matplotlib.rcParams['font.serif'] = ['SimHei']# 分叉節點,也就是決策節點 decisionNode = dict(boxstyle="sawtooth", fc="0.8")# 葉子節點 leafNode = dict(boxstyle="round4", fc="0.8")# 箭頭樣式 arrow_args = dict(arrowstyle="<-")def plotNode(nodeTxt, centerPt, parentPt, nodeType):"""繪制一個節點:param nodeTxt: 描述該節點的文本信息:param centerPt: 文本的坐標:param parentPt: 點的坐標,這里也是指父節點的坐標:param nodeType: 節點類型,分為葉子節點和決策節點:return:"""createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',xytext=centerPt, textcoords='axes fraction',va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)def getNumLeafs(myTree):"""獲取葉節點的數目:param myTree::return:"""# 統計葉子節點的總數numLeafs = 0# 得到當前第一個key,也就是根節點firstStr = list(myTree.keys())[0]# 得到第一個key對應的內容secondDict = myTree[firstStr]# 遞歸遍歷葉子節點for key in secondDict.keys():# 如果key對應的是一個字典,就遞歸調用if type(secondDict[key]).__name__ == 'dict':numLeafs += getNumLeafs(secondDict[key])# 不是的話,說明此時是一個葉子節點else:numLeafs += 1return numLeafsdef getTreeDepth(myTree):"""得到數的深度層數:param myTree::return:"""# 用來保存最大層數maxDepth = 0# 得到根節點firstStr = list(myTree.keys())[0]# 得到key對應的內容secondDic = myTree[firstStr]# 遍歷所有子節點for key in secondDic.keys():# 如果該節點是字典,就遞歸調用if type(secondDic[key]).__name__ == 'dict':# 子節點的深度加1thisDepth = 1 + getTreeDepth(secondDic[key])# 說明此時是葉子節點else:thisDepth = 1# 替換最大層數if thisDepth > maxDepth:maxDepth = thisDepthreturn maxDepthdef plotMidText(cntrPt, parentPt, txtString):"""計算出父節點和子節點的中間位置,填充信息:param cntrPt: 子節點坐標:param parentPt: 父節點坐標:param txtString: 填充的文本信息:return:"""# 計算x軸的中間位置xMid = (parentPt[0]-cntrPt[0])/2.0 + cntrPt[0]# 計算y軸的中間位置yMid = (parentPt[1]-cntrPt[1])/2.0 + cntrPt[1]# 進行繪制createPlot.ax1.text(xMid, yMid, txtString)def plotTree(myTree, parentPt, nodeTxt):"""繪制出樹的所有節點,遞歸繪制:param myTree: 樹:param parentPt: 父節點的坐標:param nodeTxt: 節點的文本信息:return:"""# 計算葉子節點數numLeafs = getNumLeafs(myTree=myTree)# 計算樹的深度depth = getTreeDepth(myTree=myTree)# 得到根節點的信息內容firstStr = list(myTree.keys())[0]# 計算出當前根節點在所有子節點的中間坐標,也就是當前x軸的偏移量加上計算出來的根節點的中心位置作為x軸(比如說第一次:初始的x偏移量為:-1/2W,計算出來的根節點中心位置為:(1+W)/2W,相加得到:1/2),當前y軸偏移量作為y軸cntrPt = (plotTree.xOff + (1.0 + float(numLeafs))/2.0/plotTree.totalW, plotTree.yOff)# 繪制該節點與父節點的聯系plotMidText(cntrPt, parentPt, nodeTxt)# 繪制該節點plotNode(firstStr, cntrPt, parentPt, decisionNode)# 得到當前根節點對應的子樹secondDict = myTree[firstStr]# 計算出新的y軸偏移量,向下移動1/D,也就是下一層的繪制y軸plotTree.yOff = plotTree.yOff - 1.0/plotTree.totalD# 循環遍歷所有的keyfor key in secondDict.keys():# 如果當前的key是字典的話,代表還有子樹,則遞歸遍歷if isinstance(secondDict[key], dict):plotTree(secondDict[key], cntrPt, str(key))else:# 計算新的x軸偏移量,也就是下個葉子繪制的x軸坐標向右移動了1/WplotTree.xOff = plotTree.xOff + 1.0/plotTree.totalW# 打開注釋可以觀察葉子節點的坐標變化# print((plotTree.xOff, plotTree.yOff), secondDict[key])# 繪制葉子節點plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)# 繪制葉子節點和父節點的中間連線內容plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))# 返回遞歸之前,需要將y軸的偏移量增加,向上移動1/D,也就是返回去繪制上一層的y軸plotTree.yOff = plotTree.yOff + 1.0/plotTree.totalDdef createPlot(inTree):"""需要繪制的決策樹:param inTree: 決策樹字典:return:"""# 創建一個圖像fig = plt.figure(1, facecolor='white')fig.clf()axprops = dict(xticks=[], yticks=[])createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)# 計算出決策樹的總寬度plotTree.totalW = float(getNumLeafs(inTree))# 計算出決策樹的總深度plotTree.totalD = float(getTreeDepth(inTree))# 初始的x軸偏移量,也就是-1/2W,每次向右移動1/W,也就是第一個葉子節點繪制的x坐標為:1/2W,第二個:3/2W,第三個:5/2W,最后一個:(W-1)/2WplotTree.xOff = -0.5/plotTree.totalW# 初始的y軸偏移量,每次向下或者向上移動1/DplotTree.yOff = 1.0# 調用函數進行繪制節點圖像plotTree(inTree, (0.5, 1.0), '')# 繪制plt.show()3.結果
dataSet, featureNames, featureNamesSet=readWatermelonDataSet() testTree= createFullDecisionTree(dataSet, featureNames, featureNamesSet,featureNames) createPlot(testTree)三、參考
【機器學習】 - 決策樹(西瓜數據集)
總結
- 上一篇: 我的世界java雪村种子_我的世界:矿洞
- 下一篇: 土壤湿度传感器