# class boundary preserving algorithm for data condensation

function [clustCent,data2cluster,cluster2dataCell] = msc(dataPts,bandWidth,plotFlag)
%MSC  Mean-shift clustering of data with a flat (uniform) kernel.
%
%   [clustCent,data2cluster,cluster2dataCell] = msc(dataPts,bandWidth,plotFlag)
%
%   Inputs
%     dataPts    - numDim x numPts matrix; each column is one data point.
%                  All entries must be finite.
%     bandWidth  - positive scalar; radius of the kernel window. Points whose
%                  distance to the current mean is below bandWidth are used
%                  to compute the next mean.
%     plotFlag   - optional, default false. When true and numDim == 2, every
%                  mean-shift step is plotted in figure 12345 and execution
%                  pauses until a key is pressed.
%
%   Outputs
%     clustCent        - numDim x numClust matrix of cluster centres.
%     data2cluster     - 1 x numPts vector; data2cluster(i) is the index of
%                        the cluster that collected the most votes for
%                        point i.
%     cluster2dataCell - numClust x 1 cell array; cell cN holds the indices
%                        of the points assigned to cluster cN. Only computed
%                        when a third output is requested.
%
%   Seeds are drawn at random (via RAND) from the points that have not been
%   visited yet, so cluster numbering can differ between runs.

%*** Check input ****
if nargin < 2
    error('no bandwidth specified')
end
if nargin < 3
    plotFlag = false;
end
% A non-positive or NaN bandwidth yields an empty window, a NaN mean and a
% loop that never converges, so reject it up front.
if ~isscalar(bandWidth) || ~(bandWidth > 0)
    error('bandWidth must be a positive scalar')
end
% Non-finite coordinates likewise produce an empty window (NaN distances).
if any(~isfinite(dataPts(:)))
    error('dataPts must not contain NaN or Inf values')
end

%**** Initialize stuff ***
[numDim,numPts] = size(dataPts);
numClust        = 0;
bandSq          = bandWidth^2;
initPtInds      = 1:numPts;
stopThresh      = 1e-3*bandWidth;                 % convergence threshold on the mean shift
clustCent       = [];                             % cluster centres, one per column
beenVisitedFlag = zeros(1,numPts,'uint8');        % 1 once a point fell inside any window
numInitPts      = numPts;                         % number of still-unvisited points
clusterVotes    = zeros(1,numPts,'uint16');       % row cN: votes of each point for cluster cN

while numInitPts

    % Pick a random unvisited point as the seed of a new trajectory.
    tempInd          = ceil( (numInitPts-1e-6)*rand );
    stInd            = initPtInds(tempInd);
    myMean           = dataPts(:,stInd);
    myMembers        = [];                        % only maintained for plotting
    thisClusterVotes = zeros(1,numPts,'uint16');  % votes collected along this trajectory

    while true  % loop until convergence

        % Squared distance from the current mean to every point. The explicit
        % dimension keeps this a 1 x numPts row even when numDim == 1.
        sqDistToAll = sum((repmat(myMean,1,numPts) - dataPts).^2, 1);
        inInds      = find(sqDistToAll < bandSq); % points inside the window
        thisClusterVotes(inInds) = thisClusterVotes(inInds) + 1;

        myOldMean = myMean;
        myMean    = mean(dataPts(:,inInds),2);    % shift the mean to the window centroid
        beenVisitedFlag(inInds) = 1;              % earlier windows were flagged already

        %*** plot stuff ****
        if plotFlag
            myMembers = [myMembers inInds]; %#ok<AGROW>
            figure(12345),clf,hold on
            if numDim == 2
                plot(dataPts(1,:),dataPts(2,:),'*')
                plot(dataPts(1,myMembers),dataPts(2,myMembers),'ys')
                plot(myMean(1),myMean(2),'go')
                plot(myOldMean(1),myOldMean(2),'rd')
                pause
            end
        end

        if norm(myMean-myOldMean) < stopThresh

            % Check for merge possibilities: the first existing centre that
            % lies within half a bandwidth absorbs this trajectory.
            mergeWith = 0;
            for cN = 1:numClust
                distToOther = norm(myMean-clustCent(:,cN));
                if distToOther < bandWidth/2
                    mergeWith = cN;
                    break;
                end
            end

            if mergeWith > 0
                % Something to merge: average the two centres, pool the votes.
                clustCent(:,mergeWith)    = 0.5*(myMean+clustCent(:,mergeWith));
                clusterVotes(mergeWith,:) = clusterVotes(mergeWith,:) + thisClusterVotes;
            else
                % It is a new cluster.
                numClust                 = numClust+1;
                clustCent(:,numClust)    = myMean;
                clusterVotes(numClust,:) = thisClusterVotes;
            end

            break;
        end

    end

    % Continue seeding from the points no window has covered yet.
    initPtInds = find(beenVisitedFlag == 0);
    numInitPts = length(initPtInds);
end

% Each point goes to the cluster that voted for it most often
% (the maximum value itself is not needed).
[maxVotes,data2cluster] = max(clusterVotes,[],1); %#ok<ASGLU>

%*** If they want the cluster2data cell, compute it for them
if nargout > 2
    cluster2dataCell = cell(numClust,1);
    for cN = 1:numClust
        cluster2dataCell{cN} = find(data2cluster == cN);
    end
end

