From 1f1163940118712d449b3cc2d1f869a03ba541fc Mon Sep 17 00:00:00 2001 From: Apostolof Date: Sat, 8 Dec 2018 13:26:43 +0200 Subject: [PATCH] Init random forest --- .../model_training.py | 15 +++++++++++++++ classifier/pipeline.py | 11 +++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/classifier/classification_model_training/model_training.py b/classifier/classification_model_training/model_training.py index 69d980f..c506012 100644 --- a/classifier/classification_model_training/model_training.py +++ b/classifier/classification_model_training/model_training.py @@ -51,6 +51,17 @@ def simpleTrain(dataset, target, model='all'): elif model == 'bayes': return bayesAccuracy +def randomForest(dataset, target): + from sklearn.ensemble import RandomForestClassifier + from sklearn.model_selection import train_test_split + + trainingSet, testSet, trainingTarget, testTarget = train_test_split(dataset, + target, test_size=0.4, random_state=0) + clf = RandomForestClassifier(n_estimators=500, criterion = 'entropy', + n_jobs = -1, random_state = 4) + clf = clf.fit(trainingSet, trainingTarget) + print("Random forest accuracy: {0:.2f}".format(100*clf.score(testSet, testTarget))) + def kFCrossValid(dataset, target, model = 'svm'): from sklearn.model_selection import cross_val_score from sklearn import metrics @@ -73,6 +84,10 @@ def kFCrossValid(dataset, target, model = 'svm'): # Naive Bayes from sklearn.naive_bayes import GaussianNB clf = GaussianNB() + elif model == 'rndForest': + from sklearn.ensemble import ExtraTreesClassifier + clf = ExtraTreesClassifier(n_estimators=1500, criterion = 'entropy', + n_jobs = -1, random_state = 4) else: print('Error. model specified not supported') return None diff --git a/classifier/pipeline.py b/classifier/pipeline.py index 1235ba4..2fdf6eb 100644 --- a/classifier/pipeline.py +++ b/classifier/pipeline.py @@ -1,11 +1,18 @@ import numpy as np from preprocessing.data_preprocessing import createSingleFeaturesArray, standardization, PCA -from classification_model_training.model_training import simpleTrain +from classification_model_training.model_training import simpleTrain, kFCrossValid dataset, target, featureKeys = createSingleFeaturesArray( 'feature_extraction/music_features/', 'feature_extraction/speech_features/') dataset = standardization(dataset) +# dataset = PCA(dataset) +print('Simple train accuracy achieved = ' + str(simpleTrain(dataset, target))) +kFCrossValid(dataset, target, model = 'svm') +kFCrossValid(dataset, target, model = 'rndForest') + dataset = PCA(dataset) -print('Max accuracy achieved = ' + str(simpleTrain(dataset, target))) \ No newline at end of file +print('Simple train accuracy achieved = ' + str(simpleTrain(dataset, target))) +kFCrossValid(dataset, target, model = 'svm') +kFCrossValid(dataset, target, model = 'rndForest') \ No newline at end of file