Ich verwende Python sklearn (Version 0.17), um das ideale Modell für einen Datensatz auszuwählen. Um dies zu tun, habe ich die folgenden Schritte ausgeführt:
cross_validation.train_test_split
mit test_size = 0.2
. GridSearchCV
, um den idealen k-nearest-neighbors-Klassifikator auf dem Trainingssatz auszuwählen. GridSearchCV
zurückgegebenen Klassifikator an plot_learning_curve
. plot_learning_curve
gab das unten gezeigte Diagramm. GridSearchCV
zurückgegebenen Klassifikator für den erhaltenen Testsatz aus. Aus der Handlung können wir sehen, dass die Punktzahl für die max. Trainingsgröße ist etwa 0,43. Dieser Score ist der von sklearn.learning_curve.learning_curve
function zurückgegebene Score.
Aber wenn ich den besten Klassifikator auf dem Test-Set betreibe, erhalte ich eine Genauigkeit von 0,61, wie von sklearn.metrics.accuracy_score
(korrekt vorhergesagte Labels / Anzahl der Labels) zurückgegeben wird
Link zum Bild:
Dies ist der Code, den ich verwende. Ich habe die Funktion plot_learning_curve
nicht eingeschlossen, da sie viel Speicherplatz beanspruchen würde. Ich habe das plot_learning_curve
von hier
Ist das normal? Ich kann verstehen, dass es einige Unterschiede in den Scores geben könnte, aber das ist eine Differenz von 0,18, die, wenn sie in Prozentsätze umgerechnet wird, 43% gegenüber 61% ist. Der Klassifikationsbericht gibt auch einen durchschnittlichen Rückruf von 0,61.
Mache ich etwas falsch? Gibt es einen Unterschied darin, wie learning_curve
Scores berechnet? Ich habe auch versucht, scoring='accuracy'
an learning_curve
function zu übergeben, um zu sehen, ob es mit der Genauigkeit übereinstimmt, aber es hat keinen Unterschied gemacht.
Jeder Rat wäre sehr hilfreich.
Ich verwende die Weinqualität (weiß) Datensatz von UCI und entfernte auch die Kopfzeile vor dem Ausführen des Codes.
Wenn Sie die Funktion learning_curve
aufrufen, führt sie eine Kreuzvalidierung für Ihre gesamten Daten durch. Wenn Sie den Parameter cv
leer lassen, handelt es sich um eine 3-fache Kreuzvalidierungsspaltungsstrategie. Und hier kommt der knifflige Teil, denn wie in der Dokumentation angegeben: "Wenn der Schätzer ein Klassifikator ist oder wenn y weder binär noch multiklassig ist, wird KFold verwendet." Und Ihr Schätzer ist ein Klassifikator.
Also, was ist der Unterschied zwischen KFold und StratifiedKFold?
KFold = Teilt in k aufeinanderfolgende Falten ( ohne standardmäßiges Mischen )
StratifiedKFold="Die Falten werden erstellt, indem der Prozentsatz der Proben für jede Klasse beibehalten wird."
Lassen Sie uns ein einfaches Beispiel machen:
Dies erklärt Ihre geringe Genauigkeit beim Zeichnen Ihrer Lernkurve (~ 0,43%). Natürlich ist dies ein Extrembeispiel , um die Situation zu veranschaulichen, aber Ihre Daten sind irgendwie strukturiert und Sie müssen sie mischen.
Wenn Sie jedoch die Genauigkeit von ~ 61% erhalten, haben Sie die Daten nach der Methode train_test_split
aufgeteilt, die standardmäßig Daten mischt und die Proportionen beibehält.
Sehen Sie sich das an, ich habe einen einfachen Test durchgeführt, um meine Hypothese zu stützen:
%Vor% In Ihrem Beispiel haben Sie die learning_curve
mit all Ihren Daten X,y
gefüllt. Ich mache hier einen kleinen Trick, um Daten zu teilen, die test_size=0.
bedeuten, alle Daten sind in train
Variablen. Auf diese Weise behalte ich immer noch alle Daten, aber es wird jetzt gemischt, während es durch die Funktion train_test_split
ging.
Dann habe ich deine Plotfunktion aufgerufen, aber mit gemischten Daten:
%Vor% Nun ist die Ausgabe mit max num Trainingsbeispielen anstelle von 0.43
0.59
, was bei Ihren GridSearch Ergebnissen viel sinnvoller ist.
Beobachtung : Ich denke, der entscheidende Punkt bei der Planung der Lernkurve ist, ob unser Schätzer bessere Ergebnisse erzielen kann oder nicht. So können Sie beispielsweise entscheiden, wann Es ist nicht notwendig, weitere Beispiele hinzuzufügen). Wie in
train_sizes
füttern Sie nur die Wertenp.linspace(.05, 1.0, 5) --> [ 0.05 , 0.2875, 0.525 , 0.7625, 1. ]
Ich bin mir nicht ganz sicher, ob dies die Verwendung ist, die Sie in dieser Art von Test verfolgen.
Tags und Links python scikit-learn