找回密码
 立即注册
查看: 317|回复: 0

Spark MLlib KMeans聚类算法

[复制链接]
发表于 2022-4-14 17:30 | 显示全部楼层 |阅读模式
1.1 KMeans聚类算法

1.1.1 基础理论

KMeans算法的基本思想是初始随机给定K个簇中心,按照最邻近原则把待分类样本点分到各个簇。然后按平均法重新计算各个簇的质心,从而确定新的簇心。一直迭代,直到簇心的移动距离小于某个给定的值。

K-Means聚类算法主要分为三个步骤:

(1)第一步是为待聚类的点寻找聚类中心;

(2)第二步是计算每个点到聚类中心的距离,将每个点聚类到离该点最近的聚类中去;

(3)第三步是计算每个聚类中所有点的坐标平均值,并将这个平均值作为新的聚类中心;

反复执行(2)、(3),直到聚类中心不再进行大范围移动或者聚类次数达到要求为止。

1.1.2过程演示

下图展示了对n个样本点进行K-means聚类的效果,这里k取2:

(a)未聚类的初始点集;

(b)随机选取两个点作为聚类中心;

(c)计算每个点到聚类中心的距离,并聚类到离该点最近的聚类中去;

(d)计算每个聚类中所有点的坐标平均值,并将这个平均值作为新的聚类中心;

(e)重复(c),计算每个点到聚类中心的距离,并聚类到离该点最近的聚类中去;

(f)重复(d),计算每个聚类中所有点的坐标平均值,并将这个平均值作为新的聚类中心。





参照以下文档:

http://blog.sina.com.cn/s/blog_62186b46010145ne.html

1.2 Spark Mllib KMeans源码分析

