From 11593acac6570376428234c1ee7746fa84babef7 Mon Sep 17 00:00:00 2001 From: Apostolof Date: Sun, 30 Dec 2018 17:34:34 +0200 Subject: [PATCH] Add hierarchical clustering --- spike_sorting.m | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/spike_sorting.m b/spike_sorting.m index 9171556..29a2fbb 100644 --- a/spike_sorting.m +++ b/spike_sorting.m @@ -39,7 +39,7 @@ for fileIndex=1:8 dataMedian = median(abs(data)/0.6745); datasetMedians(fileIndex) = dataMedian; - parfor factorIteration=1:numberOfFactors % runs for each k + for factorIteration=1:numberOfFactors % runs for each k % builds threshold thresholdFactor = thresholdFactorInitValue + (factorIteration - 1) * thresholdFactorStep; threshold = thresholdFactor * dataMedian; @@ -258,6 +258,8 @@ for fileIndex=1:4 fprintf('Accuracy achieved is %.2f%%\n\n', accuracy); % clustering using DB-SCAN algorithm + % code for DB-SCAN downloaded from here: + % https://www.peterkovesi.com/matlabfns/ [~, dbScanClasses, ~] = dbscan(features(:, 6:7)', 0.4, 20); % fixes classes enumeration dbScanClasses(dbScanClasses==1) = 7; @@ -276,6 +278,28 @@ for fileIndex=1:4 accuracy = classperf(spikeClass',dbScanClasses); fprintf('Accuracy achieved with DB-SCAN is %.2f%%\n\n', accuracy.CorrectRate*100); + % hierarchical clustering + distances = pdist(features(:, 6:7)); + linkages = linkage(distances, 'ward'); + hierarchicalClusters = cluster(linkages, 'maxclust', 3); + % fixes classes enumeration + hierarchicalClusters(hierarchicalClusters==1) = 7; + hierarchicalClusters(hierarchicalClusters==2) = 1; + hierarchicalClusters(hierarchicalClusters==7) = 2; + + figure(); + scatter(features(hierarchicalClusters == 0, 6), features(hierarchicalClusters == 0, 7), [], 'k', 'o'); + hold on; + scatter(features(hierarchicalClusters == 2, 6), features(hierarchicalClusters == 2, 7), [], 'r', '*'); + scatter(features(hierarchicalClusters == 3, 6), features(hierarchicalClusters == 3, 7), [], 'g', '*'); + scatter(features(hierarchicalClusters == 1, 6), features(hierarchicalClusters == 1, 7), [], 'b', '*'); + title(['Dataset #' num2str(fileIndex) ' feature plot after clustering with K-Means']); + xlabel('PCA feature 1'); + ylabel('PCA feature 2'); + accuracy = classperf(spikeClass',hierarchicalClusters); + fprintf('Accuracy achieved with K-Means is %.2f%%\n\n', accuracy.CorrectRate*100); + + % clustering using kmeans algorithm rng(1); % For reproducibility kMeansClasses = kmeans(features(:, 6:7), 3); @@ -283,12 +307,12 @@ for fileIndex=1:4 kMeansClasses(kMeansClasses==2) = 7; kMeansClasses(kMeansClasses==3) = 2; kMeansClasses(kMeansClasses==7) = 3; + figure(); - scatter(features(kMeansClasses == 0, 6), features(kMeansClasses == 0, 7), [], 'k', 'o'); + scatter(features(kMeansClasses == 1, 6), features(kMeansClasses == 1, 7), [], 'b', '*'); hold on; scatter(features(kMeansClasses == 2, 6), features(kMeansClasses == 2, 7), [], 'r', '*'); scatter(features(kMeansClasses == 3, 6), features(kMeansClasses == 3, 7), [], 'g', '*'); - scatter(features(kMeansClasses == 1, 6), features(kMeansClasses == 1, 7), [], 'b', '*'); title(['Dataset #' num2str(fileIndex) ' feature plot after clustering with K-Means']); xlabel('PCA feature 1'); ylabel('PCA feature 2');