簡體   English   中英

更改使用導出 graphviz 創建的決策樹圖的顏色

[英]Changing colors for decision tree plot created using export graphviz

我正在使用 scikit 的回歸樹函數和 graphviz 來生成一些決策樹的精彩、易於解釋的視覺效果:

dot_data = tree.export_graphviz(Run.reg, out_file=None, 
                         feature_names=Xvar,  
                         filled=True, rounded=True,  
                         special_characters=True) 
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_png('CART.png')
graph.write_svg("CART.svg")

在此處輸入圖片說明

這運行完美,但如果可能,我想更改配色方案? 該圖表示 CO 2通量,因此我想將負值設為綠色,將正值設為棕色。 我可以導出為 svg 並手動更改所有內容,但是當我這樣做時,文本與框不太對齊,因此手動更改顏色並修復所有文本為我的工作流程添加了一個非常乏味的步驟,我真的很喜歡避免! 在此處輸入圖片說明

此外,我還看到了一些樹,其中連接節點的線的長度與拆分解釋的方差百分比成正比。 如果可能的話,我也希望能夠做到這一點?

  • 您可以通過graph.get_edge_list()獲取所有邊的列表
  • 每個源節點應該有兩個目標節點,索引較低的一個被評估為True,索引較高的為False
  • 顏色可以通過set_fillcolor()分配

在此處輸入圖片說明

import pydotplus
from sklearn.datasets import load_iris
from sklearn import tree
import collections

clf = tree.DecisionTreeClassifier(random_state=42)
iris = load_iris()

clf = clf.fit(iris.data, iris.target)

dot_data = tree.export_graphviz(clf,
                                feature_names=iris.feature_names,
                                out_file=None,
                                filled=True,
                                rounded=True)
graph = pydotplus.graph_from_dot_data(dot_data)

colors = ('brown', 'forestgreen')
edges = collections.defaultdict(list)

for edge in graph.get_edge_list():
    edges[edge.get_source()].append(int(edge.get_destination()))

for edge in edges:
    edges[edge].sort()    
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        dest.set_fillcolor(colors[i])

graph.write_png('tree.png')

另外,我已經看到一些樹,其中連接節點的線的長度與分裂解釋的方差百分比成正比。 如果可能的話,我也希望能夠做到這一點!?

您可以使用set_weight()set_len()但這有點棘手,需要一些擺弄才能正確使用,但這里有一些代碼可以幫助您入門。

for edge in edges:
    edges[edge].sort()
    src = graph.get_node(edge)[0]
    total_weight = int(src.get_attributes()['label'].split('samples = ')[1].split('<br/>')[0])
    for i in range(2):
        dest = graph.get_node(str(edges[edge][i]))[0]
        weight = int(dest.get_attributes()['label'].split('samples = ')[1].split('<br/>')[0])
        graph.get_edge(edge, str(edges[edge][0]))[0].set_weight((1 - weight / total_weight) * 100)
        graph.get_edge(edge, str(edges[edge][0]))[0].set_len(weight / total_weight)
        graph.get_edge(edge, str(edges[edge][0]))[0].set_minlen(weight / total_weight)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM