2017-05-04 21:55:48 u011746554 阅读数 213
  • 数据挖掘模型篇之R语言实践

    理论与实践结合的方式,通过通俗易懂的教学方式培养学生运用R语言完成常用挖掘模型算法建立及评估,学习完课程可以掌握:线性回归模型、聚类分析、关联规则算法、KNN近邻算法和主成分分析等常用的模型算法实现。针对具体的数据挖掘应用需求,能熟练抽象出可合适的数据挖掘模型,并整理出其技术实现路线。

    3696 人正在学习 去看看 谢佳标

kNN的概念
kNN是一种较为简单的监督学习方法,输入没有标注的新数据后,将新数据的特征与样本集中的每个数据对应的特征比较,然后算法选择出最接近的k的个数据,根据这k个数据判断新数据。如果是分类问题,投票法,加权投票法。回归问题可以是平均法。

实验
这次实践,采用最简单的欧式距离才度量特征间的相似性。数据集来源是“手写数字数据集的光学识别”。

# -*- coding: utf-8 -*-
"""
kNN.py
Created on Thu May 04 12:43:21 2017
@author: holy
"""

from numpy import *
import operator
from os import listdir

def classify0(inX,dataSet,labels,k):
    dataSetSize=dataSet.shape[0]
    diffMat=tile(inX,(dataSetSize,1))-dataSet
    sqDiffMat=diffMat**2
    sqDistances=sqDiffMat.sum(axis=1)
    distances=sqDistances**0.5
    sortedDistIndicies=distances.argsort()
    classCount={}
    for i in range(k):
        voteIlabel=labels[sortedDistIndicies[i]]
        classCount[voteIlabel]=classCount.get(voteIlabel,0)+1
 sortedClassCount=sorted(classCount.iteritems(),key=operator.itemgetter(1),reverse=True)
    return sortedClassCount[0][0]

def img2vector(filename):
    returnVect = zeros((1,1024))
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect[0,32*i+j] = int(lineStr[j])
    return returnVect

def handwritingClassTest():
    hwLabels = []
    trainingFileList = listdir('trainingDigits')           #load the training set
    m = len(trainingFileList)
    trainingMat = zeros((m,1024))
    for i in range(m):
        fileNameStr = trainingFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        hwLabels.append(classNumStr)
        trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
    testFileList = listdir('testDigits')        #iterate through the test set
    errorCount = 0.0
    mTest = len(testFileList)
    for i in range(mTest):
        fileNameStr = testFileList[i]
        fileStr = fileNameStr.split('.')[0]     #take off .txt
        classNumStr = int(fileStr.split('_')[0])
        vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
        classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3)
        print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
        if (classifierResult != classNumStr): errorCount += 1.0
    print "\nthe total number of errors is: %d" % errorCount
    print "\nthe total error rate is: %f" % (errorCount/float(mTest))

下面是测试代码

# -*- coding: utf-8 -*-
"""
Created on Thu May 04 15:31:17 2017
@author: holy
"""
import kNN
kNN.handwritingClassTest()

这里写图片描述

对于1000个数据集,错误率为0.011。

问题分析
从实验中可以体会到,每个样例预测都要使用到整个测试集且每个要对面的样本要进行距离的计算。
那么kNN的优缺点是什么呢?
另外,
k值的选取?
相似性的度量?
分类决策规则?
有没有对kNN的改进算法?

2019-11-25 17:19:00 Python_Matlab 阅读数 26
  • 数据挖掘模型篇之R语言实践

    理论与实践结合的方式,通过通俗易懂的教学方式培养学生运用R语言完成常用挖掘模型算法建立及评估,学习完课程可以掌握:线性回归模型、聚类分析、关联规则算法、KNN近邻算法和主成分分析等常用的模型算法实现。针对具体的数据挖掘应用需求,能熟练抽象出可合适的数据挖掘模型,并整理出其技术实现路线。

    3696 人正在学习 去看看 谢佳标

算法简介

K-近邻算法(KNN)是通过测量不同特征值之间的距离进行分类。
该算法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别,其中K通常是不大于20的整数。KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。


算法原理

