This repository has been archived on 2020-04-25. You can view files and clone it, but cannot push or open issues or pull requests.
ml/decisionTree/trees.py

79 lines
2.0 KiB
Python
Raw Permalink Normal View History

2020-02-23 14:14:06 +00:00
#!/usr/bin/env python
# coding: utf-8
from math import log
# Compute the Shannon entropy of a data set (class label in the last column).
def calcShannonEnt(dataSet):
    """Return the Shannon entropy of *dataSet*.

    Each row's final element is treated as the class label; the entropy is
    computed over the label frequency distribution using log base 2.
    """
    total = len(dataSet)
    counts = {}
    for row in dataSet:
        label = row[-1]
        counts[label] = counts.get(label, 0) + 1
    entropy = 0.0
    for count in counts.values():
        prob = float(count) / total
        entropy -= prob * log(prob, 2)
    return entropy
# Initial toy data set for the decision-tree demos.
def createDataSet():
    """Return (dataSet, label).

    *dataSet* rows are [feature0, feature1, class-label]; *label* gives the
    two feature names.
    """
    samples = [
        [1, 1, 'yes'],
        [1, 0, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    featureNames = ['no surfaceing', 'flippers']
    return samples, featureNames
# Partition the data set on a given feature value, dropping that column.
def splitDataSet(dataSet, axis, value):
    """Return the rows of *dataSet* whose feature at index *axis* equals
    *value*, with that feature column removed from each returned row.
    """
    return [row[:axis] + row[axis + 1:] for row in dataSet if row[axis] == value]
# Choose the best feature to split the data set on (maximum information gain).
def choooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split maximizes information gain.

    Assumes every row of *dataSet* ends with the class label, so the number
    of candidate features is len(row) - 1.  Returns -1 when no split yields
    a positive gain (or the data set is empty).
    """
    if not dataSet:
        return -1
    # BUG FIX: the original used len(dataSet) (the number of rows) as the
    # feature count, which walks past the feature columns and raises
    # IndexError whenever there are more rows than columns.
    numFeatures = len(dataSet[0]) - 1
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = set(example[i] for example in dataSet)
        # Expected entropy after splitting on feature i, weighted by the
        # fraction of rows in each branch.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / float(len(dataSet))
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
"""
myData, labels = createDataSet()
print(myData)
calcShannonEnt(myData)
myData[0][-1] = 'maybe'
print(myData)
shannonEnt = calcShannonEnt(myData)
print(shannonEnt)
myData[0][-1] = 'yes'
data = splitDataSet(myData, 0, 1)
print(data)
data = splitDataSet(myData, 0, 0)
print(data)
"""