class KMeansprivate (

privatevar k: Int,

privatevar maxIterations: Int,

privatevar runs: Int,

privatevar initializationMode: String,

privatevar initializationSteps: Int,

privatevar epsilon: Double,

privatevar seed: Long)extends Serializablewith Logging {

// KMeans类参数:

k:聚类个数,默认2maxIterations:迭代次数,默认20runs:并行度,默认1

initializationMode:初始中心算法,默认"k-means||"initializationSteps:初始步长,默认5epsilon:中心距离阈值,默认1e-4seed:随机种子。

/**

* Constructs a KMeans instance with default parameters: {k: 2, maxIterations: 20, runs: 1,

* initializationMode: "k-means||", initializationSteps: 5, epsilon: 1e-4, seed: random}.

*/

defthis() =this(2,20, 1, KMeans.K_MEANS_PARALLEL,5, 1e-4, Utils.random.nextLong())
// 参数设置

/** Set the number of clusters to create (k). Default: 2. */

def setK(k: Int):this.type = {

this.k = k

this

}
**省略各个参数设置代码**

// run方法,KMeans主入口函数

/**

* Train a K-means model on the given set of points; `data` should be cached for high

* performance, because this is an iterative algorithm.

*/

def run(data: RDD[Vector]): KMeansModel = {


if (data.getStorageLevel == StorageLevel.NONE) {

logWarning("The input data is not directly cached, which may hurt performance if its"

+ " parent RDDs are also uncached.")

}


// Compute squared norms and cache them.

// 计算每行数据的L2范数,数据转换:data[Vector]=> data[(Vector, norms)],其中norms是Vector的L2范数,norms就是

val norms = data.map(Vectors.norm(_,2.0))

norms.persist()

val zippedData = data.zip(norms).map {case (v, norm) =>

new VectorWithNorm(v, norm)

}

val model = runAlgorithm(zippedData)

norms.unpersist()


// Warn at the end of the run as well, for increased visibility.

if (data.getStorageLevel == StorageLevel.NONE) {

logWarning("The input data was not directly cached, which may hurt performance if its"

+ " parent RDDs are also uncached.")

}

model

}

// runAlgorithm方法,KMeans实现方法。

/**

* Implementation of K-Means algorithm.

*/

privatedef runAlgorithm(data: RDD[VectorWithNorm]): KMeansModel = {


val sc = data.sparkContext


val initStartTime = System.nanoTime()


val centers =if (initializationMode == KMeans.RANDOM) {

initRandom(data)

} else {

initKMeansParallel(data)

}


val initTimeInSeconds = (System.nanoTime() - initStartTime) /1e9

logInfo(s"Initialization with $initializationMode took " +"%.3f".format(initTimeInSeconds) +

" seconds.")


val active = Array.fill(runs)(true)

val costs = Array.fill(runs)(0.0)


var activeRuns =new ArrayBuffer[Int] ++ (0 until runs)

var iteration =0


val iterationStartTime = System.nanoTime()

//KMeans迭代执行,计算每个样本属于哪个中心点,中心点累加样本的值及计数,然后根据中心点的所有的样本数据进行中心点的更新,并比较更新前的数值,判断是否完成。其中runs代表并行度。

// Execute iterations of Lloyd's algorithm until all runs have converged

while (iteration < maxIterations && !activeRuns.isEmpty) {

type WeightedPoint = (Vector, Long)

def mergeContribs(x: WeightedPoint, y: WeightedPoint): WeightedPoint = {

axpy(1.0, x._1, y._1)

(y._1, x._2 + y._2)

}


val activeCenters = activeRuns.map(r => centers(r)).toArray

val costAccums = activeRuns.map(_ => sc.accumulator(0.0))


val bcActiveCenters = sc.broadcast(activeCenters)


// Find the sum and count of points mapping to each center

//计算属于每个中心点的样本,对每个中心点的样本进行累加和计算;

runs代表并行度,k中心点个数,sums代表中心点样本累加值,counts代表中心点样本计数;

contribs代表((并行度I,中心J),(中心J样本之和,中心J样本计数和));

findClosest方法:找到点与所有聚类中心最近的一个中心

val totalContribs = data.mapPartitions { points =>

val thisActiveCenters = bcActiveCenters.value

val runs = thisActiveCenters.length

val k = thisActiveCenters(0).length

val dims = thisActiveCenters(0)(0).vector.size


val sums = Array.fill(runs, k)(Vectors.zeros(dims))

val counts = Array.fill(runs, k)(0L)


points.foreach { point =>

(0 until runs).foreach { i =>

val (bestCenter, cost) = KMeans.findClosest(thisActiveCenters(i), point)

costAccums(i) += cost

val sum = sums(i)(bestCenter)

axpy(1.0, point.vector, sum)

counts(i)(bestCenter) += 1

}

}


val contribs =for (i <-0 until runs; j <-0 until k) yield {

     ((i, j), (sums(i)(j), counts(i)(j)))

}

contribs.iterator

}.reduceByKey(mergeContribs).collectAsMap()

//更新中心点,更新中心点= sum/count

判断newCentercenters之间的距离是否 > epsilon * epsilon;

// Update the cluster centers and costs for each active run

for ((run, i) <- activeRuns.zipWithIndex) {

var changed =false

var j =0

while (j < k) {

val (sum, count) = totalContribs((i, j))

if (count !=0) {

scal(1.0 / count, sum)

val newCenter =new VectorWithNorm(sum)

if (KMeans.fastSquaredDistance(newCenter, centers(run)(j)) > epsilon * epsilon) {

changed = true

}

centers(run)(j) = newCenter

}

j += 1

}

if (!changed) {

active(run) = false

logInfo("Run " + run +" finished in " + (iteration +1) + " iterations")

}

costs(run) = costAccums(i).value

}


activeRuns = activeRuns.filter(active(_))

iteration += 1

}


val iterationTimeInSeconds = (System.nanoTime() - iterationStartTime) /1e9

logInfo(s"Iterations took " +"%.3f".format(iterationTimeInSeconds) +" seconds.")


if (iteration == maxIterations) {

logInfo(s"KMeans reached the max number of iterations: $maxIterations.")

} else {

logInfo(s"KMeans converged in $iteration iterations.")

}


val (minCost, bestRun) = costs.zipWithIndex.min


logInfo(s"The cost for the best run is $minCost.")


new KMeansModel(centers(bestRun).map(_.vector))

}

//findClosest方法:找到点与所有聚类中心最近的一个中心

/**

* Returns the index of the closest center to the given point, as well as the squared distance.

*/

private[mllib]def findClosest(

centers: TraversableOnce[VectorWithNorm],

point: VectorWithNorm): (Int, Double) = {

var bestDistance = Double.PositiveInfinity

var bestIndex =0

var i =0

centers.foreach { center =>

// Since `\|a - b\| \geq |\|a\| - \|b\||`, we can use this lower bound to avoid unnecessary

// distance computation.

var lowerBoundOfSqDist = center.norm - point.norm

lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist

if (lowerBoundOfSqDist < bestDistance) {

val distance: Double = fastSquaredDistance(center, point)

if (distance < bestDistance) {

bestDistance = distance

bestIndex = i

}

}

i += 1

}

(bestIndex, bestDistance)

}
findClosest方法中:var lowerBoundOfSqDist = center.norm - point.norm

lowerBoundOfSqDist = lowerBoundOfSqDist * lowerBoundOfSqDist

如果中心点center是(a1,b1),需要计算的点point是(a2,b2),那么lowerBoundOfSqDist是:




如下是展开式,第二个是真正计算欧式距离时的除去开平方的公式。(在查找最短距离的时候无需计算开方,因为只需要计算出开方里面的式子就可以进行比较了,mllib也是这样做的)



可轻易证明上面两式的第一式将会小于等于第二式,因此在进行距离比较的时候,先计算很容易计算的lowerBoundOfSqDist,如果lowerBoundOfSqDist都不小于之前计算得到的最小距离bestDistance,那真正的欧式距离也不可能小于bestDistance了,因此这种情况下就不需要去计算欧式距离,省去很多计算工作。

如果lowerBoundOfSqDist小于了bestDistance,则进行距离的计算,调用fastSquaredDistance,这个方法将调用MLUtils.scala里面的fastSquaredDistance方法,计算真正的欧式距离,代码如下:

/**

* Returns the squared Euclidean distance between two vectors. The following formula will be used

* if it does not introduce too much numerical error:

* <pre>

*   \|a - b\|_2^2 = \|a\|_2^2 + \|b\|_2^2 - 2 a^T b.

* </pre>

* When both vector norms are given, this is faster than computing the squared distance directly,

* especially when one of the vectors is a sparse vector.

*

* @param v1 the first vector

* @param norm1 the norm of the first vector, non-negative

* @param v2 the second vector

* @param norm2 the norm of the second vector, non-negative

* @param precision desired relative precision for the squared distance

* @return squared distance between v1 and v2 within the specified precision

*/

private[mllib]def fastSquaredDistance(

v1: Vector,

norm1: Double,

v2: Vector,

     norm2: Double,

precision: Double = 1e-6): Double = {

val n = v1.size

require(v2.size == n)

require(norm1 >= 0.0 && norm2 >=0.0)

val sumSquaredNorm = norm1 * norm1 + norm2 * norm2

val normDiff = norm1 - norm2

var sqDist =0.0

/*

* The relative error is

* <pre>

* EPSILON * ( \|a\|_2^2 + \|b\\_2^2 + 2 |a^T b|) / ( \|a - b\|_2^2 ),

* </pre>

* which is bounded by

* <pre>

* 2.0 * EPSILON * ( \|a\|_2^2 + \|b\|_2^2 ) / ( (\|a\|_2 - \|b\|_2)^2 ).

* </pre>

* The bound doesn't need the inner product, so we can use it as a sufficient condition to

* check quickly whether the inner product approach is accurate.

*/

val precisionBound1 =2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON)

if (precisionBound1 < precision) {

sqDist = sumSquaredNorm - 2.0 * dot(v1, v2)

} elseif (v1.isInstanceOf[SparseVector] || v2.isInstanceOf[SparseVector]) {

val dotValue = dot(v1, v2)

sqDist = math.max(sumSquaredNorm - 2.0 * dotValue,0.0)

val precisionBound2 = EPSILON * (sumSquaredNorm +2.0 * math.abs(dotValue)) /

(sqDist + EPSILON)

if (precisionBound2 > precision) {

sqDist = Vectors.sqdist(v1, v2)

}

} else {

sqDist = Vectors.sqdist(v1, v2)

}

sqDist

}

