Training and Visualizing a Decision Tree
Data
from sklearn.datasets import load_iris
import pandas as pd
iris = load_iris(as_frame=True)
iris_df: pd.DataFrame = iris.frame
# View 10 random samples
iris_df.sample(10, random_state=42)
| | sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | target |
---|---|---|---|---|---|
73 | 6.1 | 2.8 | 4.7 | 1.2 | 1 |
18 | 5.7 | 3.8 | 1.7 | 0.3 | 0 |
118 | 7.7 | 2.6 | 6.9 | 2.3 | 2 |
78 | 6.0 | 2.9 | 4.5 | 1.5 | 1 |
76 | 6.8 | 2.8 | 4.8 | 1.4 | 1 |
31 | 5.4 | 3.4 | 1.5 | 0.4 | 0 |
64 | 5.6 | 2.9 | 3.6 | 1.3 | 1 |
141 | 6.9 | 3.1 | 5.1 | 2.3 | 2 |
68 | 6.2 | 2.2 | 4.5 | 1.5 | 1 |
82 | 5.8 | 2.7 | 3.9 | 1.2 | 1 |
from sklearn.model_selection import train_test_split
X = iris.data.to_numpy()
y = iris.target.to_numpy()
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.2,
    random_state=42
)
Training a Decision Tree
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
# Define the model
tree_clf = DecisionTreeClassifier(
    max_depth=2,
    random_state=42
)
# Fit the model
tree_clf.fit(X_train, y_train)
# Evaluate on test dataset
y_pred = tree_clf.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.3f}")
Accuracy: 0.967
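A single accuracy figure does not show which classes the tree confuses. As a minimal sketch, the confusion matrix below breaks the test predictions down per class. It rebuilds the same split and model as above so that it runs on its own; `cm` is a name introduced here for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Rebuild the same split and model as above so this block is self-contained
iris = load_iris(as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    iris.data.to_numpy(), iris.target.to_numpy(),
    test_size=0.2,
    random_state=42
)
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X_train, y_train)

# Rows are the true classes, columns are the predicted classes
cm = confusion_matrix(y_test, tree_clf.predict(X_test), labels=[0, 1, 2])
print(cm)
```

Off-diagonal entries count the misclassified test samples, so a perfect classifier would produce a diagonal matrix.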
Visualization of the Tree
from sklearn.tree import export_graphviz
export_graphviz(
    tree_clf,
    # The out/ directory must already exist
    out_file="out/decision_tree_on_iris_dataset.dot",
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    # Each node will be drawn as a rounded rectangle
    rounded=True,
    # Each node will be filled with a color indicating its majority class
    filled=True
)
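If Graphviz is not installed, the learned rules can still be inspected as plain text. The sketch below uses scikit-learn's `export_text` for a quick look at the splits. It refits the same model so that it runs on its own, and `tree_rules` is a name introduced here for illustration.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

# Rebuild the same split and model as above so this block is self-contained
iris = load_iris(as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    iris.data.to_numpy(), iris.target.to_numpy(),
    test_size=0.2,
    random_state=42
)
tree_clf = DecisionTreeClassifier(max_depth=2, random_state=42)
tree_clf.fit(X_train, y_train)

# Render the decision rules as an indented text outline
tree_rules = export_text(tree_clf, feature_names=list(iris.feature_names))
print(tree_rules)
```

Each indentation level corresponds to one level of the tree, and each leaf line reports the predicted class index.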
To convert the .dot file to an image file (e.g., PNG, JPEG), use the dot command provided by Graphviz, an open-source graph visualization package.
Basic usage:
dot -T<image format> <input dot file path> -o <output image file path>
For example, the following command generates a PNG image:
dot -Tpng out/decision_tree_on_iris_dataset.dot -o out/decision_tree_on_iris_dataset.png
And similarly,
dot -Tjpg out/decision_tree_on_iris_dataset.dot -o out/decision_tree_on_iris_dataset.jpg
will create a JPEG image.
We prefer PNG because it is a lossless format, so the text and node borders stay sharp, whereas JPEG's lossy compression can blur them.
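The dot command shown above can also be invoked from Python, which is convenient inside a notebook. The following is only a sketch using the standard library: `build_dot_command` and `render_dot` are hypothetical helper names, and the code assumes the `dot` executable is on the `PATH`.

```python
import shutil
import subprocess
from pathlib import Path
from typing import List, Optional


def build_dot_command(dot_path: str, image_format: str = "png") -> List[str]:
    """Build `dot -T<format> <input> -o <output>`, placing the output next to the input."""
    output_path = Path(dot_path).with_suffix(f".{image_format}")
    return ["dot", f"-T{image_format}", dot_path, "-o", str(output_path)]


def render_dot(dot_path: str, image_format: str = "png") -> Optional[Path]:
    """Render a .dot file to an image; return None if Graphviz's dot is not installed."""
    if shutil.which("dot") is None:
        return None
    command = build_dot_command(dot_path, image_format)
    # check=True raises CalledProcessError if dot exits with an error
    subprocess.run(command, check=True)
    return Path(command[-1])
```

For example, `render_dot("out/decision_tree_on_iris_dataset.dot")` would write `out/decision_tree_on_iris_dataset.png`, provided the .dot file exists.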