%% AUTHOR[1] : Apostolos Fanakis (8261)
%% EMAIL[1] : apostolof@auth.gr
%% AUTHOR[2] : Charalampos Papadiakos (8302)
%% EMAIL[2] : charaldp@ece.auth.gr
%% AUTHOR[3] : Hlektra Mitsi (8536)
%% EMAIL[3] : ilektraem@auth.ece.gr
%% $DATE : 28-December-2018 12:45:00 $
%% $Revision : 1.00 $
%% DEVELOPED : 9.0.0.341360 (R2016a)
%% FILENAME : spike_sorting.m
%%
%% =================================================================================================
%% S.1
% Calibrates the spike-detection threshold on the test datasets. For each
% dataset the noise estimate median(|x| / 0.6745) is computed and the threshold
% factor k (threshold = k * median) whose spike count is closest to the
% ground-truth count Dataset.spikeNum is recorded. Q.1.3 then fits a
% first-degree polynomial k = f(median) over all (median, k) pairs
% (empiricalRule).
%
% Expects files dataset/Data_Test_<n>.mat providing the fields 'data' and
% 'spikeNum'.

numberOfTestDatasets = 8;
datasetMedians = zeros(1, numberOfTestDatasets);
datasetFactors = zeros(1, numberOfTestDatasets);

thresholdFactorInitValue = 1; % k starting value
thresholdFactorEndValue = 7; % k ending value
thresholdFactorStep = 0.01; % k jumping step
numberOfFactors = length(thresholdFactorInitValue:thresholdFactorStep:thresholdFactorEndValue);
% every k that is tested; element i equals init + (i - 1) * step
factorValues = thresholdFactorInitValue + thresholdFactorStep * (0:numberOfFactors - 1);

for fileIndex=1:numberOfTestDatasets
    fprintf('Loading test dataset no. %d\n', fileIndex);
    % fullfile inserts the platform's file separator; the former
    % 'dataset\\...' literal only resolved on Windows
    filename = fullfile('dataset', sprintf('Data_Test_%d.mat', fileIndex));
    Dataset = load(filename);
    data = double(Dataset.data);
    numberOfSamples = length(data);
    % spikeNum may be stored as an integer class (it was already converted for
    % plotting). Converting once keeps the k search below in double arithmetic:
    % mixed integer/double arithmetic returns the integer class, which
    % saturates at 0 for unsigned types.
    groundTruthSpikeNum = double(Dataset.spikeNum);

    %% Q.1.1
    figure();
    plot(data(1:10000));
    xlim([0, 10000]);
    title(['First 10000 samples of dataset #' num2str(fileIndex)]);
    xlabel('Sample #');
    ylabel('Trivial Unit'); %TODO: Is this mVolts?
    drawnow;

    %% Q.1.2
    dataMedian = median(abs(data)/0.6745);
    datasetMedians(fileIndex) = dataMedian;

    numberOfSpikesPerFactor = zeros(1, numberOfFactors);
    for factorIteration=1:numberOfFactors % runs for each k
        % builds threshold
        threshold = factorValues(factorIteration) * dataMedian;

        % counts spikes: a spike starts at a sample >= threshold and lasts up
        % to the first later sample <= threshold, which is skipped as well
        sample = 1;
        while sample <= numberOfSamples
            if data(sample) >= threshold
                % spike found
                numberOfSpikesPerFactor(factorIteration) = numberOfSpikesPerFactor(factorIteration) + 1;

                % skips checking until values are below threshold again; the
                % bound check stops a spike that lasts until the last sample
                % from indexing past the end of the data
                sample = sample + 1;
                while sample <= numberOfSamples && data(sample) > threshold
                    sample = sample + 1;
                end
            end
            sample = sample + 1;
        end
    end

    figure();
    % trims trailing zero counts (values of k too high for any sample to
    % reach the threshold)
    numberOfSpikesTrimmed = numberOfSpikesPerFactor(1:find(numberOfSpikesPerFactor, 1, 'last'));
    endValue = thresholdFactorInitValue + thresholdFactorStep * (length(numberOfSpikesTrimmed) - 1);
    plot(factorValues(1:length(numberOfSpikesTrimmed)), numberOfSpikesTrimmed);
    title(['Number of spikes for different values of k for dataset #' num2str(fileIndex)]);
    xlabel('Threshold factor (k)');
    ylabel('Number of spikes');
    hold on;
    % ground-truth spike count as a horizontal reference line
    plot([thresholdFactorInitValue endValue], [groundTruthSpikeNum, groundTruthSpikeNum]);
    xlim([thresholdFactorInitValue endValue]);
    drawnow;
    hold off;

    % finds the dataset's threshold factor k that produces the closest number
    % of spikes to the ground truth
    [~, closestIndex] = min(abs(numberOfSpikesPerFactor - groundTruthSpikeNum));
    datasetFactors(fileIndex) = factorValues(closestIndex);

    clear Dataset
    clear data
    clear numberOfSpikesPerFactor