fastSquaredDistance方法会先计算一个精度,有关精度的计算val precisionBound1 = 2.0 * EPSILON * sumSquaredNorm / (normDiff * normDiff + EPSILON),如果在精度满足条件的情况下,欧式距离sqDist = sumSquaredNorm - 2.0 * v1.dot(v2),sumSquaredNorm即为

,2.0 * v1.dot(v2)即为

。这也是之前将norm计算出来的好处。如果精度不满足要求,则进行原始的距离计算公式了

,即调用Vectors.sqdist(v1, v2)。

1.3 Mllib KMeans实例

1、数据

数据格式为:特征1 特征2 特征3
0.0 0.0 0.0
0.1 0.1 0.1
0.2 0.2 0.2
9.0 9.0 9.0
9.1 9.1 9.1
9.2 9.2 9.2
2、代码

//1读取样本数据

valdata_path ="/home/jb-huangmeiling/kmeans_data.txt"

valdata =sc.textFile(data_path)

valexamples =data.map { line =>

Vectors.dense(line.split(' ').map(_.toDouble))

}.cache()

valnumExamples =examples.count()

println(s"numExamples = $numExamples.")

//2建立模型

valk =2

valmaxIterations =20

valruns =2

valinitializationMode ="k-means||"

valmodel = KMeans.train(examples,k, maxIterations,runs, initializationMode)

//3计算测试误差

valcost =model.computeCost(examples)

println(s"Total cost = $cost.")

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
懒得打字嘛,点击右侧快捷回复 【右侧内容,后台自定义】
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Unity开发者联盟 ( 粤ICP备20003399号 )

GMT+8, 2025-6-1 04:21 , Processed in 0.342569 second(s), 26 queries .

Powered by Discuz! X3.5 Licensed

© 2001-2025 Discuz! Team.

快速回复 返回顶部 返回列表