Browse Source

Critical fixes, Add feature extraction and model evaluation

master
Apostolos Fanakis 6 years ago
parent
commit
293784e885
  1. 11
      MyClassify.m
  2. 107
      spike_sorting.m

11
MyClassify.m

@ -0,0 +1,11 @@
function Acc = MyClassify(Data,group)
group=categorical(group);
idx=randperm(size(Data,1),floor(size(Data,1)*0.7));
train=Data(idx,:);
trGroup=group(idx);
test=Data; test(idx,:)=[];
teGroup=group; teGroup(idx)=[];
class = classify(test,train,trGroup);
Acc=(sum(class==teGroup)/numel(teGroup))*100;
end

107
spike_sorting.m

@ -11,9 +11,15 @@
%%
%% =================================================================================================
%% S.1
datasetMedians = zeros(8);
datasetFactors = zeros(8);
datasetMedians(8) = 0;
datasetFactors(8) = 0;
thresholdFactorInitValue = 3; % k starting value
thresholdFactorEndValue = 12; % k ending value
thresholdFactorStep = 0.01; % k jumping step
numberOfFactors = length(thresholdFactorInitValue:thresholdFactorStep:thresholdFactorEndValue);
numberOfSpikesPerFactor(numberOfFactors) = 0;
for fileIndex=1:8
fprintf('Loading test dataset no. %d\n', fileIndex);
filename = sprintf('dataset\\Data_Test_%d.mat', fileIndex);
@ -33,12 +39,6 @@ for fileIndex=1:8
dataMedian = median(abs(data)/0.6745);
datasetMedians(fileIndex) = dataMedian;
thresholdFactorInitValue = 2; % k starting value
thresholdFactorEndValue = 14; % k ending value
thresholdFactorStep = 0.01; % k jumping step
numberOfFactors = length(thresholdFactorInitValue:thresholdFactorStep:thresholdFactorEndValue);
numberOfSpikesPerFactor = zeros(numberOfFactors);
parfor factorIteration=1:numberOfFactors % runs for each k
% builds threshold
thresholdFactor = thresholdFactorInitValue + (factorIteration - 1) * thresholdFactorStep;
@ -82,7 +82,7 @@ for fileIndex=1:8
[minValue, closestIndex] = min(abs(numberOfSpikesTrimmed-Dataset.spikeNum));
datasetFactors(fileIndex) = thresholdFactorInitValue + (closestIndex - 1) * thresholdFactorStep;
clear dataset
clear Dataset
clear data
end
fprintf('\n');
@ -94,7 +94,8 @@ title('Polynomial curve fitting on median-threshold factor value pairs');
xlabel('Dataset median');
ylabel('Threshold factor');
hold on;
empiricalRule = polyfit(datasetMedians, datasetFactors, 8);
empiricalRule = polyfit(datasetMedians, datasetFactors, 3);
visualizationX = linspace(0, 0.5, 50);
visualizationY = polyval(empiricalRule, visualizationX);
plot(visualizationX, visualizationY);
@ -137,14 +138,24 @@ for fileIndex=1:4
while sample <= length(data)
sample = sample + 1;
if (data(sample) <= threshold)
% finds the index of the max sample for this spike
spikeEndIndex = sample;
[~, minIndex] = min(data(spikeStartIndex:spikeEndIndex));
[~, maxIndex] = max(data(spikeStartIndex:spikeEndIndex));
firstIndex = min([minIndex maxIndex]);
[~, relativeMaxIndex] = max(data(spikeStartIndex:spikeEndIndex));
absoluteMaxIndex = spikeStartIndex - 1 + relativeMaxIndex;
% defines an area of -41/+22 samples around the max
% and searches for the min
[~, relativeMinIndex] = min(data(absoluteMaxIndex-41:absoluteMaxIndex+22));
absoluteMinIndex = absoluteMaxIndex - 41 + relativeMinIndex;
% discernes the extrema (minimum or maximum) that
% occurs first
firstIndex = min([absoluteMaxIndex absoluteMinIndex]);
% Q.2.1
spikesTimesEst(numberOfSpikes) = firstIndex;
% Q.2.2
spikesEst(numberOfSpikes, :) = data(spikeStartIndex-1+firstIndex-34:spikeStartIndex-1+firstIndex+29);
spikesEst(numberOfSpikes, :) = data(firstIndex-34:firstIndex+29);
break;
end
end
@ -154,13 +165,79 @@ for fileIndex=1:4
fprintf('%d spikes found for dataset #%d\n', numberOfSpikes, fileIndex);
fprintf('actual number of spikes = %d\n', length(Dataset.spikeTimes));
fprintf('diff = %d\n\n', abs(length(Dataset.spikeTimes) - numberOfSpikes));
fprintf('diff = %d\n\n', numberOfSpikes - length(Dataset.spikeTimes));
figure();
hold on;
for spike=1:numberOfSpikes
plot(1:64, spikesEst(spike, :));
end
title(['Spikes of dataset #' num2str(fileIndex) ' aligned at the first extrema']);
xlabel('Samples');
ylabel('Trivial Unit');
drawnow;
hold off;
%% Q.2.3
realSpikeIndex = double(Dataset.spikeTimes);
numberOfCorrectEstimations = 0;
numberOfUndetectedSpikes = 0;
averageEstimationError = 0;
correctSpikes(2500, 64) = 0;
spikeClass(2500) = 0;
for trueSpikeIndex=1:length(realSpikeIndex)
[estimationError, closestIndex] = min(abs(spikesTimesEst-realSpikeIndex(trueSpikeIndex)));
if estimationError < 32
numberOfCorrectEstimations = numberOfCorrectEstimations + 1;
averageEstimationError = averageEstimationError + estimationError;
correctSpikes(numberOfCorrectEstimations, :) = spikesEst(closestIndex, :);
spikeClass(numberOfCorrectEstimations) = double(Dataset.spikeClass(trueSpikeIndex));
else
numberOfUndetectedSpikes = numberOfUndetectedSpikes + 1;
end
end
averageEstimationError = averageEstimationError / numberOfSpikes;
fprintf('Number of correct spike detections is %d\n', numberOfCorrectEstimations);
if numberOfSpikes-numberOfCorrectEstimations > 0
fprintf('Number of uncorrect spike detections is %d\n', numberOfSpikes-numberOfCorrectEstimations);
end
fprintf('Number of undetected spikes is %d\n', numberOfUndetectedSpikes);
fprintf('Average error of spike index estimation is %d\n\n', averageEstimationError);
%% Q.2.4
features(numberOfCorrectEstimations, 5) = 0;
for spike=1:numberOfCorrectEstimations
% finds index of max
[maxVal, features(spike, 1)] = max(correctSpikes(spike, :));
% calculates Vpeak-to-peak
features(spike, 2) = maxVal - min(correctSpikes(spike, :));
% calculates ZCR
features(spike, 3) = mean(abs(diff(sign(correctSpikes(spike, :)))));
% calculates the signal energy
features(spike, 4) = sum(correctSpikes(spike, :).^2);
% calculates the fft of the signal
asud = fft(correctSpikes(spike, :));
features(spike, 5) = asud(1)^2;
end
figure();
scatter(features(:, 1), features(:, 2));
title(['Dataset #' num2str(fileIndex) ' feature plot']);
xlabel('Index of max value');
ylabel('Peak-to-peak amplitude');
%% Q.2.5
accuracy = MyClassify(features, spikeClass');
fprintf('Accuracy achieved is %.2f%%\n\n', accuracy)
clearvars = {'spikesTimesEst', 'spikesEst', 'data', 'features', 'realSpikeIndex', ...
'correctSpikes', 'spikeClass', 'Data'};
clear(clearvars{:})
clear clearvars
end

Loading…
Cancel
Save