簡體   English   中英

如何使用 sklearn plot_tree 更改決策樹 plot 的 colors?

[英]How to change colors for decision tree plot using sklearn plot_tree?

How to change colors in decision tree plot using sklearn.tree.plot_tree without using graphviz as in this question: Changing colors for decision tree plot created using export graphviz ?

plt.figure(figsize=[21, 6])
ax1 = plt.subplot(121)
ax2 = plt.subplot(122)

ax1.plot(X[:, 0][y == 0], X[:, 1][y == 0], "bo")
ax1.plot(X[:, 0][y == 1], X[:, 1][y == 1], "g^")
ax1.contourf(xx, yy, pred.reshape(xx.shape), cmap=matplotlib.colors.ListedColormap(['b', 'g']), alpha=0.25)
ax1.set_title(title)

plot_tree(tree_clf, feature_names=["X", "y"], class_names=["blue", "green"], filled=True, rounded=True)

在此處輸入圖像描述

許多 matplotlib 函數遵循顏色循環器來分配默認 colors,但這里似乎並不適用。

以下方法循環通過生成的藝術家和結構來分配顏色,具體取決於多數 class 和雜質 (gini)。 請注意,我們不能使用 alpha,因為透明背景會顯示通常隱藏的箭頭部分。

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap, to_rgb
import numpy as np
from sklearn import tree

# generate some test data, separated by a diagonal line
X = np.random.rand(50, 2)
y = X[:, 0] - X[:, 1] > 0.2

clf = tree.DecisionTreeClassifier(random_state=2021)
clf = clf.fit(X, y)

fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=[21, 6])

colors = ['dodgerblue', 'limegreen']
ax1.plot(X[:, 0][y == 0], X[:, 1][y == 0], "bo")
ax1.plot(X[:, 0][y == 1], X[:, 1][y == 1], "g^")
xx, yy = np.meshgrid(np.linspace(0, 1, 100), np.linspace(0, 1, 100))
pred = clf.predict(np.c_[(xx.ravel(), yy.ravel())])
ax1.contourf(xx, yy, pred.reshape(xx.shape), cmap=ListedColormap(colors), alpha=0.25)

# ax2.set_prop_cycle(mpl.cycler(color=colors)) # doesn't seem to work

artists = tree.plot_tree(clf, feature_names=["X", "y"], class_names=["blue", "green"], filled=True, rounded=True)
for artist, impurity, value in zip(artists, clf.tree_.impurity, clf.tree_.value):
    # let the max value decide the color; whiten the color depending on impurity (gini)
    r, g, b = to_rgb(colors[np.argmax(value)])
    artist.get_bbox_patch().set_facecolor((r + (1 - r) * impurity, g + (1 - g) * impurity, b + (1 - b) * impurity))

plt.tight_layout()
plt.show()

改變 sklearn plot_tree 的顏色

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM