模型融合思想很简单,就是将多种不同类型的模型结合起来共同预测结果——”三个臭皮匠,顶个诸葛亮“。模型融合主要有以下方法:
- 平均:简单平均和加权平均
- 投票:简单投票和加权投票
- stacking:多层模型,利用预测结果再拟合预测
- blending:选取部分数据预测,得到的值作为新特征
数据来自:零基础入门金融风控-贷款违约预测-天池大赛-阿里云天池
代码 Notebook 在这里:ModelFusing,或用 nbviewer 查看。
平均和投票
这两个比较简单,平均就是多个模型的均值。投票可以用 sklearn 的 VotingClassifier:
1 | vclf = VotingClassifier( |
这里要注意下这个 voting 参数,可以选择 soft 或 hard,两者的区别在于:
- soft 是把模型预测出来的 probability 平均后取大的作为预测值
- hard 是把模型预测出来的 label 按出现次数多的作为预测值
举个例子:
1 | # 代码来自 sklearn 源码 |
其中 np.bincount
是统计非负整数数组中每个值的出现次数:
1 | np.bincount([3]) # array([0, 0, 0, 1]) 0-3 的个数 |
然后取出现次数最多的作为预测值。
Stacking
Stacking 的思想是从一系列基模型中获得的预测结果作为特征来训练模型。训练基模型时,一般会使用 K 折交叉验证。Stacking 的基本步骤如下:
- 将训练数据划分成 k 个互斥子集
- 使用某个基模型在 k-1 个子集上训练,剩下 1 个子集上测试;重复 k 次后所有训练数据都有一个预测值,将该预测值作为一个新的特征
- 用模型训练所有的训练数据(不使用 K-fold),用来预测测试集,结果同样作为新特征
- 在所有的基模型上重复步骤 2-3,将会得到新的新特征(每个模型对应一个)
- 使用所有训练集对应的新特征训练 final 模型,然后在测试集上进行预测得到最终结果
sklearn
有这个功能,不过这个很简单,我们也可以自己实现一下:
1 | from dataclasses import dataclass |
和 sklearn
的 StackingClassifier
对比一下:
1 | from sklearn.datasets import load_iris |
补充说明一下 StratifiedKFold
,它是 K-fold 的变种,folds 是通过保留每个类别样本的百分比来操作的。拿官方文档的例子来说明:
1 | from sklearn.model_selection import StratifiedKFold, KFold |
上面方括号里的两个数字分别是 label(0 和 1)的数量,显然,StratifiedKFold
在训练和测试数据集中都保留了类别的比例,而 KFold 却没有。所以如果 Label 不均衡,最好不要使用 KFold。
Blending
Blending 和 Stacking 非常相似,基本步骤如下:
- 训练集划分为 base 和 holdout 两部分
- 在 base 训练集上训练 base 模型,预测 heldout 和 测试集
- final 模型使用 heldout 的原始特征和预测结果作为特征训练
- final 模型使用测试集的原始特征和预测结果作为特征进行预测
相比 Stacking 的优势如下:
- 比 Stacking 简单
- 可以防止信息泄露(base 和 final 使用不同的数据)
不足如下:
- 使用的数据少了
- final 模型可能在 holdout 上过拟合
- CV(Cross-Validator)相比一个简单的 holdout 会更加稳固
还是直接上代码:
1 |
|
还是用上面的数据测试:
1 | bl = Blending([clf1, clf2], clff) |
以上就是常用的模型融合方法,其实设计思路都挺浅显的,实际应用时,感觉最好选择算法模型不相同(互补)的分类器,比如可以选树模型 GBDT、SVM、随机森林、KNN 等。另外要注意的是,模型融合大多见于比赛,实际应用时不光要考虑性能,很多时候还要考虑解释性,一般不会这么做。
参考资料
- Stacking and Blending — An Intuitive Explanation | by Steven Yu | Medium
- Do you want to learn about stacking, blending and ensembling machine learning models?
- sklearn.ensemble.StackingClassifier — scikit-learn 0.23.2 documentation
- 3.1. Cross-validation: evaluating estimator performance — scikit-learn 0.23.2 documentation
- Kaggle Ensembling Guide | MLWave
- team-learning-data-mining/Task5 模型融合.md at master · datawhalechina/team-learning-data-mining