TabNet feature importance
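import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# clf is assumed to be an already fitted TabNet model and X_train a pandas DataFrame
feature_imp = pd.DataFrame(sorted(zip(clf.feature_importances_, X_train.columns)), columns=['Value', 'Feature'])
# Plot the 20 most important features, highest first
plt.figure(figsize=(12, 20))
sns.barplot(x="Value", y="Feature", data=feature_imp.sort_values(by="Value", ascending=False).iloc[:20, :])
plt.title('TabNet Feature Importance')
plt.show()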
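
To check the ranking step without a trained model, here is a minimal sketch that pairs scores with names and keeps the top entries. The function name `top_features`, the column names, and the scores are all hypothetical, made up for illustration; in the snippet above the scores would come from clf.feature_importances_ and the names from X_train.columns.

```python
def top_features(importances, names, n=20):
    """Return up to n (name, score) pairs, highest score first."""
    if len(importances) != len(names):
        raise ValueError("importances and names must have the same length")
    pairs = zip(names, importances)
    # sorted() is stable, so tied features keep their original column order
    ranked = sorted(pairs, key=lambda p: p[1], reverse=True)
    return ranked[:n]


# Made-up scores standing in for clf.feature_importances_ (assumption)
scores = [0.10, 0.55, 0.05, 0.30]
columns = ["age", "income", "zip_code", "balance"]

ranked = top_features(scores, columns, n=3)
for name, score in ranked:
    print(f"{name}: {score:.2f}")
```

The length check is there because a mismatch between the number of scores and the number of column names would otherwise be silently truncated by zip, which can mislabel the bars in the plot.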