找回密码
 立即注册
查看: 234|回复: 0

matlab中kmeans聚类算法

[复制链接]
发表于 2022-4-14 09:43 | 显示全部楼层 |阅读模式
  1.                     版权声明:本文为博主原创文章,未经博主允许不得转载。                        https://blog.csdn.net/xholes/article/details/52911781                    </div>
  2.                                                 <link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/ck_htmledit_views-cd6c485e8b.css">
  3.                                     <link rel="stylesheet" href="https://csdnimg.cn/release/phoenix/template/css/ck_htmledit_views-cd6c485e8b.css">
  4.             <div class="htmledit_views" id="content_views">
  5.                                         <p>本文介绍了K-means聚类算法,并注释了部分matlab实现的源码。<br></p><h2><a name="t0"></a>K-means算法</h2><p><span></span>K-means算法是一种硬聚类算法,根据数据到聚类中心的某种距离来作为判别该数据所属类别。K-means算法以距离作为相似度测度。</p><p>假设将对象数据集分为个不同的类,k均值聚类算法步骤如下:</p><p>Step1:随机从对象集中抽取个对象作为初始聚类中心;</p><p></p><p>Step2:对于所有的对象,分别计算其到各个聚类中的欧氏距离,相互比较后将其归属于距离最小的那一类;</p><p>Step3:根据step2得到的初始分类,对每个类别计算均值用来更新聚类中心;</p><p>Step4:根据新的聚类中心,重复进行step2和step3,直至满足算法终止条件。</p><p><span></span>K-means算法是基于划分的思想,因此算法易于理解且实现方法简单易行,但需要人工选择初始的聚类数目即算法是带参数的。类的数目确定往往非常复杂和具有不确定性,因此需要专业的知识和行业经验才能较好的确定。而且因为初始聚类中心的选择是随机的,因此会造成部分初始聚类中心相似或者处于数据边缘,造成算法的迭代次数明显增加,甚至会因为个别数据而造成聚类失败的现象。</p><p>其流程图大致如下:</p>                                                   <br><p><br></p><h2><a name="t1"></a>matlab源码</h2><p></p><pre><code class="language-ruby">function varargout = kmeans(X, k, varargin)
复制代码

