
Add db-scan and k-means clustering

Branch: master
Apostolos Fanakis committed 6 years ago
Commit: 71bfc02dea

Changed files:
  1. dbscan.m (138 lines changed)
  2. spike_sorting.m (80 lines changed)

dbscan.m

@@ -0,0 +1,138 @@
% DBSCAN  DBSCAN clustering algorithm
%
% Usage:  [C, ptsC, centres] = dbscan(P, E, minPts)
%
% Arguments:
%         P - dim x Npts array of points.
%         E - Distance threshold.
%    minPts - Minimum number of points required to form a cluster.
%
% Returns:
%         C - Cell array of length Nc listing indices of points associated with
%             each cluster.
%      ptsC - Array of length Npts listing the cluster number associated with
%             each point. If a point is denoted as noise (not enough nearby
%             elements to form a cluster) its cluster number is 0.
%   centres - dim x Nc array of the average centre of each cluster.
%
% Reference:
% Martin Ester, Hans-Peter Kriegel, Jörg Sander, Xiaowei Xu (1996). "A
% density-based algorithm for discovering clusters in large spatial databases
% with noise". Proceedings of the Second International Conference on Knowledge
% Discovery and Data Mining (KDD-96). AAAI Press. pp. 226-231.
% Also see: http://en.wikipedia.org/wiki/DBSCAN
%
% Copyright (c) 2013 Peter Kovesi
% Centre for Exploration Targeting
% The University of Western Australia
% peter.kovesi at uwa edu au
%
% Permission is hereby granted, free of charge, to any person obtaining a copy
% of this software and associated documentation files (the "Software"), to deal
% in the Software without restriction, subject to the following conditions:
%
% The above copyright notice and this permission notice shall be included in
% all copies or substantial portions of the Software.
%
% The Software is provided "as is", without warranty of any kind.
%
% PK January 2013
function [C, ptsC, centres] = dbscan(P, E, minPts)

    [dim, Npts] = size(P);

    ptsC = zeros(Npts,1);
    C = {};
    Nc = 0;                   % Cluster counter.
    Pvisit = zeros(Npts,1);   % Array to keep track of points that have been visited.

    for n = 1:Npts
        if ~Pvisit(n)                             % If this point not visited yet
            Pvisit(n) = 1;                        % mark as visited
            neighbourPts = regionQuery(P, n, E);  % and find its neighbours

            if length(neighbourPts) < minPts-1    % Not enough points to form a cluster
                ptsC(n) = 0;                      % Mark point n as noise.
            else                                  % Form a cluster...
                Nc = Nc + 1;                      % Increment number of clusters and process
                                                  % neighbourhood.
                C{Nc} = [n];                      % Initialise cluster Nc with point n
                ptsC(n) = Nc;                     % and mark point n as being a member of cluster Nc.

                ind = 1;                          % Initialise index into neighbourPts array.

                % For each point P' in neighbourPts ...
                while ind <= length(neighbourPts)
                    nb = neighbourPts(ind);

                    if ~Pvisit(nb)                % If this neighbour has not been visited
                        Pvisit(nb) = 1;           % mark it as visited.

                        % Find the neighbours of this neighbour and if it has
                        % enough neighbours add them to the neighbourPts list
                        neighbourPtsP = regionQuery(P, nb, E);
                        if length(neighbourPtsP) >= minPts
                            neighbourPts = [neighbourPts neighbourPtsP];
                        end
                    end

                    % If this neighbour nb not yet a member of any cluster add it
                    % to this cluster.
                    if ~ptsC(nb)
                        C{Nc} = [C{Nc} nb];
                        ptsC(nb) = Nc;
                    end

                    ind = ind + 1;                % Increment neighbour point index and process
                                                  % next neighbour
                end
            end
        end
    end

    % Find centres of each cluster
    centres = zeros(dim,length(C));

    for n = 1:length(C)
        for k = 1:length(C{n})
            centres(:,n) = centres(:,n) + P(:,C{n}(k));
        end
        centres(:,n) = centres(:,n)/length(C{n});
    end

end % of dbscan

%------------------------------------------------------------------------
% Find indices of all points within distance E of point with index n
% This function could make use of a precomputed distance table to avoid
% repeated distance calculations, however this would require N^2 storage.
% Not a big problem either way if the number of points being clustered is
% small. For large datasets this function will need to be optimised.
%
% Arguments:
%       P - the dim x Npts array of data points
%       n - Index of point of interest
%       E - Distance threshold

function neighbours = regionQuery(P, n, E)

    E2 = E^2;
    [dim, Npts] = size(P);
    neighbours = [];

    for i = 1:Npts
        if i ~= n
            % Test if distance^2 < E^2
            v = P(:,i)-P(:,n);
            dist2 = v'*v;

            if dist2 < E2
                neighbours = [neighbours i];
            end
        end
    end

end % of regionQuery
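For reference, a minimal usage sketch of the dbscan function added above, on synthetic 2-D data; the radius E = 1.0 and minPts = 5 below are illustrative values, not taken from the commit:

    % Toy example: two Gaussian blobs, clustered with the dbscan.m added in this commit.
    rng(0);                                  % reproducible toy data
    P = [randn(2, 100), randn(2, 100) + 5];  % dim x Npts array, as the docstring expects
    [C, ptsC, centres] = dbscan(P, 1.0, 5);  % E = 1.0, minPts = 5 (illustrative)
    % C{k} lists the point indices of cluster k, ptsC is 0 for noise points,
    % and centres(:, k) is the mean of cluster k.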

spike_sorting.m

@@ -72,7 +72,7 @@ for fileIndex=1:8
     xlabel('Threshold factor (k)');
     ylabel('Number of spikes');
     hold on;
-    plot([thresholdFactorInitValue endValue], [Dataset.spikeNum, Dataset.spikeNum]);
+    plot([thresholdFactorInitValue endValue], [double(Dataset.spikeNum), double(Dataset.spikeNum)]);
     xlim([thresholdFactorInitValue endValue]);
     drawnow;
     hold off;
@@ -95,7 +95,7 @@ xlabel('Dataset median');
 ylabel('Threshold factor');
 hold on;
-empiricalRule = polyfit(datasetMedians, datasetFactors, 3);
+empiricalRule = polyfit(datasetMedians, datasetFactors, 8);
 visualizationX = linspace(0, 0.5, 50);
 visualizationY = polyval(empiricalRule, visualizationX);
 plot(visualizationX, visualizationY);
@@ -103,9 +103,9 @@ hold off
 %% =================================================================================================
 %% S.2
-clearvars = {'closestIndex' 'datasetFactors' 'datasetMedians' 'endValue' 'minValue' 'numberOfFactors' ...
-    'numberOfSpikesPerFactor' 'numberOfSpikesTrimmed' 'thresholdFactorEndValue' 'thresholdFactorInitValue' ...
-    'thresholdFactorStep' 'visualizationX' 'visualizationY'};
+clearvars = {'closestIndex' 'datasetFactors' 'datasetMedians' 'endValue' 'minValue' ...
+    'numberOfFactors' 'numberOfSpikesPerFactor' 'numberOfSpikesTrimmed' 'thresholdFactorEndValue' ...
+    'thresholdFactorInitValue' 'thresholdFactorStep' 'visualizationX' 'visualizationY'};
 clear(clearvars{:})
 clear clearvars
@@ -119,11 +119,22 @@ for fileIndex=1:4
     dataMedian = median(abs(data)/0.6745);
     factorEstimation = polyval(empiricalRule, dataMedian);
+    factorEstimation = 4;
     threshold = factorEstimation * dataMedian;
     numberOfSpikes = 0;
     spikesTimesEst(2500) = 0;
     spikesEst(2500, 64) = 0;
+
+    figure();
+    plot(data(1:10000));
+    xlim([0, 10000]);
+    title(['First 10000 samples of dataset #' num2str(fileIndex)]);
+    xlabel('Sample #');
+    ylabel('Trivial Unit'); %TODO: Is this mVolts?
+    hold on;
+    plot([1 10000], [threshold, threshold]);
+    drawnow;
     % calculates number of spikes
     spikeStartIndex = 1;
     spikeEndIndex = 1;
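For context on the threshold drawn in the added figure: median(abs(data))/0.6745 is a robust estimate of the noise standard deviation (for Gaussian noise the median absolute value is about 0.6745 sigma), and the hard-coded factorEstimation = 4 overrides the fitted rule, so the plotted line sits at roughly four noise standard deviations. A minimal restatement of what the hunk computes:

    % Sketch: the detection threshold used above, written out explicitly.
    sigmaNoise = median(abs(data)) / 0.6745;   % robust noise-std estimate (same value as dataMedian)
    threshold  = 4 * sigmaNoise;               % fixed factor 4, as hard-coded in this hunk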
@@ -204,13 +215,15 @@ for fileIndex=1:4
     fprintf('Number of correct spike detections is %d\n', numberOfCorrectEstimations);
     if numberOfSpikes-numberOfCorrectEstimations > 0
-        fprintf('Number of uncorrect spike detections is %d\n', numberOfSpikes-numberOfCorrectEstimations);
+        fprintf('Number of uncorrect spike detections is %d\n', ...
+            numberOfSpikes - numberOfCorrectEstimations);
     end
     fprintf('Number of undetected spikes is %d\n', numberOfUndetectedSpikes);
-    fprintf('Average error of spike index estimation is %d\n\n', averageEstimationError);
+    fprintf('Average error of spike index estimation is %.2f\n\n', averageEstimationError);
 
     %% Q.2.4
-    features(numberOfCorrectEstimations, 5) = 0;
+    features(numberOfCorrectEstimations, 7) = 0;
+    pcaCoefficients = pca(correctSpikes, 'NumComponents', 2);
 
     for spike=1:numberOfCorrectEstimations
         % finds index of max
@@ -224,6 +237,9 @@ for fileIndex=1:4
         % calculates the fft of the signal
         asud = fft(correctSpikes(spike, :));
         features(spike, 5) = asud(1)^2;
+
+        features(spike, 6) = correctSpikes(spike, :) * pcaCoefficients(:, 1);
+        features(spike, 7) = correctSpikes(spike, :) * pcaCoefficients(:, 2);
     end
 
     figure();
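The two features added in this hunk are projections of each spike onto the first two principal components. A minimal alternative sketch, assuming correctSpikes holds one correctly detected spike per row as the loop suggests, is to take the score output of pca directly; it differs from the committed projection only by a constant offset per feature (pca centres the data), which does not affect the later clustering:

    % Sketch: the same two PCA features via pca's score output (offset by the data mean
    % relative to the committed projection of raw spikes onto the coefficients).
    [pcaCoefficients, pcaScores] = pca(correctSpikes, 'NumComponents', 2);
    features(1:numberOfCorrectEstimations, 6:7) = pcaScores;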
@@ -231,13 +247,57 @@ for fileIndex=1:4
     title(['Dataset #' num2str(fileIndex) ' feature plot']);
     xlabel('Index of max value');
     ylabel('Peak-to-peak amplitude');
+
+    figure();
+    scatter(features(:, 6), features(:, 7));
+    title(['Dataset #' num2str(fileIndex) ' feature plot']);
+    xlabel('PCA feature 1');
+    ylabel('PCA feature 2');
 
     %% Q.2.5
     accuracy = MyClassify(features, spikeClass');
-    fprintf('Accuracy achieved is %.2f%%\n\n', accuracy)
+    fprintf('Accuracy achieved is %.2f%%\n\n', accuracy);
+
+    % clustering using DB-SCAN algorithm
+    [~, dbScanClasses, ~] = dbscan(features(:, 6:7)', 0.4, 20);
+
+    % fixes classes enumeration
+    dbScanClasses(dbScanClasses==1) = 7;
+    dbScanClasses(dbScanClasses==3) = 1;
+    dbScanClasses(dbScanClasses==7) = 3;
+
+    figure();
+    scatter(features(dbScanClasses == 0, 6), features(dbScanClasses == 0, 7), [], 'k', 'o');
+    hold on;
+    scatter(features(dbScanClasses == 1, 6), features(dbScanClasses == 1, 7), [], 'b', '*');
+    scatter(features(dbScanClasses == 2, 6), features(dbScanClasses == 2, 7), [], 'r', '*');
+    scatter(features(dbScanClasses == 3, 6), features(dbScanClasses == 3, 7), [], 'g', '*');
+    title(['Dataset #' num2str(fileIndex) ' feature plot after clustering with DB-SCAN']);
+    xlabel('PCA feature 1');
+    ylabel('PCA feature 2');
+
+    accuracy = classperf(spikeClass', dbScanClasses);
+    fprintf('Accuracy achieved with DB-SCAN is %.2f%%\n\n', accuracy.CorrectRate*100);
+
+    % clustering using kmeans algorithm
+    rng(1); % For reproducibility
+    kMeansClasses = kmeans(features(:, 6:7), 3);
+
+    % fixes classes enumeration
+    kMeansClasses(kMeansClasses==2) = 7;
+    kMeansClasses(kMeansClasses==3) = 2;
+    kMeansClasses(kMeansClasses==7) = 3;
+
+    figure();
+    scatter(features(kMeansClasses == 0, 6), features(kMeansClasses == 0, 7), [], 'k', 'o');
+    hold on;
+    scatter(features(kMeansClasses == 2, 6), features(kMeansClasses == 2, 7), [], 'r', '*');
+    scatter(features(kMeansClasses == 3, 6), features(kMeansClasses == 3, 7), [], 'g', '*');
+    scatter(features(kMeansClasses == 1, 6), features(kMeansClasses == 1, 7), [], 'b', '*');
+    title(['Dataset #' num2str(fileIndex) ' feature plot after clustering with K-Means']);
+    xlabel('PCA feature 1');
+    ylabel('PCA feature 2');
+
+    accuracy = classperf(spikeClass', kMeansClasses);
+    fprintf('Accuracy achieved with K-Means is %.2f%%\n\n', accuracy.CorrectRate*100);
+
     clearvars = {'spikesTimesEst', 'spikesEst', 'data', 'features', 'realSpikeIndex', ...
-        'correctSpikes', 'spikeClass', 'Data'};
+        'correctSpikes', 'spikeClass', 'Data', 'Dataset', 'pcaCoefficients', 'accuracy', 'asud', ...
+        'dbScanClasses', 'kMeansClasses'};
     clear(clearvars{:})
     clear clearvars
 end
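A note on the hard-coded relabelling after dbscan and kmeans above: both algorithms number clusters arbitrarily, so their labels must be aligned with the ground-truth classes before classperf can score them; the ==1/==3/==7 swaps do this by hand for the label order observed on these datasets. A more general option is to map each cluster to the majority ground-truth class inside it. A minimal sketch with a hypothetical helper, alignClusterLabels, which is not part of the commit:

    % Align arbitrary cluster labels to ground-truth classes by majority vote.
    % clusterLabels and trueLabels are vectors of equal length; noise label 0 is left unchanged.
    function aligned = alignClusterLabels(clusterLabels, trueLabels)
        aligned = zeros(size(clusterLabels));
        for c = reshape(unique(clusterLabels(clusterLabels > 0)), 1, [])
            aligned(clusterLabels == c) = mode(trueLabels(clusterLabels == c));
        end
    end

It would be used as, e.g., dbScanClasses = alignClusterLabels(dbScanClasses, spikeClass'); before the classperf call, in place of the manual swaps.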