end
fprintf('\n');

%% Q.1.3
figure();
plot(datasetMedians, datasetFactors, 'o');
title('Polynomial curve fitting on median-threshold factor value pairs');
xlabel('Dataset median');
ylabel('Threshold factor');
hold on;
% linear rule: k = empiricalRule(1) * median + empiricalRule(2)
empiricalRule = polyfit(datasetMedians, datasetFactors, 1);
visualizationX = linspace(0, 0.5, 50);
visualizationY = polyval(empiricalRule, visualizationX);
plot(visualizationX, visualizationY);
hold off

%% =================================================================================================
%% S.2
% Runs the pipeline on the evaluation datasets:
%   - threshold detection with the empirical rule fitted in Q.1.3 (empiricalRule);
%   - waveform alignment;
%   - detection scoring against the ground truth;
%   - feature extraction and clustering.
% Expects files dataset/Data_Eval_E_<n>.mat providing the fields 'data',
% 'spikeTimes' and 'spikeClass'. Calls the external functions MyClassify and
% dbscan (Peter Kovesi's implementation), which are not defined in this file.

% removes the S.1 variables that are no longer needed; the list is not named
% "clearvars" so that MATLAB's clearvars function is not shadowed
varsToClear = {'closestIndex' 'datasetFactors' 'datasetMedians' 'endValue' ...
    'numberOfFactors' 'numberOfSpikesPerFactor' 'numberOfSpikesTrimmed' 'thresholdFactorEndValue' ...
    'thresholdFactorInitValue' 'thresholdFactorStep' 'visualizationX' 'visualizationY'};
clear(varsToClear{:})
clear varsToClear

numberOfEvalDatasets = 4;
numberOfClasses = 3;           % clusters sought by hierarchical clustering and K-Means
initialSpikeCapacity = 2500;   % preallocation only; arrays grow if more spikes are found
waveformPreSamples = 34;       % samples kept before the first extremum
waveformPostSamples = 29;      % samples kept after the first extremum
waveformLength = waveformPreSamples + waveformPostSamples + 1; % 64 samples
minSearchPreSamples = 41;      % the minimum is searched from (max - 41) ...
minSearchPostSamples = 22;     % ... up to (max + 22)
maxMatchingError = 32;         % max |estimated - true| spike time (samples) of a correct detection
% DB-SCAN parameters tuned per evaluation dataset (the last pair is also used
% for any further dataset)
dbScanDistThresholds = [0.40 0.15 0.40 0.30];
dbScanMinPts = [20 30 50 40];
% every mapping of cluster labels 1..numberOfClasses onto class labels; used to
% align the arbitrary cluster numbering with the ground-truth classes
labelPermutations = perms(1:numberOfClasses);
clusteringNames = {'DB-SCAN', 'hierarchical clustering', 'K-Means'};

for fileIndex=1:numberOfEvalDatasets
    fprintf('========================================================\n');
    fprintf('Loading evaluation dataset no. %d\n', fileIndex);
    % fullfile inserts the platform's file separator; the former
    % 'dataset\\...' literal only resolved on Windows
    filename = fullfile('dataset', sprintf('Data_Eval_E_%d.mat', fileIndex));
    Dataset = load(filename);
    data = double(Dataset.data);
    numberOfSamples = length(data);

    %% Q.2.1 and Q.2.2
    dataMedian = median(abs(data)/0.6745);
    factorEstimation = polyval(empiricalRule, dataMedian);
    threshold = factorEstimation * dataMedian;

    numberOfSpikes = 0;
    spikesTimesEst = zeros(1, initialSpikeCapacity);
    spikesEst = zeros(initialSpikeCapacity, waveformLength);

    figure();
    plot(data(1:10000));
    xlim([0, 10000]);
    title(['First 10000 samples of dataset #' num2str(fileIndex)]);
    xlabel('Sample #');
    ylabel('Trivial Unit'); %TODO: Is this mVolts?
    hold on;
    plot([1 10000], [threshold, threshold]);
    drawnow;

    % detects spikes and extracts their aligned waveforms
    sample = 1;
    while sample <= numberOfSamples
        if data(sample) >= threshold
            % spike found
            spikeStartIndex = sample;

            % skips checking until values are below threshold again; the
            % bound check stops a spike that lasts until the last sample
            % from indexing past the end of the data
            sample = sample + 1;
            while sample <= numberOfSamples && data(sample) > threshold
                sample = sample + 1;
            end
            spikeEndIndex = min(sample, numberOfSamples);

            % finds the index of the max sample for this spike
            [~, relativeMaxIndex] = max(data(spikeStartIndex:spikeEndIndex));
            absoluteMaxIndex = spikeStartIndex - 1 + relativeMaxIndex;

            % searches for the min in an area of -41/+22 samples around the
            % max, clipped to the data bounds
            searchStart = max(1, absoluteMaxIndex - minSearchPreSamples);
            searchEnd = min(numberOfSamples, absoluteMaxIndex + minSearchPostSamples);
            [~, relativeMinIndex] = min(data(searchStart:searchEnd));
            % relative index 1 corresponds to searchStart, hence the -1 (the
            % former "absoluteMaxIndex - 41 + relativeMinIndex" was one late)
            absoluteMinIndex = searchStart - 1 + relativeMinIndex;

            % aligns the waveform at the extremum (min or max) that occurs first
            firstIndex = min(absoluteMaxIndex, absoluteMinIndex);
            windowStart = firstIndex - waveformPreSamples;
            windowEnd = firstIndex + waveformPostSamples;

            % a spike too close to either end of the recording cannot provide a
            % full waveform; it is discarded instead of causing an index error
            if windowStart >= 1 && windowEnd <= numberOfSamples
                numberOfSpikes = numberOfSpikes + 1;
                % Q.2.1
                spikesTimesEst(numberOfSpikes) = spikeStartIndex;
                % Q.2.2
                spikesEst(numberOfSpikes, :) = data(windowStart:windowEnd);
            end
        end
        sample = sample + 1;
    end
    % drops the unused preallocated entries so that zero padding takes no part
    % in the matching below
    spikesTimesEst = spikesTimesEst(1:numberOfSpikes);
    spikesEst = spikesEst(1:numberOfSpikes, :);

    fprintf('%d spikes found for dataset #%d\n', numberOfSpikes, fileIndex);
    fprintf('actual number of spikes = %d\n', length(Dataset.spikeTimes));
    fprintf('diff = %d\n\n', numberOfSpikes - length(Dataset.spikeTimes));

    figure();
    hold on;
    if numberOfSpikes > 0
        % one line per spike (each column of the transposed matrix)
        plot(1:waveformLength, spikesEst');
    end
    title(['Spikes of dataset #' num2str(fileIndex) ' aligned at the first extrema']);
    xlabel('Samples');
    ylabel('Trivial Unit');
    xlim([1 waveformLength]);
    drawnow;
    hold off;

    %% Q.2.3
    % matches every ground-truth spike to the closest detection
    realSpikeIndex = double(Dataset.spikeTimes);
    numberOfCorrectEstimations = 0;
    numberOfUndetectedSpikes = 0;
    totalEstimationError = 0;
    correctSpikes = zeros(numel(realSpikeIndex), waveformLength);
    spikeClass = zeros(1, numel(realSpikeIndex));
    for trueSpikeIndex=1:length(realSpikeIndex)
        [estimationError, closestIndex] = min(abs(spikesTimesEst - realSpikeIndex(trueSpikeIndex)));
        if ~isempty(estimationError) && estimationError < maxMatchingError
            numberOfCorrectEstimations = numberOfCorrectEstimations + 1;
            totalEstimationError = totalEstimationError + estimationError;
            correctSpikes(numberOfCorrectEstimations, :) = spikesEst(closestIndex, :);
            spikeClass(numberOfCorrectEstimations) = double(Dataset.spikeClass(trueSpikeIndex));
        else
            numberOfUndetectedSpikes = numberOfUndetectedSpikes + 1;
        end
    end
    % keeps only the filled rows: zero rows would bias pca below and make
    % spikeClass longer than features
    correctSpikes = correctSpikes(1:numberOfCorrectEstimations, :);
    spikeClass = spikeClass(1:numberOfCorrectEstimations);

    % the error sum covers the correct estimations only, so it is averaged over
    % them (it used to be divided by the number of detected spikes)
    if numberOfCorrectEstimations > 0
        averageEstimationError = totalEstimationError / numberOfCorrectEstimations;
    else
        averageEstimationError = NaN;
    end

    fprintf('Number of correct spike detections is %d\n', numberOfCorrectEstimations);
    if numberOfSpikes-numberOfCorrectEstimations > 0
        fprintf('Number of uncorrect spike detections is %d\n', ...
            numberOfSpikes - numberOfCorrectEstimations);
    end
    fprintf('Number of undetected spikes is %d\n', numberOfUndetectedSpikes);
    fprintf('Average error of spike index estimation is %.2f\n\n', averageEstimationError);

    % two PCA components and numberOfClasses clusters need at least that many spikes
    if numberOfCorrectEstimations < numberOfClasses
        fprintf('Too few correctly detected spikes for feature extraction and clustering\n\n');
        continue;
    end

    %% Q.2.4
    features = zeros(numberOfCorrectEstimations, 7);
    pcaCoefficients = pca(correctSpikes, 'NumComponents', 2);
    for spike=1:numberOfCorrectEstimations
        waveform = correctSpikes(spike, :);
        % finds index of max
        [maxVal, features(spike, 1)] = max(waveform);
        % calculates Vpeak-to-peak
        features(spike, 2) = maxVal - min(waveform);
        % calculates ZCR
        features(spike, 3) = mean(abs(diff(sign(waveform))));
        % calculates the signal energy
        features(spike, 4) = sum(waveform.^2);
        % squared first (DC) bin of the fft of the signal
        spectrum = fft(waveform);
        features(spike, 5) = spectrum(1)^2;
    end
    % projections (not mean-centered) on the first two principal directions
    features(:, 6:7) = correctSpikes * pcaCoefficients;

    figure();
    scatter(features(:, 1), features(:, 2));
    title(['Dataset #' num2str(fileIndex) ' feature plot']);
    xlabel('Index of max value');
    ylabel('Peak-to-peak amplitude');

    figure();
    scatter(features(:, 6), features(:, 7));
    title(['Dataset #' num2str(fileIndex) ' feature plot']);
    xlabel('PCA feature 1');
    ylabel('PCA feature 2');

    %% Q.2.5
    accuracy = MyClassify(features, spikeClass');
    fprintf('Accuracy achieved is %.2f%%\n', accuracy);

    % clustering using DB-SCAN algorithm
    % code for DB-SCAN downloaded from here:
    % https://www.peterkovesi.com/matlabfns/
    paramIndex = min(fileIndex, numel(dbScanMinPts));
    [~, dbScanClasses, ~] = dbscan(features(:, 6:7)', dbScanDistThresholds(paramIndex), ...
        dbScanMinPts(paramIndex));

    % hierarchical clustering
    distances = pdist(features(:, 6:7));
    linkages = linkage(distances, 'ward');
    hierarchicalClusters = cluster(linkages, 'maxclust', numberOfClasses);

    % clustering using kmeans algorithm
    rng(1); % For reproducibility
    kMeansClasses = kmeans(features(:, 6:7), numberOfClasses);

    % plots and scores every clustering, in the order of clusteringNames
    clusterLabels = {dbScanClasses(:), hierarchicalClusters(:), kMeansClasses(:)};
    for method=1:numel(clusteringNames)
        labels = clusterLabels{method};

        % fixes classes enumeration: cluster numbers are arbitrary, so the
        % permutation of labels 1..numberOfClasses that agrees best with the
        % ground truth is applied (this replaces hand-picked per-dataset
        % swaps that broke whenever the clustering output changed); any other
        % label, e.g. 0 (plotted as noise below), is left untouched.
        % NOTE(review): assumes ground-truth classes are numbered 1..numberOfClasses
        bestAgreement = -1;
        bestLabels = labels;
        for permutation=1:size(labelPermutations, 1)
            candidate = labels;
            for label=1:numberOfClasses
                candidate(labels == label) = labelPermutations(permutation, label);
            end
            agreement = sum(candidate == spikeClass(:));
            if agreement > bestAgreement
                bestAgreement = agreement;
                bestLabels = candidate;
            end
        end
        labels = bestLabels;

        figure();
        hold on;
        scatter(features(labels == 0, 6), features(labels == 0, 7), [], 'k', 'o');
        scatter(features(labels == 1, 6), features(labels == 1, 7), [], 'b', '*');
        scatter(features(labels == 2, 6), features(labels == 2, 7), [], 'r', '*');
        scatter(features(labels == 3, 6), features(labels == 3, 7), [], 'g', '*');
        title(['Dataset #' num2str(fileIndex) ' feature plot after clustering with ' clusteringNames{method}]);
        xlabel('PCA feature 1');
        ylabel('PCA feature 2');
        hold off;

        accuracy = classperf(spikeClass', labels);
        fprintf('Accuracy achieved with %s is %.2f%%\n', clusteringNames{method}, accuracy.CorrectRate*100);
    end
    fprintf('\n');

    varsToClear = {'spikesTimesEst', 'spikesEst', 'data', 'features', 'realSpikeIndex', ...
        'correctSpikes', 'spikeClass', 'Dataset', 'pcaCoefficients', 'accuracy', 'spectrum', ...
        'dbScanClasses', 'hierarchicalClusters', 'kMeansClasses', 'clusterLabels', 'labels', ...
        'bestLabels', 'candidate', 'waveform'};
    clear(varsToClear{:})
    clear varsToClear
end