%K均值聚类.
% IDX = KMEANS(X, K) 分割X[N P]的数据矩阵中的样本为K个类,是一种最小化类内点到中心距离和的总和的分割。
% 矩阵X中的行对应的是数据样本,列对应的是变量。
% 提示: 当X是一个向量,本函数会忽略它的方向,将其当作一个[N 1]的数据矩阵。
% KMEANS 函数返回一个代表各个数据样本所属类别索引的[N 1]维向量,函数默认使用平方的欧氏距离。
% KMEANS 将NaNs当作丢失的数据并且忽略X中任何包含NaNs的行
%
%
% [IDX, C] = KMEANS(X, K) 返回一个包含K个聚类中心的[K P]维的矩阵C.
%
% [IDX, C, SUMD] = KMEANS(X, K) 返回一个类间点到聚类中心距离和的[K 1]维向量SUMD。
%
% [IDX, C, SUMD, D] = KMEANS(X, K) 返回一个每个点到任一聚类中心距离的[N K]维矩阵D。
%
% [ … ] = KMEANS(…, ‘PARAM1’,val1, ‘PARAM2’,val2, …) 指定了可选参数对(参数名/参数值)来控制算法的迭代。
% 参数如下:
%
% ‘Distance’ - 距离测度, P维空间, KMEANS算法需要最小化的值
% 可以选择:
% ‘sqeuclidean’ - 平方的欧氏距离 (默认)
% ‘cityblock’ - 曼哈顿距离,各维度差异的绝对值之和。
% ‘cosine’ - 1减去两个样本(当作向量)夹角的余弦值
% ‘correlation’ - 1减去两个样本(当作值的序列)的相关系数
%
% ‘hamming’ - 汉明距离,二进制数据相匹配位置的不同比特百分比。
%
% ‘Start’ - 选择初始聚类中心的方法,有时候也称作种子。
% 可以选择:
% ‘plus’ - 默认值。 利用k-means++算法从X中选择K个观测值:从X中随机的选取第一个聚类中心;之后的
% 聚类中心以一定的概率从剩余的样本中根据其到最近的聚类中心的比例来随机的选取。
% ‘sample’ - 随机的从X中选取K个观测值。
% ‘uniform’ - 根据X的取值范围均匀的随机选取K个样本,对汉明距离不适用。
% ‘cluster’ - 随机的利用X中10%的样本进行一个预聚类的阶段,预聚类阶段的初始聚类中心选取采用‘sample’。
% matrix - 一个初始聚类中心的[K P]维矩阵。此时,你可以用[]代替K,算法会自动的根据矩阵的第一个维度推算K值。
% 你也可以使用3D数组,暗含着第三维为参数’Replicates’的值。
%
% ‘Replicates’ - 重复聚类的次数,默认为1。 每次都会有一个新的初始聚类中心。
%
% ‘EmptyAction’ - 发生空类时的处理措施。
% 可以选择:
% ‘singleton’ - 默认方法。利用据该中心最远的一个观测值建立一个新的类。
% ‘error’ - 将产生空类作为一个错误(error)。
% ‘drop’ - 移除空类并将对应的C和D中的值设置为NaN。
%
%
% ‘Options’ - 迭代算法最小化拟合准则(?)的选项,通过STATSET创建。 Choices of STATSET
% STATSET参数可以选择:
%
% ‘Display’ - 显示输出的哪一阶段的值,可以为 ‘off’(默认),‘iter’和‘final’;
% ‘MaxIter’ - 最大的迭代次数,默认值为100。
%
% ‘UseParallel’ - 在满足条件下,如果为真则开启并行计算否则使用串行模式。默认使用串行模式。
% ‘UseSubstreams’ - 默认不使用。
% ‘Streams’ - 这些区域指明是否执行并行的多个‘Start’值和当产生初始聚类中心时如何使用随机数值,
% 更详细的参考 PARALLELSTATS。
% 提示: 如果 'UseParallel’为TRUE且 'UseSubstreams’为FALSE,
% 那么’Streams’的长度必须等于KMEANS使用的workers的数目。
% 如果打开了并行池,那么它的大小和并行池一样。如果没有打开并行池,
% 那么MATLAB可能会自动的打开(这取决于你的安装设置)。为了得到更好的结果,
% 建议运用PARPOOL命令创建并行池的优先级以便当’UseParallel’为TRUE时执行算法。
%
% ‘OnlinePhase’ - 标志位,表示KMEANS是否除了运行一个"batch update"阶段还需一个"on-line
% update"阶段 。on-line阶段在大数据量时耗时很多。默认为‘off’。
%
% 示例:
%
% X = [randn(20,2)+ones(20,2); randn(20,2)-ones(20,2)];
% opts = statset(‘Display’,‘final’);
% [cidx, ctrs] = kmeans(X, 2, ‘Distance’,‘city’, …
% ‘Replicates’,5, ‘Options’,opts);
% plot(X(cidx==1,1),X(cidx==1,2),‘r.’, …
% X(cidx==2,1),X(cidx==2,2),‘b.’, ctrs(:,1),ctrs(:,2),‘kx’);
%
% 也可以参考LINKAGE, CLUSTERDATA, SILHOUETTE。
% KMEANS 运用两阶段迭代算法来最小化K个类中样本到中心的距离和。
% 第一阶段利用文献中经常描述的"batch" 更新, 其中每次迭代中都一
% 次性地将样本分配到最近的聚类中心,然后更新聚类中心。这一阶段
% 偶尔(特别实在小样本的时候)会陷入局部最优。因此,"batch"阶段可
% 以考虑为第二阶段提供一个快速且可能为近似解的初始聚类中心。第二
% 阶段利用文献中常提及的"on-line"更新, 其中。如果能够减小距离
% 的总和那么其中的样本点都是单独地重新分配且每次分配后都重新计算
% 聚类中心。第二阶段中的每次迭代都会遍历所有的点,但是on-line阶段会收
% 敛到一个局部最小值。寻找全局最优的问题一般只能通过详细(幸运)地选择初始
% 聚类中心,但是使用重复多次的使用随机初始聚类中心中的典型结果是一个全局最小。
%
% 参考文献:
%
% [1] Seber, G.A.F. (1984) Multivariate Observations, Wiley, New York.
% [2] Spath, H. (1985) Cluster Dissection andAnalysis: Theory, FORTRAN
% Programs, Examples, translated by J. Goldschmidt, Halsted Press,
% New York.
%判断输入变量是否少于两个
if nargin < 2
error(message(‘stats:kmeans:TooFewInputs’));
end
%判断X是否是实数矩阵;
if ~isreal(X)
error(message(‘stats:kmeans:ComplexData’));
end
%查找是否有NaN数据,有的话就删除,更新X矩阵;
wasnan = any(isnan(X),2);
hadNaNs = any(wasnan);
if hadNaNs
warning(message(‘stats:kmeans:MissingDataRemoved’));
X = X(~wasnan,:);
end
% 获取X矩阵的维数
[n, p] = size(X);
%参数名与默认参数值设置
pnames = { ‘distance’‘start’‘replicates’‘emptyaction’‘onlinephase’‘options’‘maxiter’‘display’};
dflts = {‘sqeuclidean’‘plus’ [] ‘singleton’‘off’ [] [] []};
[distance,start,reps,emptyact,online,options,maxit,display] …
= internal.stats.parseArgs(pnames, dflts, varargin{:});
distNames = {‘sqeuclidean’,‘cityblock’,‘cosine’,‘correlation’,‘hamming’};
distance = internal.stats.getParamVal(distance,distNames,’’‘Distance’’’);
switch distance
case‘cosine’
Xnorm = sqrt(sum(X.^2, 2));%模长
if any(min(Xnorm) <= eps(max(Xnorm)))
error(message(‘stats:kmeans:ZeroDataForCos’));
end
X = bsxfun(@rdivide,X,Xnorm);%标准化
case‘correlation’
X = bsxfun(@minus, X, mean(X,2));
Xnorm = sqrt(sum(X.^2, 2));
if any(min(Xnorm) <= eps(max(Xnorm)))
error(message(‘stats:kmeans:ConstantDataForCorr’));
end
X = bsxfun(@rdivide,X,Xnorm);
case‘hamming’
if ~all( X(:) ==0| X(:)==1)
error(message(‘stats:kmeans:NonbinaryDataForHamm’));
end
end
Xmins = [];
Xmaxs = [];
CC = [];
if ischar(start)
startNames = {‘uniform’,‘sample’,‘cluster’,‘plus’,‘kmeans++’};
j = find(strncmpi(start,startNames,length(start)));
if length(j) > 1
error(message(‘stats:kmeans:AmbiguousStart’, start));
elseif isempty(j)
error(message(‘stats:kmeans:UnknownStart’, start));
elseif isempty(k)
error(message(‘stats:kmeans:MissingK’));
end
start = startNames{j};
if strcmp(start, ‘uniform’)
if strcmp(distance, ‘hamming’)
error(message(‘stats:kmeans:UniformStartForHamm’));
end
Xmins = min(X,[],1);%求每一列的最小值
Xmaxs = max(X,[],1);%求每一列的最大值
end
elseif isnumeric(start) %如果初始中心是数值类型(numeric)
CC = start;
start = ‘numeric’;
if isempty(k)
k = size(CC,1);%如果K为空通过数值的初始聚类中心获取K值
elseif k ~= size(CC,1);%检测初始聚类中心行是否合法
error(message(‘stats:kmeans:StartBadRowSize’));
elseif size(CC,2) ~= p %检测初始聚类中心列是否合法
error(message(‘stats:kmeans:StartBadColumnSize’));
end
if isempty(reps)
reps = size(CC,3);%如果重复次数参数为空,检测初始聚类中心的第三维获取
elseif reps ~= size(CC,3);
error(message(‘stats:kmeans:StartBadThirdDimSize’));
end
  1. % Need to center explicit starting points <span class="hljs-keyword">for</span> 'correlation'. (Re)normalization
  2. % <span class="hljs-keyword">for</span> 'cosine'/'correlation' is done at each iteration.
  3. <span class="hljs-keyword">if</span> isequal(distance, 'correlation')
  4.       CC = bsxfun(@minus, CC, mean(CC,2));%如果距离测度为相关性需要中心化数据
  5. <span class="hljs-keyword">end</span>
