When I wrote the KD tree I did not take the class labels into account... so the approach is: first use the KD tree to find the k nearest points, then look up their classes and output the class with the largest share.

A KD tree is a binary tree used to partition points in space.

A tree node has the following structure:

class TreeNode:
    index = -1  # index of the split dimension at this node
    point = None  # the point stored at this node
    left = None  # left subtree
    right = None  # right subtree
    data = None  # optional extra payload

    def __init__(self, index=-1, point=None, left=None, right=None):
        self.index = index
        self.point = point
        self.left = left
        self.right = right

    def set_data(self, data):
        self.data = data

    def get_data(self):
        return self.data

The tree-building procedure is:

  1. Pick the dimension with the largest variance
  2. Sort the current data along that dimension
  3. Take the median point of the sorted data
  4. The median point becomes this node's data
  5. Points to the left of the median are passed to the left-subtree construction, points to the right likewise to the right
  6. The next level of the tree splits on the next dimension

Code:

    def build_tree(self, dataset, split):
        # Empty dataset: nothing to build
        if dataset is None or len(dataset) == 0:
            return None
        # Wrap the split dimension around once it runs past the last dimension
        if split >= len(dataset[0]):
            split %= len(dataset[0])
        # A single point is necessarily a leaf
        if len(dataset) == 1:
            return TreeNode(split, dataset[0], None, None)

        data_sum = len(dataset)
        # Sort along the current split dimension and take the median point
        dataset.sort(key=lambda x: x[split])
        node = TreeNode()
        node.index = split
        point_index = int(data_sum / 2)
        node.point = dataset[point_index]
        # Points below the median go left, points above it go right;
        # the next level splits on the next dimension
        node.left = self.build_tree(dataset[0:point_index], split + 1)
        node.right = self.build_tree(dataset[point_index + 1:], split + 1)
        return node

    def create(self, dataset):
        # Choose the starting split dimension (the one with the largest variance)
        start_split = self.get_var(dataset)
        root = self.build_tree(dataset, start_split)
        self.root = root
        return root
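
`create` calls a `get_var` helper that is not listed in this post; following step 1 of the build procedure it picks the dimension with the largest variance. A minimal sketch of such a helper (only the name and the fact that it returns a dimension index come from the call above, the body is an assumption):

    def get_var(self, dataset):
        # Sketch: return the index of the dimension whose values have
        # the largest variance across the whole dataset.
        dims = len(dataset[0])
        best_dim, best_var = 0, -1.0
        for d in range(dims):
            values = [p[d] for p in dataset]
            mean = sum(values) / len(values)
            var = sum((v - mean) ** 2 for v in values) / len(values)
            if var > best_var:
                best_dim, best_var = d, var
        return best_dim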

Sometimes a data point that has been checked afterwards needs to be inserted into the tree.

Insertion is straightforward: starting from the root, compare the point against the current node along that level's split dimension; if the point's value is smaller, continue searching the left subtree, otherwise the right subtree.

When the search reaches None, the new node is inserted there.

    def insert(self, point):
        if self.root is None:
            print('Build a tree first!')
            return
        if len(point) != len(self.root.point):
            print('This point has {l} dimensions but the tree has {m}'.format(
                l=len(point), m=len(self.root.point)))
            return
        flag = False
        root = self.root
        while not flag:
            # Compare along this node's split dimension: smaller goes left, otherwise right
            if point[root.index] < root.point[root.index]:
                if root.left is not None:
                    root = root.left
                else:
                    # Empty slot found: the new node splits on the next dimension
                    split = (root.index + 1) % len(point)
                    root.left = TreeNode(split, point, None, None)
                    flag = True
            else:
                if root.right is not None:
                    root = root.right
                else:
                    split = (root.index + 1) % len(point)
                    root.right = TreeNode(split, point, None, None)
                    flag = True

The search works as follows: first descend the tree exactly as in insertion to reach the leaf region closest to the query point.

Then backtrack upward node by node; whenever the other subtree of a node could still contain a closer point, push it onto the stack.

When the stack is empty, the nearest point has been found. Any distance (similarity) metric can be plugged in for the point-to-point distance.

    # requires `from math import sqrt` at the top of the file
    def sim_distance(self, p1, p2):
        # Euclidean distance between two points
        sum_of_squares = sum([pow(p1[i] - p2[i], 2) for i in range(len(p1))])

        return sqrt(sum_of_squares)

    def find_nearest(self, point):
        # Descend to the leaf region containing the query point,
        # pushing every node visited onto the stack
        root = self.root
        s = Stack(99999)
        while root is not None:
            index = root.index
            s.push(root)
            if point[index] <= root.point[index]:
                root = root.left
            else:
                root = root.right

        # The last node visited is the first candidate for the nearest point
        nearest = s.pop()
        min_dist = self.sim_distance(nearest.point, point)
        # Backtrack: pop nodes and check whether the far side of their
        # splitting plane could still hold a closer point
        while not s.isempty():
            back_point = s.pop()
            if back_point is None:
                continue
            index = back_point.index
            # Distance from the query to this node's splitting plane
            if self.sim_distance([point[index]], [back_point.point[index]]) < min_dist:
                # The far side may contain a closer point, push it for inspection
                if point[index] <= back_point.point[index]:
                    root = back_point.right
                else:
                    root = back_point.left
                s.push(root)
            # The node itself may also be closer than the current best
            if min_dist > self.sim_distance(back_point.point, point):
                nearest = back_point
                min_dist = self.sim_distance(back_point.point, point)

        return nearest.point, min_dist
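
Both search routines rely on a `Stack` helper that is not shown here either; judging from the calls it only needs `push`, `pop` and `isempty`, with the constructor argument acting as a capacity. A minimal sketch under that assumption:

class Stack:
    # Assumed minimal stack used by find_nearest / find_near_kth.
    def __init__(self, size):
        self.size = size    # capacity (assumed meaning of the constructor argument)
        self.items = []

    def push(self, item):
        if len(self.items) < self.size:
            self.items.append(item)

    def pop(self):
        return self.items.pop()

    def isempty(self):
        return len(self.items) == 0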

The core of the KNN algorithm is to find the k nearest points and then decide the query point's class from the classes of those points.

I maintain a list that never holds more than k entries to keep the k smallest distances found so far.

Each candidate is compared with the tail of the list; if its distance is smaller (or the list still has fewer than k entries) it is appended, the list is re-sorted, and only the first k entries are kept.

    def find_near_kth(self, point, k):
        # Descend to the leaf region containing the query point,
        # pushing every node visited onto the stack
        root = self.root
        result = []
        s = Stack(99999)
        while root is not None:
            index = root.index
            s.push(root)
            if point[index] <= root.point[index]:
                root = root.left
            else:
                root = root.right
        # The last node visited is the first candidate
        t_point = s.pop()
        result.append((t_point, self.sim_distance(t_point.point, point)))
        # Backtrack, keeping result sorted by distance and never longer than k
        while not s.isempty():
            back_point = s.pop()
            if back_point is None:
                continue
            index = back_point.index
            # If the splitting plane is closer than the current k-th distance
            # (or we have fewer than k candidates), the far side may matter
            if self.sim_distance([point[index]], [back_point.point[index]]) <= result[-1][1] or len(result) < k:
                if point[index] <= back_point.point[index]:
                    root = back_point.right
                else:
                    root = back_point.left
                s.push(root)
            # The node itself may enter the current top-k
            if result[-1][1] > self.sim_distance(back_point.point, point) or len(result) < k:
                result.append((back_point, self.sim_distance(back_point.point, point)))
                result.sort(key=lambda x: x[1])
                result = result[0:k]

        return result
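
Re-sorting the whole list after every insertion is fine for small k; an alternative is to keep the k smallest distances in a max-heap (by negating them). This is not the code used above, just a sketch of the bookkeeping step with Python's heapq, where point is the candidate's coordinate tuple:

import heapq

def push_candidate(heap, point, dist, k):
    # Keep at most k entries; the heap root holds the largest kept distance.
    if len(heap) < k:
        heapq.heappush(heap, (-dist, point))
    elif dist < -heap[0][0]:
        heapq.heapreplace(heap, (-dist, point))

Sorting `[(-d, p) for d, p in heap]` at the end returns the kept candidates in increasing distance order.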

Finally, a rather clumsy way of looking up the class of each returned point:

def decide_type(kd_result, t_point, t_type):
    # One vote per neighbour: find each neighbour among the training
    # points and look up its class in the parallel t_type list
    ans = {i: 0 for i in t_type}
    for node in kd_result:
        for i in range(len(t_point)):
            if node[0].point == t_point[i]:
                ans[t_type[i]] += 1
                break

    # Return the class with the most votes
    max_v = 0
    max_type = None
    for i in ans:
        if ans[i] > max_v:
            max_v = ans[i]
            max_type = i

    return max_type
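
A less clumsy alternative, assuming the training labels were stored in a dict keyed by point (this pairing is hypothetical; the code above keeps two parallel lists instead):

from collections import Counter

def decide_type_by_label(kd_result, label_of_point):
    # label_of_point: hypothetical dict mapping each training point (a tuple)
    # to its class label; kd_result is the list returned by find_near_kth.
    votes = Counter(label_of_point[tuple(node.point)] for node, _ in kd_result)
    return votes.most_common(1)[0][0]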

A quick test:

kd = KdTree()
kd.create(train_point)
print(kd.find_near_kth((1, 1), 2))
# print(decide_type(kd.find_near_kth((6.5, 6), 3),train_point,train_type))
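
For reference, a self-contained run with made-up data; train_point and train_type below are hypothetical placeholders, not the original training set:

# Hypothetical 2-D points and class labels, just to make the test runnable.
train_point = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
train_type = ['A', 'A', 'B', 'A', 'B', 'B']

kd = KdTree()
# Pass a copy: build_tree sorts the list in place, which would otherwise
# break the pairing between train_point and train_type used by decide_type.
kd.create(list(train_point))
neighbours = kd.find_near_kth((6.5, 6), 3)
print([(n.point, round(d, 3)) for n, d in neighbours])
print(decide_type(neighbours, train_point, train_type))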