前面是做了一轮决策,按照信息论的方式,对各特征做了分析,确定了能够带来最大信息增益(注意是熵减)的特征。但仅这一步是不够的,我们需要继续对叶子节点进行同样的操作,直到完成如下的目标:
[if !supportLists]1)[endif]程序遍历完所有划分数据集的属性;
[if !supportLists]2)[endif]每个分支下的所有实例都具有相同的分类;
如果程序已经遍历完所有划分数据集的属性,叶子节点下的实例仍然不具备相同的分类,那就采用多数表决的方法(有点像KNN)来决定该叶子节点的分类。
好,上代码。
def majorityCnt(classList):
classCount={}
for vote in classList:
if vote not in classCount.keys(): classCount[vote] = 0
classCount[vote] += 1
sortedClassCount=sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
return sortedClassCount[0][0]
如上代码就不逐行展开了,其实就是把一个数组中的标签项数一下数,然后找到哪一个标签出现的次数最多,和KNN中相关的排序方式类似。
我们再来看看,整棵树的遍历:
def createTree(dataSet, labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValues = [example[bestFeat] for example in dataSet]
uniqueVals = set(featValues)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
return myTree
代码一共16行。这段代码比较关键,而且不算太容易看懂,我们来逐行看一下:
def createTree(dataSet, labels):
#定义函数,dataSet实际上是带了标签值的数据集,labels其实是标签值的意义定义,详见前面的数据集定义
classList = [example[-1] for example in dataSet]
#classList实际上就是标签值的数组,标签值位于数据集的最后一列
if classList.count(classList[0]) == len(classList):
#这里做了一个判断,count方法的作用就是数出数组中某个元素值的个数,在这里就是对classList[0]做了计数,当它的数量等同于数组的长度时,说明这个数组里没有别的标签了,即已经分到了标签唯一的状态;按决策树叶子节点是否达到不可分的条件2,已经完成
return classList[0]
#返回classList[0],即当前叶子节点唯一的标签值
if len(dataSet[0]) == 1:
return majorityCnt(classList)
#如果dataSet的长度为1,那就等于是叶子节点中特征值只有1个,这个时候就满足了决策树叶子节点是否达到不可分的条件1,程序遍历完所有划分数据集的属性,这个时候我们要对叶子节点中的标签进行统计,通过多数表决的方法确定并返回这一分支的标签值。
bestFeat = chooseBestFeatureToSplit(dataSet)
bestFeatLabel = labels[bestFeat]
#如果不是上述的两种情况,即节点还可分,那么就用chooseBestFeatureToSplit方法找到最佳的特征项,并对节点进行分解。
myTree = {bestFeatLabel:{}}
#建立一个字典myTree
del(labels[bestFeat])
#删除已选出的特征
featValues = [example[bestFeat] for example in dataSet]
#根据最佳特征的下标,从dataSet里取出相关特征的值的数组
uniqueVals = set(featValues)
#去重,得到该数组中可能的值
for value in uniqueVals:
#遍历所有可能的值
subLabels = labels[:]
#复制出一个label数组,这个数组已经删掉了之前的最佳特征项
myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
#根据最优的特征进行分叉,每一个分叉再去进行生成子树的递归操作(这是关键,通过递归遍历生成所有的子节点,并根据那两条要求确定是否终止并直接return)
return myTree
#返回树(注意这里的返回,可能是一棵子树,因为是通过递归生成了整棵树,只有最开始的调用才是根节点)
好了,至此,这段代码结束。这段代码不算太容易看懂,原理好懂,但是算法理论要想跟代码联系在一起,还是挺复杂的。最终生成的树结构如下:
它是个什么呢?其实就是一开始的最优特征项是“no surfacing”,然后进行分叉,左边由于标签一致所以结束,右边进行再分叉,然后因为特征用完,结束。
一棵有意思的树。 |