在训练集数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类,其算法的描述为:
1)计算从测试数据与各个训练数据之间的距离;
2)按照距离的递增关系进行排序;
3)选取截距最小的K个点,其中K一般不大于20且为奇数;
4)确定前K个点所在类别的出现频率;
5)返回前K个点中出现频率最高的类别作为测试数据的预测分类。
在KNN中,通过计算对象间距来作为各个对象之间的非相似性指标,避免了对象之间的匹配问题,通常采用欧式距离,欧式距离也称为欧几里得距离;
在这里插入图片描述
还有其他距离衡量亦可以,比如:余弦值距离(cos), 相关度(correlation), 曼哈顿距离 (Manhattan distance)。

假设我们现在有两种不同形状的点分别为黑色矩形和蓝色圆行,分布在二维空间中,这就对应了训练样点包含的两个类别,且特征数量为3。如果我们希望推测图中红色五角星那个点是属于哪个类别,那么KNN算法将会计算该待推测点与所有训练样点之间的距离,并且挑出距离最小的K个样点(此处分别设K=1、K=5),则图中圈起来的点将被视为待推测点类别的参考依据。当K=1时,圈起来的点除了待测点之外就只有矩形,故预测该点为黑色矩形类别的;当K=5时,圈起来的点中有4个是圆行,1个是矩形,针对这种情况,KNN通常采用投票法来进行推测,即找出K个样本中类别出现次数最多的那个类别,因此该待推测点的类型值即被推测为蓝色圆形类别。
K=1情况
K=5的情况
K的取值对算法的结果有影响,K太小,分类结果容易受到噪点的影响,误差会增大;K太大,近邻中可能包含太多其他类别的点(对距离加权,可以降低K值设定的影响);K=N(样本数),则完全不足取,因为此时无论输入实例是什么,都只是简单的预测它属于在训练实例中最多的类,模型过于简单,忽略了训练实例中大量有用信息。
一般采用8:2或者6:4对训练集切分,通过交叉与验证得出最佳K值。

程序:

# KNN算法
def knn(x_test, x_data, y_data, k):
    '''
    x_test:测试数据
    x_data:已知数据
    y_data:已知数据的标签
    K:选择K个最近的实例
    返回k个中标签最多的类别
    '''
    # 计算样本数量
    x_data_size = x_data.shape[0]
    # 复制x_test
    np.tile(x_test, (x_data_size,1))
    # 计算x_test与每一个样本的差值
    diffMat = np.tile(x_test, (x_data_size,1)) - x_data
    # 计算差值的平方
    sqDiffMat = diffMat**2
    # 求和
    sqDistances = sqDiffMat.sum(axis=1)
    # 开方
    distances = sqDistances**0.5
    # 从小到大排序
    sortedDistances = distances.argsort()
    classCount = {}
    for i in range(k):
        # 获取标签
        votelabel = y_data[sortedDistances[i]]
        # 统计标签数量
        classCount[votelabel] = classCount.get(votelabel,0) + 1
    # 根据operator.itemgetter(1)-第1个值对classCount排序,然后再取倒序
    sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1), reverse=True)
    # 获取数量最多的标签
    return sortedClassCount[0][0]
2018-11-27 11:24:37 qq_31721173 阅读数 74
  • 数据挖掘模型篇之R语言实践

    理论与实践结合的方式,通过通俗易懂的教学方式培养学生运用R语言完成常用挖掘模型算法建立及评估,学习完课程可以掌握:线性回归模型、聚类分析、关联规则算法、KNN近邻算法和主成分分析等常用的模型算法实现。针对具体的数据挖掘应用需求,能熟练抽象出可合适的数据挖掘模型,并整理出其技术实现路线。

    3696 人正在学习 去看看 谢佳标

CS321n入门之KNN(1)——一个初学者的随学笔记

来源:吃树叶的土豆

视频中的第一个实例就是一个简单的KNN算法,作为小白特别是对py特性还不是很清楚的初学者。刷完视频之后再去看作业PPt的时候着实是一脸懵逼。通过几个小时的调试勉强弄明白是什么情况,同时也希望给同样是小白的初学者分享一些经验,以提高学习效率。

1、了解数据集

视频中实现KNN进行图像分类算法的数据集是:CIFAR-10。所以我们从这个数据集开始分析:
通过调试我们可以很直观的看到CIFAR-10数据集中的数据呈现状况如下:
1)batch_label: 用于标记数据集的类型,如:这里的训练集:training batch
2)label:标签,CIFAR-10是具有10种类型的图像数据库。这里的标签是指是数据集中的每张图片属于10种类型中的哪一类:“0~9”
3)data:数据,这里存的是数据集中所有图像的各个像素值。实际上是将nm二维的图像数据先转化成一位数组,由此data中的每一行则代表一张图像。图像大小为3232
4)filename:文件名,记录数据集中所有图像的文件名
如下图所示:
在这里插入图片描述

