scikit-learn的tree.plot_tree很简单吗?很方便,所以我尝试总结一下如何轻松使用它


经常使用以Python实现的机器学习库scikit-learn,因为它可以轻松地尝试各种算法。说到花的形状,TensorFlow和PyTorch在刚性领域很难使用。 .. ..通过这种scikit-learn,从版本0.21.x开始实现了一种便于学习"决策树"后进行绘图的功能,这是一种典型的监督学习方法,因此我在与常规方法进行比较的同时进行了尝试GraphViz。

传统的可视化方法:使用GraphViz

以前,我安装并使用了另一个名为GraphViz的库。这需要很多时间和精力。 .. ..

安装GraphViz @Mac

1
2
brew install graphviz
pip install graphviz

安装GraphViz @Ubuntu

1
2
sudo apt install -y graphviz
pip install graphviz

使用GraphViz的方法

1
2
3
4
5
6
7
8
9
10
import graphviz
from sklearn import tree
from sklearn.datasets import load_iris

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)

graph = graphviz.Source(tree.export_graphviz(clf, class_names=iris.feature_names, filled=True))
graph

执行结果

通过执行graph.render('decision_tree'),可以将

的执行结果另存为PDF。

graphviz

使用tree.plot_tree

让我们绘制一个类似于使用

tree.plot_tree使用GraphViz绘制的图形。由于它存储在scikit-learn的tree模块中,因此不需要其他安装。 (默认情况下,filled选项为False,但是如果设置为True,它将被上色)

使用tree.plot_tree的方法

1
2
3
4
5
6
7
8
9
10
11
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, plot_tree

iris = load_iris()
clf = DecisionTreeClassifier()
clf = clf.fit(iris.data, iris.target)
iris = load_iris()
plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=iris.feature_names, filled=True)
plt.show()

执行结果

我能够输出与使用GraphViz的方法相同的图形。如果在Jupyter Notebook上执行它,则可以按原样右键单击绘图结果并将其另存为图像。

plot_tree

2020/11/27后记:类别名称也显示在决策树

您还可以通过添加

class_name选项来显示最终的分类类名称。

1
2
3
plt.figure(figsize=(15, 10))
plot_tree(clf, feature_names=iris.feature_names, class_names=iris.target_names, filled=True)
plt.show()

plot_tree2

概要

使用scikit-learn的tree.plot_tree和常规GraphViz进行可视化决策树的方法,我意识到tree.plot_tree比常规方法更容易,更方便。我想在将来积极利用它。

参考

  • 1.10。决策树
  • sklearn.tree.tree_plot