精华内容
下载资源
问答
  • kd树python
    千次阅读
    2022-03-31 16:26:54

    前两天学习了knn算法,knn的思想很简单,不过其中提出的kd树有理解的必要。故就用python写了一个kd树代码。
    个人感想是,把kd树算法实现一遍比看书看半天有用多了,而且还不会犯困(bushi
    思路来自https://www.joinquant.com/view/community/detail/dd60bd4e89761b916fe36dc4d14bb272
    讲的很好,不过有一个小漏洞,编程实现一遍才发现

    # 余康盛   python学习
    # 2022/3/31
    # 16:11
    # kd树结点
    class Node:
        def __init__(self):
            # 左孩子
            self.left = None
            # 右孩子
            self.right = None
            # 父节点
            self.parent = None
            # 特征坐标
            self.x = None
            # 切分轴
            self.dimension = None
            # 是否被访问过
            self.flag = False
    
    
    # 构建kd树
    def construct(d, data, node, layer):
        """
        :type d: int
        d是向量的维数
        :type data: list
        data是所有向量构成的列表
        :type node: Node
        node是当前进行运算的结点
        :type layer: int
        layer是当前kd树所在层数
        """
        node.dimension = layer
        # 如果只有一个元素,说明到了叶子结点,该分支结束
        if len(data) == 1:
            node.x = data[0]
            return
        if len(data) == 0:  # 没有代表的数据就作为一个空叶子结点
            return
        # 1,data中的数据按layer%N维进行排序
        data.sort(key=lambda x: x[layer % d])
        # 2,计算中间点的索引,偶数则取中间两位中较大的一位,记为该结点的特征坐标
        middle = len(data) // 2
        node.x = data[middle]
        # 3,划分data
        dataleft = data[:middle]
        dataright = data[middle + 1:]
        # 4,左孩子结点
    
        left_node = Node()
        node.left = left_node
        left_node.parent = node
        construct(d, dataleft, left_node, layer + 1)
        # 5,右孩子结点
    
        right_node = Node()
        node.right = right_node
        right_node.parent = node
        construct(d, dataright, right_node, layer + 1)
    
    
    def distance(a, b):  # 计算欧式距离
        """
        :type a: list
        :type b: list
        """
        dis = 0
        for i in range(0, len(a)):
            dis += (a[i] - b[i]) ** 2
        return dis ** 0.5
    
    
    def change_L(L, x, p, K):  # 判断并进行是否将该点加入近邻点列表
        """
        :type L: list
        L是近邻点列表
        :type x: list
        x是判断是否要加入近邻列表的向量
        :type p: list
        p是目标向量
        :type K:int
        K是近邻列表的最大元素个数
        """
        if len(L) < K:
            L.append(x)
            return
        dislist = []
        for i in range(0, K):
            dislist.append(distance(p, L[i]))
        index = dislist.index(max(dislist))
        if distance(p, x) < dislist[index]:  # 若x和p之间的距离小于L到p中最远的点,就用x替换此最远点
            L[index] = x
        return max(dislist)
    
    
    # 搜索kd树
    def search(node, p, L, K):
        """
        :type List: list
        :type node: Node
        :type p: list
        :type L: list
        :type K: int
        :type L0: list
        :type f: bool
        """
        # L为有k个座位的列表,用于保存已搜寻到的最近点
        # 1,根据p的坐标值和每个点的切分轴向下搜索,先到达底部结点
        n = node  # 用n来记录结点的位置,先从顶部开始,直到叶子结点
        while True:
            # 若到达了叶子结点则退出循环
            if (n.left == None) & (n.right == None):
                break
            if n.x[n.dimension] > p[n.dimension]:
                n = n.left
            else:
                n = n.right
        n.flag = True  # 标记为已访问过
        if n.x is None:  # 若为空叶子结点,则不必记录数值
            pass
        else:
            change_L(L, n.x, p, K)  # 若符合插入条件,就插入,不符合就不插入
        # (三)
        while True:
            # 若当前结点是根结点则输出L算法完成
            if n.parent is None:
                if len(L) < K:
                    print('K值超过数据总量')
                return L
            # 当前结点不是根结点,向上爬一格
            else:
                n = n.parent
                while n.flag == True:
                    # 若当前结点被访问过,就一直向上爬,到没被访问过的结点为止
                    # 若向上爬时遇到了已经被访问过的根结点,说明另一边已经搜索过了搜索结束
                    if (n.parent is None) & (n.flag):
                        if len(L) < K:
                            print('K值超过数据总量')
                        return L
                    n = n.parent
                # 此时n未被访问过,将其标记为访问过
                n.flag = True
    
                # (1)如果此时 L 里不足 k 个点,则将节点特征加入 L;
                # 如果 L 中已满 k 个点,且当前结点与 p 的距离小于与L的最大距离,
                # 则用节点特征替换掉 LL 中离最远的点。
                change_L(L, n.x, p, K)
                ''' 计算p和当前节点切分线的距离。如果该距离小等于于 LL 中最远的距离或者 LL 中不足 kk 个点,
                            则切分线另一边或者 切分线上可能有更近的点,
                            因此在当前节点的另一个枝从 (一) 开始执行。'''
                dislist = []
                for i in range(0, len(L)):
                    dislist.append(distance(p, L[i]))
                if (abs(p[n.dimension] - n.x[n.dimension]) < max(dislist)) | (len(L) < K):
                    if n.left.flag == False:
                        return search(n.left, p, L, K)
                    else:
                        return search(n.right, p, L, K)
                # 如果该距离大于等于 L 中距离 p 最远的距离并且 L 中已有 k 个点,则在切分线另一边不会有更近的点,重新执行(三)
    
    
    # 使用说明
    # data表示数据集,这里是list类型,元素表示数据点,是d维向量,d表示data中数据点的维度,p为要寻找k近邻的点,K为近邻个数,其他均为默认值
    data = [[5, 4], [7, 2], [2, 3], [4, 7], [8, 1], [9, 6]]
    node = Node()
    construct(d=2, data=data, node=node, layer=0)
    print(search(node=node, p=[5, 4], L=[], K=6))
    
    
    更多相关内容
  • 主要介绍了python K近邻算法的kd树实现,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧
  • python K近近邻邻算算法法的的kd树树实实现现 这篇文章主要介绍了python K近邻算法的kd树实现小编觉得挺不错的现在分享给大家也给大家做个参考 一起跟随小编过 看看吧 k近近邻算算法法的的介介绍 k近邻算法是一种...
  • 李航例题3.2构造kd树python代码

    千次阅读 2018-01-15 14:15:50
    关于kd树的原理可以看这篇http://blog.csdn.net/qll125596718/article/details/8426458 下面主要是关于李航统计学习方法中例3.2的pyhton实现: 先来贴一下运行结果: 这里还没把左右都出现的写成root。。。...

    关于kd树的原理可以看这篇http://blog.csdn.net/qll125596718/article/details/8426458

    下面主要是关于李航统计学习方法中例3.2的pyhton实现:

    先来贴一下运行结果:


    这里还没把左右都出现的写成root。。。其他跟图3.4显示一样

    下面就上一下代码:

    #author:xinxinzhang 
    import numpy as np
    def loadDataSet():
        T=[[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]  #书上数据加载一下
        return np.mat(T)
    
    def Kd_split(T):
        T0_var=np.var(T[:,0])    #选择x轴为坐标轴并算方差
        T1_var=np.var(T[:,1])    #选择y轴为坐标轴并算方差
        Tmax=max(T0_var,T1_var)  #看哪个轴上的方差大,取方差大的轴
        left_node=[]
        right_node=[]
        if Tmax==T0_var:     #如果x轴方差大
            X0_sorted=T[np.lexsort(T[:,::-1].T)]  #按x轴排个序先(这里就是按第一列从小到大排序)
            
            Kd_split(X0_sorted[0,:int(len(T)/2)])    #把从小一直到中值的数继续算方差、选轴、排序
            left_node.append(X0_sorted[0,int(len(T)/2)+1].T)#取中值给左节点列表
            
            Kd_split(X0_sorted[0,int(len(T) / 2):])    #把从中值一直到大的数算方差选轴排序
            right_node.append(X0_sorted[0,int(len(T)/2)].T)  #取中值给右节点列表
          
        else:
            X1_sorted = T[np.lexsort(T.T)]     #如果y轴方差大
           
            left_node.append(X1_sorted[0][:int(len(T) / 2)])  #取从小到中值的数给左节点
    
            right_node.append(X1_sorted[0][int(len(T) / 2):])#取从中值到大的数给右节点
            print('left_node',left_node)     #打印左右节点列表
            print('right_node:',right_node)
    
    T=loadDataSet()
    Kd_split(T)

    展开全文
  • Python实现kd树

    千次阅读 2020-06-17 10:57:22
    KD树建树采用的是从m个样本的n维特征中,分别计算n个特征的取值的方差,用方差最大的第k维特征nknk来作为根节点。对于这个特征,我们选择特征nknk的取值的中位数nkvnkv对应的样本作为划分点,对于所有第k维特征的...

    kd树的数据结构和二叉树类似,每个节点存有当前节点的数值,左右子树的节点,和以当前节点为根节点的子树的划分维度。

    
    class KdNode:
        def __init__(self, dim, val, left=None, right=None):
            self.dim = dim
            self.val = val
            self.right = right
            self.left = left

    建kd树采用的是从m个样本的n维特征中,分别计算n个特征的取值的方差,用方差最大的第k维特征n_{_{k}}作为划分左右子树的依据。对于这个特征,选择特征n_{_{k}}的取值的中位数n_{_{ki}}对应的样本作为划分点,对于所有第k维特征的取值小于n_{_{ki}}的样本,划入左子树,对于第k维特征的取值大于等于n_{_{ki}}的样本,划入右子树,对于左子树和右子树,采用和刚才同样的办法来找方差最大的特征来做更节点,递归生成kd树。

    比如我们有二维样本6个,[(2,3),(5,4),(9,6),(4,7),(8,1),(7,2)],构建kd树的具体步骤为:

    1. 找到划分的特征。6个数据点在x,y维度上的数据方差分别为6.97,5.37,所以在x轴上方差更大,用第0维特征建树。
    2. 确定划分点(5,4)。根据x维上的值将数据排序,6个数据的中值(所谓中值,即中间大小的值)为5,所以划分点的数据是(5,4)。这样,该节点的分割超平面就是通过(5,4)并垂直于:划分点维度的直线x=5;
    3. 确定左子空间和右子空间。 分割超平面x=5将整个空间分为两部分:x<=5的部分为左子空间,包含2个节点=[(2,3),(4,7)];另一部分为右子空间,包含2个节点=[(7,2),(8,1),(9,6)]。
    4. 用同样的办法划分左子树的节点[(2,3),(4,7)]和右子树的节点[(7,2),(8,1),(9,6)]。最终得到kd树。
    
    def make_kdTree(data):
        if data.shape[0] == 0:
            return None
        elif data.shape[0] == 1:
            node = KdNode(dim=0, val=data[0], left=None, right=None)
            return node
        vars = np.var(data, axis=0)
        dim = np.argmax(vars)
        my_sort = lambda c: sorted(c, key=lambda i: i[dim])
        data = my_sort(data)
        data = np.asarray(data)
        index = (data.shape[0]-1)//2
        # print(dim)
        # print('left:', data[:index, :])
        # print('right:', data[index + 1:, :])
        left = None if index == 0 else make_kdTree(data[:index, :])
        right = None if index+1 == index else make_kdTree(data[index+1:, :])
        node = KdNode(dim=dim, val=data[index], left=left
                      , right=right)
        return node
    
    
    def in_order(root):
        if root:
            print(root.dim, root.val)
            in_order(root.left)
            in_order(root.right)
        else:
            return
    
    if __name__ == '__main__':
    
        pts = [[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]]
        pts = np.asarray(pts)
        root = make_kdTree(pts)
        in_order(root)

    结果:

    0
    left: [[2 3]
     [4 7]]
    right: [[7 2]
     [8 1]
     [9 6]]
    1
    left: []
    right: [[4 7]]
    1
    left: [[8 1]]
    right: [[9 6]]
    0 [5 4]
    1 [2 3]
    0 [4 7]
    1 [7 2]
    0 [8 1]
    0 [9 6]

     参考链接:

    https://www.cnblogs.com/pinard/p/6061661.html?utm_source=itdadao&utm_medium=referral

    https://blog.51cto.com/underthehood/687160

    展开全文
  • python实现kd树以及最近邻查找算法

    千次阅读 2022-04-26 21:59:44
    python实现kd树以及最近邻查找算法一、kd树简介二、kd树生成1.确定切分域2.确定数据域3.理解递归树4.python实现递归树代码三、kd树上的最近邻查找算法 一、kd树简介 kd树是一种树形结构,树的每个节点存放一个k维...

    一、kd树简介

    kd树是一种树形结构,树的每个节点存放一个k维数据,某一节点的子节点可以看作是由过该节点一个平面切割后产生的(想象一下切蛋糕的过程),如此反复产生切割平面,就能为每个数据在空间中建立索引,如下图所示:
    在这里插入图片描述
    由于采用这种特殊的分割方式,使得在利用kd树做最近邻查找时,可以避开一些距离很远的点,查找速度得到了较大的提升,对于空间中Nk维数据,穷举法的算法复杂度为O(Nk),而使用kd树查找的算法复杂度只有O(klog(N))。kd树是一种典型的空间换时间的方式,即花费存储空间为数据建立索引,这样使得后续查找时速度更快,花费时间更少。

    二、kd树生成

    具体的算法实现主要参考的是这篇文章:https://www.cnblogs.com/eyeszjwang/articles/2429382.html,实现时有少量改动。生成kd树有两个关键的中间过程,即:

    1.确定切分域

    (1)确定split域:对于所有描述子数据(特征矢量),统计它们在每个维上的数据方差。以SURF特征为例,描述子为64维,可计算64个方差。挑选出最大值,对应的维就是split域的值。数据方差大表明沿该坐标轴方向上的数据分散得比较开,在这个方向上进行数据分割有较好的分辨率;

    这段文字用通俗一点的语言来说就是:对于二维的情况,每一次做数据切分的时候,沿着x轴还是y轴做切分是一个问题,那么我们要怎么确定呢?我们可以统计这些二维数据的x值和y值的方差,方差越大说明数据在这一方向上越离散,而数据越离散说明沿着这一方向上数据之间的距离区分度越大,简单点来说就是相互之间隔得更远,我们就用这个方向做切分。
    确定了切分域之后,我们就需要来对数据做切分了。

    2.确定数据域

    (2)确定Node-data域:数据点集Data-set按其第split域的值排序。位于正中间的那个数据点被选为Node-data。此时新的Data-set’ = Data-set\Node-data(除去其中Node-data这一点)。

    简单来说,这句话的意思是:现在我们已经确定了沿着x轴做切分,那么我们要怎么决定在x轴哪里做切分呢?我们可以将所有数据根据x值的大小做一个排序,然后选取正中间那个数据的x值作为切分的位置。注意,这里有一个关键的问题是:如果我们有偶数个数据,怎么确定中间那个数据?难道我们选取中间两个数做一下平均???如果没有记错的话这应该是中位数的定义。。。如果这样完全就是自找麻烦!因为我们要确保至少有一个数据的x值落在切分点上,但是取平均之后并不能保证!!!所以更好的办法是,在有两个中间数据的情况下,随便选取一个数据的x值就行了。
    决定了在x轴哪里做切分之后,我们就需要把数据做切分了,这里根据数据的x值相对于切分位置的大小,可以归为左节点和右节点,同时不要忘了:当前主节点也要保存一个数据,选取一个x值大小和切分位置相等的数据保存就行(如果有多个随便选一个就行,关键之处在于这个数据的x值落在切割线上。)

    3.理解递归树

    前面提到过,kd树是一种树形结构,因此可以递归生成,这是树形结构的共性,用程序语言来说,递归就是函数自己调用自己,在理解上也是很自然的。对于一组数据,我们通过找到的一个切分线把数据一分为二,而这个切分线的确定只和这组数据有关,左边的数据归为左节点,右边的数据归为右节点,更进一步,对于左边或者右边的这组数据,我们又可以将其看作一个整体,找到一个切分线把它一分为二,这样将一组数据一分为二的过程反复进行,相当于这个过程函数不断地调用自身,最终生成二叉树,将所有的数据分开。

    4.python实现递归树代码

    ###建立kd树和实现查询功能
    import numpy as np
    import matplotlib.pyplot as plt
    
    class kdTree:
        def __init__(self, parent_node):
            '''
            节点初始化
            '''
            self.nodedata = None   ###当前节点的数据值,二维数据
            self.split = None ###分割平面的方向轴序号,0代表沿着x轴分割,1代表沿着y轴分割
            self.range = None  ###分割临界值
            self.left = None    ###左子树节点
            self.right = None   ###右子树节点
            self.parent = parent_node  ###父节点
            self.leftdata = None  ###保留左边节点的所有数据
            self.rightdata = None ###保留右边节点的所有数据
            self.isinvted = False ###记录当前节点是否被访问过
    
        def print(self):
            '''
            打印当前节点信息
            '''
            print(self.nodedata, self.split, self.range)
    
        def getSplitAxis(self, all_data):
            '''
            根据方差决定分割轴
            '''
            var_all_data = np.var(all_data, axis=0)
            if var_all_data[0] > var_all_data[1]:
                return 0
            else:
                return 1
        
    
        def getRange(self, split_axis, all_data):
            '''
            获取对应分割轴上的中位数据值大小
            '''
            split_all_data = all_data[:, split_axis]
            data_count = split_all_data.shape[0]
            med_index = int(data_count/2)
            sort_split_all_data = np.sort(split_all_data)
            range_data = sort_split_all_data[med_index]
            return range_data
    
    
        def getNodeLeftRigthData(self, all_data):
            '''
            将数据划分到左子树,右子树以及得到当前节点
            '''
            data_count = all_data.shape[0]
            ls_leftdata = []
            ls_rightdata = []
            for i in range(data_count):
                now_data = all_data[i]
                if now_data[self.split] < self.range:
                    ls_leftdata.append(now_data)
                elif now_data[self.split] == self.range and self.nodedata == None:
                    self.nodedata = now_data
                else:
                    ls_rightdata.append(now_data)
            self.leftdata = np.array(ls_leftdata)
            self.rightdata = np.array(ls_rightdata)
    
    
        def createNextNode(self,all_data):
            '''
            迭代创建节点,生成kd树
            '''
            if all_data.shape[0] == 0:
                print("create kd tree finished!")
                return None
            self.split = self.getSplitAxis(all_data)
            self.range = self.getRange(self.split, all_data)
            self.getNodeLeftRigthData(all_data)
            if self.leftdata.shape[0] != 0:
                self.left = kdTree(self)
                self.left.createNextNode(self.leftdata)
            if self.rightdata.shape[0] != 0:
                self.right = kdTree(self)
                self.right.createNextNode(self.rightdata)
    
        def plotKdTree(self):
            '''
            在图上画出来树形结构的递归迭代过程
            '''
            if self.parent == None:
                plt.figure(dpi=300)
                plt.xlim([0.0, 10.0])
                plt.ylim([0.0, 10.0])
            color = np.random.random(3)
            if self.left != None:
                plt.plot([self.nodedata[0], self.left.nodedata[0]],[self.nodedata[1], self.left.nodedata[1]], '-o', color=color)
                plt.arrow(x=self.nodedata[0], y=self.nodedata[1], dx=(self.left.nodedata[0]-self.nodedata[0])/2.0, dy=(self.left.nodedata[1]-self.nodedata[1])/2.0, color=color, head_width=0.2)
                self.left.plotKdTree()
            if self.right != None:
                plt.plot([self.nodedata[0], self.right.nodedata[0]],[self.nodedata[1], self.right.nodedata[1]], '-o', color=color)
                plt.arrow(x=self.nodedata[0], y=self.nodedata[1], dx=(self.right.nodedata[0]-self.nodedata[0])/2.0, dy=(self.right.nodedata[1]-self.nodedata[1])/2.0, color=color, head_width=0.2)
                self.right.plotKdTree()
            # if self.split == 0:
            #     x = self.range
            #     plt.vlines(x, 0, 10, color=color, linestyles='--')
            # else:
            #     y = self.range
            #     plt.hlines(y, 0, 10, color=color, linestyles='--')
    
    
    test_array = 10.0*np.random.random([30,2])
    my_kd_tree = kdTree(None)
    my_kd_tree.createNextNode(test_array)
    my_kd_tree.plotKdTree()
    

    这里代码中使用了Python面向对象技术,kdTree类的重要参数和前面给出的参考文章中的参数大致相同,具体代码细节不再说明,这里随机生成了30个范围在0-10之内的2维数据作为测试数据,下图是一次运行得到的结果:
    在这里插入图片描述
    可以很容易看到中间橙色的点就是根节点,以及每个节点的迭代过程,运行过程无误。

    三、kd树上的最近邻查找算法

    加快对目标数据的最近邻数据的搜索过程,是kd树这种特殊存储结构的最主要功能,尤其是在数据量非常大时,其速度优势更加明显。kd树上的最近邻查找算法主要涉及两个过程,即:

    1.生成搜索路径

    这一过程相对容易,也很好理解。由于我们之前已经根据不同的切分线,生成了包含所有数据点的kd树,那么现在给我们一个新的数据,我们首先当然是根据这些切分线来判断待查找的数据是属于哪个分区的,我们当然有理由相信与这个数据同属一个分区的数据点(即某个叶节点)是其最近邻点的概率比不同分区的点的概率要大。因此,我们通过对目标数据的二叉查找,可以确定出一条搜索路径以及初始的最近邻点,但是要注意的是,通过二叉查找找到的叶节点是目标点的最近邻点的可能性较大,但不是一定的,如下图:
    在这里插入图片描述
    目标点落在了y=4的上半平面,但是其最近邻点却在y=4的下半平面,所以这里我们初步搜索出来的一个叶节点并不一定是目标点的最近邻点,我们还需要不断地沿着搜索路径回溯,确定同一主节点的其它子节点中是否存在与目标点距离更近的点。

    2.搜索路径回溯

    为了实现路径回溯的功能,这里需要使用来存储搜索路径,具体说来,当回溯到某一节点的父节点时,需要判断目标点到该父节点对应切分线的距离是否小于当前的最小距离,如果比最短距离还小,说明在该父节点对应的另一分支中有可能存在与目标点距离更小的点,因此就需要搜索该分支中的节点。
    为了更加形象地说明,还是以上图为例。首先通过二分查找我们确定目标点与(4,7)点落在同一域内,因此将(4,7)作为初始最近邻点,然后向上回溯到(5,4)点,而(5,4)点对应的切分线是y=4,通过计算发现目标点到直线y=4的距离小于当前最短距离,因此在目标点的对侧即(5,4)节点的另一分支可能存在与目标点距离更近的点,因此我们需要跳到另一分支中重新检索,这里由于另一分支的深度不一定和前一分支相同,因此在跳到另一分支的头节点之后,我们还需要在此基础之上重复第1步中的路径搜索过程,到达该分支的叶节点,然后重复向上回溯查找直到将搜索路径全部回溯完成,我们就可以得到目标点的最近邻点。
    这其中还有一个值得注意的地方,就是向上回溯时为了避免路径在两个分支之间来回跳跃导致死循环,需要将整个回溯过程中访问过的节点从路径中去掉,用一个标签来指示就可以,上述代码中使用的是
    isinvted
    来标记当前节点是否被访问过。

    3.最近邻查找算法代码

    具体代码实现是在以上kdTree类的基础上在添加几个内部函数就可以了,具体添加的函数为:

    	def divDataToLeftOrRight(self, find_data):
            '''
            根据传入的数据将其分给左节点(0)或右节点(1)
            '''
            data_value = find_data[self.split]
            if data_value < self.range:
                return 0
            else:
                return 1
    
        def getSearchPath(self, ls_path, find_data):
            '''
            二叉查找到叶节点上
            '''
            now_node = ls_path[-1]
            if now_node == None:
                return ls_path
            now_split = now_node.divDataToLeftOrRight(find_data)
            if now_split == 0:
                next_node = now_node.left
            else:
                next_node = now_node.right
            while(next_node!=None):
                ls_path.append(next_node)
                next_split = next_node.divDataToLeftOrRight(find_data)
                if next_split == 0:
                    next_node = next_node.left
                else:
                    next_node = next_node.right
            return ls_path
                
        def getNestNode(self, find_data, min_dist, min_data):
            '''
            回溯查找目标点的最近邻距离
            '''
            ls_path = []
            ls_path.append(self)
            self.getSearchPath(ls_path, find_data)
            now_node = ls_path.pop()
            now_node.isinvted = True
            min_data = now_node.nodedata
            min_dist = np.linalg.norm(find_data-min_data)
            while(len(ls_path)!=0):
                back_node = ls_path.pop()   ### 向上回溯一个节点
                if back_node.isinvted == True:
                    continue
                else:
                    back_node.isinvted = True
                back_dist = np.linalg.norm(find_data-back_node.nodedata)
                if back_dist < min_dist:
                    min_data = back_node.nodedata
                    min_dist = back_dist
                if np.abs(find_data[back_node.split]-back_node.range) < min_dist:
                    ls_path.append(back_node)
                    if back_node.left.isinvted == True:
                        if back_node.right == None:
                            continue
                        ls_path.append(back_node.right)
                    else:
                        if back_node.left == None:
                            continue
                        ls_path.append(back_node.left)
                    ls_path = back_node.getSearchPath(ls_path, find_data)
                    now_node = ls_path.pop()
                    now_node.isinvted = True
                    now_dist = np.linalg.norm(find_data-now_node.nodedata)
                    if now_dist < min_dist:
                        min_data = now_node.nodedata
                        min_dist = now_dist
            print("min distance:{}  min data:{}".format(min_dist, min_data))
            return min_dist
    
        def getNestDistByEx(self, test_array, find_data, min_dist, min_data):
            '''
            穷举法得到目标点的最近邻距离
            '''
            data_count = test_array.shape[0]
            min_data = test_array[0]
            min_dist = np.linalg.norm(find_data-min_data)
            for i in range(data_count):
                now_data = test_array[i]
                now_dist = np.linalg.norm(find_data-now_data)
                if now_dist < min_dist:
                    min_dist = now_dist
                    min_data = now_data
            print("min distance:{}  min data:{}".format(min_dist, min_data))
            return min_dist
    

    代码的对齐格式是一致的,直接加入以上类中就可以,当然为了对比以及验证结果的正确性,在类中还实现了穷举查找算法。首先用50个点测试了一下回溯查找结果的正确性,绘制的结果如下:
    在这里插入图片描述
    查找的目标点是(5.0, 5.0),查找到的最近邻点在目标点左下角,从图上来看结果是正确的。为了对比穷举法和利用kd树回溯查找的速度,数据点设置为10000个,代码为:

    test_array = 10.0*np.random.random([10000,2])   ### 随机生成n个2维0-10以内的数据点
    my_kd_tree = kdTree(None)                    ### kd树实例化
    my_kd_tree.createNextNode(test_array)        ### 生成kd树
    # my_kd_tree.plotKdTree()   
    find_data = np.array([5.0, 5.0])             ### 待查找目标点
    min_dist = 0                                 ### 临时变量,存储最短距离
    min_data = np.array([0.0, 0.0])              ### 临时变量,存储取到最短距离时对应的数据点
    
    %time min_dist = my_kd_tree.getNestNode(find_data, min_dist, min_data)        ### 利用kd树回溯查找
    %time min_dist = my_kd_tree.getNestDistByEx(test_array, find_data, min_dist, min_data)    ### 穷举法查找
    

    用%time命令来显示单步运行查找算法所需的时间,运行结果如下:
    在这里插入图片描述
    可以看到两种算法最终查找到的最短距离以及最近邻数据点都是一样的,证明了算法的正确性。同时kd树查找过程只用了1ms左右,而穷举法查找用了70ms左右,二者相差了70倍,当然随着数据量增大这个差距还会继续增加的,最终应该会趋于某个极限值。

    展开全文
  • 【统计学习方法】k近邻 kd树python实现

    千次阅读 多人点赞 2017-08-17 15:22:35
    前言 代码可在Github上下载:代码下载 k近邻可以算是机器学习中易于理解、实现的一个算法了,《机器学习...所以需要将数据用树形结构存储,以便快速检索,这也就是本文要阐述的kd树。 实现 分为两部分,一个是k...
  • 博主费了好大劲,搜索和尝试pyhton实现kd树以及用pythonkd树实现mnist分类。 最大的难点是二叉树的python索引(或者说是递归索引) 让博主深为收益的两篇文章:感谢大佬博主们 传送门 统计学习方法笔记(二)-kd树...
  • 《统计学习方法》中,kd树Python实现。
  • 机器学习笔记——kd树python实现

    千次阅读 2017-12-06 16:01:24
    kd树实现k近邻时当训练数据量较大时,采用线性扫描法(将数据集中的数据与查询点逐个计算距离比对)会导致计算量大效率低下.这时可以利用数据本身蕴含的结构信息,构造数据索引进行快速匹配.索引树便是其中常用的一...
  • 主要介绍了Python语言描述KNN算法与Kd树,具有一定借鉴价值,需要的朋友可以参考下。
  • 提到KD-Tree相信大家应该都不会觉得陌生(不陌生你点进来干嘛[捂脸]),大名鼎鼎的KNN算法就用到了KD-Tree。本文就KD-Tree的基本原理进行讲解,并手把手、肩并肩地带您实现这一算法。完整实现代码请参考本人的p...哦...
  • 在Open3D中构建KD树,以便快速检索最近的邻居。
  • 《统计学习方法》——kd树python实现

    千次阅读 2017-04-18 23:24:06
    kd树原理之前看KNN时,确实发现这个计算量很大。因此有人提出了kd树算法,其作用是,当你需要求得与预测点最近的K个点时,这个算法可以达到O(logN)的时间复杂度(相当于搜索一颗二叉树的...kd树python实现这里给出的是kd
  • KD树python实现

    千次阅读 2020-02-12 19:58:56
    结点类型 class Kd_node: value = [] #节点值 deep = None #节点深度 feature = None #划分标志 left = None #左子树 right = None # 右子树 ...1.建立kd树 def Train(x): """ 训练模型,输入x,y来...
  • python实现KD树

    千次阅读 多人点赞 2018-11-21 08:50:48
    关于KD树的介绍,许多博客已经描述的很清楚了,这里就不再...构建kd树,提高KNN算法的效率(数据结构要自己做出来才有趣) 1. 使用对象方法封装kd树 2. 每一个结点也用对象表示,结点的相关信息保存在实例属性中 ...
  • KD树的时候没把类别考虑进去。。。所以先用KD算出最近的k个点,然后找到对应分类最后输出占比最大的KD树是一种二叉树,用来分割空间上得点一个树节点的结构如下:class TreeNode:index = -1 # 对应维度序号point =...
  • 参考文章:wenffe:python实现KD树 1. kd树的构造 import numpy as np class Node(object): """ 定义节点类: val:节点中的实例点 label:节点中实例的类别 dim:当前节点的分割维度 left:节点的左子树 ...
  • Kd树实现knn算法(python)

    2019-12-01 22:26:34
    python建立kd树,然后实现knn算法,数据集为白酒品质 机器学习课程第二次作业part1,程序借鉴了很多别人的内容,自己目前还是个菜鸡,入门机器学习,道阻且长! 白酒数据集下载地址 (不知道为什么正确率很低,...
  • K近邻算法的kd树实现

    2020-11-26 07:33:57
    以下是kd树python实现 准备工作 #读取数据准备 def file2matrix(filename): fr = open(filename) returnMat = [] #样本数据矩阵 for line in fr.readlines(): line = line.strip().split('\t') returnMat.append...
  • K近邻法之kd树及其Python实现

    千次阅读 2017-07-30 20:01:55
    下面以Python来实现kd树的生成及搜索。 ##Generate KD tree def createTree(dataSet, layer = 0, feature = 2): length = len(dataSet) dataSetCopy = dataSet[:] featureNum = layer % feature dataSetCopy....
  • KD_K._Kd树_

    2021-09-28 21:07:02
    基于KD树实现的python算法,用于KD树的生成及K近邻算法
  • 按照链接里的算法写了k近邻的python实现from math import sqrtclass KDnode:def __init__(self, data, left, right, split):self.left = leftself.right = rightself.split = splitself.data = dataclass KDtree:def...

空空如也

空空如也

1 2 3 4 5 ... 20
收藏数 3,424
精华内容 1,369
关键字:

kd树python