博客
关于我
【Spark MLlib】(四)K-Means 聚类分析
阅读量:366 次
发布时间:2019-03-05

本文共 3309 字,大约阅读时间需要 11 分钟。

使用 Spark 机器学习库来做机器学习工作,可以说是非常的简单,通常只需要在对原始数据进行处理后,然后直接调用相应的 API 就可以实现。但是要想选择合适的算法,高效准确地对数据进行分析,可能还需要深入了解下算法原理,以及相应 Spark MLlib API 实现的参数的意义,本文带你了解 K-means 聚类算法。

文章目录

一、K-means 聚类算法原理

何谓聚类,聚类指的是将数据分类到不同的类或者簇这样的一个过程,所以同一个簇中的对象有很大的相似性,而不同簇间的对象有很大的相异性,聚类与分类的不同在于,聚类所要求划分的类是未知的。

聚类分析是一个无监督学习 (Unsupervised Learning) 过程, 一般是用来对数据对象按照其特征属性进行分组,经常被应用在客户分群,欺诈检测,图像分析等领域。K-means 应该是最有名并且最经常使用的聚类算法了,其原理比较容易理解,并且聚类效果良好,有着广泛的使用。

和诸多机器学习算法一样,K-means 算法也是一个迭代式的算法,其主要步骤如下:

  • 第一步,选择 K 个点作为初始聚类中心。
  • 第二步,计算其余所有点到聚类中心的距离,并把每个点划分到离它最近的聚类中心所在的聚类中去。在这里,衡量距离一般有多个函数可以选择,最常用的是欧几里得距离 (Euclidean Distance), 也叫欧式距离。
  • 第三步,重新计算每个聚类中所有点的平均值,并将其作为新的聚类中心点。
  • 第四步,重复 (二)、(三) 步的过程,直至聚类中心不再发生改变,或者算法达到预定的迭代次数,又或聚类中心的改变小于预先设定的阀值。

在实际应用中,K-means 算法有两个不得不面对并且克服的问题。

  • 聚类个数 K 的选择。K 的选择是一个比较有学问和讲究的步骤,我们会在后文专门描述如何使用 Spark 提供的工具选择 K。
  • 初始聚类中心点的选择。选择不同的聚类中心可能导致聚类结果的差异。

Spark MLlib K-means 算法的实现在初始聚类点的选择上,借鉴了一个叫 K-means||的类 K-means++ 实现。K-means++ 算法在初始点选择上遵循一个基本原则: 初始聚类中心点相互之间的距离应该尽可能的远。基本步骤如下:

  • 第一步,从数据集 X 中随机选择一个点作为第一个初始点。
  • 第二步,计算数据集中所有点与最新选择的中心点的距离 D(x)。
  • 第三步,选择下一个中心点,使得最大。
  • 第四部,重复 (二)、(三) 步过程,直到 K 个初始点选择完成。

二、K-means 实现

Spark MLlib 中 K-means 算法的实现类 (KMeans.scala) 具有以下参数,具体如下:

class KMeans private (    private var k: Int,    private var maxIterations: Int,    private var runs: Int,    private var initializationMode: String,    private var initializationSteps: Int,    private var epsilon: Double,    private var seed: Long) extends Serializable with Logging
  • k 表示期望的聚类的个数。
  • maxInterations 表示方法单次运行最大的迭代次数。
  • runs 表示算法被运行的次数。K-means 算法不保证能返回全局最优的聚类结果,所以在目标数据集上多次跑 K-means 算法,有助于返回最佳聚类结果。
  • initializationMode 表示初始聚类中心点的选择方式, 目前支持随机选择或者 K-means||方式。默认是 K-means||。
  • initializationSteps表示 K-means||方法中的部数。
  • epsilon 表示 K-means 算法迭代收敛的阀值。
  • seed 表示集群初始化时的随机种子。

运行示例

import org.apache.spark.mllib.clustering.KMeansimport org.apache.spark.mllib.linalgimport org.apache.spark.{   SparkConf, SparkContext}import org.apache.spark.mllib.linalg.Vectorsimport org.apache.spark.rdd.RDD object KmeansSpark {      def main(args: Array[String]): Unit = {        //在本地启动Spark    val sparkConf = new SparkConf().setMaster("local[2]").setAppName("KmeansSpark")    val sc = new SparkContext(sparkConf)     //加载本地文件数据形成RDD    val data = sc.textFile("file:///root/test.txt")    val parsedData: RDD[linalg.Vector] = data.map(s=>{         val values: Array[Double] = s.split(" ").map(x => x.toDouble)      Vectors.dense(values)    })     //聚类中心个数    val numClusters = 8    //算法迭代次数    val numIterations = 20    //算法运行次数    val runs = 10    //KMeans训练    val kmeansModel = KMeans.train(parsedData, numClusters, numIterations, runs)     //打印聚类中心ID    kmeansModel.clusterCenters.foreach(x=>{        println(x)    })    //打印数据归属哪个聚类中心ID    parsedData.map(v => v.toString + " belong to cluster: " +kmeansModel.predict(v))    ss.foreach(x=>      println(x)    )    sc.stop()    }}

三、K值的选择

前面提到 K 的选择是 K-means 算法的关键,Spark MLlib 在 KMeansModel 类里提供了 computeCost 方法,该方法通过计算所有数据点到其最近的中心点的平方和来评估聚类的效果。一般来说,同样的迭代次数和算法跑的次数,这个值越小代表聚类的效果越好。但是在实际情况下,我们还要考虑到聚类结果的可解释性,不能一味的选择使 computeCost 结果值最小的那个 K。

val ks:Array[Int] = Array(3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20)ks.foreach(cluster => {    val model:KMeansModel = KMeans.train(parsedData, cluster,30,1) val ssd = model.computeCost(parsedData) println("sum of squared distances of points to their nearest center when k=" + cluster + " -> "+ ssd)})

转载地址:http://npig.baihongyu.com/

你可能感兴趣的文章
MySQL8修改密码的方法
查看>>
Mysql8在Centos上安装后忘记root密码如何重新设置
查看>>
Mysql8在Windows上离线安装时忘记root密码
查看>>
MySQL8找不到my.ini配置文件以及报sql_mode=only_full_group_by解决方案
查看>>
mysql8的安装与卸载
查看>>
MySQL8,体验不一样的安装方式!
查看>>
MySQL: Host '127.0.0.1' is not allowed to connect to this MySQL server
查看>>
Mysql: 对换(替换)两条记录的同一个字段值
查看>>
mysql:Can‘t connect to local MySQL server through socket ‘/var/run/mysqld/mysqld.sock‘解决方法
查看>>
MYSQL:基础——3N范式的表结构设计
查看>>
MYSQL:基础——触发器
查看>>
Mysql:连接报错“closing inbound before receiving peer‘s close_notify”
查看>>
mysqlbinlog报错unknown variable ‘default-character-set=utf8mb4‘
查看>>
mysqldump 参数--lock-tables浅析
查看>>
mysqldump 导出中文乱码
查看>>
mysqldump 导出数据库中每张表的前n条
查看>>
mysqldump: Got error: 1044: Access denied for user ‘xx’@’xx’ to database ‘xx’ when using LOCK TABLES
查看>>
Mysqldump参数大全(参数来源于mysql5.5.19源码)
查看>>
mysqldump备份时忽略某些表
查看>>
mysqldump实现数据备份及灾难恢复
查看>>