You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
327 lines
13 KiB
327 lines
13 KiB
%% AUTHOR[1] : Apostolos Fanakis (8261)
|
|
%% EMAIL[1] : apostolof@auth.gr
|
|
%% AUTHOR[2] : Charalampos Papadiakos (8302)
|
|
%% EMAIL[2] : charaldp@ece.auth.gr
|
|
%% AUTHOR[3] : Hlektra Mitsi ()
|
|
%% EMAIL[3] :
|
|
%% $DATE : 28-December-2018 12:45:00 $
|
|
%% $Revision : 1.00 $
|
|
%% DEVELOPED : 9.0.0.341360 (R2016a)
|
|
%% FILENAME : spike_sorting.m
|
|
%%
|
|
%% =================================================================================================
|
|
%% S.1
% Per-dataset results of the threshold sweep: the median-based noise estimate
% of each test dataset and the threshold factor k that best matches its
% ground-truth spike count.
% zeros() is used instead of "x(8) = 0": assigning to the last element keeps
% any values left in the workspace by a previous run of the script.
datasetMedians = zeros(1, 8);
datasetFactors = zeros(1, 8);

thresholdFactorInitValue = 3;   % k starting value
thresholdFactorEndValue = 12;   % k ending value
thresholdFactorStep = 0.01;     % k jumping step
numberOfFactors = length(thresholdFactorInitValue:thresholdFactorStep:thresholdFactorEndValue);

% Number of detected spikes for each candidate k
numberOfSpikesPerFactor = zeros(1, numberOfFactors);
|
|
|
|
for fileIndex = 1:8
    fprintf('Loading test dataset no. %d\n', fileIndex);
    filename = sprintf('dataset\\Data_Test_%d.mat', fileIndex);
    Dataset = load(filename);
    data = double(Dataset.data);
    numberOfSamples = length(data);
    % Ground-truth spike count converted to double. If spikeNum is stored in
    % an integer class (not visible here - assumption), mixed arithmetic with
    % it would otherwise be carried out, and saturated, in that class.
    trueSpikeCount = double(Dataset.spikeNum);

    %% Q.1.1
    % Plots (at most) the first 10000 samples of the recording
    figure();
    plot(data(1:min(10000, numberOfSamples)));
    xlim([0, 10000]);
    title(['First 10000 samples of dataset #' num2str(fileIndex)]);
    xlabel('Sample #');
    ylabel('Trivial Unit'); %TODO: Is this mVolts?
    drawnow;

    %% Q.1.2
    % median(|x|/0.6745): the median-based noise estimate; each candidate
    % threshold below is k times this value.
    dataMedian = median(abs(data)/0.6745);
    datasetMedians(fileIndex) = dataMedian;

    % Counts must restart from zero for every dataset. Otherwise the counts
    % of all previously processed datasets accumulate into this one.
    numberOfSpikesPerFactor = zeros(1, numberOfFactors);

    for factorIteration = 1:numberOfFactors % runs for each k
        % builds threshold
        thresholdFactor = thresholdFactorInitValue + (factorIteration - 1) * thresholdFactorStep;
        threshold = thresholdFactor * dataMedian;

        % Counts spikes: a spike starts at a sample >= threshold and ends at
        % the next sample <= threshold (that ending sample is skipped).
        sample = 1;
        while sample <= numberOfSamples
            if data(sample) >= threshold
                % spike found
                numberOfSpikesPerFactor(factorIteration) = numberOfSpikesPerFactor(factorIteration) + 1;

                % Skips checking until values are below the threshold again.
                % It stops at the last sample, so a recording that ends above
                % the threshold is never indexed past its end.
                while sample < numberOfSamples
                    sample = sample + 1;
                    if data(sample) <= threshold
                        break;
                    end
                end
            end
            sample = sample + 1;
        end
    end

    % Trims trailing zeros (values of k that detect no spikes at all). If no k
    % detects anything, the whole range is kept so the plot and the search
    % below still have data.
    lastNonZero = find(numberOfSpikesPerFactor, 1, 'last');
    if isempty(lastNonZero)
        lastNonZero = numberOfFactors;
    end
    numberOfSpikesTrimmed = numberOfSpikesPerFactor(1:lastNonZero);
    % k values matching numberOfSpikesTrimmed element for element
    trimmedFactors = thresholdFactorInitValue + thresholdFactorStep * (0:lastNonZero - 1);
    endValue = trimmedFactors(end);

    figure();
    plot(trimmedFactors, numberOfSpikesTrimmed);
    title(['Number of spikes for different values of k for dataset #' num2str(fileIndex)]);
    xlabel('Threshold factor (k)');
    ylabel('Number of spikes');
    hold on;
    % horizontal line at the ground-truth spike count
    plot([thresholdFactorInitValue endValue], [trueSpikeCount, trueSpikeCount]);
    xlim([thresholdFactorInitValue endValue]);
    drawnow;
    hold off;

    % Finds the dataset's threshold factor k whose spike count is closest to
    % the ground truth (the first such k on ties).
    [~, closestIndex] = min(abs(numberOfSpikesTrimmed - trueSpikeCount));
    datasetFactors(fileIndex) = thresholdFactorInitValue + (closestIndex - 1) * thresholdFactorStep;

    clear Dataset
    clear data
