Browse Source

Init random forest

master
Apostolos Fanakis 6 years ago
parent
commit
1f11639401
No known key found for this signature in database GPG Key ID: 56CE2DEDE9F1FB78
  1. 15
      classifier/classification_model_training/model_training.py
  2. 11
      classifier/pipeline.py

15
classifier/classification_model_training/model_training.py

@ -51,6 +51,17 @@ def simpleTrain(dataset, target, model='all'):
elif model == 'bayes':
return bayesAccuracy
def randomForest(dataset, target):
    """Fit a random forest on a hold-out split and print its test score.

    The samples are split into a training part and a test part with
    ``test_size=0.4`` and a fixed ``random_state`` of 0. A
    ``RandomForestClassifier`` with 500 trees and the entropy criterion is
    fitted on the training part, then its score on the test part is
    multiplied by 100 and printed with two decimals.

    Args:
        dataset: Feature matrix passed straight to ``train_test_split``.
        target: Labels matching the rows of ``dataset``.

    Returns:
        None. The result is only printed.
    """
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier

    xTrain, xTest, yTrain, yTest = train_test_split(
        dataset, target, test_size=0.4, random_state=0)

    # Fixed seeds keep both the split and the forest reproducible between runs.
    forest = RandomForestClassifier(
        n_estimators=500, criterion='entropy', n_jobs=-1, random_state=4)
    forest.fit(xTrain, yTrain)

    scaledScore = 100 * forest.score(xTest, yTest)
    print("Random forest accuracy: {0:.2f}".format(scaledScore))
def kFCrossValid(dataset, target, model = 'svm'):
from sklearn.model_selection import cross_val_score
from sklearn import metrics
@ -73,6 +84,10 @@ def kFCrossValid(dataset, target, model = 'svm'):
# Naive Bayes
from sklearn.naive_bayes import GaussianNB
clf = GaussianNB()
elif model == 'rndForest':
from sklearn.ensemble import ExtraTreesClassifier
clf = ExtraTreesClassifier(n_estimators=1500, criterion = 'entropy',
n_jobs = -1, random_state = 4)
else:
print('Error. model specified not supported')
return None

11
classifier/pipeline.py

@ -1,11 +1,18 @@
import numpy as np
from preprocessing.data_preprocessing import createSingleFeaturesArray, standardization, PCA
from classification_model_training.model_training import simpleTrain
from classification_model_training.model_training import simpleTrain, kFCrossValid
# Load dataset, target and featureKeys from the two feature directories
# (music features first, speech features second).
dataset, target, featureKeys = createSingleFeaturesArray(
'feature_extraction/music_features/',
'feature_extraction/speech_features/')
# Pass the dataset through the project's standardization helper before any
# model sees it.
dataset = standardization(dataset)
# dataset = PCA(dataset)
# Evaluate on the standardized data, before PCA is applied.
print('Simple train accuracy achieved = ' + str(simpleTrain(dataset, target)))
kFCrossValid(dataset, target, model = 'svm')
# NOTE(review): in the visible part of kFCrossValid the 'rndForest' key builds
# an ExtraTreesClassifier, not a RandomForestClassifier -- confirm intended.
kFCrossValid(dataset, target, model = 'rndForest')
# Re-assign dataset to the output of the project's PCA helper; every call
# below this line runs on that output.
dataset = PCA(dataset)
# NOTE(review): the next two prints call simpleTrain(dataset, target) with the
# same arguments and differ only in their label; they look like the old and
# new version of one line -- confirm which one should remain.
print('Max accuracy achieved = ' + str(simpleTrain(dataset, target)))
print('Simple train accuracy achieved = ' + str(simpleTrain(dataset, target)))
# Repeat both cross-validation runs on the PCA output.
kFCrossValid(dataset, target, model = 'svm')
kFCrossValid(dataset, target, model = 'rndForest')
Loading…
Cancel
Save