随机森林规则发现2023版

有时候在别人的建模环境是不方便安装graphviz的,这个时候可以用内置函数plot_tree
from sklearn.ensemble import RandomForestClassifier
from sklearn import tree
rf = RandomForestClassifier(n_estimators=1000,max_depth=3,min_samples_leaf=3000,max_features='sqrt',bootstrap=False)
rf.fit(X_train,y_train)
for i,dt in enumerate(rf.estimators_):
    if (y_train.groupby(dt.apply(X_train)).mean().min() <= 0.05) & \
       (y_valid.groupby(dt.apply(X_valid)).mean().min() <= 0.05) & \
       (y_oot.groupby(dt.apply(X_oot)).mean().min() <= 0.06):
        plt.figure(figsize=(12,9))
        tree.plot_tree(dt,fontsize=8)
        plt.show()
这里就是找到在三个train、valid、oot上bad rate都小于一定阈值的规则,这里是做白名单找好人

留言

熱門文章