[英]Plotting a decision tree manually with pyplot
I'm new to matplotlib
and I'm trying to plot my decision tree that was built from scratch (not with sklearn
) so it's basically a Node
object with left
, right
and other identification variables which was built recursively.我是matplotlib
新手,我正在尝试绘制从头开始构建的决策树(不是使用sklearn
),因此它基本上是一个Node
对象,其中包含left
、 right
和其他递归构建的标识变量。
This is my program:这是我的程序:
def plot_tree(node, x_axis=0, y_axis=10, space=5):
if node.label is not None:
ax.text(x_axis, y_axis, node.label[0],
bbox=dict(boxstyle='round', facecolor='green', edgecolor='g'), ha='center', va='center')
else:
ax.text(x_axis, y_axis, f'{node.value:.2f}\nidx:{node.feature_idx}',
bbox=dict(boxstyle='round', facecolor='red', edgecolor='r'), ha='center', va='center')
# x2, y2, w2, h2 = t2.get_tightbbox(fig.canvas.get_renderer()).bounds
# plt.annotate(' ', xy=(x2 + w2, y2 + h2), xytext=(x_axis, y_axis), xycoords='figure points',
# arrowprops=dict(arrowstyle="<|-,head_length=1,head_width=0.5", lw=2, color='b'))
plot_tree(node.left, x_axis + space, y_axis + space)
plot_tree(node.right, x_axis + space, y_axis - space)
if __name__ == '__main__':
node = root.load_tree()
fig, ax = plt.subplots(1, 1)
ax.axis('off')
ax.set_aspect('equal')
ax.autoscale_view()
ax.set_xlim(0, 30)
ax.set_ylim(-10, 30)
plt.tick_params(axis='both', labelsize=0, length=0)
plot_tree(node)
and my result:我的结果:
I know the y axis collides because of the y_axis + space
and y_axis - space
but I don't really know how to make it stay symmetrical in its spacing and not to have this.我知道 y 轴因为y_axis + space
和y_axis - space
而发生碰撞,但我真的不知道如何使它的间距保持对称而不是这样。 And as you see the arrows are commented out because they are a mess on their own, this library is very rich and it's kinda overwhelming figuring it out.正如你看到的箭头被注释掉了,因为它们本身就是一团糟,这个库非常丰富,弄清楚它有点不知所措。
Edit: this is a print representation of the tree:编辑:这是树的打印表示:
split is at feature: 27 and value 0.14235 and depth is: 1
split is at feature: 20 and value 17.615000000000002 and depth is: 2
label is: B and depth is: 3
split is at feature: 8 and value 0.15165 and depth is: 3
label is: B and depth is: 4
label is: M and depth is: 4
split is at feature: 13 and value 13.93 and depth is: 2
label is: B and depth is: 3
label is: M and depth is: 3
You are better off using Graphviz since it will take care of spacing for you.您最好使用 Graphviz,因为它会为您处理间距。 Download Graphviz and its Python bindings , then you can render graphs pretty easily like so:下载Graphviz和它的Python bindings ,然后你可以很容易地渲染图形,如下所示:
dot = graphviz.Digraph(comment="A graph", format="svg")
dot.node('A', 'King Arthur')
dot.node('B', 'Sir Bedevere the Wise')
dot.node('C', 'Sir Lancelot the Brave')
dot.edge('A', 'B')
dot.edge('A', 'C')
dot.render('digraph.gv', view=True)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.