end
fprintf('\n');
|
|
|
|
%% Q.1.3
% Fits an empirical polynomial rule k = p(median) through the
% (dataset median, best k) pairs found in S.1 and plots it over the points.
figure();
plot(datasetMedians, datasetFactors, 'o');
title('Polynomial curve fitting on median-threshold factor value pairs');
xlabel('Dataset median');
ylabel('Threshold factor');
hold on;

% A degree-d polynomial needs at least d + 1 points to be unique. The former
% degree 8 with only 8 datasets was under-determined (polyfit warns that the
% polynomial is not unique). A low-order fit is used instead, and the degree
% is always capped by the number of available points.
empiricalRuleDegree = min(2, numel(datasetMedians) - 1);
empiricalRule = polyfit(datasetMedians, datasetFactors, empiricalRuleDegree);

% Evaluates the rule over [0, 0.5], widened if any dataset median lies outside
visualizationX = linspace(min([0, datasetMedians]), max([0.5, datasetMedians]), 50);
visualizationY = polyval(empiricalRule, visualizationX);
plot(visualizationX, visualizationY);
hold off
|
|
|
|
%% =================================================================================================
|
|
%% S.2
% Removes the intermediate results of S.1 that the evaluation part does not
% need. empiricalRule (fitted in Q.1.3) is deliberately not in this list.
clear('closestIndex', 'datasetFactors', 'datasetMedians', 'endValue', 'minValue', ...
    'numberOfFactors', 'numberOfSpikesPerFactor', 'numberOfSpikesTrimmed', ...
    'thresholdFactorEndValue', 'thresholdFactorInitValue', 'thresholdFactorStep', ...
    'visualizationX', 'visualizationY')
