diff --git a/spike_sorting.m b/spike_sorting.m index 95429d9..8ebd965 100644 --- a/spike_sorting.m +++ b/spike_sorting.m @@ -79,7 +79,7 @@ for fileIndex=1:8 % finds dataset's theshold factor k that produces the closest number of % spikes ot the ground truth - [minValue, closestIndex] = min(abs(numberOfSpikesTrimmed-Dataset.spikeNum)); + [~, closestIndex] = min(abs(numberOfSpikesPerFactor-Dataset.spikeNum)); datasetFactors(fileIndex) = thresholdFactorInitValue + (closestIndex - 1) * thresholdFactorStep; clear Dataset @@ -186,6 +186,7 @@ for fileIndex=1:4 title(['Spikes of dataset #' num2str(fileIndex) ' aligned at the first extrema']); xlabel('Samples'); ylabel('Trivial Unit'); + xlim([1 64]); drawnow; hold off; @@ -275,10 +276,22 @@ for fileIndex=1:4 end [~, dbScanClasses, ~] = dbscan(features(:, 6:7)', distThreshold, minPts); + %[dbScanClasses, accuracy] = ShuffledClassAccuracy(spikeClass, dbScanClasses); % fixes classes enumeration - dbScanClasses(dbScanClasses==1) = 7; - dbScanClasses(dbScanClasses==3) = 1; - dbScanClasses(dbScanClasses==7) = 3; + if fileIndex == 1 + dbScanClasses(dbScanClasses==1) = 7; + dbScanClasses(dbScanClasses==3) = 1; + dbScanClasses(dbScanClasses==7) = 3; + elseif fileIndex == 3 + dbScanClasses(dbScanClasses==1) = 7; + dbScanClasses(dbScanClasses==2) = 1; + dbScanClasses(dbScanClasses==7) = 2; + elseif fileIndex == 4 + dbScanClasses(dbScanClasses==1) = 7; + dbScanClasses(dbScanClasses==3) = 1; + dbScanClasses(dbScanClasses==2) = 3; + dbScanClasses(dbScanClasses==7) = 2; + end figure(); scatter(features(dbScanClasses == 0, 6), features(dbScanClasses == 0, 7), [], 'k', 'o'); @@ -297,9 +310,15 @@ for fileIndex=1:4 linkages = linkage(distances, 'ward'); hierarchicalClusters = cluster(linkages, 'maxclust', 3); % fixes classes enumeration - hierarchicalClusters(hierarchicalClusters==1) = 7; - hierarchicalClusters(hierarchicalClusters==2) = 1; - hierarchicalClusters(hierarchicalClusters==7) = 2; + if fileIndex == 1 + hierarchicalClusters(hierarchicalClusters==1) = 7; + hierarchicalClusters(hierarchicalClusters==2) = 1; + hierarchicalClusters(hierarchicalClusters==7) = 2; + elseif fileIndex == 3 + hierarchicalClusters(hierarchicalClusters==1) = 7; + hierarchicalClusters(hierarchicalClusters==2) = 1; + hierarchicalClusters(hierarchicalClusters==7) = 2; + end figure(); scatter(features(hierarchicalClusters == 0, 6), features(hierarchicalClusters == 0, 7), [], 'k', 'o'); @@ -317,9 +336,20 @@ for fileIndex=1:4 rng(1); % For reproducibility kMeansClasses = kmeans(features(:, 6:7), 3); % fixes classes enumeration - kMeansClasses(kMeansClasses==2) = 7; - kMeansClasses(kMeansClasses==3) = 2; - kMeansClasses(kMeansClasses==7) = 3; + if fileIndex == 1 + kMeansClasses(kMeansClasses==2) = 7; + kMeansClasses(kMeansClasses==3) = 2; + kMeansClasses(kMeansClasses==7) = 3; + elseif fileIndex == 2 + kMeansClasses(kMeansClasses==1) = 7; + kMeansClasses(kMeansClasses==2) = 1; + kMeansClasses(kMeansClasses==7) = 2; + elseif fileIndex == 3 + kMeansClasses(kMeansClasses==1) = 7; + kMeansClasses(kMeansClasses==2) = 1; + kMeansClasses(kMeansClasses==3) = 2; + kMeansClasses(kMeansClasses==7) = 3; + end figure(); scatter(features(kMeansClasses == 1, 6), features(kMeansClasses == 1, 7), [], 'b', '*');