31 lines
903 B
Python
31 lines
903 B
Python
|
#!/usr/bin/env python3
# coding: utf-8
"""Train the project's NeuralNetwork on scikit-learn's handwritten-digits data.

Loads ``sklearn.datasets.load_digits``, rescales the features into [0, 1],
trains a [64, 100, 10] network with the 'logistic' activation on a random
train/test split, then prints a confusion matrix and a per-class
classification report for the held-out split.
"""

import numpy as np
import sklearn.datasets as ds
import sklearn.metrics as mt
import sklearn.preprocessing as prep

try:
    # train_test_split lives in sklearn.model_selection since 0.18; the old
    # sklearn.cross_validation module was removed in scikit-learn 0.20.
    import sklearn.model_selection as cross
except ImportError:
    # Fallback for scikit-learn releases older than 0.18.
    import sklearn.cross_validation as cross

import neuralNetwork as nn

digits = ds.load_digits()
x = digits.data
y = digits.target

# Scale the values in x into the range [0, 1] using the global min/max of the
# whole array (not per feature).
x -= x.min()
x /= x.max()

# Layer sizes: 64 inputs (load_digits images are 8x8 pixels) -> 100 hidden
# units -> 10 outputs, one per digit class.
# Alternative activation: nn.NeuralNetwork([64, 100, 10], 'tanh')
n = nn.NeuralNetwork([64, 100, 10], 'logistic')

# No random_state is passed, so the split differs from run to run.
x_train, x_test, y_train, y_test = cross.train_test_split(x, y)

# Fit ONE binarizer on the training labels and reuse it for the test labels,
# so both one-hot matrices share the same column-to-class mapping even if a
# class is missing from one of the splits.
binarizer = prep.LabelBinarizer()
labels_train = binarizer.fit_transform(y_train)
labels_test = binarizer.transform(y_test)

print("start fitting")
n.fit(x_train, labels_train, epochs=10000)

# Predict one sample (row of x_test) at a time. The largest output's column
# index is mapped back to its class label through the fitted binarizer.
# NOTE(review): `predit` is spelled as in the original call; confirm it
# matches the method name defined in neuralNetwork.NeuralNetwork.
predictions = [binarizer.classes_[np.argmax(n.predit(sample))]
               for sample in x_test]

print(mt.confusion_matrix(y_test, predictions))
print(mt.classification_report(y_test, predictions))