function y = meanshift(x, h, varargin)
% MEANSHIFT - Mean shift implementation
%   
% SYNTAX
%
%   YOUT = MEANSHIFT( XIN, BAND )
%   YOUT = MEANSHIFT( ..., 'epsilon', EPSILON )
%   YOUT = MEANSHIFT( ..., 'verbose', VERBOSE )
%   YOUT = MEANSHIFT( ..., 'display', DISPLAY )
%
% INPUT
%
%   XIN         Input data (for clustering)             [n-by-d]
%   BAND        Bandwidth value                         [scalar]
%   
% OPTIONAL
% 
%   EPSILON     Threshold for convergence               [scalar]
%               {default: 1e-4*h}
%   VERBOSE     Print iteration number & error?         [boolean]
%               {default: false}
%   DISPLAY     Plot results of each iteration?         [boolean]
%               (only for 2D points)
%               {default: false}
%
% OUTPUT
%
%   YOUT        Final points location after mean shift  [n-by-d]
%
% DESCRIPTION
%
%   YOUT = MEANSHIFT(XIN,BAND) implements the mean shift algorithm on
%   input points XIN, using a Gaussian kernel with bandwidth BAND that
%   is truncated at radius BAND. At every iteration each current point
%   is moved to the kernel-weighted mean of the input points lying
%   within radius BAND of it. Iterations stop once the norm (NORM of the
%   n-by-d displacement matrix) is no longer greater than EPSILON. The
%   final location reached from each input point is returned in YOUT.
%
% DEPENDENCIES
%
%   rangesearch (Statistics and Machine Learning Toolbox)
%
% LOCAL-FUNCTIONS
% 
%   rangesearch2sparse
%   parseOptArgs
%
% See also      kmeans
%
  
  %% PARAMETERS

  % stopping threshold
  opt.epsilon = 1e-4*h;
  opt.verbose = false;
  opt.display = false;
  
  
  %% PARSE OPTIONAL INPUTS
  
  opt = parseOptArgs(opt, varargin{:});
  
  
  %% INITIALIZATION
  
  % number of points -- dimensionality
  [n, d] = size( x );
  
  % initialize output points to input points
  y = x;
  
  % nothing to shift for an empty point set
  if n == 0
    return
  end
  
  % mean shift vectors (initialize to infinite so the loop runs at least once)
  m = inf;
  
  % iteration counter
  iter = 0;
  
  % Gaussian kernel, evaluated on SQUARED Euclidean distances
  kernel = @(dist2) exp( -dist2 / (2*h^2) );
  
  if opt.display && d == 2
    fig = figure(1337);
    set(fig, 'name', 'real_time_quiver')
  end
  
  while norm(m) > opt.epsilon  % --- iterate until convergence
  
    iter = iter + 1;
    
    % neighbors in x (inside radius h) of every current point y
    [I, D] = rangesearch( x, y, h );
    
    % kernel weights are computed BEFORE sparse assembly: a neighbor at
    % zero distance gets weight exp(0) = 1, whereas a zero squared
    % distance would be silently discarded by SPARSE. Within radius h all
    % weights lie in [exp(-0.5), 1], so no entry is lost.
    K = cellfun( @(dist) kernel( dist.^2 ), D, 'UniformOutput', false );
    W = rangesearch2sparse( I, K );
    
    % kernel-weighted mean of the neighbors of each point (any dimension d)
    wsum  = full( sum( W, 2 ) );
    y_new = bsxfun( @rdivide, full( W * x ), wsum );
    
    % a point without neighbors inside the radius stays where it is
    lonely = (wsum == 0);
    y_new(lonely, :) = y(lonely, :);
        
    % calculate mean-shift vector
    m = y_new - y;
    
    if opt.display && d == 2
      
      figure(1337)
      clf
      hold on
      scatter( y(:,1), y(:,2) );
      quiver( y(:,1), y(:,2), m(:,1), m(:,2), 0 );
      pause(0.3)
      
    end
    
    % update y
    y = y_new;
    
    if opt.verbose
      fprintf( ' Iteration %d - error %.2g\n', iter, norm(m) );
    end    
    
  end % while (norm(m) > epsilon)
  
    
end
  


%% LOCAL FUNCTION: CREATE SPARSE MATRIX FROM RANGE SEARCH

function mat = rangesearch2sparse(idxCol, dist)
% RANGESEARCH2SPARSE - Assemble per-row neighbor lists into a sparse matrix
%
% INPUT         idxCol  Column indices, one cell per row        [n-cell]
%               dist    Values paired with idxCol, same layout  [n-cell]
% OUTPUT        mat     Sparse matrix with                      [n-by-n sparse]
%                       mat(i, idxCol{i}) = dist{i}
%
% NOTE          SPARSE ignores zero values, so a 0 in DIST leaves
%               no stored entry in MAT.

  nRows = numel( idxCol );

  % row index of every entry: all entries of cell k belong to row k
  rowCells = cell( 1, nRows );
  for k = 1 : nRows
    rowCells{k} = repmat( k, 1, numel( idxCol{k} ) );
  end

  % flatten the three lists in matching order and build the matrix
  mat = sparse( [rowCells{:}], [idxCol{:}], [dist{:}], nRows, nRows );

end



%% LOCAL FUNCTION: PARSE OPTIONAL ARGUMENTS

function opt = parseOptArgs (dflt, varargin)
% PARSEOPTARGS - Merge name-value pairs into a struct of default options
%
% INPUT         dflt    Struct with default parameters  [struct]
%               <name-value pairs>                      [varargin]
% OUTPUT        opt     Updated parameters              [struct]
%
% Names are matched case-insensitively and may be abbreviated. A struct
% may be passed in place of name-value pairs. Unknown names raise an
% error, and a value given as empty falls back to its default.

  %% CONFIGURE PARSER

  parser = inputParser;

  parser.CaseSensitive   = false;
  parser.KeepUnmatched   = false;
  parser.PartialMatching = true;
  parser.StructExpand    = true;

  names = fieldnames( dflt );

  % every field of the defaults struct becomes an accepted parameter
  for name = names.'
    addParameter( parser, name{1}, dflt.(name{1}) );
  end


  %% PARSE

  parse( parser, varargin{:} );
  opt = parser.Results;


  %% EMPTY VALUES -> DEFAULTS

  for name = names.'
    key = name{1};
    if isempty( opt.(key) )
      opt.(key) = dflt.(key);
    end
  end

end



%%------------------------------------------------------------
%
% AUTHORS
%
%   Dimitris Floros                         fcdimitr@auth.gr
%
% VERSION
%
%   0.2 - January 04, 2018
%
% CHANGELOG
% 
%   0.2 (Jan 04, 2018) - Dimitris
%       * FIX: distance should be squared euclidean
%       * FIX: range search radius should be bandwidth
%
%   0.1 (Dec 29, 2017) - Dimitris
%       * initial implementation
%
% ------------------------------------------------------------