机器学习调参神器--网格搜索
调好参,过好年!
·
超参数是模型中的参数中不能通过学习得到的参数。在scikit-learn中,典型的例子有支持向量分类器的参数C,kernel和gamma,Lasso的参数alpha等。在超参数集中搜索以获得最佳交叉验证分数的方法是可实现并且推荐的,网格搜索GridSearchCV应运而生!
实例
以支持向量机模型为例,训练鸢尾花数据集,搜索最优参数组合C和gamma。
from sklearn.svm import SVC
from sklearn import datasets
from sklearn.model_selection import GridSearchCV,train_test_split
# 导入鸢尾花数据
iris = datasets.load_iris()
# 设置待搜索参数及其参数值
param_grid = {'C':[0.001,0.01,0.1,1,10,100],
'gamma':[0.001,0.01,0.1,1,10,100]}
# 建立网格搜索模型
grid = GridSearchCV(SVC(),param_grid,cv=5,return_train_score=True)
# 划分数据集
X_train,X_test,y_train,y_test = train_test_split(iris.data,iris.target,random_state=0)
# 拟合训练集
grid.fit(X_train,y_train)
# 最优模型在测试集的得分
grid.score(X_test,y_test)
# 最优参数模型
grid.best_estimator_
# 最优参数组合
grid.best_params_
分析
上述的参数C和gamma分别有6个不同的取值,所以有36种不同的参数组合,利用GridSearchCV分别对训练集进行交叉验证评估,筛选出最优的参数组合,并且自动生成效果最好的最优模型,此乃神器也~~~~
哈哈哈哈哈哈哈哈哈哈哈
更多推荐
所有评论(0)