diff --git a/MyClassify.m b/MyClassify.m new file mode 100644 index 0000000..eaa38c5 --- /dev/null +++ b/MyClassify.m @@ -0,0 +1,11 @@ +function Acc = MyClassify(Data,group) +group=categorical(group); +idx=randperm(size(Data,1),floor(size(Data,1)*0.7)); +train=Data(idx,:); +trGroup=group(idx); +test=Data; test(idx,:)=[]; +teGroup=group; teGroup(idx)=[]; +class = classify(test,train,trGroup); +Acc=(sum(class==teGroup)/numel(teGroup))*100; +end + diff --git a/spike_sorting.m b/spike_sorting.m index 4cb1468..2d0d04c 100644 --- a/spike_sorting.m +++ b/spike_sorting.m @@ -11,9 +11,15 @@ %% %% ================================================================================================= %% S.1 -datasetMedians = zeros(8); -datasetFactors = zeros(8); +datasetMedians(8) = 0; +datasetFactors(8) = 0; +thresholdFactorInitValue = 3; % k starting value +thresholdFactorEndValue = 12; % k ending value +thresholdFactorStep = 0.01; % k jumping step +numberOfFactors = length(thresholdFactorInitValue:thresholdFactorStep:thresholdFactorEndValue); +numberOfSpikesPerFactor(numberOfFactors) = 0; + for fileIndex=1:8 fprintf('Loading test dataset no. %d\n', fileIndex); filename = sprintf('dataset\\Data_Test_%d.mat', fileIndex); @@ -33,12 +39,6 @@ for fileIndex=1:8 dataMedian = median(abs(data)/0.6745); datasetMedians(fileIndex) = dataMedian; - thresholdFactorInitValue = 2; % k starting value - thresholdFactorEndValue = 14; % k ending value - thresholdFactorStep = 0.01; % k jumping step - numberOfFactors = length(thresholdFactorInitValue:thresholdFactorStep:thresholdFactorEndValue); - numberOfSpikesPerFactor = zeros(numberOfFactors); - parfor factorIteration=1:numberOfFactors % runs for each k % builds threshold thresholdFactor = thresholdFactorInitValue + (factorIteration - 1) * thresholdFactorStep; @@ -82,7 +82,7 @@ for fileIndex=1:8 [minValue, closestIndex] = min(abs(numberOfSpikesTrimmed-Dataset.spikeNum)); datasetFactors(fileIndex) = thresholdFactorInitValue + (closestIndex - 1) * thresholdFactorStep; - clear dataset + clear Dataset clear data end fprintf('\n'); @@ -94,7 +94,8 @@ title('Polynomial curve fitting on median-threshold factor value pairs'); xlabel('Dataset median'); ylabel('Threshold factor'); hold on; -empiricalRule = polyfit(datasetMedians, datasetFactors, 8); + +empiricalRule = polyfit(datasetMedians, datasetFactors, 3); visualizationX = linspace(0, 0.5, 50); visualizationY = polyval(empiricalRule, visualizationX); plot(visualizationX, visualizationY); @@ -137,14 +138,24 @@ for fileIndex=1:4 while sample <= length(data) sample = sample + 1; if (data(sample) <= threshold) + % finds the index of the max sample for this spike spikeEndIndex = sample; - [~, minIndex] = min(data(spikeStartIndex:spikeEndIndex)); - [~, maxIndex] = max(data(spikeStartIndex:spikeEndIndex)); - firstIndex = min([minIndex maxIndex]); + [~, relativeMaxIndex] = max(data(spikeStartIndex:spikeEndIndex)); + absoluteMaxIndex = spikeStartIndex - 1 + relativeMaxIndex; + + % defines an area of -41/+22 samples around the max + % and searches for the min + [~, relativeMinIndex] = min(data(absoluteMaxIndex-41:absoluteMaxIndex+22)); + absoluteMinIndex = absoluteMaxIndex - 41 + relativeMinIndex; + + % discernes the extrema (minimum or maximum) that + % occurs first + firstIndex = min([absoluteMaxIndex absoluteMinIndex]); + % Q.2.1 spikesTimesEst(numberOfSpikes) = firstIndex; % Q.2.2 - spikesEst(numberOfSpikes, :) = data(spikeStartIndex-1+firstIndex-34:spikeStartIndex-1+firstIndex+29); + spikesEst(numberOfSpikes, :) = data(firstIndex-34:firstIndex+29); break; end end @@ -154,13 +165,79 @@ for fileIndex=1:4 fprintf('%d spikes found for dataset #%d\n', numberOfSpikes, fileIndex); fprintf('actual number of spikes = %d\n', length(Dataset.spikeTimes)); - fprintf('diff = %d\n\n', abs(length(Dataset.spikeTimes) - numberOfSpikes)); + fprintf('diff = %d\n\n', numberOfSpikes - length(Dataset.spikeTimes)); figure(); hold on; for spike=1:numberOfSpikes plot(1:64, spikesEst(spike, :)); end + title(['Spikes of dataset #' num2str(fileIndex) ' aligned at the first extrema']); + xlabel('Samples'); + ylabel('Trivial Unit'); drawnow; hold off; + + %% Q.2.3 + realSpikeIndex = double(Dataset.spikeTimes); + numberOfCorrectEstimations = 0; + numberOfUndetectedSpikes = 0; + averageEstimationError = 0; + + correctSpikes(2500, 64) = 0; + spikeClass(2500) = 0; + + for trueSpikeIndex=1:length(realSpikeIndex) + [estimationError, closestIndex] = min(abs(spikesTimesEst-realSpikeIndex(trueSpikeIndex))); + if estimationError < 32 + numberOfCorrectEstimations = numberOfCorrectEstimations + 1; + averageEstimationError = averageEstimationError + estimationError; + + correctSpikes(numberOfCorrectEstimations, :) = spikesEst(closestIndex, :); + spikeClass(numberOfCorrectEstimations) = double(Dataset.spikeClass(trueSpikeIndex)); + else + numberOfUndetectedSpikes = numberOfUndetectedSpikes + 1; + end + end + + averageEstimationError = averageEstimationError / numberOfSpikes; + + fprintf('Number of correct spike detections is %d\n', numberOfCorrectEstimations); + if numberOfSpikes-numberOfCorrectEstimations > 0 + fprintf('Number of uncorrect spike detections is %d\n', numberOfSpikes-numberOfCorrectEstimations); + end + fprintf('Number of undetected spikes is %d\n', numberOfUndetectedSpikes); + fprintf('Average error of spike index estimation is %d\n\n', averageEstimationError); + + %% Q.2.4 + features(numberOfCorrectEstimations, 5) = 0; + + for spike=1:numberOfCorrectEstimations + % finds index of max + [maxVal, features(spike, 1)] = max(correctSpikes(spike, :)); + % calculates Vpeak-to-peak + features(spike, 2) = maxVal - min(correctSpikes(spike, :)); + % calculates ZCR + features(spike, 3) = mean(abs(diff(sign(correctSpikes(spike, :))))); + % calculates the signal energy + features(spike, 4) = sum(correctSpikes(spike, :).^2); + % calculates the fft of the signal + asud = fft(correctSpikes(spike, :)); + features(spike, 5) = asud(1)^2; + end + + figure(); + scatter(features(:, 1), features(:, 2)); + title(['Dataset #' num2str(fileIndex) ' feature plot']); + xlabel('Index of max value'); + ylabel('Peak-to-peak amplitude'); + + %% Q.2.5 + accuracy = MyClassify(features, spikeClass'); + fprintf('Accuracy achieved is %.2f%%\n\n', accuracy) + + clearvars = {'spikesTimesEst', 'spikesEst', 'data', 'features', 'realSpikeIndex', ... + 'correctSpikes', 'spikeClass', 'Data'}; + clear(clearvars{:}) + clear clearvars end