-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbai9_PLA_Perceptron_Learning_Algorithrm.py
More file actions
107 lines (86 loc) · 2.74 KB
/
bai9_PLA_Perceptron_Learning_Algorithrm.py
File metadata and controls
107 lines (86 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# -*- coding: utf-8 -*-
"""
Created on Sun Apr 19 18:45:33 2020
@author: phamk
"""
# generate data
# list of points
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial.distance import cdist
# Fixed seed so the generated clusters (and hence the PLA run) are reproducible.
np.random.seed(2)
means = [[2, 2], [4, 2]]
cov = [[.3, .2], [.2, .3]]
N = 10
# multivariate_normal returns (N, 2); transpose so each point is a column.
X0 = np.random.multivariate_normal(means[0], cov, N).T
X1 = np.random.multivariate_normal(means[1], cov, N).T
# Place the two clusters side by side: shape (2, 2N).
X = np.hstack((X0, X1))
# Row vector of labels, shape (1, 2N): +1 for the X0 columns, -1 for X1.
y = np.hstack((np.ones((1, N)), -np.ones((1, N))))
# Xbar: prepend a row of ones (bias feature), giving shape (3, 2N).
X = np.vstack((np.ones((1, 2 * N)), X))
def h(w, x):
    """Classify the columns of ``x`` with the linear model ``w``.

    Returns ``sign(w^T x)`` elementwise: +1 or -1 for each column, or 0 for a
    column lying exactly on the decision boundary.
    """
    scores = w.T.dot(x)
    return np.sign(scores)
def has_converged(X, y, w):
    """Return True iff ``h(w, X)`` matches the label array ``y`` exactly.

    ``np.array_equal`` also yields False when the shapes differ.
    """
    predictions = h(w, X)
    return np.array_equal(predictions, y)
def perceptron(X, y, w_init, max_iter=None):
    """Run the Perceptron Learning Algorithm (PLA) starting from ``w_init``.

    Each pass visits the columns of ``X`` in a fresh random order, drawn from
    the global NumPy RNG. Every point the current weights misclassify
    triggers the update ``w <- w + y_i * x_i``. Passes repeat until
    ``has_converged`` reports that every point is classified correctly.

    Parameters
    ----------
    X : ndarray, shape (d, N)
        One sample per column (including the bias row, if any).
    y : ndarray, shape (1, N)
        Labels, expected to be +1 / -1.
    w_init : ndarray, shape (d, 1)
        Initial weight vector (a column, so the update keeps shape (d, 1)).
    max_iter : int or None, optional
        Maximum number of passes over the data. ``None`` (the default) keeps
        looping until convergence, which never terminates on data that is not
        linearly separable.

    Returns
    -------
    (w, mis_points) : tuple
        ``w`` is the list of weight vectors. It starts with ``w_init`` and
        gains one entry per update. ``mis_points`` lists, in update order, the
        column index of the point that triggered each update.
    """
    w = [w_init]
    d, N = X.shape
    mis_points = []
    epochs = 0
    while True:
        # mix data: visit the samples in a new random order on every pass
        for idx in np.random.permutation(N):
            # Reshape with the real feature count instead of a hard-coded 3,
            # so inputs of any dimension work.
            xi = X[:, idx].reshape(d, 1)
            yi = y[0, idx]
            # h returns a single-element array here; compare its value.
            if h(w[-1], xi).item() != yi:
                mis_points.append(idx)
                w.append(w[-1] + yi * xi)
        epochs += 1
        if has_converged(X, y, w[-1]):
            break
        # Optional safety valve for non-separable data.
        if max_iter is not None and epochs >= max_iter:
            break
    return (w, mis_points)
# One weight per row of X (bias + features), drawn at random as the PLA start.
d = len(X)
w_init = np.random.randn(d, 1)
w, m = perceptron(X, y, w_init)
# m: column indices of the points that triggered each weight update.
print(m)
def draw_line(w):
    """Plot the boundary ``w0 + w1*x1 + w2*x2 = 0`` as a black line.

    ``w`` holds the coefficients with the bias first. A (3, 1) column (as
    produced by ``perceptron``) and a flat length-3 array both work. The
    line is drawn with ``plt.plot`` on the current axes, spanning
    [-100, 100] along x1 (or along x2 for a vertical boundary). The axes
    limits decide which part is visible.

    Returns the list of Line2D objects from ``plt.plot``. When w1 and w2 are
    both zero there is no boundary to draw. In that case an empty line is
    plotted instead of dividing by zero.
    """
    w0, w1, w2 = np.ravel(w)[:3]
    x11, x12 = -100, 100
    if w2 != 0:
        # Solve for x2 at both ends of the x1 range.
        return plt.plot([x11, x12],
                        [-(w1 * x11 + w0) / w2, -(w1 * x12 + w0) / w2], 'k')
    if w1 != 0:
        # w2 == 0: the boundary is the vertical line x1 = -w0 / w1.
        x10 = -w0 / w1
        return plt.plot([x10, x10], [x11, x12], 'k')
    # Degenerate weights (w1 == w2 == 0): nothing meaningful to draw.
    return plt.plot([], [], 'k')
## Visualization
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation
def viz_alg_1d_2(w, mis_points=None, filename='pla_vis.gif'):
    """Animate the PLA weight history, save it as a GIF, then show it.

    Frame ``i`` draws:
    - the two classes (module-level ``X0`` and ``X1``);
    - the boundary for ``w[i]``;
    - while updates remain, a circle around the point
      ``X[:, mis_points[i]]`` whose misclassification produced ``w[i + 1]``.

    Two extra frames at the end repeat the final boundary.

    Parameters
    ----------
    w : list of ndarray
        Weight history as returned by ``perceptron``.
    mis_points : sequence of int, optional
        Column indices of the misclassified points, as returned by
        ``perceptron``. Defaults to the module-level ``m``.
    filename : str, optional
        Output path of the GIF, written with the ImageMagick writer.
    """
    if mis_points is None:
        mis_points = m
    it = len(w)
    fig, ax = plt.subplots(figsize=(5, 5))

    def update(i):
        ax.cla()
        # points
        ax.plot(X0[0, :], X0[1, :], 'b^', markersize=8, alpha=.8)
        ax.plot(X1[0, :], X1[1, :], 'ro', markersize=8, alpha=.8)
        # Fix the limits before drawing the (very long) boundary line so the
        # view is not autoscaled to it.
        ax.axis([0, 6, -2, 4])
        i2 = min(i, it - 1)
        line = draw_line(w[i2])  # draws on the current axes, i.e. ax
        if i < it - 1:
            # draw one misclassified point
            circle = plt.Circle((X[1, mis_points[i]], X[2, mis_points[i]]),
                                0.15, color='k', fill=False)
            ax.add_artist(circle)
        # hide axis ticks
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_xlabel('PLA: iter %d/%d' % (i2, it - 1))
        return line, ax

    anim = FuncAnimation(fig, update, frames=np.arange(0, it + 2), interval=1000)
    # save
    anim.save(filename, dpi=100, writer='imagemagick')
    plt.show()
# Render, save and display the animation of the weight history from the run above.
viz_alg_1d_2(w=w)