复制代码
else
error(message(‘stats:kmeans:InvalidStart’));
end
emptyactNames = {‘error’,‘drop’,‘singleton’};
emptyact = internal.stats.getParamVal(emptyact,emptyactNames,’’‘EmptyAction’’’);
[~,online] = internal.stats.getParamVal(online,{‘on’,‘off’},’’‘OnlinePhase’’’);
online = (online==1);
% ‘maxiter’ and ‘display’ are grandfathered as separate param name/value pairs
if ~isempty(display)
options = statset(options,‘Display’,display);
end
if ~isempty(maxit)
options = statset(options,‘MaxIter’,maxit);
end
options = statset(statset(‘kmeans’), options);
display = find(strncmpi(options.Display, {‘off’,‘notify’,‘final’,‘iter’},…
length(options.Display))) - 1;
maxit = options.MaxIter;%确定最大迭代次数
if ~(isscalar(k) && isnumeric(k) && isreal(k) && k > 0 && (round(k)==k))
error(message(‘stats:kmeans:InvalidK’));
% elseif k == 1
% this special case works automatically
elseif n < k
error(message(‘stats:kmeans:TooManyClusters’));
end
% Assume one replicate 检测重复次数的值
if isempty(reps)
reps = 1;
elseif ~internal.stats.isScalarInt(reps,1)
error(message(‘stats:kmeans:BadReps’));
end
[useParallel, RNGscheme, poolsz] = …
internal.stats.parallel.processParallelAndStreamOptions(options,true);
usePool = useParallel && poolsz>0;%检测是否使用并行池
% Prepare forin-progress
if display > 1 % ‘iter’ or ‘final’
if usePool
% If we are running on a parallel pool, each worker will generate
% a separate periodic report. Before starting the loop, we
% seed the parallel pool so that each worker will have an
% identifying label (eg, index) for its report.
internal.stats.parallel.distributeToPool( …
‘workerID’, num2cell(1:poolsz) );
  1.     % Periodic reports behave differently <span class="hljs-keyword">in</span> parallel than they <span class="hljs-keyword">do</span>
  2.     % <span class="hljs-keyword">in</span> serial computation (which is the baseline).
  3.     % We advise the user of the difference.
  4.    
  5.     <span class="hljs-keyword">if</span> display == 3 % 'iter' only
  6.         warning(message('stats:kmeans:displayParallel2'));
  7.         fprintf('    worker\t  iter\t phase\t     num\t         sum\n' );
  8.     <span class="hljs-keyword">end</span>
  9. <span class="hljs-keyword">else</span>
  10.     <span class="hljs-keyword">if</span> useParallel
  11.         warning(message('stats:kmeans:displayParallel'));
  12.     <span class="hljs-keyword">end</span>
  13.     <span class="hljs-keyword">if</span> display == 3 % 'iter' only
  14.         fprintf('  iter\t phase\t     num\t         sum\n');
  15.     <span class="hljs-keyword">end</span>
  16. <span class="hljs-keyword">end</span>
