This repository has been archived on 2020-04-25. You can view files and clone it, but cannot push or open issues or pull requests.
ml/nn/handWrittenDigitsRecoginition.py
2020-02-23 22:14:06 +08:00

31 lines
903 B
Python
Executable File

#!/usr/bin/env python3
# coding: utf-8
import numpy as np
import sklearn.datasets as ds
import sklearn.metrics as mt
import sklearn.preprocessing as prep
import neuralNetwork as nn

# sklearn.cross_validation was deprecated in scikit-learn 0.18 and removed in
# 0.20; use sklearn.model_selection and fall back only for old installs.
try:
    from sklearn.model_selection import train_test_split
except ImportError:
    from sklearn.cross_validation import train_test_split

# Train a small fully connected network on scikit-learn's bundled
# handwritten-digits dataset and report its accuracy on a held-out split.
digits = ds.load_digits()
x = digits.data
y = digits.target

# Min-max scale all feature values into [0, 1] using the global min and max.
x -= x.min()
x /= x.max()

x_train, x_test, y_train, y_test = train_test_split(x, y)

# Fit ONE binarizer on the training labels and reuse it: the one-hot columns
# used for training and the argmax -> label mapping used for evaluation must
# share the same class order.
binarizer = prep.LabelBinarizer()
labels_train = binarizer.fit_transform(y_train)

# Layer sizes follow the data (64 features / 10 classes for this dataset)
# instead of being hard-coded. Assumes more than two classes: LabelBinarizer
# emits a single column for binary targets.
# NOTE(review): an earlier revision also tried the 'tanh' activation here,
# presumably supported by neuralNetwork.NeuralNetwork; confirm before switching.
n_features = x_train.shape[1]
n_classes = labels_train.shape[1]
n = nn.NeuralNetwork([n_features, 100, n_classes], 'logistic')

print("start fitting")
n.fit(x_train, labels_train, epochs=10000)

# Predict one sample (row) at a time and translate the index of the strongest
# output back into the actual class label via the binarizer.
# NOTE(review): the inference method is called `predit` (sic) on the
# project-local NeuralNetwork; rename here too if the module fixes the name.
predictions = [
    binarizer.classes_[np.argmax(n.predit(sample))] for sample in x_test
]
print(mt.confusion_matrix(y_test, predictions))
print(mt.classification_report(y_test, predictions))