在这里插入图片描述

2、相关步骤

该算法实现的主要步骤可以分为以下:
1)数据处理:将数据集中的数据进行格式标准处理,这里通常会使用到numpy数据处理库。
2)模型建立:利用已封装好或自定义的模型进行训练,也就是class
3)评价指标:通过测试集的分类结果和测试集中已标记的label进行比较,求得该模型训练结果的有效性。由此来判断一个模型的好坏。
4)模型优化:调整超参数。在本算法中存在的超参数包括两个。其一:distance距离公式的选择,通常是曼哈顿距离和欧式距离;其二,k值的设定。

3、代码

import pickle as p
import matplotlib.pyplot as plt
import numpy as np


# NearestNeighbor class
class NearestNeighbor(object):
    def __init__(self):
        pass

    def train(self, X, y):
        """ X is N x D where each row is an example. Y is 1-dimension of size N """
        # the nearest neighbor classifier simply remembers all the training data
        self.Xtr = X
        self.ytr = y

    def predict(self, X):
        """ X is N x D where each row is an example we wish to predict label for """
        num_test = X.shape[0]#获取数据大小
        print(num_test)
        # lets make sure that the output type matches the input type
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)

        # loop over all test rows
        for i in range(num_test):
            # find the nearest training image to the i'th test image
            # using the L1 distance (sum of absolute value differences)
            distances = np.sum(np.sqrt(pow(self.Xtr - X[i, :],2)), axis=1)
            print(distances)
            min_index = np.argmin(distances)  # get the index with smallest distance
            Ypred[i] = self.ytr[min_index]  # predict the label of the nearest example

        return Ypred


def load_CIFAR_batch(filename):
    """ load single batch of cifar """
    #打开文件赋予权限
    with open(filename, 'rb')as f:
        datadict = p.load(f, encoding='latin1') #建立文件读取变量
        X = datadict['data']#读取data字段
        Y = datadict['labels']#读取labels字段
        #print(Y)
        Y = np.array(Y)  # 字典里载入的Y是list类型,把它变成array类型,具体是将Y中的逗号去掉
        #print(Y)
        return X, Y


def load_CIFAR_Labels(filename):
    with open(filename, 'rb') as f:
        label_names = p.load(f, encoding='latin1')
        names = label_names['label_names']
        return names


# load data
label_names = load_CIFAR_Labels("cifar-10-batches-py/batches.meta")#读取数据集中数据名
imgX1, imgY1 = load_CIFAR_batch("cifar-10-batches-py/data_batch_1")
imgX2, imgY2 = load_CIFAR_batch("cifar-10-batches-py/data_batch_2")
imgX3, imgY3 = load_CIFAR_batch("cifar-10-batches-py/data_batch_3")
imgX4, imgY4 = load_CIFAR_batch("cifar-10-batches-py/data_batch_4")
imgX5, imgY5 = load_CIFAR_batch("cifar-10-batches-py/data_batch_5")#分别读取数据集中的label和data字段
Xte_rows, Yte = load_CIFAR_batch("cifar-10-batches-py/test_batch")#测试集

Xtr_rows = np.concatenate((imgX1, imgX2, imgX3, imgX4, imgX5))
print(Xtr_rows)
Ytr_rows = np.concatenate((imgY1, imgY2, imgY3, imgY4, imgY5))
print(Ytr_rows)

nn = NearestNeighbor()  # create a Nearest Neighbor classifier class
nn.train(Xtr_rows[:1000, :], Ytr_rows[:1000])  # train the classifier on the training images and labels
Yte_predict = nn.predict(Xte_rows[:100, :])  # predict labels on the test images
# and now print the classification accuracy, which is the average number
# of examples that are correctly predicted (i.e. label matches)
print('accuracy: %f' % (np.mean(Yte_predict == Yte[:100])))#计算准确率
print('fenlei:' ,Yte_predict)
print('ceshibiaoqian:',Yte)
# show a picture
image = imgX1[6, 0:1024].reshape(32, 32)
print(image.shape)
plt.imshow(image, cmap=plt.cm.gray)
plt.axis('off')  # 去除图片边上的坐标轴
plt.show()