复制代码
end
if issparse(X) || ~isfloat(X) || strcmp(distance,‘cityblock’) || …
strcmp(distance,‘hamming’)
[varargout{1:nargout}] = kmeans2(X,k, distance, emptyact,reps,start,…
Xmins,Xmaxs,CC,online,display, maxit,useParallel, RNGscheme,usePool,…
wasnan,hadNaNs,varargin{:});
return;
end
emptyErrCnt = 0;
% Define the function that will perform one iteration of the
% loop inside smartFor
loopbody = @loopBody;%定义循环体函数
% Initialize nested variables so they will not appear to be functions here
%初始化循环嵌套变量
totsumD = 0;
iter = 0;
%将数据转置
X = X’;
Xmins = Xmins’;
Xmaxs = Xmaxs’;
% 执行KMEANS多次(reps)在各自的工作区上.
ClusterBest = internal.stats.parallel.smartForReduce(…
reps, loopbody, useParallel, RNGscheme, ‘argmin’);
% 选出最优解
varargout{1} = ClusterBest{5};%最优解的索引idx
varargout{2} = ClusterBest{6}’;%最优解的聚类中心C
varargout{3} = ClusterBest{3}; %最优解的类内距离和sumD
totsumDbest = ClusterBest{1};%最优解的所有类内距离和的总和
if nargout > 3
varargout{4} = ClusterBest{7}; %最优解的点到任意聚类中心的距离
end
if display > 1 % ‘final’ or ‘iter’
fprintf(’%s\n’,getString(message(‘stats:kmeans:FinalSumOfDistances’,sprintf(’%g’,totsumDbest))));
end
if hadNaNs
varargout{1} = statinsertnan(wasnan, varargout{1});% idxbest
if nargout > 3
varargout{4} = statinsertnan(wasnan, varargout{4}); %Dbest
end
end
  1. function cellout = loopBody(rep,S)%循环体函数
  2.    
  3.     <span class="hljs-keyword">if</span> isempty(S)
  4.         S = RandStream.getGlobalStream;
  5.     <span class="hljs-keyword">end</span>
  6.    
  7.     <span class="hljs-keyword">if</span> display > 1 % 'iter'
  8.         <span class="hljs-keyword">if</span> usePool
  9.             dispfmt = '%8d\t%6d\t%6d\t%8d\t%12g\n';
  10.             labindx = internal.stats.parallel.workerGetValue('workerID');
  11.         <span class="hljs-keyword">else</span>
  12.             dispfmt = '%6d\t%6d\t%8d\t%12g\n';
  13.         <span class="hljs-keyword">end</span>
  14.     <span class="hljs-keyword">end</span>
  15.     %定义元胞数组
  16.     cellout = cell(7,1);  % cellout{1}类间距离总和
  17.                           % cellout{2}重复次数
  18.                           % cellout{3}类内距离总和
  19.                           % cellout{4}迭代次数
  20.                           % cellout{5}索引
  21.                           % cellout{6}聚类中心
  22.                           % cellout{7}距离
  23.    
  24.     % Populating total sum of distances to Inf. This is used <span class="hljs-keyword">in</span> the
  25.     % reduce operation <span class="hljs-keyword">if</span> update fails due to empty cluster.
  26.     cellout{1} = Inf;%赋值
  27.     cellout{2} = rep;
  28.     %初始化聚类中心
  29.     switch start
  30.         <span class="hljs-keyword">case</span> 'uniform'
  31.             %C = Xmins(:,ones(1,k)) + rand(S,[p,k]).*(Xmaxs(:,ones(1,k))-Xmins(:,ones(1,k)));
  32.             C = Xmins(:,ones(1,k)) + rand(S,[k,p])'.*(Xmaxs(:,ones(1,k))-Xmins(:,ones(1,k)));
  33.             % For 'cosine' <span class="hljs-keyword">and</span> 'correlation', these are uniform inside a subset
  34.             % of the unit hypersphere.仍需要为'correlation'进行中心化.  
  35.             %  'cosine'/'correlation'的正交化在每次迭代中完成
  36.             <span class="hljs-keyword">if</span> isequal(distance, 'correlation')
  37.                 C = bsxfun(@minus, C, mean(C,1));
  38.             <span class="hljs-keyword">end</span>
  39.             <span class="hljs-keyword">if</span> isa(X,'single')
  40.                 C = single(C);
  41.             <span class="hljs-keyword">end</span>
  42.         <span class="hljs-keyword">case</span> 'sample'
  43.             C = X(:,randsample(S,n,k));
  44.         <span class="hljs-keyword">case</span> 'cluster'
  45.             Xsubset = X(:,randsample(S,n,floor(.1*n)));
  46.             % Turn display off <span class="hljs-keyword">for</span> the initialization
  47.             optIndex = find(strcmpi('options',varargin));
  48.             <span class="hljs-keyword">if</span> isempty(optIndex)
  49.                 opts = statset('Display','off');
  50.                 varargin = [varargin,'options',opts];
  51.             <span class="hljs-keyword">else</span>
  52.                 varargin{optIndex+1}.Display = 'off';
  53.             <span class="hljs-keyword">end</span>
  54.             [~, C] = kmeans(Xsubset', k, varargin{:}, 'start','sample', 'replicates',1);
  55.             C = C';
  56.         <span class="hljs-keyword">case</span> 'numeric'
  57.             C = CC(:,:,rep)';
  58.             <span class="hljs-keyword">if</span> isa(X,'single')
  59.                 C = single(C);
  60.             <span class="hljs-keyword">end</span>
  61.         <span class="hljs-keyword">case</span> {'plus','kmeans++'}
  62.             % Select the first seed by sampling uniformly at random
  63.             index = zeros(1,k);
  64.             [C(:,1), index(1)] = datasample(S,X,1,2);
  65.             minDist = inf(n,1);
  66.       
  67.             % Select the rest of the seeds by a probabilistic model
  68.            <span class="hljs-keyword">for</span> ii = 2:k                    
  69.                 minDist = min(minDist,distfun(X,C(:,ii-1),distance));
  70.                 denominator = sum(minDist);
  71.                 <span class="hljs-keyword">if</span> denominator==0 |</span><span class="hljs-params">| isinf(denominator) |</span><span class="hljs-params">| isnan(denominator)
  72.                     C(:,ii:k) = datasample(S,X,k-ii+1,2,'Replace',<span class="hljs-literal">false</span>);
  73.                     <span class="hljs-keyword">break</span>;
  74.                 <span class="hljs-keyword">end</span>
  75.                 sampleProbability = minDist/denominator;
  76.                 [C(:,ii), index(ii)] = datasample(S,X,1,2,'Replace',<span class="hljs-literal">false</span>,...
  77.                     'Weights',sampleProbability);        
  78.             <span class="hljs-keyword">end</span>
  79.     <span class="hljs-keyword">end</span>
  80.     <span class="hljs-keyword">if</span> ~isfloat(C)      % X may be logical
  81.         C = double(C);
  82.     <span class="hljs-keyword">end</span>
  83.    
  84.     % 计算点到聚类中心的距离和归属到各个类别
  85.     D = distfun(X, C, distance, 0, rep, reps);%计算点到个中心的距离
  86.     [d, idx] = min(D, [], 2);%根据最短距离归属到各个类
  87.     m = accumarray(idx,1,[k,1])';%计算各个类中样本的个数
  88.    
  89.     try % catch空类错误并转移到下一个重复次
  90.         
  91.         %开始第一阶段:批分配
  92.         converged = batchUpdate();
  93.         
  94.         % 开始第二阶段:单个分配
  95.         <span class="hljs-keyword">if</span> online
  96.             converged = onlineUpdate();
  97.         <span class="hljs-keyword">end</span>
  98.         
  99.         
  100.         <span class="hljs-keyword">if</span> display == 2 % 'final'
  101.             fprintf('%s\n',getString(message('stats:kmeans:IterationsSumOfDistances',rep,iter,sprintf('%g',totsumD) )));
  102.         <span class="hljs-keyword">end</span>
  103.         
  104.         <span class="hljs-keyword">if</span> ~converged
  105.             <span class="hljs-keyword">if</span> reps==1
  106.                 warning(message('stats:kmeans:FailedToConverge', maxit));
  107.             <span class="hljs-keyword">else</span>
  108.                 warning(message('stats:kmeans:FailedToConvergeRep', maxit, rep));
  109.             <span class="hljs-keyword">end</span>
  110.         <span class="hljs-keyword">end</span>
  111.         
  112.         % 计算类内距离和
  113.         nonempties = find(m>0);%判断没有空类,生成非空类的线性目录
  114.         D(:,nonempties) = distfun(X, C(:,nonempties), distance, iter, rep, reps);
  115.         d = D((idx-1)*n + (1:n)');
  116.         sumD = accumarray(idx,d,[k,1]);% 计算类内距离和
  117.         totsumD = sum(sumD(nonempties));% 计算所有类内距离和的总和
  118.         
  119.         % 保存目前最好的解
  120.         cellout = {totsumD,rep,sumD,iter,idx,C,D}';
  121.       
  122.         % 如果在重复运行中发生空类现象,进行捕获并警告,然后继续下一次重复运行,
  123.         %  只有在所有的重复运行失败才会ERROR,再次引发另一种ERROR。
  124.     catch ME
  125.         <span class="hljs-keyword">if</span> reps == 1 |</span><span class="hljs-params">| (~isequal(ME.identifier,'stats:kmeans:EmptyCluster')  && ...
  126.                      ~isequal(ME.identifier,'stats:kmeans:EmptyClusterRep'))
  127.             rethrow(ME);
  128.         <span class="hljs-keyword">else</span>
  129.             emptyErrCnt = emptyErrCnt + 1;
  130.             warning(message('stats:kmeans:EmptyClusterInBatchUpdate', rep, iter));
  131.             <span class="hljs-keyword">if</span> emptyErrCnt == reps
  132.                 error(message('stats:kmeans:EmptyClusterAllReps'));
  133.             <span class="hljs-keyword">end</span>
  134.         <span class="hljs-keyword">end</span>
  135.     <span class="hljs-keyword">end</span> % catch
  136.    
  137.     %------------------------------------------------------------------
  138.    
  139.     function converged = batchUpdate()
  140.         
  141.         % 遍历每个点,更新每个类
  142.         moved = 1:n;
  143.         changed = 1:k;
  144.         previdx = zeros(n,1);
  145.         prevtotsumD = Inf;
  146.         
  147.         %
  148.         % 开始第一阶段
  149.         %
  150.         
  151.         iter = 0;
  152.         converged = <span class="hljs-literal">false</span>;
  153.         <span class="hljs-keyword">while</span> <span class="hljs-literal">true</span>
  154.             iter = iter + 1;
  155.             
  156.             % 更新新的聚类中心和数目以及每个样本到新聚类中心的距离
  157.             [C(:,changed), m(changed)] = gcentroids(X, idx, changed, distance);
  158.             D(:,changed) = distfun(X, C(:,changed), distance, iter, rep, reps);
  159.             
  160.             %处理空类
  161.             empties = changed(m(changed) == 0);
  162.             <span class="hljs-keyword">if</span> ~isempty(empties)
  163.                 <span class="hljs-keyword">if</span> strcmp(emptyact,'error')
  164.                     <span class="hljs-keyword">if</span> reps==1
  165.                         error(message('stats:kmeans:EmptyCluster', iter));
  166.                     <span class="hljs-keyword">else</span>
  167.                         error(message('stats:kmeans:EmptyClusterRep', iter, rep));
  168.                     <span class="hljs-keyword">end</span>
  169.                 <span class="hljs-keyword">end</span>
  170.                 switch emptyact
  171.                     <span class="hljs-keyword">case</span> 'drop'
  172.                         <span class="hljs-keyword">if</span> reps==1
  173.                             warning(message('stats:kmeans:EmptyCluster', iter));
  174.                         <span class="hljs-keyword">else</span>
  175.                             warning(message('stats:kmeans:EmptyClusterRep', iter, rep));
  176.                         <span class="hljs-keyword">end</span>
  177.                         % Remove the empty cluster from any further processing
  178.                         D(:,empties) = NaN;
  179.                         changed = changed(m(changed) > 0);
  180.                     <span class="hljs-keyword">case</span> 'singleton'
  181.                         <span class="hljs-keyword">for</span> i = empties
  182.                             d = D((idx-1)*n + (1:n)'); % use newly updated distances
  183.                            
  184.                             % 选取一个距离当前类最远的样本作为一个新的类
  185.                             [~, lonely] = max(d);
  186.                             from = idx(lonely); % taking from this cluster
  187.                             <span class="hljs-keyword">if</span> m(from) < 2
  188.                                 % In the very unusual event that the cluster had only
  189.                                 % one member, pick any other non-singleton point.
  190.                                 from = find(m>1,1,'first');
  191.                                 lonely = find(idx==from,1,'first');
  192.                             <span class="hljs-keyword">end</span>
  193.                             C(:,i) = X(:,lonely);
  194.                             m(i) = 1;
  195.                             idx(lonely) = i;
  196.                             D(:,i) = distfun(X, C(:,i), distance, iter, rep, reps);
  197.                            
  198.                             % Update clusters from which points are taken
  199.                             [C(:,from), m(from)] = gcentroids(X, idx, from, distance);
  200.                             D(:,from) = distfun(X, C(:,from), distance, iter, rep, reps);
  201.                             changed = unique([changed from]);
  202.                         <span class="hljs-keyword">end</span>
  203.                 <span class="hljs-keyword">end</span>
  204.             <span class="hljs-keyword">end</span>
  205.             
  206.             % 在当前配置下计算总距离
  207.             totsumD = sum(D((idx-1)*n + (1:n)'));
  208.             % 循环测试: 如果目标为减少,返回出去
  209.             % 最后一步,之后进行单个更新阶段
  210.             <span class="hljs-keyword">if</span> prevtotsumD <= totsumD
  211.                 idx = previdx;
  212.                 [C(:,changed), m(changed)] = gcentroids(X, idx, changed, distance);
  213.                 iter = iter - 1;
  214.                 <span class="hljs-keyword">break</span>;
  215.             <span class="hljs-keyword">end</span>
  216.             <span class="hljs-keyword">if</span> display > 2 % 'iter'
  217.                 <span class="hljs-keyword">if</span> usePool
  218.                     fprintf(dispfmt,labindx,iter,1,length(moved),totsumD);
  219.                 <span class="hljs-keyword">else</span>
  220.                     fprintf(dispfmt,iter,1,length(moved),totsumD);
  221.                 <span class="hljs-keyword">end</span>
  222.             <span class="hljs-keyword">end</span>
  223.             <span class="hljs-keyword">if</span> iter >= maxit
  224.                 <span class="hljs-keyword">break</span>;
  225.             <span class="hljs-keyword">end</span>
  226.             
  227.             %对每个点根据就近原则归属到各自的类
  228.             previdx = idx;
  229.             prevtotsumD = totsumD;
  230.             [d, nidx] = min(D, [], 2);
  231.             
  232.             % 决定哪个样本点移动
  233.             moved = find(nidx ~= previdx);
  234.             <span class="hljs-keyword">if</span> ~isempty(moved)
  235.                 % Resolve ties <span class="hljs-keyword">in</span> favor of <span class="hljs-keyword">not</span> moving
  236.                 moved = moved(D((previdx(moved)-1)*n + moved) > d(moved));
  237.             <span class="hljs-keyword">end</span>
  238.             <span class="hljs-keyword">if</span> isempty(moved)
  239.                 converged = <span class="hljs-literal">true</span>;
  240.                 <span class="hljs-keyword">break</span>;
  241.             <span class="hljs-keyword">end</span>
  242.             idx(moved) = nidx(moved);
  243.             
  244.             % 寻找得到或者失去样本点的类
  245.             changed = unique([idx(moved); previdx(moved)])';
  246.             
  247.         <span class="hljs-keyword">end</span> % phase one
  248.         
  249.     <span class="hljs-keyword">end</span> % nested function
  250.    
  251.     %------------------------------------------------------------------
  252.    
  253.     function converged = onlineUpdate()
  254.                   
  255.         %
  256.         % 第二阶段开始: 单个分配
  257.         %
  258.         changed = find(m > 0);
  259.         lastmoved = 0;
  260.         nummoved = 0;
  261.         iter1 = iter;
  262.         converged = <span class="hljs-literal">false</span>;
  263.         Del = NaN(n,k); % 重新分配的准则
  264.         <span class="hljs-keyword">while</span> iter < maxit
  265.             %计算每个样本点到各个类的距离以及因每个类中添加或者移除样本点引起的误差和的变化
  266.             %没有发生变化的类并不用更新。仅含单个样本点的类是总距离计算中的特殊情况。
  267.             %移除它们仅有的样本点并不是最好的选择,根据分配准则最好保证它们能够得到保留,
  268.             %令人高兴地是,对于这种情况我们自动的使用Del(i,idx(i)) == 0。
  269.             switch distance
  270.                 <span class="hljs-keyword">case</span> 'sqeuclidean'
  271.                     <span class="hljs-keyword">for</span> i = changed
  272.                         mbrs = (idx == i)';
  273.                         sgn = 1 - 2*mbrs; % -1 <span class="hljs-keyword">for</span> members, 1 <span class="hljs-keyword">for</span> nonmembers
  274.                         <span class="hljs-keyword">if</span> m(i) == 1
  275.                             sgn(mbrs) = 0; % prevent divide-by-zero <span class="hljs-keyword">for</span> singleton mbrs
  276.                         <span class="hljs-keyword">end</span>
  277.                       Del(:,i) = (m(i) ./ (m(i) + sgn)) .* sum((bsxfun(@minus, X, C(:,i))).^2, 1);
  278.                     <span class="hljs-keyword">end</span>
  279.                   <span class="hljs-keyword">case</span> {'cosine','correlation'}
  280.                     % The points are normalized, centroids are <span class="hljs-keyword">not</span>, so normalize them
  281.                     normC = sqrt(sum(C.^2, 1));
  282.                     <span class="hljs-keyword">if</span> any(normC < eps(<span class="hljs-keyword">class</span>(normC))) % small relative to unit-length data points
  283.                         <span class="hljs-keyword">if</span> reps==1
  284.                             error(message('stats:kmeans:ZeroCentroid', iter));
  285.                         <span class="hljs-keyword">else</span>
  286.                             error(message('stats:kmeans:ZeroCentroidRep', iter, rep));
  287.                         <span class="hljs-keyword">end</span>
  288.                         
  289.                     <span class="hljs-keyword">end</span>
  290.                     % This can be done without a loop, but the loop saves memory allocations
  291.                     <span class="hljs-keyword">for</span> i = changed
  292.                         XCi =  C(:,i)'*X;
  293.                         mbrs = (idx == i)';
  294.                         sgn = 1 - 2*mbrs; % -1 <span class="hljs-keyword">for</span> members, 1 <span class="hljs-keyword">for</span> nonmembers
  295.                         Del(:,i) = 1 + sgn .*...
  296.                             (m(i).*normC(i) - sqrt((m(i).*normC(i)).^2 + 2.*sgn.*m(i).*XCi + 1));
  297.                     <span class="hljs-keyword">end</span>
  298.             <span class="hljs-keyword">end</span>
  299.             
  300.             % 对于任意一个样本点,确定可能是最好的移动方式。然后选择其中的一个进行移动
  301.             previdx = idx;
  302.             prevtotsumD = totsumD;
  303.             [minDel, nidx] = min(Del, [], 2);
  304.             moved = find(previdx ~= nidx);
  305.             moved(m(previdx(moved))==1)=[];
  306.             <span class="hljs-keyword">if</span> ~isempty(moved)
  307.                 % Resolve ties <span class="hljs-keyword">in</span> favor of <span class="hljs-keyword">not</span> moving
  308.                 moved = moved(Del((previdx(moved)-1)*n + moved) > minDel(moved));
  309.             <span class="hljs-keyword">end</span>
  310.             <span class="hljs-keyword">if</span> isempty(moved)
  311.                 % Count an iteration <span class="hljs-keyword">if</span> phase 2 did nothing at all, <span class="hljs-keyword">or</span> <span class="hljs-keyword">if</span> we're
  312.                 % <span class="hljs-keyword">in</span> the middle of a pass through all the points
  313.                 <span class="hljs-keyword">if</span> (iter == iter1) |</span><span class="hljs-params">| nummoved > 0
  314.                     iter = iter + 1;
  315.                     <span class="hljs-keyword">if</span> display > 2 % 'iter'
  316.                         <span class="hljs-keyword">if</span> usePool
  317.                             fprintf(dispfmt,labindx,iter,2,length(moved),totsumD);
  318.                         <span class="hljs-keyword">else</span>
  319.                             fprintf(dispfmt,iter,2,length(moved),totsumD);
  320.                         <span class="hljs-keyword">end</span>
  321.                     <span class="hljs-keyword">end</span>
  322.                 <span class="hljs-keyword">end</span>
  323.                 converged = <span class="hljs-literal">true</span>;
  324.                 <span class="hljs-keyword">break</span>;
  325.             <span class="hljs-keyword">end</span>
  326.             
  327.             % Pick the <span class="hljs-keyword">next</span> move <span class="hljs-keyword">in</span> cyclic order
  328.             %循环地选择下一次的移动
  329.             moved = mod(min(mod(moved - lastmoved - 1, n) + lastmoved), n) + 1;
  330.             
  331.             % 遍历完所有的样本点,则完成一次迭代
  332.             <span class="hljs-keyword">if</span> moved <= lastmoved
  333.                 iter = iter + 1;
  334.                 <span class="hljs-keyword">if</span> display > 2 % 'iter'
  335.                     <span class="hljs-keyword">if</span> usePool
  336.                         fprintf(dispfmt,labindx,iter,2,length(moved),totsumD);
  337.                     <span class="hljs-keyword">else</span>
  338.                         fprintf(dispfmt,iter,2,length(moved),totsumD);
  339.                     <span class="hljs-keyword">end</span>
  340.                 <span class="hljs-keyword">end</span>
  341.                 <span class="hljs-keyword">if</span> iter >= maxit, <span class="hljs-keyword">break</span>; <span class="hljs-keyword">end</span>
  342.                 nummoved = 0;
  343.             <span class="hljs-keyword">end</span>
  344.             nummoved = nummoved + 1;
  345.             lastmoved = moved;
  346.             
  347.             oidx = idx(moved);
  348.             nidx = nidx(moved);
  349.             totsumD = totsumD + Del(moved,nidx) - Del(moved,oidx);
  350.             
  351.             %更新类的索引向量、新旧类别各自的样本数目和中心
  352.             idx(moved) = nidx;
  353.             m(nidx) = m(nidx) + 1;
  354.             m(oidx) = m(oidx) - 1;
  355.             switch distance
  356.                 <span class="hljs-keyword">case</span> {'sqeuclidean','cosine','correlation'}
  357.                     C(:,nidx) = C(:,nidx) + (X(:,moved) - C(:,nidx)) / m(nidx);
  358.                     C(:,oidx) = C(:,oidx) - (X(:,moved) - C(:,oidx)) / m(oidx);
  359.             <span class="hljs-keyword">end</span>
  360.             changed = sort([oidx nidx]);
  361.         <span class="hljs-keyword">end</span> % phase two
  362.         
  363.     <span class="hljs-keyword">end</span> % nested function
  364.    
  365. <span class="hljs-keyword">end</span>
复制代码
end % main function
%------------------------------------------------------------------
function D = distfun(X, C, dist, iter,rep, reps)
%DISTFUN计算样本点到中心的距离
switch dist
case ‘sqeuclidean’
D = pdist2mex(X,C,‘sqe’,[],[],[]);
case {‘cosine’,‘correlation’}
% 样本点已被标准化而中心点没有,因此对它们进行标准化
normC = sqrt(sum(C.^2, 1));
if any(normC < eps(class(normC))) % small relative to unit-length data points(?)
if reps==1
error(message(‘stats:kmeans:ZeroCentroid’, iter));
else
error(message(‘stats:kmeans:ZeroCentroidRep’, iter, rep));
end
  1.     <span class="hljs-keyword">end</span>
  2.     C = bsxfun(@rdivide,C,normC);
  3.     D = pdist2mex(X,C,'cos',[],[],[]);
复制代码
end
end % function
%------------------------------------------------------------------
function [centroids, counts] = gcentroids(X, index, clusts, dist)
%GCENTROIDS Centroids and counts stratified by group.
%计算各类的样本数目和中心点
p = size(X,1);
num = length(clusts);
centroids = NaN(p,num,‘like’,X);
counts = zeros(1,num,‘like’,X);
for i = 1:num
members = (index == clusts(i));
if any(members)
counts(i) = sum(members);
switch dist
case {‘sqeuclidean’,‘cosine’,‘correlation’}
centroids(:,i) = sum(X(:,members),2) / counts(i);
end
end
end
end % function
Python 中的Kmeans



  • from sklearn.cluster      import KMeans   



  • import numpy      as np   



  •      X = np.array([[     1,      2], [     1,      4], [     1,      0], [     4,      2], [     4,      4], [     4,     0]])   



  •      kmeans=KMeans(n_clusters=     2,random_state=     0).fit(X)   




转载自: https://blog.csdn.net/xholes/article/details/52911781

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

×
懒得打字嘛,点击右侧快捷回复 【右侧内容,后台自定义】
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Unity开发者联盟 ( 粤ICP备20003399号 )

GMT+8, 2025-5-7 12:30 , Processed in 0.776298 second(s), 23 queries .

Powered by Discuz! X3.5 Licensed

© 2001-2025 Discuz! Team.

快速回复 返回顶部 返回列表