-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbai12_binaryClassification_MNIST.py
More file actions
78 lines (60 loc) · 1.95 KB
/
bai12_binaryClassification_MNIST.py
File metadata and controls
78 lines (60 loc) · 1.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# -*- coding: utf-8 -*-
"""
Created on Wed Apr 22 23:07:11 2020
@author: phamk
"""
import numpy as np
from mnist import MNIST
import matplotlib.pyplot as plt
from sklearn import linear_model
from sklearn.metrics import accuracy_score
from display_network import *
# Directory handed to the MNIST loader for both the training and test splits.
mnist_dir = 'E:/AI/example/MNIST/'

# Training split: load it, then convert images and labels to numpy arrays.
mntrain = MNIST(mnist_dir)
mntrain.load_training()
Xtrain_all = np.asarray(mntrain.train_images)
ytrain_all = np.array(mntrain.train_labels.tolist())

# Test split: a second loader on the same directory, converted the same way.
mntest = MNIST(mnist_dir)
mntest.load_testing()
Xtest_all = np.asarray(mntest.test_images)
ytest_all = np.array(mntest.test_labels.tolist())

# Two label groups for the binary problem: labels in the first list are
# mapped to class 0 and labels in the second list to class 1.
cls = [[0], [1]]
def extract_data(X, y, classes):
    """
    Pick the samples whose label is in one of two label groups and relabel
    them as 0 (first group) or 1 (second group).

    X: numpy array, matrix of size (N, d), d is data dim
    y: numpy array, size (N, )
    classes: two lists of labels. For example:
        classes = [[1, 4, 7], [5, 6, 8]]

    return:
        X: extracted data, divided by 255.0
        y: extracted label
            (0 and 1, corresponding to two lists in classes)

    Rows are returned group by group: every sample of classes[0] (label by
    label, in the order the labels are listed) followed by every sample of
    classes[1].
    """
    # Seed each stack with an empty index array so an empty label list
    # still yields a valid (empty) integer index instead of an error.
    no_rows = np.empty(0, dtype=np.intp)

    # Use the `classes` argument, not a module-level variable, so the
    # function honours whatever label groups the caller passes in.
    rows0 = np.hstack([no_rows] + [np.where(y == label)[0] for label in classes[0]])
    rows1 = np.hstack([no_rows] + [np.where(y == label)[0] for label in classes[1]])

    row_ids = np.hstack((rows0, rows1))
    X_res = X[row_ids, :] / 255.0
    y_res = np.asarray([0] * len(rows0) + [1] * len(rows1))
    return (X_res, y_res)
# Build the binary training and test sets from the full splits.
X_train, y_train = extract_data(Xtrain_all, ytrain_all, cls)
X_test, y_test = extract_data(Xtest_all, ytest_all, cls)

# Fit logistic regression; C is set to a large value (1e5).
logreg = linear_model.LogisticRegression(C=1e5)
logreg.fit(X_train, y_train)

# Predict on the test set and report accuracy as a percentage.
y_pred = logreg.predict(X_test)
accuracy = accuracy_score(y_test, y_pred.tolist())
print("Accuracy: %.2f %%" % (100 * accuracy))

# Collect the test samples whose predicted label differs from the true one.
mis = np.flatnonzero(y_pred != y_test)
Xmis = X_test[mis]

# Show the misclassified samples with display_network, axes hidden,
# in grayscale.
plt.axis('off')
n_mis = Xmis.shape[0]
A = display_network(Xmis.T, n_mis, 1)
f2 = plt.imshow(A, interpolation='nearest')
plt.gray()
plt.show()