Python 感知器(PLA)实现
Table of Contents
1 PLA回顾
PLA是一种构建于感知器的学习算法。其本质为超平面对线性可分的N维数据集合的线性分割 PLA算法的目的,即通过不断的修正超平面本身法向量,来寻找到当前Training Set下没有任何错误的超平面。 而PLA有两种类型,一种为经典的PLA形式,其表现为超平面法向量的朴素表达。而另外一种,则为PLA的对偶形式,表现为超平面法向量由Traning Set中元素的线性集合 前者在Traning Set维度很大时,效率不如后者来的高。至于原因,不再赘述。
2 PLA – 原始形式
import numpy as np
import matplotlib.pyplot as plt
def create_train_data_set():
train_data_mat = np.array([[1, 1, 4],
[1, 2, 3],
[1, -2, 3],
[1, -2, 2],
[1, 0, 1],
[1, 1, 2]])
train_data_res = np.array([1, 1, 1, -1, -1, -1])
train_data_res = train_data_res[:, np.newaxis]
return train_data_mat, train_data_res
def pla(in_train_mat, in_train_mat_res):
pla_weight = np.ones((3, 1))
while True:
pla_finish = True
for i in range(len(in_train_mat)):
x = in_train_mat[i][:]
y = np.dot(pla_weight.T, x)
if np.sign(y) == np.sign(in_train_mat_res[i]):
continue
else:
pla_finish = False
w_i = x * in_train_mat_res[i]
w_i = w_i[:, np.newaxis]
pla_weight = pla_weight + w_i
if pla_finish:
break
return pla_weight
def plot(point_mat, pla_res):
fig = plt.figure()
ax = fig.add_subplot(111)
xx = list(filter(lambda x: x[3] == -1, point_mat))
ax.scatter([x[1] for x in xx], [x[2] for x in xx], s=100, c='b', marker="x", label='-1')
oo = list(filter(lambda x: x[3] == 1, point_mat))
ax.scatter([x[1] for x in oo], [x[2] for x in oo], s=100, c='r', marker="o", label='1')
x = np.linspace(-3, 3, 50)
#b*pla_res[0] + x*pla_res[1] + y*pla_res[2] = 0
#y = ax+b --> a = - pla_res[1]/pla_res[2] + - pla_res[0]/pra_res[2] pla_res[2]!=0
#pla_res[2]=0 --> x = -pla_res[0]/pla_res[1]b
if pla_res[2]:
a, b = -pla_res[1] / pla_res[2], -pla_res[0] / pla_res[2]
y = a*x+b
ax.plot(x, y, 'b-')
else:
b = -pla_res[0]/pla_res[1]
ax.plot(b * len(x), x, 'b')
plt.legend(loc='upper left', scatterpoints=1)
plt.show()
def main():
train_data_mat, train_mat_res = create_train_data_set()
weight = pla(train_data_mat, train_mat_res)
point_mat = np.hstack((train_data_mat, train_mat_res))
plot(point_mat, weight)
if __name__ == '__main__':
main()
3 PLA – pocket algorithm
正如我们所知,PLA(原始形式)仅可在数据集线性可分的前提下才能work perfectly..然而在现实世界中,
实际上,抛开数据集是否可分不谈,对于一个具有生产价值的系统而言,要如何保证数据集必定线性可分呢?
也许我们会说,我的数据属性决定了,我的数据集合必定是可分的,但不能忽略,现实世界中的数据,往往带有噪音。
如何确保程序在该情形下也可以正常的工作?原始形式,显然是行不通的。
因此,我们提出PLA – pocket algorithm
而pocket algorithm的不同,一言以蔽之,即:有限次数的随机迭代
import random
import numpy as np
import matplotlib.pyplot as plt
# train_data_mat 6*3
train_data_mat = np.array([[1, 1, 4],
[1, 2, 3],
[1, -2, 3],
[1, -2, 2],
[1, 0, 1],
[1, 1, 2]])
# train_data_res_mat 6*1
train_data_res = np.array([[1], [1], [1], [-1], [-1], [-1]])
def cheack(nw):
cnt_nw = 0
for i in range(len(train_data_mat)):
y_nw = np.dot(nw, train_data_mat[i])
if np.sign(y_nw) != np.sign(train_data_res[i]):
cnt_nw = cnt_nw + 1
return cnt_nw
def pla_pocket():
w = np.ones((1, 3))
least_false = cheack(w)
for i in range(0,100000):
de_choice = random.randint(0,5)
de_data = train_data_mat[de_choice]
de_res = train_data_res[de_choice]
y = np.dot(w, de_data)
if np.sign(y) != np.sign(de_res):
n_w = w + de_res * de_data
n_cnt = cheack(n_w)
w = n_w
'''
只要出现错误的点,就一定要更新。不这么做的话,可能会出现w根本无法进行更改的情况。
因为,可能存在这样的w,在任何wrong point上都无法获得比当前更好的结果,则iter停顿。
'''
if n_cnt <= least_false:
least_false = n_cnt
res = n_w
return res
def pla_plot(w):
point_mat = np.hstack((train_data_mat, train_data_res))
fig = plt.figure()
ax = fig.add_subplot(111)
xx = list(filter(lambda x: x[3] == -1, point_mat))
ax.scatter([x[1] for x in xx], [x[2] for x in xx], s=100, c='b', marker="x", label='-1')
oo = list(filter(lambda x: x[3] == 1, point_mat))
ax.scatter([x[1] for x in oo], [x[2] for x in oo], s=100, c='r', marker="o", label='1')
x = np.linspace(-3, 3, 50)
# b*pla_res[0] + x*pla_res[1] + y*pla_res[2] = 0
# y = ax+b --> a = - pla_res[1]/pla_res[2] + - pla_res[0]/pra_res[2] pla_res[2]!=0
# pla_res[2]=0 --> x = -pla_res[0]/pla_res[1]b
if w[2]:
print('w2')
a, b = -w[1] / w[2], -w[0] / w[2]
y = a * x + b
ax.plot(x, y, 'b-')
else:
print('w2nil')
b = (-w[0] / w[1]) * x
ax.plot(b, x, 'b')
plt.legend(loc='upper left', scatterpoints=1)
plt.show()
def main():
w = pla_pocket()
pla_plot(w[0])
if __name__ == '__main__':
main()
附上各种各样的结果图[include the error versions]