Training and Visualizing a Decision Tree

Data

from sklearn.datasets import load_iris
import pandas as pd

iris = load_iris(as_frame=True)
iris_df: pd.DataFrame = iris.frame

# View 10 random samples
iris_df.sample(10, random_state=42)
     sepal length (cm)  sepal width (cm)  petal length (cm)  petal width (cm)  target
73                 6.1               2.8                4.7               1.2       1
18                 5.7               3.8                1.7               0.3       0
118                7.7               2.6                6.9               2.3       2
78                 6.0               2.9                4.5               1.5       1
76                 6.8               2.8                4.8               1.4       1
31                 5.4               3.4                1.5               0.4       0
64                 5.6               2.9                3.6               1.3       1
141                6.9               3.1                5.1               2.3       2
68                 6.2               2.2                4.5               1.5       1
82                 5.8               2.7                3.9               1.2       1

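The target column stores integer class codes rather than species names. A minimal sketch, assuming the same load_iris(as_frame=True) call as above, maps the codes to names through iris.target_names and counts the samples per class; the species column it creates is an illustrative addition, not part of the dataset.

```python
from sklearn.datasets import load_iris

iris = load_iris(as_frame=True)
iris_df = iris.frame.copy()

# Map integer codes (0, 1, 2) to species names; "species" is a hypothetical helper column
code_to_name = dict(enumerate(iris.target_names))
iris_df["species"] = iris_df["target"].map(code_to_name)

# Number of samples per species, sorted alphabetically by name
class_counts = iris_df["species"].value_counts().sort_index()
print(class_counts)
```

The iris dataset is balanced, so each species should appear 50 times.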
from sklearn.model_selection import train_test_split

X = iris.data.to_numpy()
y = iris.target.to_numpy()

X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42
)
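With 150 samples and test_size=0.2, the split above should yield 120 training and 30 test samples. The sketch below checks those sizes and, as an optional variant that the original code does not use, passes stratify=y so that both subsets keep the same class proportions as the full dataset.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris(as_frame=True)
X = iris.data.to_numpy()
y = iris.target.to_numpy()

# Same split as above: 20% of the samples are held out for testing
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
print(X_train.shape, X_test.shape)

# Per-class counts in the test set of the plain (unstratified) split
print(np.bincount(y_test, minlength=3))

# Variant (not used in the original): stratify=y preserves class proportions
_, _, _, y_test_strat = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)
print(np.bincount(y_test_strat, minlength=3))
```

Because the three classes are equally sized, the stratified test set contains exactly 10 samples of each class, whereas the unstratified one may be slightly unbalanced.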

Training a Decision Tree

from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

# Define the model
tree_clf = DecisionTreeClassifier(
    max_depth=2,
    random_state=42
)

# Fit the model
tree_clf.fit(X_train, y_train)

# Evaluate on the held-out test set
y_pred = tree_clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.3f}")
Accuracy: 0.967
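Before rendering the tree as an image, the learned split rules can be checked in plain text. This sketch retrains the same model so it runs on its own and uses scikit-learn's export_text together with get_depth and get_n_leaves; the exact thresholds printed depend on the training split, so they are not reproduced here.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris(as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    iris.data.to_numpy(), iris.target.to_numpy(), test_size=0.2, random_state=42
)

tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X_train, y_train)

# max_depth=2 allows at most two successive splits on any root-to-leaf path
print("depth:", tree_clf.get_depth(), "leaves:", tree_clf.get_n_leaves())

# Text rendering of the split rules; each leaf line reports its predicted class
rules = export_text(tree_clf, feature_names=list(iris.feature_names))
print(rules)
```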

Visualization of the Tree

import os

from sklearn.tree import export_graphviz

# export_graphviz does not create missing directories, so make sure out/ exists
os.makedirs("out", exist_ok=True)

export_graphviz(
    tree_clf,
    out_file="out/decision_tree_on_iris_dataset.dot",
    feature_names=iris.feature_names,
    class_names=iris.target_names,

    # Draw each node as a rounded rectangle
    rounded=True,

    # Fill each node with a color indicating its majority class
    filled=True
)
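If only the DOT source is needed (for example, to inspect it or hand it to another tool), export_graphviz returns it as a string when out_file=None instead of writing a file. The sketch below retrains the same model so it is self-contained and otherwise uses the same export options as above.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_graphviz

iris = load_iris(as_frame=True)
X_train, _, y_train, _ = train_test_split(
    iris.data.to_numpy(), iris.target.to_numpy(), test_size=0.2, random_state=42
)
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X_train, y_train)

# out_file=None makes export_graphviz return the DOT source instead of writing it
dot_source = export_graphviz(
    tree_clf,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    rounded=True,
    filled=True,
)

# The first line declares the directed graph
print(dot_source.splitlines()[0])
```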

To convert the .dot file to an image file (e.g., PNG or JPEG), use the dot command-line tool that ships with Graphviz, an open-source graph visualization package.

Basic usage:

dot -T<image format> <input dot file path> -o <output image file path>

For example, the following command generates a PNG image:

dot -Tpng out/decision_tree_on_iris_dataset.dot -o out/decision_tree_on_iris_dataset.png

Similarly, the following command creates a JPEG image:

dot -Tjpg out/decision_tree_on_iris_dataset.dot -o out/decision_tree_on_iris_dataset.jpg
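The same conversion can be scripted from Python with the standard subprocess module. This is a minimal sketch, assuming Graphviz's dot executable is on PATH; the helper names build_dot_command and render_dot are hypothetical, and the optional dpi argument (passed to dot as the -Gdpi graph attribute, which sets the pixel density of raster output) is an addition not used in the commands above.

```python
import shutil
import subprocess


def build_dot_command(dot_path, image_path, image_format="png", dpi=None):
    """Return the argument list for: dot -T<format> [-Gdpi=N] <input> -o <output>."""
    cmd = ["dot", f"-T{image_format}"]
    if dpi is not None:
        # -G sets a graph attribute; dpi controls the resolution of raster output
        cmd.append(f"-Gdpi={dpi}")
    cmd += [dot_path, "-o", image_path]
    return cmd


def render_dot(dot_path, image_path, image_format="png", dpi=None):
    """Run dot if it is installed; return False when Graphviz is not on PATH."""
    if shutil.which("dot") is None:
        return False
    # check=True raises CalledProcessError if dot exits with an error
    subprocess.run(build_dot_command(dot_path, image_path, image_format, dpi), check=True)
    return True


print(build_dot_command(
    "out/decision_tree_on_iris_dataset.dot",
    "out/decision_tree_on_iris_dataset.png",
))
```

Building the argument list separately from running it keeps the command easy to inspect and avoids shell quoting issues, since subprocess.run receives a list rather than a shell string.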

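We prefer PNG because its lossless compression keeps the node text and edges sharp, whereas JPEG's lossy compression tends to introduce artifacts around text and thin lines. The file format itself does not determine the image resolution.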

A decision tree trained on the iris dataset