
How can I plot a truncated dendrogram using Plotly?

I want to plot a dendrogram for hierarchical clustering using Plotly and show only a small subset of it, because with a large number of samples the plot becomes very dense at the bottom.

I have plotted it using Plotly's wrapper function create_dendrogram with the code below:

from scipy.cluster.hierarchy import linkage
import plotly.figure_factory as ff

# create_dendrogram() passes the output of distfun (pairwise Euclidean
# distances by default) to linkagefun, so use that argument rather than test_df
fig = ff.create_dendrogram(test_df, linkagefun=lambda d: linkage(d, 'average'))
fig.update_layout(autosize=True, hovermode='closest')
fig.update_xaxes(mirror=False, showgrid=True, showline=False, showticklabels=False)
fig.update_yaxes(mirror=False, showgrid=True, showline=True)
fig.show()


And below is the same dendrogram plotted with matplotlib, which scipy uses by default, truncated to 4 levels for easier interpretation:

import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram, linkage

x = linkage(test_df, method='average')
# Show at most p=4 levels below the final merge
dendrogram(x, truncate_mode='level', p=4)
plt.show()


As you can see, truncation is very useful for interpreting a large number of samples. How can I achieve this in Plotly?

There doesn't seem to be a straightforward way to do this with ff.create_dendrogram(). That doesn't mean it's impossible, though. But I would at least consider the brilliant functionality that Dash Clustergram has to offer. If you insist on sticking to ff.create_dendrogram(), this is going to get a bit messier than Plotly users have rightfully grown accustomed to. You haven't provided a data sample, so let's use the Plotly Basic Dendrogram example instead:

Plot 1


Code 1

import plotly.figure_factory as ff
import numpy as np
np.random.seed(1)

X = np.random.rand(15, 12) # 15 samples, with 12 dimensions each
fig = ff.create_dendrogram(X)
fig.update_layout(width=800, height=500)
# Not needed for plotting; full_figure_for_development() requires the kaleido package
f = fig.full_figure_for_development(warn=False)
fig.show()

The good news is that the exact same snippet will produce the following truncated plot after we've taken a few steps that I'll explain in the details below.

Plot 2


The details

If anyone who got this far in my answer knows a better way to do the following, then please share.

1. ff.create_dendrogram() is a wrapper for scipy.cluster.hierarchy.dendrogram

You can call help(ff.create_dendrogram) and learn that:

[...] This is a thin wrapper around scipy.cluster.hierarchy.dendrogram.

Looking at the available arguments, you can also see that none of them seem to handle truncation:

create_dendrogram(X, orientation='bottom', labels=None, colorscale=None, distfun=None, linkagefun=<function at 0x0000016F09D4CEE0>, hovertext=None, color_threshold=None)

2. Take a closer look at scipy.cluster.hierarchy.dendrogram

Comparing the wrapper with the source, we can see that some central arguments were left out when the function was wrapped in ff.create_dendrogram(X):

scipy.cluster.hierarchy.dendrogram(Z, p=30, truncate_mode=None, color_threshold=None, get_leaves=True, orientation='top', labels=None, count_sort=False, distance_sort=False, show_leaf_counts=True, no_plot=False, no_labels=False, leaf_font_size=None, leaf_rotation=None, leaf_label_func=None, show_contracted=False, link_color_func=None, ax=None, above_threshold_color='C0')

truncate_mode should be exactly what we're looking for. So, now we know that scipy probably has all we need to build the foundation for a truncated dendrogram, but what's next?
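Before touching plotly at all, it's worth confirming that scipy does the heavy lifting on its own. The sketch below calls scipy.cluster.hierarchy.dendrogram() with no_plot=True, so it only returns coordinates and labels instead of drawing anything, and compares a full and a truncated result for the Code 1 data. Using "complete" linkage is my assumption, chosen to mirror what I believe is create_dendrogram()'s default linkagefun:

```python
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage

np.random.seed(1)
X = np.random.rand(15, 12)  # same toy data as Code 1
# Assumption: "complete" linkage, to mirror create_dendrogram()'s default
Z = linkage(X, "complete")

# no_plot=True returns the link coordinates and leaf labels without drawing
full = dendrogram(Z, no_plot=True)
truncated = dendrogram(Z, no_plot=True, truncate_mode="level", p=2)

# One entry in icoord per link ('∩' shape) of the tree
print(len(full["icoord"]), len(truncated["icoord"]))
# Contracted leaves are labelled with their sample count, in the form "(k)"
print(truncated["ivl"])
```

The untruncated result has n - 1 = 14 links for 15 samples, while the truncated one keeps only the top of the tree, so the data that plotly would need is already there.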

3. Find where scipy.cluster.hierarchy.dendrogram is hiding in ff.create_dendrogram(X)

ff.create_dendrogram.__code__ will reveal where the source code lives on your system. In my case, this is:

"C:\Users\vestland\Miniconda3\envs\dashy\lib\site-packages\plotly\figure_factory\_dendrogram.py"

If you like, you can study the complete source in the corresponding folder on your system. If you do, you'll find one particularly interesting section where some of the arguments listed above are handled:

def get_dendrogram_traces(
    self, X, colorscale, distfun, linkagefun, hovertext, color_threshold
):
    """
    Calculates all the elements needed for plotting a dendrogram.
    [...]
    """
    # [...] (code omitted)
    P = sch.dendrogram(
        Z,
        orientation=self.orientation,
        labels=self.labels,
        no_plot=True,
        color_threshold=color_threshold,
    )

Here we are at the very core of the problem. The first step towards a complete answer to your question is simply to add truncate_mode and p to the sch.dendrogram() call that produces P, like this:

P = sch.dendrogram(
    Z,
    orientation=self.orientation,
    labels=self.labels,
    no_plot=True,
    color_threshold=color_threshold,
    truncate_mode='level',
    p=2,
)

And here's how you do that:

4. Monkey patching

In Python, monkey patching means modifying a class or module dynamically at runtime: a monkey patch is a piece of Python code that extends or modifies other code while the program is running. Here's the essence of how you can do exactly that in our case:

import plotly.figure_factory._dendrogram as original_dendrogram
original_dendrogram._Dendrogram.get_dendrogram_traces = modified_dendrogram_traces

Here, modified_dendrogram_traces() is a copy of the complete original get_dendrogram_traces() definition with the amendments I've already mentioned, plus a few imports that would otherwise be missing because they normally run at module level when plotly.figure_factory is imported.
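In case the mechanics of that one-line assignment are unclear, here is a self-contained toy version. The Greeter class is purely hypothetical and just stands in for _Dendrogram: assigning a plain function to a class attribute replaces the method for every instance, including ones that already exist, and keeping a reference to the original lets you undo the patch:

```python
class Greeter:
    """Hypothetical stand-in for plotly's _Dendrogram class."""

    def greet(self):
        return "hello"


def shouting_greet(self):
    # Replacement method: receives `self` just like the original method
    return "HELLO"


g = Greeter()
original_greet = Greeter.greet  # keep a reference so the patch can be undone
Greeter.greet = shouting_greet  # patch the class at runtime
patched = g.greet()             # existing instances pick up the change too
Greeter.greet = original_greet  # restore the original behaviour
restored = g.greet()
print(patched, restored)
```

Restoring the original is worth doing if you only want the modified behaviour in part of a session, since otherwise the patch affects every later call.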

Enough details for now. Below is the whole thing. If this is something you can use, we could perhaps make it a bit more dynamic than hardcoding truncate_mode='level' and p=2.

Complete code:

from scipy.cluster.hierarchy import linkage
import plotly.figure_factory as ff
import plotly.figure_factory._dendrogram as original_dendrogram
import numpy as np

def modified_dendrogram_traces(
    self, X, colorscale, distfun, linkagefun, hovertext, color_threshold
):
    """
    Calculates all the elements needed for plotting a dendrogram.

    :param (ndarray) X: Matrix of observations as array of arrays
    :param (list) colorscale: Color scale for dendrogram tree clusters
    :param (function) distfun: Function to compute the pairwise distance
                               from the observations
    :param (function) linkagefun: Function to compute the linkage matrix
                                  from the pairwise distances
    :param (list) hovertext: List of hovertext for constituent traces of dendrogram
    :rtype (tuple): Contains all the traces in the following order:
        (a) trace_list: List of Plotly trace objects for dendrogram tree
        (b) icoord: All X points of the dendrogram tree as array of arrays
            with length 4
        (c) dcoord: All Y points of the dendrogram tree as array of arrays
            with length 4
        (d) ordered_labels: leaf labels in the order they are going to
            appear on the plot
        (e) P['leaves']: left-to-right traversal of the leaves

    """
    # sch is normally defined at module level in
    # plotly.figure_factory._dendrogram, so fetch it here as well
    from plotly import optional_imports
    sch = optional_imports.get_module("scipy.cluster.hierarchy")
    d = distfun(X)
    Z = linkagefun(d)
    P = sch.dendrogram(
        Z,
        orientation=self.orientation,
        labels=self.labels,
        no_plot=True,
        color_threshold=color_threshold,
        # The amendment: show at most p=2 levels below the final merge
        truncate_mode='level',
        p=2,
    )

    # np.array instead of scp.array: the NumPy aliases in the top-level
    # scipy namespace are deprecated and unavailable in recent SciPy releases
    icoord = np.array(P["icoord"])
    dcoord = np.array(P["dcoord"])
    ordered_labels = np.array(P["ivl"])
    color_list = np.array(P["color_list"])
    colors = self.get_color_dict(colorscale)
    colors = self.get_color_dict(colorscale)

    trace_list = []

    for i in range(len(icoord)):
        # xs and ys are arrays of 4 points that make up the '∩' shapes
        # of the dendrogram tree
        if self.orientation in ["top", "bottom"]:
            xs = icoord[i]
        else:
            xs = dcoord[i]

        if self.orientation in ["top", "bottom"]:
            ys = dcoord[i]
        else:
            ys = icoord[i]
        color_key = color_list[i]
        hovertext_label = None
        if hovertext:
            hovertext_label = hovertext[i]
        trace = dict(
            type="scatter",
            x=np.multiply(self.sign[self.xaxis], xs),
            y=np.multiply(self.sign[self.yaxis], ys),
            mode="lines",
            marker=dict(color=colors[color_key]),
            text=hovertext_label,
            hoverinfo="text",
        )

        try:
            x_index = int(self.xaxis[-1])
        except ValueError:
            x_index = ""

        try:
            y_index = int(self.yaxis[-1])
        except ValueError:
            y_index = ""

        # f-strings handle both the int and the "" case
        # ("x" + 2 would raise a TypeError)
        trace["xaxis"] = f"x{x_index}"
        trace["yaxis"] = f"y{y_index}"

        trace_list.append(trace)

    return trace_list, icoord, dcoord, ordered_labels, P["leaves"]

# Replace the original method with the modified one (the monkey patch)
original_dendrogram._Dendrogram.get_dendrogram_traces = modified_dendrogram_traces

np.random.seed(1)  # same random data as in Code 1
X = np.random.rand(15, 12) # 15 samples, with 12 dimensions each
fig = ff.create_dendrogram(X)
fig.update_layout(width=800, height=500)
f = fig.full_figure_for_development(warn=False)
fig.show()

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 