diff --git a/dbscan.m b/dbscan.m new file mode 100644 index 0000000..843dd06 --- /dev/null +++ b/dbscan.m @@ -0,0 +1,138 @@ +% DBSCAN DBSCAN clustering algorithm +% +% Usage: [C, ptsC, centres] = dbscan(P, E, minPts) +% +% Arguments: +% P - dim x Npts array of points. +% E - Distance threshold. +% minPts - Minimum number of points required to form a cluster. +% +% Returns: +% C - Cell array of length Nc listing indices of points associated with +% each cluster. +% ptsC - Array of length Npts listing the cluster number associated with +% each point. If a point is denoted as noise (not enough nearby +% elements to form a cluster) its cluster number is 0. +% centres - dim x Nc array of the average centre of each cluster. + +% Reference: +% Martin Ester, Hans-Peter Kriegel, Jörg Sander, Xiaowei Xu (1996). "A +% density-based algorithm for discovering clusters in large spatial databases +% with noise". Proceedings of the Second International Conference on Knowledge +% Discovery and Data Mining (KDD-96). AAAI Press. pp. 226-231. +% Also see: http://en.wikipedia.org/wiki/DBSCAN + +% Copyright (c) 2013 Peter Kovesi +% Centre for Exploration Targeting +% The University of Western Australia +% peter.kovesi at uwa edu au +% +% Permission is hereby granted, free of charge, to any person obtaining a copy +% of this software and associated documentation files (the "Software"), to deal +% in the Software without restriction, subject to the following conditions: +% +% The above copyright notice and this permission notice shall be included in +% all copies or substantial portions of the Software. +% +% The Software is provided "as is", without warranty of any kind. + +% PK January 2013 + +function [C, ptsC, centres] = dbscan(P, E, minPts) + + [dim, Npts] = size(P); + + ptsC = zeros(Npts,1); + C = {}; + Nc = 0; % Cluster counter. + Pvisit = zeros(Npts,1); % Array to keep track of points that have been visited. + + for n = 1:Npts + if ~Pvisit(n) % If this point not visited yet + Pvisit(n) = 1; % mark as visited + neighbourPts = regionQuery(P, n, E); % and find its neighbours + + if length(neighbourPts) < minPts-1 % Not enough points to form a cluster + ptsC(n) = 0; % Mark point n as noise. + + else % Form a cluster... + Nc = Nc + 1; % Increment number of clusters and process + % neighbourhood. + + C{Nc} = [n]; % Initialise cluster Nc with point n + ptsC(n) = Nc; % and mark point n as being a member of cluster Nc. + + ind = 1; % Initialise index into neighbourPts array. + + % For each point P' in neighbourPts ... + while ind <= length(neighbourPts) + + nb = neighbourPts(ind); + + if ~Pvisit(nb) % If this neighbour has not been visited + Pvisit(nb) = 1; % mark it as visited. + + % Find the neighbours of this neighbour and if it has + % enough neighbours add them to the neighbourPts list + neighbourPtsP = regionQuery(P, nb, E); + if length(neighbourPtsP) >= minPts + neighbourPts = [neighbourPts neighbourPtsP]; + end + end + + % If this neighbour nb not yet a member of any cluster add it + % to this cluster. + if ~ptsC(nb) + C{Nc} = [C{Nc} nb]; + ptsC(nb) = Nc; + end + + ind = ind + 1; % Increment neighbour point index and process + % next neighbour + end + end + end + end + + % Find centres of each cluster + centres = zeros(dim,length(C)); + for n = 1:length(C) + for k = 1:length(C{n}) + centres(:,n) = centres(:,n) + P(:,C{n}(k)); + end + centres(:,n) = centres(:,n)/length(C{n}); + end + +end % of dbscan + +%------------------------------------------------------------------------ +% Find indices of all points within distance E of point with index n +% This function could make use of a precomputed distance table to avoid +% repeated distance calculations, however this would require N^2 storage. +% Not a big problem either way if the number of points being clustered is +% small. For large datasets this function will need to be optimised. + +% Arguments: +% P - the dim x Npts array of data points +% n - Index of point of interest +% E - Distance threshold + +function neighbours = regionQuery(P, n, E) + + E2 = E^2; + [dim, Npts] = size(P); + neighbours = []; + + for i = 1:Npts + if i ~= n + % Test if distance^2 < E^2 + v = P(:,i)-P(:,n); + dist2 = v'*v; + if dist2 < E2 + neighbours = [neighbours i]; + end + end + end + +end % of regionQuery + diff --git a/spike_sorting.m b/spike_sorting.m index 2d0d04c..9171556 100644 --- a/spike_sorting.m +++ b/spike_sorting.m @@ -19,7 +19,7 @@ thresholdFactorEndValue = 12; % k ending value thresholdFactorStep = 0.01; % k jumping step numberOfFactors = length(thresholdFactorInitValue:thresholdFactorStep:thresholdFactorEndValue); numberOfSpikesPerFactor(numberOfFactors) = 0; - + for fileIndex=1:8 fprintf('Loading test dataset no. %d\n', fileIndex); filename = sprintf('dataset\\Data_Test_%d.mat', fileIndex); @@ -72,7 +72,7 @@ for fileIndex=1:8 xlabel('Threshold factor (k)'); ylabel('Number of spikes'); hold on; - plot([thresholdFactorInitValue endValue], [Dataset.spikeNum, Dataset.spikeNum]); + plot([thresholdFactorInitValue endValue], [double(Dataset.spikeNum), double(Dataset.spikeNum)]); xlim([thresholdFactorInitValue endValue]); drawnow; hold off; @@ -95,7 +95,7 @@ xlabel('Dataset median'); ylabel('Threshold factor'); hold on; -empiricalRule = polyfit(datasetMedians, datasetFactors, 3); +empiricalRule = polyfit(datasetMedians, datasetFactors, 8); visualizationX = linspace(0, 0.5, 50); visualizationY = polyval(empiricalRule, visualizationX); plot(visualizationX, visualizationY); @@ -103,9 +103,9 @@ hold off %% ================================================================================================= %% S.2 -clearvars = {'closestIndex' 'datasetFactors' 'datasetMedians' 'endValue' 'minValue' 'numberOfFactors' ... - 'numberOfSpikesPerFactor' 'numberOfSpikesTrimmed' 'thresholdFactorEndValue' 'thresholdFactorInitValue' ... - 'thresholdFactorStep' 'visualizationX' 'visualizationY'}; +clearvars = {'closestIndex' 'datasetFactors' 'datasetMedians' 'endValue' 'minValue' ... + 'numberOfFactors' 'numberOfSpikesPerFactor' 'numberOfSpikesTrimmed' 'thresholdFactorEndValue' ... + 'thresholdFactorInitValue' 'thresholdFactorStep' 'visualizationX' 'visualizationY'}; clear(clearvars{:}) clear clearvars @@ -119,11 +119,22 @@ for fileIndex=1:4 dataMedian = median(abs(data)/0.6745); factorEstimation = polyval(empiricalRule, dataMedian); + factorEstimation = 4; threshold = factorEstimation * dataMedian; numberOfSpikes = 0; spikesTimesEst(2500) = 0; spikesEst(2500, 64) = 0; + figure(); + plot(data(1:10000)); + xlim([0, 10000]); + title(['First 10000 samples of dataset #' num2str(fileIndex)]); + xlabel('Sample #'); + ylabel('Trivial Unit'); %TODO: Is this mVolts? + hold on; + plot([1 10000], [threshold, threshold]); + drawnow; + % calculates number of spikes spikeStartIndex = 1; spikeEndIndex = 1; @@ -151,7 +162,7 @@ for fileIndex=1:4 % discernes the extrema (minimum or maximum) that % occurs first firstIndex = min([absoluteMaxIndex absoluteMinIndex]); - + % Q.2.1 spikesTimesEst(numberOfSpikes) = firstIndex; % Q.2.2 @@ -166,7 +177,7 @@ for fileIndex=1:4 fprintf('%d spikes found for dataset #%d\n', numberOfSpikes, fileIndex); fprintf('actual number of spikes = %d\n', length(Dataset.spikeTimes)); fprintf('diff = %d\n\n', numberOfSpikes - length(Dataset.spikeTimes)); - + figure(); hold on; for spike=1:numberOfSpikes @@ -204,13 +215,15 @@ for fileIndex=1:4 fprintf('Number of correct spike detections is %d\n', numberOfCorrectEstimations); if numberOfSpikes-numberOfCorrectEstimations > 0 - fprintf('Number of uncorrect spike detections is %d\n', numberOfSpikes-numberOfCorrectEstimations); + fprintf('Number of uncorrect spike detections is %d\n', ... + numberOfSpikes - numberOfCorrectEstimations); end fprintf('Number of undetected spikes is %d\n', numberOfUndetectedSpikes); - fprintf('Average error of spike index estimation is %d\n\n', averageEstimationError); + fprintf('Average error of spike index estimation is %.2f\n\n', averageEstimationError); %% Q.2.4 - features(numberOfCorrectEstimations, 5) = 0; + features(numberOfCorrectEstimations, 7) = 0; + pcaCoefficients = pca(correctSpikes, 'NumComponents', 2); for spike=1:numberOfCorrectEstimations % finds index of max @@ -224,6 +237,9 @@ for fileIndex=1:4 % calculates the fft of the signal asud = fft(correctSpikes(spike, :)); features(spike, 5) = asud(1)^2; + + features(spike, 6) = correctSpikes(spike, :) * pcaCoefficients(:, 1); + features(spike, 7) = correctSpikes(spike, :) * pcaCoefficients(:, 2); end figure(); @@ -231,13 +247,57 @@ for fileIndex=1:4 title(['Dataset #' num2str(fileIndex) ' feature plot']); xlabel('Index of max value'); ylabel('Peak-to-peak amplitude'); + figure(); + scatter(features(:, 6), features(:, 7)); + title(['Dataset #' num2str(fileIndex) ' feature plot']); + xlabel('PCA feature 1'); + ylabel('PCA feature 2'); %% Q.2.5 accuracy = MyClassify(features, spikeClass'); - fprintf('Accuracy achieved is %.2f%%\n\n', accuracy) + fprintf('Accuracy achieved is %.2f%%\n\n', accuracy); + + % clustering using DB-SCAN algorithm + [~, dbScanClasses, ~] = dbscan(features(:, 6:7)', 0.4, 20); + % fixes classes enumeration + dbScanClasses(dbScanClasses==1) = 7; + dbScanClasses(dbScanClasses==3) = 1; + dbScanClasses(dbScanClasses==7) = 3; + + figure(); + scatter(features(dbScanClasses == 0, 6), features(dbScanClasses == 0, 7), [], 'k', 'o'); + hold on; + scatter(features(dbScanClasses == 1, 6), features(dbScanClasses == 1, 7), [], 'b', '*'); + scatter(features(dbScanClasses == 2, 6), features(dbScanClasses == 2, 7), [], 'r', '*'); + scatter(features(dbScanClasses == 3, 6), features(dbScanClasses == 3, 7), [], 'g', '*'); + title(['Dataset #' num2str(fileIndex) ' feature plot after clustering with DB-SCAN']); + xlabel('PCA feature 1'); + ylabel('PCA feature 2'); + accuracy = classperf(spikeClass',dbScanClasses); + fprintf('Accuracy achieved with DB-SCAN is %.2f%%\n\n', accuracy.CorrectRate*100); + + % clustering using kmeans algorithm + rng(1); % For reproducibility + kMeansClasses = kmeans(features(:, 6:7), 3); + % fixes classes enumeration + kMeansClasses(kMeansClasses==2) = 7; + kMeansClasses(kMeansClasses==3) = 2; + kMeansClasses(kMeansClasses==7) = 3; + figure(); + scatter(features(kMeansClasses == 0, 6), features(kMeansClasses == 0, 7), [], 'k', 'o'); + hold on; + scatter(features(kMeansClasses == 2, 6), features(kMeansClasses == 2, 7), [], 'r', '*'); + scatter(features(kMeansClasses == 3, 6), features(kMeansClasses == 3, 7), [], 'g', '*'); + scatter(features(kMeansClasses == 1, 6), features(kMeansClasses == 1, 7), [], 'b', '*'); + title(['Dataset #' num2str(fileIndex) ' feature plot after clustering with K-Means']); + xlabel('PCA feature 1'); + ylabel('PCA feature 2'); + accuracy = classperf(spikeClass',kMeansClasses); + fprintf('Accuracy achieved with K-Means is %.2f%%\n\n', accuracy.CorrectRate*100); clearvars = {'spikesTimesEst', 'spikesEst', 'data', 'features', 'realSpikeIndex', ... - 'correctSpikes', 'spikeClass', 'Data'}; + 'correctSpikes', 'spikeClass', 'Data', 'Dataset', 'pcaCoefficients', 'accuracy', 'asud', ... + 'dbScanClasses', 'kMeansClasses'}; clear(clearvars{:}) clear clearvars end