#!/usr/bin/env python
# coding: utf-8
|
|
|
|
from math import log
|
|
|
|
|
|
# Compute the Shannon entropy
|
|
def calcShannonEnt(dataSet):
    """Return the Shannon entropy (in bits) of the class labels in ``dataSet``.

    Args:
        dataSet: Sequence of feature vectors. The last element of each
            vector is taken as its class label and must be hashable.

    Returns:
        The entropy as a float. An empty ``dataSet`` yields ``0.0``.
    """
    numEntries = len(dataSet)

    # Tally how many rows carry each class label.
    labelCounts = {}
    for featVec in dataSet:
        currentLabel = featVec[-1]
        labelCounts[currentLabel] = labelCounts.get(currentLabel, 0) + 1

    # H = -sum(p * log2(p)) over the observed labels.  Only labels that
    # actually occur are visited, so prob is never 0 and log() is safe.
    shannonEnt = 0.0
    for count in labelCounts.values():
        prob = float(count) / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt
|
|
|
|
|
|
# Initial values (sample data set)
|
|
def createDataSet():
    """Return a small sample data set together with its feature names.

    Returns:
        A ``(dataSet, label)`` tuple. ``dataSet`` is a freshly built list of
        five rows of the form ``[feature0, feature1, class_label]`` and
        ``label`` is a list naming the two feature columns.
    """
    dataSet = [
        [1, 1, 'yes'],
        [1, 0, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    label = ['no surfaceing', 'flippers']
    return dataSet, label
|
|
|
|
|
|
# Split the data set on a given feature
|
|
def splitDataSet(dataSet, axis, value):
    """Select the rows whose ``axis`` column equals ``value``, dropping that column.

    Args:
        dataSet: Sequence of feature vectors (lists or tuples).
        axis: Index of the column to test and then remove.
        value: Value the ``axis`` column must equal for a row to be kept.

    Returns:
        A new list of new rows; neither ``dataSet`` nor its rows are
        modified. Each returned row has one element fewer than its source.
    """
    # Slice concatenation builds a fresh row and works for lists and tuples.
    return [
        featVec[:axis] + featVec[axis + 1:]
        for featVec in dataSet
        if featVec[axis] == value
    ]
|
|
|
|
|
|
# Choose the best way to split the data set
|
|
def choooseBestFeatureToSplit(dataSet):
    """Return the index of the feature whose split yields the largest information gain.

    For every feature column, the data set is partitioned by each distinct
    value of that column (via ``splitDataSet``) and the weighted entropy of
    the partitions (via ``calcShannonEnt``) is subtracted from the entropy
    of the whole set.

    Args:
        dataSet: Sequence of feature vectors whose last element is the
            class label; all preceding elements are candidate features.

    Returns:
        The column index of the best feature, or ``-1`` when ``dataSet`` is
        empty or no feature gives a strictly positive gain.
    """
    if not dataSet:
        return -1

    # Count feature columns, not rows: the last column is the class label
    # and must not be considered as a split candidate.
    numFeatures = len(dataSet[0]) - 1
    numEntries = float(len(dataSet))
    baseEntropy = calcShannonEnt(dataSet)

    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        uniqueVals = {example[i] for example in dataSet}

        # Expected entropy after splitting on feature i.
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet, i, value)
            prob = len(subDataSet) / numEntries
            newEntropy += prob * calcShannonEnt(subDataSet)

        infoGain = baseEntropy - newEntropy
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature
|
|
|
|
"""
|
|
myData, labels = createDataSet()
|
|
print(myData)
|
|
calcShannonEnt(myData)
|
|
myData[0][-1] = 'maybe'
|
|
print(myData)
|
|
shannonEnt = calcShannonEnt(myData)
|
|
print(shannonEnt)
|
|
myData[0][-1] = 'yes'
|
|
data = splitDataSet(myData, 0, 1)
|
|
print(data)
|
|
data = splitDataSet(myData, 0, 0)
|
|
print(data)
|
|
"""
|