写KD树的时候没把类别考虑进去。。。所以先用KD算出最近的k个点,然后找到对应分类最后输出占比最大的
KD树是一种二叉树,用来分割空间上的点
一个树节点的结构如下:
class TreeNode:
    """Single node of a KD-tree: a split dimension, a point, two children."""

    # Class-level defaults so attribute access is safe on a bare node.
    index = -1    # index of the dimension this node splits on
    point = None  # the point stored at this node
    left = None   # left subtree (values smaller on the split dimension)
    right = None  # right subtree
    data = None   # optional user payload, managed via set_data/get_data

    def __init__(self, index=-1, point=None, left=None, right=None):
        self.index, self.point = index, point
        self.left, self.right = left, right

    def set_data(self, data):
        """Attach an arbitrary payload to this node."""
        self.data = data

    def get_data(self):
        """Return the payload previously attached with set_data (or None)."""
        return self.data
建树过程是:
- 先选出方差最大的维度
- 将现有数据按该维度排序
- 取数据中位点
- 中位点即该树结点的数据
- 中位点左边的点传入左子树构造方法,右边的同理
- 下一层树结点使用下一个维度
代码:
def build_tree(self, dataset, split):
    """Recursively build a KD-tree over *dataset*, splitting on dimension *split*.

    dataset: list of points (indexable sequences); sorted in place per level.
    split:   split-dimension index for this level; wrapped into range below.
    Returns the subtree's root TreeNode, or None for an empty dataset.
    """
    if dataset is None or len(dataset) == 0:  # BUG FIX: was `len(dataset) is 0` (identity, not equality)
        return None
    # Wrap the split dimension when it runs past the usable range.
    # NOTE(review): the modulus is len(point) - 1, i.e. the LAST component
    # is never used as a split axis — presumably it carries the class
    # label; confirm against the training data layout.
    if split >= len(dataset[0]) - 1:
        split %= len(dataset[0]) - 1
    # A single remaining point is necessarily a leaf.
    if len(dataset) == 1:  # BUG FIX: was `len(dataset) is 1`
        return TreeNode(split, dataset[0], None, None)
    # Sort along the split axis and take the median point for this node.
    dataset.sort(key=lambda x: x[split])
    median = len(dataset) // 2  # integer median index (was int(n / 2))
    node = TreeNode(split, dataset[median], None, None)
    # Points left of the median feed the left subtree, those right of it
    # the right subtree; the next level splits on the next dimension.
    node.left = self.build_tree(dataset[:median], split + 1)
    node.right = self.build_tree(dataset[median + 1:], split + 1)
    return node
def create(self, dataset):
    """Build the KD-tree for *dataset*, store and return its root."""
    # Start splitting on the dimension chosen by get_var (the prose above
    # says: the dimension with the largest variance).
    first_split = self.get_var(dataset)
    self.root = self.build_tree(dataset, first_split)
    return self.root
有时候check后的数据需要插入到树中:
插入的过程较简单,从root开始,按该层维度比较:小于该结点在该维度上的值则继续搜索左子树,反之右子树
直到搜索的节点为None 则在这里插入新的结点
def insert(self, point):
    """Insert *point* into an existing tree.

    Descends from the root comparing on each node's split dimension —
    strictly smaller values go left, equal-or-greater go right — and
    attaches a new leaf where the walk falls off the tree.
    """
    if self.root is None:
        print('Build a tree first !')
        return
    # BUG FIX: the original compared the lengths with `is not`, which
    # tests object identity, not numeric equality, and is unreliable
    # for ints outside CPython's small-int cache.
    if len(point) != len(self.root.point):
        print('This point have {l} splits but tree have {m}'.format(l=len(point), m=len(self.root.point)))
        return
    # NOTE(review): the child's split dimension wraps modulo len(point)
    # here, while build_tree wraps modulo len(point) - 1 — these two
    # should probably agree; confirm the intended point layout.
    node = self.root
    while True:
        if point[node.index] < node.point[node.index]:
            if node.left is not None:
                node = node.left
            else:
                node.left = TreeNode((node.index + 1) % len(point), point, None, None)
                return
        else:
            if node.right is not None:
                node = node.right
            else:
                node.right = TreeNode((node.index + 1) % len(point), point, None, None)
                return
寻找过程,首先先按照插入的方法找到最接近的最底层子节点
然后依次向上回溯查找,如果该结点的另半个子树也可能成为最近点则将其Push进栈
查找至栈为空,则找到最近点。点间距离同理可应用不同的距离(相似度)算法
def sim_distance(self, p1, p2):
    """Euclidean distance between points *p1* and *p2* (indexed by p1's length)."""
    squared_total = sum((p1[i] - p2[i]) ** 2 for i in range(len(p1)))
    return sqrt(squared_total)
