[Python-ML] Grid Search and Cross-Validation with scikit-learn
# -*- coding: utf-8 -*-
'''
Created on 2018-01-18
@author: Jason.F
@summary:
Grid search (GridSearchCV): choose the best combination of hyperparameters for a single model;
Nested cross-validation: choose the best among different models.
'''
import pandas as pd
from sklearn.preprocessing import LabelEncoder
# train_test_split and GridSearchCV live in sklearn.model_selection;
# the old sklearn.cross_validation and sklearn.grid_search modules were removed in scikit-learn 0.20
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier
# Load the data
df = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/breast-cancer-wisconsin/wdbc.data',header=None)
X=df.loc[:,2:].values
y=df.loc[:,1].values
le=LabelEncoder()
y=le.fit_transform(y)  # encode the class labels as integers
print (le.transform(['M','B']))
# Split into training and test sets
X_train,X_test,y_train,y_test = train_test_split (X,y,test_size=0.20,random_state=1)
# Build the pipeline
pipe_svc=Pipeline([('scl',StandardScaler()),('clf',SVC(random_state=1))])
param_range=[0.0001,0.001,0.01,0.1,1.0,10.0,100.0,1000.0]
param_grid=[{'clf__C':param_range,'clf__kernel':['linear']},
            {'clf__C':param_range,'clf__gamma':param_range,'clf__kernel':['rbf']}]
# Grid search
gs=GridSearchCV(estimator=pipe_svc,param_grid=param_grid,scoring='accuracy',cv=10,n_jobs=1)
gs=gs.fit(X_train,y_train)
print (gs.best_score_)
print (gs.best_params_)
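# Select the best model
clf=gs.best_estimator_
clf.fit(X_train,y_train)  # optional: with the default refit=True, best_estimator_ is already fitted on X_train
print ('Test accuracy:%.3f' % clf.score(X_test,y_test))
# Nested cross-validation
gs_svm=GridSearchCV(estimator=pipe_svc,param_grid=param_grid,scoring='accuracy',cv=10,n_jobs=1)  # set n_jobs=-1 to use all CPU cores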
scores_svm=cross_val_score(gs_svm,X,y,scoring='accuracy',cv=5)
print ('SVM CV accuracy:%.3f +/- %.3f'%(np.mean(scores_svm),np.std(scores_svm)))
# Compare against a decision tree using nested cross-validation
gs_dt=GridSearchCV(estimator=DecisionTreeClassifier(random_state=0),param_grid=[{'max_depth':[1,2,3,4,5,6,7,None]}],scoring='accuracy',cv=5)
scores_dt=cross_val_score(gs_dt,X,y,scoring='accuracy',cv=5)
print ('DT CV accuracy:%.3f +/- %.3f'%(np.mean(scores_dt),np.std(scores_dt)))
Results:
[1 0]
0.978021978022
{'clf__C': 0.1, 'clf__kernel': 'linear'}
Test accuracy:0.965
SVM CV accuracy:0.972 +/- 0.012
DT CV accuracy:0.917 +/- 0.009
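The nested cross-validation step in the script is compact: passing a GridSearchCV object to cross_val_score hides two loops. The sketch below writes those loops out by hand to show roughly what happens, and it is an illustration under assumptions rather than the author's code. It assumes scikit-learn's bundled load_breast_cancer copy of the WDBC data in place of the UCI download (its label encoding is 0 = malignant, 1 = benign, the reverse of the LabelEncoder result above), a much smaller parameter grid, and 3 inner folds instead of 10 to keep it quick. The names outer_cv and outer_scores are made up for this example.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Assumption: the bundled WDBC copy stands in for the UCI CSV so this runs offline.
X, y = load_breast_cancer(return_X_y=True)

pipe = Pipeline([('scl', StandardScaler()), ('clf', SVC(random_state=1))])
# Assumption: a deliberately small grid, not the full grid from the script.
param_grid = {'clf__C': [0.1, 1.0, 10.0], 'clf__kernel': ['linear', 'rbf']}

# Outer loop: 5 stratified folds, matching cv=5 for a classifier.
outer_cv = StratifiedKFold(n_splits=5)
outer_scores = []
for train_idx, test_idx in outer_cv.split(X, y):
    # Inner loop: tune hyperparameters using only the outer training fold.
    gs = GridSearchCV(estimator=pipe, param_grid=param_grid, scoring='accuracy', cv=3)
    gs.fit(X[train_idx], y[train_idx])
    # Score the refitted best model on the outer test fold, which the search never saw.
    outer_scores.append(gs.score(X[test_idx], y[test_idx]))

print('Nested CV accuracy: %.3f +/- %.3f' % (np.mean(outer_scores), np.std(outer_scores)))
```

Because each outer test fold is held out from the hyperparameter search, the averaged score estimates how well the whole "tune, then fit" procedure generalizes. That is why it is a fairer basis for comparing the SVM against the decision tree than best_score_ from a single grid search.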