31 lines
884 B
Python
31 lines
884 B
Python
|
#!/usr/bin/env python3
|
||
|
# -*- coding:utf-8 -*-
|
||
|
import numpy as np
|
||
|
import operator
|
||
|
|
||
|
|
||
|
def createDataSet():
    """Build the small four-sample toy data set.

    Returns:
        tuple: ``(group, labels)`` where ``group`` is a 4x2 NumPy array with
        one sample per row and ``labels`` is a list holding the class name
        of each row, in the same order.
    """
    # Keep every sample next to its class name so the pairing is explicit.
    samples = (
        ([1.0, 1.1], 'A'),
        ([1.0, 1.0], 'A'),
        ([0, 0], 'B'),
        ([0, 0.1], 'B'),
    )
    points = [point for point, _ in samples]
    names = [name for _, name in samples]
    return np.array(points), names
|
||
|
|
||
|
|
||
|
# k-nearest neighbours (kNN) algorithm
def classify0(inX, dataSet, labels, k):
    """Tally the class votes of the ``k`` training samples nearest to ``inX``.

    Distances are Euclidean, computed between ``inX`` and every row of
    ``dataSet``.

    Args:
        inX: Feature vector to classify; must broadcast against the rows of
            ``dataSet``.
        dataSet: 2-D array-like of training samples, one sample per row.
        labels: Sequence of class labels; ``labels[i]`` belongs to row ``i``
            of ``dataSet``.
        k: Number of nearest neighbours that cast a vote.

    Returns:
        list: ``(label, votes)`` tuples ordered by vote count, highest first,
        so the predicted class is ``result[0][0]``. Labels with equal votes
        keep the order in which they were first encountered (nearest
        neighbour first), because ``sorted`` is stable.

    Raises:
        IndexError: If ``k`` is larger than the number of training samples.
    """
    # Broadcasting subtracts every training row from inX, so the query
    # vector does not have to be tiled to the data set's height first.
    diffMat = np.asarray(inX) - np.asarray(dataSet)
    sqDistances = np.sum(diffMat ** 2, axis=1)
    distances = sqDistances ** 0.5

    # Row indices ordered from nearest to farthest sample.
    sortedDistIndicies = np.argsort(distances)

    classCount = {}
    for i in range(k):
        voteLabel = labels[sortedDistIndicies[i]]
        classCount[voteLabel] = classCount.get(voteLabel, 0) + 1

    # np.sort has no key/reverse parameters and dict.iteritems() is gone in
    # Python 3, so use the built-in sorted() on items(), ranking by the vote
    # count (tuple element 1) rather than by the label itself.
    sortedClassCount = sorted(classCount.items(),
                              key=operator.itemgetter(1),
                              reverse=True)
    return sortedClassCount
|