Scikit-learn: Improve decision tree plotting in Jupyter environment

Created on 1 Feb 2016 · 32 Comments · Source: scikit-learn/scikit-learn

Recently I used the built-in visualizer for DecisionTreeClassifier in Jupyter, and I think
its interface could be better (the example is taken from the docs):

>>> from IPython.display import Image  
>>> dot_data = StringIO()  
>>> tree.export_graphviz(clf, out_file=dot_data,  
                         feature_names=iris.feature_names,  
                         class_names=iris.target_names,  
                         filled=True, rounded=True,  
                         special_characters=True)  
>>> graph = pydot.graph_from_dot_data(dot_data.getvalue())  
>>> Image(graph.create_png())  

In addition, this does not work on Python 3, since at the time of writing pydot cannot be installed for Python 3.
The ideal solution would be something like

from sklearn import tree
tc = tree.DecisionTreeClassifier()
...
tree.plot(tc) # or even tc.plot()

but in this case the tree module would have to depend on pydot and IPython.display.Image.
I can fix this issue, but what is the best way to do it?

Enhancement Moderate help wanted

All 32 comments

One of the possible solutions might be something like

class OptionalImportError(Exception):
    ...


def plot(params):
    try:
        import pydot
        from IPython.display import Image  
    except ImportError:
        raise OptionalImportError
    # do actual plotting
    ...

anyone?
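
The optional-import idea above could be fleshed out roughly as follows. This is only a sketch: the helper name `require`, the error message, and the choice of the `graphviz` package inside `plot` are assumptions made for illustration, not anything scikit-learn provides.

```python
import importlib


class OptionalImportError(ImportError):
    """Raised when an optional plotting dependency is not installed."""


def require(module_name, purpose):
    """Import module_name lazily, or explain what it is needed for.

    `require` is a hypothetical helper, not part of scikit-learn.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as exc:
        raise OptionalImportError(
            f"{module_name!r} is required for {purpose}; "
            "install it to enable this feature."
        ) from exc


def plot(dot_source):
    """Render DOT source text using the optional graphviz package."""
    graphviz = require("graphviz", "plotting decision trees")
    return graphviz.Source(dot_source)
```

Deferring the import to call time keeps `import sklearn.tree` working for users who never plot, and the error message tells the others which package is missing.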

I think it would be nice to make it easier to plot trees. Currently it's a bit of a pain. But just returning the tree structure from export_graphviz would mean you can show the tree in IPython by just calling export_graphviz. You could just add a tree_to_graphviz function or something next to it that returns a graphviz object instead of writing to a file.

btw, this works:

from sklearn.tree import export_graphviz
import graphviz

export_graphviz(tree, out_file="mytree.dot")
with open("mytree.dot") as f:
    dot_graph = f.read()
graphviz.Source(dot_graph)

And the returned graphviz value is automatically rendered in Jupyter notebook. It would be nice not to need to create a tempfile, though. So I'd like to do

from sklearn.tree import convert_to_graphviz
convert_to_graphviz(tree)

That would have an optional dependency on graphviz, though.
The other option would be to not create the object and just return the string:

from sklearn.tree import export_graphviz
import graphviz

graphviz.Source(export_graphviz(tree))

Thanks!

+(np.inf ** np.inf)

This is my current setup for diagnosing trees

from io import BytesIO, StringIO  # Python 3; on Python 2 use `from StringIO import StringIO`

import matplotlib.image as img
import matplotlib.pyplot as plt
import pygraphviz
from sklearn.tree import export_graphviz

def get_graph(dtc, n_classes, feat_names=None, size=(7, 7)):
    dot_file = StringIO()
    image_file = BytesIO()

    # Get the dot graph of our decision tree
    # class_names must be indexable, so build a list rather than passing a lazy map object
    class_names = [str(i) for i in range(n_classes)]
    export_graphviz(dtc, out_file=dot_file, feature_names=feat_names, rounded=True, filled=True,
                    special_characters=True, class_names=class_names, max_depth=10)
    dot_file.seek(0)

    # Convert this dot graph into an image
    g = pygraphviz.AGraph(dot_file.read())
    g.layout('dot')
    # g.draw doesn't work when the image object doesn't have a name (with a proper extension)
    image_file.name = "image.png"
    image_file.seek(0)
    g.draw(path=image_file)
    image_file.seek(0)

    # Plot it
    plt.figure().set_size_inches(*size)
    plt.axis('off')
    plt.imshow(img.imread(fname=image_file))
    plt.show()

This was closed by GitHub when merging #7342. However, I think we could still do better, as suggested in https://github.com/scikit-learn/scikit-learn/issues/6261#issuecomment-182125071.

I agree with @ogrisel, and think we should actually aim higher, though possibly not in scikit-learn. I'm not sure.
We can only visualize trees in environments that use conda or have some way to install graphviz. Even with conda, it's confusing (you need to conda install graphviz AND pip install graphviz). Giving installation instructions without conda is basically impossible.

There are two possible solutions: find a better way to render graphviz in Python (hard and unlikely, though for browser display we could possibly use D3: https://github.com/mstefaniuk/graph-viz-d3-js),
or write a custom (decision) tree visualization, which is 1000x easier than plotting generic graphs.
We could probably even do that with matplotlib without any graph stuff.
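
To illustrate why a dedicated tree layout is so much simpler than generic graph layout, here is a deliberately naive sketch: leaves take consecutive horizontal slots and each internal node is centred above its two children. The function name `layout_tree` is hypothetical. The input follows the convention of scikit-learn's `tree_.children_left` / `tree_.children_right` arrays, where -1 marks a leaf. This is not the algorithm any scikit-learn release uses.

```python
def layout_tree(children_left, children_right, root=0):
    """Return {node_id: (x, y)} for a binary tree given as child-index arrays.

    A child index of -1 marks a leaf, as in scikit-learn's
    ``tree_.children_left`` and ``tree_.children_right``.
    """
    positions = {}
    next_leaf_x = 0

    def place(node, depth):
        nonlocal next_leaf_x
        left, right = children_left[node], children_right[node]
        if left == -1:
            # Leaf: take the next free horizontal slot.
            x = float(next_leaf_x)
            next_leaf_x += 1
        else:
            # Internal node: centre it above its two children.
            x = (place(left, depth + 1) + place(right, depth + 1)) / 2
        positions[node] = (x, -depth)
        return x

    place(root, 0)
    return positions
```

The resulting coordinates could then be handed to matplotlib text boxes and connecting lines. Because every leaf gets its own column, this layout grows wide quickly when node labels are large.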

My motivation: I like to plot decision trees in tutorials, and I'd like the readers of my book to plot decision trees. And without anaconda it's near-impossible.

@amueller if I recall correctly, I actually installed graphviz with Homebrew, and it looks like you can download a Windows binary from their website ... (not saying this is ideal, but it doesn't involve conda)

that said, totally agree with you that it'd be worthwhile redoing this in a more integrated & modern way

I work on OS X and the graphviz stuff seems to be no longer properly supported there. So I wrote a simple ASCII based decision tree visualizer for the sklearn DecisionTreeClassifier: tree_print (see attached). One of the major benefits of decision tree models is that they are easy to understand by looking at them so having a simple way of visualizing them is important. Perhaps this simple visualizer could be directly integrated into the environment...

treeviz.txt: https://github.com/scikit-learn/scikit-learn/files/902272/treeviz.txt

Yes, I quite like the idea of a text-based tree output. A pull request is
welcome, IMO. I would also consider making a limited-depth version the
default __str__ for Tree objects.

OK, that sounds good. I'm a github newbie when it comes to large shared
projects. Could you point me to where/how I should execute the pull
request for Jupyter and also for the tree object?

Thanks!
Lutz

I'd start by implementing export_text alongside the graphviz equivalent,
write tests alongside its tests, add the changes and commit them to a
branch. Push the branch to your fork of scikit-learn on GitHub, then create
a pull request. HTH

I can give a try to a matplotlib version.

@lutz-hamel actually, it looks like the PyPI package for graphviz is now a better way to achieve this (and it currently works well on Mac): https://github.com/scikit-learn/scikit-learn/pull/9071

+1 to the concept of a very simple text-based representation of the tree. I am admittedly a newbie in the data-science/ML space (coming from a research background). Lately, I've been translating some R examples into Python. Here's how the output of a tree-based classifier appears in R (ctree() from library partykit) using the famous iris dataset:

'data.frame':   150 obs. of  5 variables:
 $ Sepal.Length: num  5.1 4.9 4.7 4.6 5 5.4 4.6 5 4.4 4.9 ...
 $ Sepal.Width : num  3.5 3 3.2 3.1 3.6 3.9 3.4 3.4 2.9 3.1 ...
 $ Petal.Length: num  1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
 $ Petal.Width : num  0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
 $ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...

Model formula:
Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width

Fitted party:
[1] root
|   [2] Petal.Length <= 1.9: setosa (n = 35, err = 0.0%)
|   [3] Petal.Length > 1.9
|   |   [4] Petal.Width <= 1.7
|   |   |   [5] Petal.Length <= 4.6: versicolor (n = 30, err = 0.0%)
|   |   |   [6] Petal.Length > 4.6: versicolor (n = 9, err = 33.3%)
|   |   [7] Petal.Width > 1.7: virginica (n = 32, err = 0.0%)

Number of inner nodes:    3
Number of terminal nodes: 4

The corresponding plot looks like this: [image: tree_plot]

Of course, the plot is nice. But the text version has exactly the same information except for p-values, and the text version is arguably just as user-friendly as the graphical plot. Advantages of the text version are compactness, simplicity, cross-platform compatibility, and also that it would work well via the terminal.

Thoughts on that type of approach?

Best wishes,
-- Alexander.

And https://github.com/scikit-learn/scikit-learn/pull/9424 for the text version.
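
For readers who want to see what such a text rendering involves, here is a minimal sketch in the spirit of the R output quoted above. The function name `tree_to_text` and its argument layout are hypothetical and are not the API of the pull request. The arrays mirror scikit-learn's `tree_.children_left`, `tree_.children_right`, `tree_.feature` and `tree_.threshold`, where a child index of -1 marks a leaf.

```python
def tree_to_text(children_left, children_right, feature, threshold,
                 leaf_labels, feature_names, node=0, depth=0):
    """Return the lines of an indented text rendering of a binary tree.

    leaf_labels maps a leaf's node id to the label to print for it.
    """
    indent = "|   " * depth
    left, right = children_left[node], children_right[node]
    if left == -1:
        # Leaf node: print its predicted label.
        return [f"{indent}class: {leaf_labels[node]}"]

    name = feature_names[feature[node]]
    shared = (children_left, children_right, feature, threshold,
              leaf_labels, feature_names)
    lines = [f"{indent}{name} <= {threshold[node]:.2f}"]
    lines += tree_to_text(*shared, node=left, depth=depth + 1)
    lines.append(f"{indent}{name} > {threshold[node]:.2f}")
    lines += tree_to_text(*shared, node=right, depth=depth + 1)
    return lines
```

With a fitted classifier `clf`, the arrays would come from `clf.tree_`, and printing `"\n".join(tree_to_text(...))` shows the tree in a terminal. Later scikit-learn releases ship `sklearn.tree.export_text`, which is the supported way to get this kind of output.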

I think it's good to have both, and I think that's the consensus. The text version is harder to read imho since it is very asymmetrical.

The first issue I posted is a pure matplotlib implementation, so no additional packages required and cross-platform (if you have matplotlib, which I think is a reasonable requirement)

I am new to ML. I am trying to create a tree graph following amueller's advice, but no luck. I am trying to do the same in an IPython notebook and I get nothing.

from sklearn.tree import convert_to_graphviz
import graphviz

graphviz.Source(export_graphviz(tree))

This doesn't give me any output. Any advice?

You can find an easier alternative in this link https://github.com/fsbeserra/treeViz. Check the visualization of the tree here

@fsbeserra that requires dot, right? The point of this PR is not to require dot.

Sorry I was referencing #9251

Text based output would still be cool.

No dot required

@fsbeserra really? Where is the tree rendered? I thought that was done by networkx, which relies on dot for all layouting, right? Or did networkx add graph layouting somewhere?

@amueller it was a fast implementation I made for an environment in which I did not have dot. I built the layout myself. Networkx is only used to build the graph structure, but the plotting layout was customized by me. Hope it helps!

Ah, ok, you're using a much simpler algorithm than I do. The problem is that our nodes have pretty big labels, so I think your algorithm would result in trees that are way too wide.

Though I guess using interactivity and mouse-overs is one way to deal with that...

@fsbeserra thanks for sharing the notebook, that's a great attempt.

Rendering a decision tree in a Jupyter notebook can now be done without a temporary file:

import graphviz
from sklearn.tree import export_graphviz

graphviz.Source(export_graphviz(tree))

Thanks for that Terje. We also recently merged #9251 which plots trees using an improved layout with matplotlib.
