Scikit-Learn: How to retrieve prediction probabilities for a KFold CV?
I have a dataset that consists of images and associated descriptions. I've split these into two separate datasets with their own classifiers (visual and textual) and now I want to combine the predictions of these two classifiers to form a final prediction.
However, my classes are binary, either 1 or 0, so I end up with two lists of n_samples filled with 1s and 0s. I assume that for most algorithms/classifiers this is not enough information to make a useful prediction (e.g. when one classifier predicts 1 and the other 0).
Therefore I thought I could use the probabilities of the predictions as some form of decisive weighting. SVC in scikit-learn has the svm.SVC.predict_proba function, which returns an array that may look like this:
[[ 0.9486674 0.0513326 ]
[ 0.97346471 0.02653529]
[ 0.9486674 0.0513326 ]]
But I can't seem to combine this with my K-fold cross-validation function, cross_validation.cross_val_predict, as this is a prediction function on its own and does not include a similar probability output. Is there any way to combine the two? Or am I missing something?
Alternatively, am I approaching my problem entirely wrong, and is there a better way to combine the predictions of two binary classifiers?
Thanks in advance
Solution 1:[1]
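You need to use GridSearchCV instead of plain cross-validation. Plain CV is used for performance evaluation and does not leave you with a fitted estimator, whereas GridSearchCV refits the best estimator on the full data by default, so you can call predict_proba on it afterwards.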
from sklearn.datasets import make_classification
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
# unbalanced classification
X, y = make_classification(n_samples=1000, weights=[0.1, 0.9])
# use grid search for tuning hyperparameters
# probability=True enables predict_proba; 'balanced' reweights classes by inverse frequency
svc = SVC(class_weight='balanced', probability=True)
params_space = {'kernel': ['linear', 'poly', 'rbf']}
# set cv to your K-fold cross-validation
gs = GridSearchCV(svc, params_space, n_jobs=-1, cv=5)
# fit the estimator (the best parameter setting is refit on all of X, y)
gs.fit(X, y)
gs.predict_proba(X)
Out[136]:
array([[ 0.0074817 , 0.9925183 ],
[ 0.03655982, 0.96344018],
[ 0.0074933 , 0.9925067 ],
...,
[ 0.02487791, 0.97512209],
[ 0.01426704, 0.98573296],
[ 0.98574072, 0.01425928]])
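As a complement to the grid-search approach above, here is a minimal sketch of getting per-fold (out-of-fold) probabilities directly and combining two classifiers. It assumes a scikit-learn version in which cross_val_predict lives in sklearn.model_selection and accepts method="predict_proba". The data is synthetic, the names X_visual, X_textual and out_of_fold_proba are purely illustrative, and averaging the two probability arrays (soft voting) is only one simple way to combine them, not necessarily the best one for your data.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.svm import SVC

# Synthetic stand-in for the two feature sets (hypothetical "visual" and "textual" halves).
X, y = make_classification(n_samples=300, n_features=20, random_state=0)
X_visual, X_textual = X[:, :10], X[:, 10:]

# Fixed random_state so both classifiers are evaluated on identical folds.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)


def out_of_fold_proba(features, labels):
    """Return class probabilities where each row comes from the fold that held it out."""
    clf = SVC(kernel="rbf", probability=True, random_state=0)
    return cross_val_predict(clf, features, labels, cv=cv, method="predict_proba")


proba_visual = out_of_fold_proba(X_visual, y)
proba_textual = out_of_fold_proba(X_textual, y)

# Soft voting: average the two probability arrays row by row.
proba_combined = (proba_visual + proba_textual) / 2.0

# Columns follow the sorted class labels (0, 1), so argmax yields the label directly.
y_pred = proba_combined.argmax(axis=1)
```

If one classifier is known to be more reliable, a weighted average (for example via np.average with a weights argument) is a natural variation on the same idea.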
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | Jianxun Li |