#!/usr/bin/env python3
# coding: utf-8
"""Train the project's NeuralNetwork on scikit-learn's digits dataset.

The script loads the digits data, rescales the features, splits them into
train/test subsets, fits the network on the one-hot encoded training labels
and prints a confusion matrix and a classification report for the test
subset.
"""

import numpy as np
import sklearn.datasets as ds
import sklearn.metrics as mt
import sklearn.preprocessing as prep

try:
    # `sklearn.cross_validation` was deprecated in scikit-learn 0.18 and
    # removed in 0.20; `model_selection` provides the same train_test_split.
    import sklearn.model_selection as cross
except ImportError:
    # Fallback for very old scikit-learn installations.
    import sklearn.cross_validation as cross

import neuralNetwork as nn

digits = ds.load_digits()
x = digits.data
y = digits.target

# Scale the values in x into the range 0-1 (in place).
x -= x.min()
x /= x.max()

# Presumably [input, hidden, output] layer sizes -- confirm against
# neuralNetwork.NeuralNetwork.  To try the other activation, swap in:
#   nn.NeuralNetwork([64, 100, 10], 'tanh')
n = nn.NeuralNetwork([64, 100, 10], 'logistic')

x_train, x_test, y_train, y_test = cross.train_test_split(x, y)

# Fit the label encoder once on the training labels and reuse it for the
# test labels, so both indicator matrices share the same column order even
# if a class happens to be missing from the test split.
label_encoder = prep.LabelBinarizer()
labels_train = label_encoder.fit_transform(y_train)
labels_test = label_encoder.transform(y_test)  # not used below

print("start fitting")
n.fit(x_train, labels_train, epochs=10000)

# Predict one sample at a time; the index of the largest output is taken as
# the predicted class.
# NOTE(review): `predit` is kept as spelled in the original call -- confirm
# whether neuralNetwork actually names this method `predict`.
predictions = [np.argmax(n.predit(sample)) for sample in x_test]

print(mt.confusion_matrix(y_test, predictions))
print(mt.classification_report(y_test, predictions))