def find_nearest(self, point):
    # Nearest-neighbour search: descend to a leaf as in insertion,
    # stacking every node on the path, then backtrack through the stack.
    root = self.root
    # NOTE(review): Stack is a project class not visible here; 99999 is
    # presumably an ample fixed capacity — confirm its API (push/pop/isempty).
    s = Stack(99999)
    # Phase 1: walk down, taking the side of each split that contains `point`.
    while root is not None:
        index = root.index
        s.push(root)
        if point[index] <= root.point[index]:
            root = root.left
        else:
            root = root.right
    # The deepest node visited is the first candidate for "nearest".
    nearest = s.pop()
    min_dist = self.sim_distance(nearest.point, point)
    # Phase 2: backtrack.  For each node on the path, if the splitting
    # hyperplane lies closer than the current best distance, the far
    # subtree may contain a closer point, so push its root for inspection.
    while not s.isempty():
        back_point = s.pop()
        if back_point is None:
            # A pushed far-subtree root may be None; skip it.
            continue
        index = back_point.index
        # 1-D distance from `point` to this node's splitting hyperplane.
        if self.sim_distance([point[index]], [back_point.point[index]]) < min_dist:
            # Cross the split: inspect the subtree on the OTHER side.
            if point[index] <= back_point.point[index]:
                root = back_point.right
            else:
                root = back_point.left
            s.push(root)
        # Promote this node to best candidate if it is closer.
        if min_dist > self.sim_distance(back_point.point, point):
            nearest = back_point
            min_dist = self.sim_distance(back_point.point, point)
    return nearest.point, min_dist
KNN 算法的核心在于找到最近的k个点,然后根据这些点的类别决定待查点的类别
我维护了一个长度始终为k的list来保存前k小的距离
每次跟 list尾部的进行比较,如果比其小则加入list,并排序 取前k项
def find_near_kth(self, point, k):
    # k-nearest search: same descend-then-backtrack scheme as find_nearest,
    # but maintains a list of (node, distance) pairs trimmed to the k
    # smallest distances, sorted ascending.
    root = self.root
    result = []
    # NOTE(review): Stack is a project class not visible here — confirm API.
    s = Stack(99999)
    # Phase 1: descend to a leaf, stacking the path.
    while root is not None:
        index = root.index
        s.push(root)
        if point[index] <= root.point[index]:
            root = root.left
        else:
            root = root.right
    # Seed the candidate list with the deepest node on the path.
    t_point = s.pop()
    result.append((t_point, self.sim_distance(t_point.point, point)))
    # Phase 2: backtrack.  A node's far subtree is explored when the
    # splitting hyperplane lies within the current worst kept distance,
    # or while fewer than k candidates are held.
    while not s.isempty():
        back_point = s.pop()
        if back_point is None:
            # A pushed far-subtree root may be None; skip it.
            continue
        index = back_point.index
        if self.sim_distance([point[index]], [back_point.point[index]]) <= result[len(result) - 1][1] or len(
                result) < k:
            if point[index] <= back_point.point[index]:
                root = back_point.right
            else:
                root = back_point.left
            s.push(root)
        # Admit this node if it beats the worst kept distance or the list
        # is not yet full, then re-sort and trim back to the k smallest.
        if result[len(result) - 1][1] > self.sim_distance(back_point.point, point) or len(result) < k:
            result.append((back_point, self.sim_distance(back_point.point, point)))
            result.sort(key=lambda x: x[1])
            result = result[0:k]
    return result
最后用了很蠢的方法来找对应点的分类:
def decide_type(kd_result, t_point, t_type):
    """Majority vote over the k nearest neighbours.

    kd_result: list of (node, distance) pairs as returned by find_near_kth.
    t_point:   training points, index-parallel to t_type.
    t_type:    training labels.
    Returns the label with the most votes, or None when nothing matched.
    """
    votes = {label: 0 for label in t_type}
    for pair in kd_result:
        # Look the node's point up in the training set to find its label.
        for i, train_point in enumerate(t_point):
            if pair[0].point == train_point:
                votes[t_type[i]] += 1
                break
    # Pick the first label whose count strictly exceeds all earlier ones;
    # with zero votes everywhere the result stays None.
    best_label, best_count = None, 0
    for label in votes:
        if votes[label] > best_count:
            best_count = votes[label]
            best_label = label
    return best_label
测试如下:
# Demo: build the tree from the training points, then query the two
# nearest neighbours of (1, 1).
# NOTE(review): train_point / train_type are expected to be defined
# elsewhere in the surrounding program.
kd = KdTree()
kd.create(train_point)
print(kd.find_near_kth((1, 1), 2))
# print(decide_type(kd.find_near_kth((6.5, 6), 3),train_point,train_type))