Browse Source

Fix class enumerations, Other fixes

master
Apostolos Fanakis 6 years ago
parent
commit
daab28c64e
  1. 32
      spike_sorting.m

32
spike_sorting.m

@ -79,7 +79,7 @@ for fileIndex=1:8
% finds dataset's theshold factor k that produces the closest number of % finds dataset's theshold factor k that produces the closest number of
% spikes ot the ground truth % spikes ot the ground truth
[minValue, closestIndex] = min(abs(numberOfSpikesTrimmed-Dataset.spikeNum)); [~, closestIndex] = min(abs(numberOfSpikesPerFactor-Dataset.spikeNum));
datasetFactors(fileIndex) = thresholdFactorInitValue + (closestIndex - 1) * thresholdFactorStep; datasetFactors(fileIndex) = thresholdFactorInitValue + (closestIndex - 1) * thresholdFactorStep;
clear Dataset clear Dataset
@ -186,6 +186,7 @@ for fileIndex=1:4
title(['Spikes of dataset #' num2str(fileIndex) ' aligned at the first extrema']); title(['Spikes of dataset #' num2str(fileIndex) ' aligned at the first extrema']);
xlabel('Samples'); xlabel('Samples');
ylabel('Trivial Unit'); ylabel('Trivial Unit');
xlim([1 64]);
drawnow; drawnow;
hold off; hold off;
@ -275,10 +276,22 @@ for fileIndex=1:4
end end
[~, dbScanClasses, ~] = dbscan(features(:, 6:7)', distThreshold, minPts); [~, dbScanClasses, ~] = dbscan(features(:, 6:7)', distThreshold, minPts);
%[dbScanClasses, accuracy] = ShuffledClassAccuracy(spikeClass, dbScanClasses);
% fixes classes enumeration % fixes classes enumeration
if fileIndex == 1
dbScanClasses(dbScanClasses==1) = 7; dbScanClasses(dbScanClasses==1) = 7;
dbScanClasses(dbScanClasses==3) = 1; dbScanClasses(dbScanClasses==3) = 1;
dbScanClasses(dbScanClasses==7) = 3; dbScanClasses(dbScanClasses==7) = 3;
elseif fileIndex == 3
dbScanClasses(dbScanClasses==1) = 7;
dbScanClasses(dbScanClasses==2) = 1;
dbScanClasses(dbScanClasses==7) = 2;
elseif fileIndex == 4
dbScanClasses(dbScanClasses==1) = 7;
dbScanClasses(dbScanClasses==3) = 1;
dbScanClasses(dbScanClasses==2) = 3;
dbScanClasses(dbScanClasses==7) = 2;
end
figure(); figure();
scatter(features(dbScanClasses == 0, 6), features(dbScanClasses == 0, 7), [], 'k', 'o'); scatter(features(dbScanClasses == 0, 6), features(dbScanClasses == 0, 7), [], 'k', 'o');
@ -297,9 +310,15 @@ for fileIndex=1:4
linkages = linkage(distances, 'ward'); linkages = linkage(distances, 'ward');
hierarchicalClusters = cluster(linkages, 'maxclust', 3); hierarchicalClusters = cluster(linkages, 'maxclust', 3);
% fixes classes enumeration % fixes classes enumeration
if fileIndex == 1
hierarchicalClusters(hierarchicalClusters==1) = 7;
hierarchicalClusters(hierarchicalClusters==2) = 1;
hierarchicalClusters(hierarchicalClusters==7) = 2;
elseif fileIndex == 3
hierarchicalClusters(hierarchicalClusters==1) = 7; hierarchicalClusters(hierarchicalClusters==1) = 7;
hierarchicalClusters(hierarchicalClusters==2) = 1; hierarchicalClusters(hierarchicalClusters==2) = 1;
hierarchicalClusters(hierarchicalClusters==7) = 2; hierarchicalClusters(hierarchicalClusters==7) = 2;
end
figure(); figure();
scatter(features(hierarchicalClusters == 0, 6), features(hierarchicalClusters == 0, 7), [], 'k', 'o'); scatter(features(hierarchicalClusters == 0, 6), features(hierarchicalClusters == 0, 7), [], 'k', 'o');
@ -317,9 +336,20 @@ for fileIndex=1:4
rng(1); % For reproducibility rng(1); % For reproducibility
kMeansClasses = kmeans(features(:, 6:7), 3); kMeansClasses = kmeans(features(:, 6:7), 3);
% fixes classes enumeration % fixes classes enumeration
if fileIndex == 1
kMeansClasses(kMeansClasses==2) = 7; kMeansClasses(kMeansClasses==2) = 7;
kMeansClasses(kMeansClasses==3) = 2; kMeansClasses(kMeansClasses==3) = 2;
kMeansClasses(kMeansClasses==7) = 3; kMeansClasses(kMeansClasses==7) = 3;
elseif fileIndex == 2
kMeansClasses(kMeansClasses==1) = 7;
kMeansClasses(kMeansClasses==2) = 1;
kMeansClasses(kMeansClasses==7) = 2;
elseif fileIndex == 3
kMeansClasses(kMeansClasses==1) = 7;
kMeansClasses(kMeansClasses==2) = 1;
kMeansClasses(kMeansClasses==3) = 2;
kMeansClasses(kMeansClasses==7) = 3;
end
figure(); figure();
scatter(features(kMeansClasses == 1, 6), features(kMeansClasses == 1, 7), [], 'b', '*'); scatter(features(kMeansClasses == 1, 6), features(kMeansClasses == 1, 7), [], 'b', '*');

Loading…
Cancel
Save