|
|
|
|
for fileIndex = 1:4
    fprintf('Loading evaluation dataset no. %d \n', fileIndex);
    filename = sprintf('dataset\\Data_Eval_E_%d.mat', fileIndex);
    Dataset = load(filename);
    data = double(Dataset.data);
    numberOfSamples = length(data);
    numberOfTrueSpikes = length(Dataset.spikeTimes);

    %% Q.2.1 and Q.2.2
    % median(|x|/0.6745): the median-based noise estimate; the detection
    % threshold is k times this value.
    dataMedian = median(abs(data)/0.6745);

    % The threshold factor k is fixed to 4. The empirical rule fitted in Q.1.3
    % used to be evaluated and then immediately overwritten. Set
    % useEmpiricalRule = true to use the fitted rule instead.
    useEmpiricalRule = false;
    if useEmpiricalRule
        factorEstimation = polyval(empiricalRule, dataMedian);
    else
        factorEstimation = 4;
    end
    threshold = factorEstimation * dataMedian;

    % Extracted waveforms are 64 samples long: 34 samples before the
    % alignment point, the point itself, and 29 samples after it.
    samplesBeforeAlignment = 34;
    samplesAfterAlignment = 29;
    waveformLength = samplesBeforeAlignment + 1 + samplesAfterAlignment;
    % The minimum is searched for in a window of -41/+22 samples around the max
    samplesBeforeMax = 41;
    samplesAfterMax = 22;

    % 2500 is only an initial capacity. Both arrays are trimmed to the number
    % of detected spikes after the detection loop.
    numberOfSpikes = 0;
    spikesTimesEst = zeros(1, 2500);
    spikesEst = zeros(2500, waveformLength);

    figure();
    plot(data(1:min(10000, numberOfSamples)));
    xlim([0, 10000]);
    title(['First 10000 samples of dataset #' num2str(fileIndex)]);
    xlabel('Sample #');
    ylabel('Trivial Unit'); %TODO: Is this mVolts?
    hold on;
    plot([1 10000], [threshold, threshold]);
    drawnow;

    % Detects spikes: a spike starts at a sample >= threshold and ends at the
    % next sample <= threshold.
    sample = 1;
    while sample <= numberOfSamples
        if data(sample) >= threshold
            % spike found
            numberOfSpikes = numberOfSpikes + 1;
            spikeStartIndex = sample;

            % Skips checking until values are below the threshold again. It
            % stops at the last sample, so data is never indexed past its end.
            while sample < numberOfSamples
                sample = sample + 1;
                if data(sample) <= threshold
                    break;
                end
            end
            spikeEndIndex = sample;

            % finds the index of the max sample for this spike
            [~, relativeMaxIndex] = max(data(spikeStartIndex:spikeEndIndex));
            absoluteMaxIndex = spikeStartIndex - 1 + relativeMaxIndex;

            % Searches for the min in a -41/+22 window around the max, clipped
            % to the recording. relativeMinIndex is 1-based, hence the -1.
            % The former "absoluteMaxIndex - 41 + relativeMinIndex" pointed
            % one sample too late.
            minWindowStart = max(1, absoluteMaxIndex - samplesBeforeMax);
            minWindowEnd = min(numberOfSamples, absoluteMaxIndex + samplesAfterMax);
            [~, relativeMinIndex] = min(data(minWindowStart:minWindowEnd));
            absoluteMinIndex = minWindowStart - 1 + relativeMinIndex;

            % aligns the spike at whichever extremum (min or max) occurs first
            firstIndex = min(absoluteMaxIndex, absoluteMinIndex);

            % Q.2.1
            spikesTimesEst(numberOfSpikes) = firstIndex;
            % Q.2.2: for spikes too close to either end of the recording, the
            % window is shifted inwards so that it stays inside the data.
            waveformStart = firstIndex - samplesBeforeAlignment;
            waveformStart = min(max(waveformStart, 1), numberOfSamples - waveformLength + 1);
            spikesEst(numberOfSpikes, :) = data(waveformStart:waveformStart + waveformLength - 1);
        end
        sample = sample + 1;
    end

    % Drops unused preallocated capacity. Zero-padded entries must not take
    % part in the matching below.
    spikesTimesEst = spikesTimesEst(1:numberOfSpikes);
    spikesEst = spikesEst(1:numberOfSpikes, :);

    fprintf('%d spikes found for dataset #%d\n', numberOfSpikes, fileIndex);
    fprintf('actual number of spikes = %d\n', numberOfTrueSpikes);
    fprintf('diff = %d\n\n', numberOfSpikes - numberOfTrueSpikes);

    figure();
    hold on;
    for spike = 1:numberOfSpikes
        plot(1:waveformLength, spikesEst(spike, :));
    end
    title(['Spikes of dataset #' num2str(fileIndex) ' aligned at the first extrema']);
    xlabel('Samples');
    ylabel('Trivial Unit');
    drawnow;
    hold off;

    %% Q.2.3
    % Matches every ground-truth spike to the closest detected spike. A match
    % less than 32 samples away counts as a correct detection.
    matchTolerance = 32;
    realSpikeIndex = double(Dataset.spikeTimes);
    numberOfCorrectEstimations = 0;
    numberOfUndetectedSpikes = 0;
    totalEstimationError = 0;

    % At most one entry per ground-truth spike; trimmed after the loop
    correctSpikes = zeros(numberOfTrueSpikes, waveformLength);
    spikeClass = zeros(1, numberOfTrueSpikes);

    for trueSpikeIndex = 1:numberOfTrueSpikes
        [estimationError, closestIndex] = min(abs(spikesTimesEst - realSpikeIndex(trueSpikeIndex)));
        % estimationError is empty when no spike was detected at all
        if ~isempty(estimationError) && estimationError < matchTolerance
            numberOfCorrectEstimations = numberOfCorrectEstimations + 1;
            totalEstimationError = totalEstimationError + estimationError;

            correctSpikes(numberOfCorrectEstimations, :) = spikesEst(closestIndex, :);
            spikeClass(numberOfCorrectEstimations) = double(Dataset.spikeClass(trueSpikeIndex));
        else
            numberOfUndetectedSpikes = numberOfUndetectedSpikes + 1;
        end
    end

    % Keeps only the filled rows. Zero padding would otherwise enter the PCA
    % and make spikeClass longer than the feature matrix.
    correctSpikes = correctSpikes(1:numberOfCorrectEstimations, :);
    spikeClass = spikeClass(1:numberOfCorrectEstimations);

    % The error sum is over correct matches only, so it is averaged over them
    if numberOfCorrectEstimations > 0
        averageEstimationError = totalEstimationError / numberOfCorrectEstimations;
    else
        averageEstimationError = NaN;
    end

    fprintf('Number of correct spike detections is %d\n', numberOfCorrectEstimations);
    if numberOfSpikes - numberOfCorrectEstimations > 0
        fprintf('Number of incorrect spike detections is %d\n', ...
            numberOfSpikes - numberOfCorrectEstimations);
    end
    fprintf('Number of undetected spikes is %d\n', numberOfUndetectedSpikes);
    fprintf('Average error of spike index estimation is %.2f\n\n', averageEstimationError);

    %% Q.2.4
    features = zeros(numberOfCorrectEstimations, 7);
    pcaCoefficients = pca(correctSpikes, 'NumComponents', 2);

    for spike = 1:numberOfCorrectEstimations
        waveform = correctSpikes(spike, :);
        % finds index of max
        [maxVal, features(spike, 1)] = max(waveform);
        % calculates Vpeak-to-peak
        features(spike, 2) = maxVal - min(waveform);
        % calculates ZCR as mean(|diff(sign(x))|)
        features(spike, 3) = mean(abs(diff(sign(waveform))));
        % calculates the signal energy
        features(spike, 4) = sum(waveform.^2);
        % square of the first (DC) FFT coefficient
        spectrum = fft(waveform);
        features(spike, 5) = spectrum(1)^2;
        % projections onto the first two PCA coefficient vectors
        features(spike, 6) = waveform * pcaCoefficients(:, 1);
        features(spike, 7) = waveform * pcaCoefficients(:, 2);
    end

    figure();
    scatter(features(:, 1), features(:, 2));
    title(['Dataset #' num2str(fileIndex) ' feature plot']);
    xlabel('Index of max value');
    ylabel('Peak-to-peak amplitude');
    figure();
    scatter(features(:, 6), features(:, 7));
    title(['Dataset #' num2str(fileIndex) ' feature plot']);
    xlabel('PCA feature 1');
    ylabel('PCA feature 2');

    %% Q.2.5
    accuracy = MyClassify(features, spikeClass');
    fprintf('Accuracy achieved is %.2f%%\n\n', accuracy);

    % clustering using DB-SCAN algorithm
    % code for DB-SCAN downloaded from here:
    % https://www.peterkovesi.com/matlabfns/
    [~, dbScanClasses, ~] = dbscan(features(:, 6:7)', 0.4, 20);
    % fixes classes enumeration (swaps labels 1 and 3, using 7 as a temporary)
    dbScanClasses(dbScanClasses==1) = 7;
    dbScanClasses(dbScanClasses==3) = 1;
    dbScanClasses(dbScanClasses==7) = 3;

    figure();
    % label 0 is drawn in black
    scatter(features(dbScanClasses == 0, 6), features(dbScanClasses == 0, 7), [], 'k', 'o');
    hold on;
    scatter(features(dbScanClasses == 1, 6), features(dbScanClasses == 1, 7), [], 'b', '*');
    scatter(features(dbScanClasses == 2, 6), features(dbScanClasses == 2, 7), [], 'r', '*');
    scatter(features(dbScanClasses == 3, 6), features(dbScanClasses == 3, 7), [], 'g', '*');
    title(['Dataset #' num2str(fileIndex) ' feature plot after clustering with DB-SCAN']);
    xlabel('PCA feature 1');
    ylabel('PCA feature 2');
    accuracy = classperf(spikeClass', dbScanClasses);
    fprintf('Accuracy achieved with DB-SCAN is %.2f%%\n\n', accuracy.CorrectRate*100);

    % hierarchical clustering (Ward linkage, cut into 3 clusters)
    distances = pdist(features(:, 6:7));
    linkages = linkage(distances, 'ward');
    hierarchicalClusters = cluster(linkages, 'maxclust', 3);
    % fixes classes enumeration (swaps labels 1 and 2, using 7 as a temporary)
    hierarchicalClusters(hierarchicalClusters==1) = 7;
    hierarchicalClusters(hierarchicalClusters==2) = 1;
    hierarchicalClusters(hierarchicalClusters==7) = 2;

    figure();
    scatter(features(hierarchicalClusters == 2, 6), features(hierarchicalClusters == 2, 7), [], 'r', '*');
    hold on;
    scatter(features(hierarchicalClusters == 3, 6), features(hierarchicalClusters == 3, 7), [], 'g', '*');
    scatter(features(hierarchicalClusters == 1, 6), features(hierarchicalClusters == 1, 7), [], 'b', '*');
    title(['Dataset #' num2str(fileIndex) ' feature plot after hierarchical clustering']);
    xlabel('PCA feature 1');
    ylabel('PCA feature 2');
    accuracy = classperf(spikeClass', hierarchicalClusters);
    fprintf('Accuracy achieved with hierarchical clustering is %.2f%%\n\n', accuracy.CorrectRate*100);

    % clustering using kmeans algorithm
    rng(1); % For reproducibility
    kMeansClasses = kmeans(features(:, 6:7), 3);
    % fixes classes enumeration (swaps labels 2 and 3, using 7 as a temporary)
    kMeansClasses(kMeansClasses==2) = 7;
    kMeansClasses(kMeansClasses==3) = 2;
    kMeansClasses(kMeansClasses==7) = 3;

    figure();
    scatter(features(kMeansClasses == 1, 6), features(kMeansClasses == 1, 7), [], 'b', '*');
    hold on;
    scatter(features(kMeansClasses == 2, 6), features(kMeansClasses == 2, 7), [], 'r', '*');
    scatter(features(kMeansClasses == 3, 6), features(kMeansClasses == 3, 7), [], 'g', '*');
    title(['Dataset #' num2str(fileIndex) ' feature plot after clustering with K-Means']);
    xlabel('PCA feature 1');
    ylabel('PCA feature 2');
    accuracy = classperf(spikeClass', kMeansClasses);
    fprintf('Accuracy achieved with K-Means is %.2f%%\n\n', accuracy.CorrectRate*100);

    % Frees per-dataset arrays before the next evaluation dataset. The list
    % variable itself is not named clearvars, to avoid shadowing the builtin.
    variablesToClear = {'spikesTimesEst', 'spikesEst', 'data', 'features', 'realSpikeIndex', ...
        'correctSpikes', 'spikeClass', 'Dataset', 'pcaCoefficients', 'accuracy', 'spectrum', ...
        'waveform', 'dbScanClasses', 'kMeansClasses', 'hierarchicalClusters', 'distances', ...
        'linkages'};
    clear(variablesToClear{:}, 'variablesToClear')
end
|
|
|