image = imgX2[6, 0:1024].reshape(32, 32)
print(image.shape)
plt.imshow(image, cmap=plt.cm.gray)
plt.axis('off')  # 去除图片边上的坐标轴
plt.show()
image = imgX3[6, 0:1024].reshape(32, 32)
print(image.shape)
plt.imshow(image, cmap=plt.cm.gray)
plt.axis('off')  # 去除图片边上的坐标轴
plt.show()
image = imgX4[6, 0:1024].reshape(32, 32)
print(image.shape)
plt.imshow(image, cmap=plt.cm.gray)
plt.axis('off')  # 去除图片边上的坐标轴
plt.show()
image = imgX5[6, 0:1024].reshape(32, 32)
print(image.shape)
plt.imshow(image, cmap=plt.cm.gray)
plt.axis('off')  # 去除图片边上的坐标轴
plt.show()
'''
image = imgX6[6, 0:1024].reshape(32, 32)
print(image.shape)
plt.imshow(image, cmap=plt.cm.gray)
plt.axis('off')  # 去除图片边上的坐标轴
plt.show()

4、相关运行结果

1)曼哈顿距离运行结果

在这里插入图片描述

2)欧式距离运行结果

在这里插入图片描述

5、建议

各位和博主一样的小白,可以通过读取各个变量中的数据来了解数据集的构成。以及各个模型中的具体操作。祝大家学习顺利。

(未完待续)

2018-11-25 20:41:21 u011583316 阅读数 2169
  • 数据挖掘模型篇之R语言实践

    理论与实践结合的方式,通过通俗易懂的教学方式培养学生运用R语言完成常用挖掘模型算法建立及评估,学习完课程可以掌握:线性回归模型、聚类分析、关联规则算法、KNN近邻算法和主成分分析等常用的模型算法实现。针对具体的数据挖掘应用需求,能熟练抽象出可合适的数据挖掘模型,并整理出其技术实现路线。

    3696 人正在学习 去看看 谢佳标

下周二算法课需要讲一个算法PPT,趁着自己在学习大数据,最佳的算法选择方向无疑是机器学习了。除了K-means我还接触过KNN以及反向传播神经网络。等到后面在系统学习复习(开天辟地)的时候再做一个详细的梳理。

K-means


K-means算法是硬聚类算法,是典型的基于原型的目标函数聚类方法的代表,它是数据点到原型的某种距离作为优化的目标函数,利用函数求极值的方法得到迭代运算的调整规则。K-means算法以欧式距离作为相似度测度,它是求对应某一初始聚类中心向量V最优分类,使得评价指标J最小。算法采用误差平方和准则函数作为聚类准则函数。

在这里插入图片描述
算法使用小案例

假设有一批人的年龄的数据,大致知道其中有一堆少年儿童,一堆青年人,一堆老年人。.

聚类就是自动发现这三堆数据,并把相似的数据聚合到同一堆中。所以对于这个例子,如果要聚成3堆的话,那么输入就是一堆年龄数据,注意,此时的年龄数据并不带有类标号,也就是说我只知道里面大致有三堆人,至于谁是哪一堆,现在是不知道的,而输出就是每个数据所属的类标号,聚类完成之后,就知道谁和谁是一堆了。

什么叫聚类?

聚类的目标:将一组向量分成若干组,组内数据是相似的,而组间数据是有较明显差异。

与分类区别:分类与聚类最大的区别在于分类的目标事先已知,聚类也被称为无监督机器学习

K 是什么?

K是聚类算法中当前类的个数

Means 是什么?

means是均值算法

算法描述:

算法核心很简单,我感觉K打头的算法都挺简单的。相信你看了也会对机器学习充满自信。

  • 任选K个点作为初始聚类中心
  • 根据每个聚类的中心,计算每个对象与这些中心的距离,并根据最小距离重新对对象进行划分
  • 重新计算每个聚类的中心(质心实际上不存在的)
  • 当满足一定条件,如类别划分不在发生变化时,算法终止,否则继续2-3



使用场景

  • 样本球形分布
  • 密度,大小不同的聚类

时间复杂度:

该算法的时间复杂度为:O(nkt)

  • n->聚类对象数
  • t->迭代次数
  • k->初始中心个数

平面划分

两点之间的垂直平分线,迭代划分。

欧式距离

对于欧式空间的样本数据,以平方误差和(sum of the squared error, SSE)作为聚类的目标函数,同时也可以衡量不同聚类结果好坏的指标:

表示样本点x到cluster Ci 的质心 ci 距离平方和;最优的聚类结果应使得SSE达到最小值。

 //欧式距离,计算两点距离
    public double EurDistance(Point point, Point center)
    {
        double detX = point.getX() - center.getX();
        double detY = point.getY() - center.getY();

        return Math.sqrt(detX * detX + detY * detY);
    }

重新计算每个聚类的中心对象

中心对象:均值

  /*
     * 调整聚类中心,按照求平衡点的方法获得新的簇心
     */
    public void adjustCenters()
    {
        double sumx[] = new double[k];
        double sumy[] = new double[k];
        int count[] = new int[k];

        // 保存每个簇的横纵坐标之和
        for (int i = 0; i < k; i++)
        {
            sumx[i] = 0.0;
            sumy[i] = 0.0;
            count[i] = 0;
        }

        // 计算每个簇的横纵坐标总和、记录每个簇的个数
        for (Point point : points)
        {
            int clusterID = point.getClusterID();

            // System.out.println(clusterID);
            sumx[clusterID - 1] += point.getX();
            sumy[clusterID - 1] += point.getY();
            count[clusterID - 1]++;
        }

        // 更新簇心坐标
        for (int i = 0; i < k; i++)
        {
            Point tmpPoint = centers.get(i);
            tmpPoint.setX(sumx[i] / count[i]);
            tmpPoint.setY(sumy[i] / count[i]);
            tmpPoint.setClusterID(i + 1);

            centers.set(i, tmpPoint);
        }
    }

终止条件

最小化蔟内对象到质心的距离,从而最小化WCSS。通过损失函数来衡量算法停止条件。

  • 损失函数:WCSS

xi代表某个样本点,ck代表每个类的中心点。每个类里的元素越凝聚越好。

中心点选择

  • K(类别)的选择 细粒度 越多越准确
  • 随机选取
  • 多次随机:选择最小的WCSS那次聚类

算法优点

  1. 时间复杂度低,速度快。
  2. 由具有出色的速度和良好的可扩展性
  3. 当簇接近高斯分布时,它的效果较好
    算法缺点
  • 在簇的平均值可被定义的情况下才能使用,可能不适用于某些应用;
  • 在 K-means 算法中 K 是事先给定的,这个 K 值的选定是非常难以估计的。很多时候,事先并不知道给定的数据集应该分成多少个类别才最合适;
  • 在 K-means 算法中,首先需要根据初始聚类中心来确定一个初始划分,然后对初始划分进行优化。这个初始聚类中心的选择对聚类结果有较大的影响,一旦初始值选择的不好,可能无法得到有效的聚类结果;
  • 该算法需要不断地进行样本分类调整,不断地计算调整后的新的聚类中心,因此当数据量非常大时,算法的时间开销是非常大的;
  • 若簇中含有异常点,将导致均值偏离严重(即:对噪声和孤立点数据敏感);
  • 不适用于发现非凸形状的簇或者大小差别很大的簇。

你可别说为啥这算法缺点比有点还多。“存在即合理!”

Java实现


这里结果可视化我用了JfreeChart图表绘制类库。数据集是iris数据集。另外提醒一下就是使用本代码时记得创建文件输出路径,以及导入绘图的核心包。

package com.xianglei.kmeansback;

public class Point
{
    // 点的坐标
    private Double x;
    private Double y;
    private Double z;
    public int getTag() {
		return tag;
	}



	public void setTag(int tag) {
		this.tag = tag;
	}

	private Double w;

    // 所在类ID
    private int clusterID = -1;
    private int tag=0;

    public Point(Double x, Double y,Double w, Double z) {

        this.x = x;
        this.y = y;
        this.w = w;
        this.z = z;
    }

  

    public Double getX()
    {
        return x;
    }

    public void setX(Double x)
    {
        this.x = x;
    }

    public Double getY()
    {
        return y;
    }

    public Double getZ() {
		return z;
	}



	@Override
	public String toString() {
		return "Point [x=" + x + ", y=" + y + ", z=" + z + ", w=" + w + ", clusterID=" + clusterID + "]";
	}



	public void setZ(Double z) {
		this.z = z;
	}



	public Double getW() {
		return w;
	}



	public void setW(Double w) {
		this.w = w;
	}



	public void setY(Double y)
    {
        this.y = y;
    }

    public int getClusterID()
    {
        return clusterID;
    }

    public void setClusterID(int clusterID)
    {
        this.clusterID = clusterID;
    }
}

package com.xianglei.kmeansback;

import java.util.ArrayList;
import java.util.List;

public class KMeansCluster
{
    // 聚类中心数
    public int k = 5;

    // 迭代最大次数
    public int maxIter = 50;

    // 测试点集
    public List<Point> points;

    // 中心点
    public List<Point> centers;

    public static final double MINDISTANCE = 10000.00;

    public KMeansCluster(int k, int maxIter, List<Point> points) {
        this.k = k;
        this.maxIter = maxIter;
        this.points = points;

        //初始化中心点
        initCenters();
    }

    /*
     * 初始化聚类中心
     * 这里的选取策略是,从点集中按序列抽取K个作为初始聚类中心
     */
    public void initCenters()
    {
        centers = new ArrayList<>(k);

        for (int i = 0; i < k; i++)
        {
            Point tmPoint = points.get(i*33+48);
            Point center = new Point(tmPoint.getX(), tmPoint.getY(),tmPoint.getW(), tmPoint.getZ());
            center.setClusterID(i + 1);
            centers.add(center);
        }
    }


    /*
     * 停止条件是满足迭代次数
     */
    public void runKmeans()
    {
        // 已迭代次数
        int count = 1;

        while (count++ <= maxIter)
        {
            // 遍历每个点,确定其所属簇
            for (Point point : points)
            {
            	if(point.getTag()==0)
                assignPointToCluster(point);
            	
            }

            //调整中心点
            adjustCenters();
        }
    }



    /*
     * 调整聚类中心,按照求平衡点的方法获得新的簇心
     */
    public void adjustCenters()
    {
        double sumx[] = new double[k];
        double sumy[] = new double[k];
        double sumw[] = new double[k];
        double sumz[] = new double[k];
        int count[] = new int[k];

        // 保存每个簇的横纵坐标之和 K=3
        for (int i = 0; i < k; i++)
        {
            sumx[i] = 0.0;
            sumy[i] = 0.0;
            sumw[i] = 0.0;
            sumz[i] = 0.0;
            count[i] = 0;
        }

        // 计算每个簇的横纵坐标总和、记录每个簇的个数
        for (Point point : points)
        {
        	if(point.getTag()==0){
            int clusterID = point.getClusterID();

            // System.out.println(clusterID);
            sumx[clusterID - 1] += point.getX();
            sumy[clusterID - 1] += point.getY();
            sumw[clusterID - 1] += point.getW();
            sumz[clusterID - 1] += point.getZ();
            count[clusterID - 1]++;
        	}
        }

        // 更新簇心坐标
        for (int i = 0; i < k; i++)
        {
            Point tmpPoint = centers.get(i);
            tmpPoint.setX(sumx[i] / count[i]);
            tmpPoint.setY(sumy[i] / count[i]);
            tmpPoint.setW(sumw[i] / count[i]);
            tmpPoint.setZ(sumz[i] / count[i]);
            tmpPoint.setClusterID(i + 1);

            centers.set(i, tmpPoint);
        }
    }


    /*划分点到某个簇中,欧式距离标准
     * 对传入的每个点,找到与其最近的簇中心点,将此点加入到簇
     */
    public void assignPointToCluster(Point point)
    {
        double minDistance = MINDISTANCE;

        int clusterID = -1;

        for (Point center : centers)
        {
            double dis = EurDistance(point, center);
            if (dis < minDistance)
            {
                minDistance = dis;
                clusterID = center.getClusterID();
            }
        }
        point.setClusterID(clusterID);

    }

    //欧式距离,计算两点距离
    public double EurDistance(Point point, Point center)
    {
        double detX = point.getX() - center.getX();
        double detY = point.getY() - center.getY();
        double detW = point.getW() - center.getW();
        double detZ = point.getZ() - center.getZ();

        return Math.sqrt(detX * detX + detY * detY+ detW * detW+ detZ * detZ);
    }
}


package com.xianglei.kmeansback;

import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Font;
import java.awt.Image;
import java.awt.image.ImageObserver;
import java.awt.image.ImageProducer;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Scanner;

import org.jfree.chart.ChartFactory;
import org.jfree.chart.ChartFrame;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.StandardChartTheme;
import org.jfree.chart.annotations.XYTextAnnotation;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.axis.ValueAxis;
import org.jfree.chart.plot.CategoryPlot;
import org.jfree.chart.plot.Plot;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.XYLineAndShapeRenderer;
import org.jfree.data.xy.DefaultXYDataset;
import org.jfree.data.xy.XYDataset;
import org.jfree.ui.RefineryUtilities;

public class Kmean {
	// 用来聚类的点集
	public List<Point> points;

	// 将聚类结果保存到文件
	FileWriter out = null;

	// 格式化double类型的输出,保留两位小数
	DecimalFormat dFormat = new DecimalFormat("00.00");

	// 具体执行聚类的对象
	public KMeansCluster kMeansCluster;

	// 簇的数量,迭代次数
	public int numCluster = 0;

	public int numIterator = 200;

	// 点集的数量,生成指定数量的点集
	public int numPoints = 50;

	// 聚类结果保存路径
	public static final String FILEPATH = "f:/kmeans/res.txt";
	public static final String DATAPATH = "f:/kmeans/iris.txt";

	public static void main(String[] args) {
		// 指定点集个数,簇的个数,迭代次数
		Kmean kmeans = new Kmean(0, 3, 200000);

		// 初始化点集、KMeansCluster对象
		kmeans.init();

		// 使用KMeansCluster对象进行聚类
		kmeans.runKmeans();

		kmeans.printRes();
		kmeans.Test();
		kmeans.saveResToFile(FILEPATH);

	}

	

	private void Test() {
		// TODO Auto-generated method stub
		
	}



	public Kmean(int numPoints, int cluster_number, int iterrator_number) {

		this.numPoints = numPoints;
		this.numCluster = cluster_number;
		this.numIterator = iterrator_number;
	}

	private void init() {
		this.initPoints();
		kMeansCluster = new KMeansCluster(numCluster, numIterator, points);
	}

	private void runKmeans() {
		kMeansCluster.runKmeans();
	}

	// 初始化点集
	public void initPoints() {
		points = new ArrayList<>(numPoints);

		try {
			Scanner in = new Scanner(new File(DATAPATH));// 读入文件
			while (in.hasNextLine() && numPoints <= 154) {
				Point tmpPoint = new Point(null, null, null, null);
				numPoints++;
				String str = in.nextLine();// 将文件的每一行存到str的临时变量中
				String[] split = str.split(" ");

				tmpPoint.setX(Double.valueOf(split[1]));
				tmpPoint.setY(Double.valueOf(split[2]));
				tmpPoint.setZ(Double.valueOf(split[3]));
				tmpPoint.setW(Double.valueOf(split[4]));
				if (split[5].contains("setosa"))
					tmpPoint.setClusterID(1);
				if (split[5].contains("versicolor"))
					tmpPoint.setClusterID(2);
				if (split[5].contains("virginica"))
					tmpPoint.setClusterID(3);

				points.add(tmpPoint);
				System.out.println(numPoints);
				System.out.println(split[1] + "-" + split[2] + "-" + split[3] + "-" + split[4] + "-" + split[5]
						+ " - - - 类别:" + tmpPoint.getClusterID());
			}
	
		} catch (Exception e) {

		}

	}

	public void printRes() {

		System.out.println("==================Centers-I====================");
		for (Point center : kMeansCluster.centers) {
			System.out.println(center.toString());
		}

		System.out.println("==================Points====================");

		for (Point point : points) {
			if (point.getTag() == 0)
				System.out.println(point.toString());
		}
	}

	public void saveResToFile(String filePath) {
		try {
			out = new FileWriter(new File(filePath));
			String[] stinga = new String[numPoints];
			String[] stingb = new String[numPoints];
			String[] tag = new String[numPoints];
			int i = 0;
			for (Point point : points) {
				if (point.getTag() == 0) {
					out.write(String.valueOf(point.getClusterID()));
					out.write("-");

					out.write(dFormat.format(point.getX()));
					out.write("-");
					out.write(dFormat.format(point.getY()));
					out.write(dFormat.format(point.getW()));
					out.write("-");
					out.write(dFormat.format(point.getZ()));
					out.write("\r\n");

					stinga[i] = Double.toString(point.getZ());
					stingb[i] = Double.toString(point.getW());
					tag[i] = Double.toString(point.getClusterID());
					i++;
					System.out.println("=================================");
					System.out.println("聚类后结果:" + point.toString());
				}
			}

			data("k-means", stinga, stingb, tag);
			out.flush();
			out.close();

		} catch (IOException e) {
			e.printStackTrace();
		}
	}

	public static void data(String title, String[] a, String[] b, String[] t) {
		DefaultXYDataset xydataset = new DefaultXYDataset();

		double[][] data = new double[2][a.length];
		double[][] data2 = new double[2][a.length];
		double[][] data3 = new double[2][a.length];

		for (int i = 0; i < a.length; i++) {

			if (t[i].contains("1")) {
				data[0][i] = Double.parseDouble(a[i]);
				data[1][i] = Double.parseDouble(b[i]);

			}
			if (t[i].contains("2")) {
				data2[0][i] = Double.parseDouble(a[i]);
				data2[1][i] = Double.parseDouble(b[i]);

			}
			if (t[i].contains("3")) {
				data3[0][i] = Double.parseDouble(a[i]);
				data3[1][i] = Double.parseDouble(b[i]);
			}

		}

		xydataset.addSeries("Setosa", data);
		xydataset.addSeries("Versicolor", data2);
		xydataset.addSeries("Virginica", data3);

		final JFreeChart chart = ChartFactory.createScatterPlot("K-Means", "X", "Y", xydataset,
				PlotOrientation.VERTICAL, true, true, false);

		chart.setBorderVisible(false);

		XYPlot xyPlot2 = chart.getXYPlot();
		xyPlot2.getRenderer().setSeriesPaint(0, Color.RED);
		xyPlot2.getRenderer().setSeriesPaint(1, Color.GREEN);
		xyPlot2.getRenderer().setSeriesPaint(2, Color.black);

		ChartFrame frame = new ChartFrame(title, chart);
		frame.pack();
		RefineryUtilities.centerFrameOnScreen(frame);
		frame.setVisible(true);
	}

}

聚类结果

一:随机模拟数据集
在这里插入图片描述

二.iris数据集

在这里插入图片描述


总结

就算法的思想和实现来讲是比较简单的,也为我的算法学习开了个好头。但是我看了其他大牛的博客后,发现这个算法更多可玩的地方是如何去解决这些缺点。但是由于时间关系,在当下这个时间段我就不深挖了,你要是有兴趣的话给你一个传送门深入理解 K-means算法。???

KNN算法
2016-12-14 19:03:19 lxqluo 阅读数 498
  • 数据挖掘模型篇之R语言实践

    理论与实践结合的方式,通过通俗易懂的教学方式培养学生运用R语言完成常用挖掘模型算法建立及评估,学习完课程可以掌握:线性回归模型、聚类分析、关联规则算法、KNN近邻算法和主成分分析等常用的模型算法实现。针对具体的数据挖掘应用需求,能熟练抽象出可合适的数据挖掘模型,并整理出其技术实现路线。

    3696 人正在学习 去看看 谢佳标

1. 百度百科概念介绍

http://baike.baidu.com/link?url=5qCFfNMvGkBLEzdoCIBVwvw8MAHqlKdEREaafle5_BOUffJvIw1cXZT_n1MwNXMoKdqtZVG0-_haJtUgV3l8DpWT0cr5h3-IpnOqkHmN0hjMyIPwxjjcd0o9DaVEqWql


2. PPT介绍

http://wenku.baidu.com/link?url=Ay-8o704KSL1mIJTw55qa5sPuuW34lc0FvUatD-NnitHA_oZZPWuaNBiU7UK9VxV1FUyGbGqgPfSP4gTIxi04opkdKe_jHi02RqP3Tp5nie


3. KNN用途

https://www.douban.com/note/176119064/


4. KNN距离公式比较

http://liyonghui160com.iteye.com/blog/2084557


5. KNN算法介绍

http://liyonghui160com.iteye.com/blog/2084557


6. K最近邻(KNN)算法的java实现

http://www.open-open.com/lib/view/open1409822475713.html


7. 各种分类算法比较

http://bbs.pinggu.org/thread-2604496-1-1.html

PPT机器学习概述

阅读数 4888

k-近邻算法

阅读数 12

没有更多推荐了,返回首页