function y = meanshift(x, h, varargin)
% MEANSHIFT - Mean shift implementation
%
% SYNTAX
%
%   YOUT = MEANSHIFT( XIN, BAND )
%   YOUT = MEANSHIFT( ..., 'epsilon', EPSILON )
%   YOUT = MEANSHIFT( ..., 'verbose', VERBOSE )
%   YOUT = MEANSHIFT( ..., 'display', DISPLAY )
%
% INPUT
%
%   XIN         Input data (for clustering)             [n-by-d]
%   BAND        Bandwidth value (positive, finite)      [scalar]
%
% OPTIONAL
%
%   EPSILON     Threshold for convergence               [scalar]
%               {default: 1e-4*h}
%   VERBOSE     Print iteration number & error?         [boolean]
%               {default: false}
%   DISPLAY     Plot results of each iteration?         [boolean]
%               (only for 2D points)
%               {default: false}
%
% OUTPUT
%
%   YOUT        Final points location after mean shift  [n-by-d]
%
% DESCRIPTION
%
%   YOUT = MEANSHIFT(XIN,BAND) implements the mean shift algorithm on
%   input points XIN, using a Gaussian kernel with bandwidth BAND that
%   is truncated to neighbors within distance BAND. At every iteration
%   each point is moved to the kernel-weighted mean of the input points
%   around it; iterations stop once the norm of the mean-shift matrix
%   is no longer greater than EPSILON. The location each point
%   converges to (its local maximum) is recorded in the output array
%   YOUT. Any dimensionality d is supported.
%
% DEPENDENCIES
%
%   rangesearch (Statistics and Machine Learning Toolbox)
%
% LOCAL-FUNCTIONS
%
%   rangesearch2sparse
%   parseOptArgs
%
% See also kmeans
%

  %% PARAMETERS

  % a non-positive / non-finite bandwidth would make the default
  % stopping threshold <= 0 (or NaN) and the loop below would not end
  validateattributes( h, {'numeric'}, {'scalar', 'positive', 'finite'}, ...
                      mfilename, 'BAND', 2 );

  % stopping threshold
  opt.epsilon = 1e-4*h;
  opt.verbose = false;
  opt.display = false;

  %% PARSE OPTIONAL INPUTS

  opt = parseOptArgs(opt, varargin{:});

  %% INITIALIZATION

  % number of points -- dimensionality
  [n, d] = size( x ); %#ok<ASGLU> n kept for readability

  % initialize output points to input points
  y = x;

  % mean shift vectors (initialize to infinite)
  m = inf;

  % iteration counter
  iter = 0;

  if opt.display && d == 2
    fig = figure(1337);
    set(fig, 'name', 'real_time_quiver')
  end

  while norm(m) > opt.epsilon % --- iterate until convergence

    iter = iter + 1;

    % neighbors of each current point y(i,:) among the inputs x that lie
    % within radius h, together with their euclidean distances
    [I, D] = rangesearch( x, y, h );

    % Gaussian kernel weights exp( -dist^2 / (2h^2) ). They are computed
    % BEFORE the sparse matrix is formed: SPARSE does not store zeros, so
    % building it from raw distances would silently drop every
    % zero-distance neighbor (a point's own location on the first
    % iteration, duplicate points, ...). Since every returned distance is
    % at most h, every weight here is >= exp(-1/2) > 0 and none is lost.
    K = cellfun( @(dst) exp( -dst.^2 / (2*h^2) ), D, ...
                 'UniformOutput', false );
    W = rangesearch2sparse( I, K );

    % total kernel weight of each row (made full so the division below is
    % a dense operation). On the first iteration each point is its own
    % neighbor; afterwards a weighted mean of points within h of y(i,:)
    % stays within h of at least one of them, so rows are non-empty.
    wsum = full( sum( W, 2 ) );

    % new locations: kernel-weighted mean of the neighboring inputs
    % (row-wise normalization, valid for any dimensionality d)
    y_new = bsxfun( @rdivide, W * x, wsum );

    % calculate mean-shift vector
    m = y_new - y;

    if opt.display && d == 2
      figure(1337)
      clf
      hold on
      scatter( y(:,1), y(:,2) );
      quiver( y(:,1), y(:,2), m(:,1), m(:,2), 0 );
      pause(0.3)
    end

    % update y
    y = y_new;

    if opt.verbose
      fprintf( ' Iteration %d - error %.2g\n', iter, norm(m) );
    end

  end % while (norm(m) > epsilon)

end


%% LOCAL FUNCTION: CREATE SPARSE MATRIX FROM RANGE SEARCH

function mat = rangesearch2sparse(idxCol, vals)
% RANGESEARCH2SPARSE - Assemble cell outputs of RANGESEARCH into a sparse
% matrix with MAT(i, idxCol{i}(k)) = vals{i}(k).
%
% INPUT   idxCol    Index columns for matrix            [n-cell]
%         vals      Values for those entries            [n-cell]
% OUTPUT  mat       Sparse matrix with the values       [n-by-n sparse]
%
% NOTE: SPARSE does not store zero values, so any zero in VALS becomes
% an implicit (unstored) zero of MAT.

  % number of neighbors for each point
  nNbr = cellfun( @numel, idxCol );

  % number of points
  n = numel( idxCol );

  % row indices: row i repeated once per neighbor of point i
  idxRow = arrayfun( @(cnt, i) i * ones( 1, cnt ), nNbr(:), (1:n)', ...
                     'UniformOutput', false );

  % sparse matrix formation
  mat = sparse( [idxRow{:}], [idxCol{:}], [vals{:}], n, n );

end


%% LOCAL FUNCTION: PARSE OPTIONAL ARGUMENTS

function opt = parseOptArgs (dflt, varargin)
% PARSEOPTARGS - Parse name/value pairs against a struct of defaults.
%
% INPUT   dflt      Struct with default parameters      [struct]
%         [varargin] Name/value pairs (or a struct)
% OUTPUT  opt       Updated parameters                  [struct]
%
% Names are matched case-insensitively and may be abbreviated; unknown
% names raise an error. Parameters passed as [] fall back to defaults.

  %% INITIALIZATION

  ip = inputParser;

  ip.CaseSensitive   = false;
  ip.KeepUnmatched   = false;
  ip.PartialMatching = true;
  ip.StructExpand    = true;

  %% PARAMETERS

  argNames = fieldnames( dflt );
  for i = 1 : length(argNames)
    addParameter( ip, argNames{i}, dflt.(argNames{i}) );
  end

  %% PARSE AND RETURN

  parse( ip, varargin{:} );

  opt = ip.Results;

  %% SET EMPTY VALUES TO DEFAULTS

  for i = 1 : length(argNames)
    if isempty( opt.(argNames{i}) )
      opt.(argNames{i}) = dflt.(argNames{i});
    end
  end

end


%%------------------------------------------------------------
%
% AUTHORS
%
%   Dimitris Floros                         fcdimitr@auth.gr
%
% VERSION
%
%   0.2 - January 04, 2018
%
% CHANGELOG
%
%   0.2 (Jan 04, 2018) - Dimitris
%       * FIX: distance should be squared euclidean
%       * FIX: range search radius should be bandwidth
%
%   0.1 (Dec 29, 2017) - Dimitris
%       * initial implementation
%
% ------------------------------------------------------------