How To Plot A Confusion Matrix In Python
In this post I will demonstrate how to plot a confusion matrix. I will be using the confusion matrix from the Scikit-Learn library (`sklearn.metrics`) and Matplotlib to display the results in a more intuitive visual format.
The documentation for the confusion matrix is pretty good, but I struggled to find a quick way to add labels and visualize the output as a 2x2 table.
For a good introductory read on confusion matrices, check out this great post:
http://www.dataschool.io/simple-guide-to-confusion-matrix-terminology
This is a mockup of the look I am trying to achieve:
- TN = True Negative
- FN = False Negative
- FP = False Positive
- TP = True Positive
Let’s go through a quick Logistic Regression example using Scikit-Learn. For data I will use the popular Iris dataset (to read more about it, see https://en.wikipedia.org/wiki/Iris_flower_data_set).
We will use the confusion matrix to evaluate the accuracy of the classification and plot it using matplotlib:
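Assuming the data is loaded with Scikit-Learn’s built-in loader (a sketch; the variable names are my own, since the original notebook cell isn’t shown), the first five rows look like this:

```python
import pandas as pd
from sklearn.datasets import load_iris

# Load the Iris dataset into a pandas DataFrame
iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df["Target"] = iris.target  # 0 = setosa, 1 = versicolor, 2 = virginica
df.head()
```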
| sepal length (cm) | sepal width (cm) | petal length (cm) | petal width (cm) | Target |
|---|---|---|---|---|
| 5.1 | 3.5 | 1.4 | 0.2 | 0 |
| 4.9 | 3.0 | 1.4 | 0.2 | 0 |
| 4.7 | 3.2 | 1.3 | 0.2 | 0 |
| 4.6 | 3.1 | 1.5 | 0.2 | 0 |
| 5.0 | 3.6 | 1.4 | 0.2 | 0 |
We can examine our data quickly using the Pandas `corr()` function to pick a suitable feature for our logistic regression. We will use the default Pearson method.
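A minimal sketch of that check (the loading lines are repeated so the snippet runs on its own):

```python
import pandas as pd
from sklearn.datasets import load_iris

iris = load_iris()
df = pd.DataFrame(iris.data, columns=iris.feature_names)
df["Target"] = iris.target

# Pearson correlation (the default) of every feature against the target
print(df.corr(method="pearson")["Target"])
```

The petal length and petal width columns show by far the strongest correlation with the target.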
So, let’s pick the two with the highest potential, petal width (cm) and petal length (cm), as our independent variables (X). For our target/dependent variable (Y) we can pick the Versicolor class. The target actually has three classes; to simplify our task and narrow it down to a binary classifier (0 or 1), I will pick Versicolor: either it is Versicolor (1) or it is not Versicolor (0).
Let’s now create our X and Y:
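A sketch of that step (column indices 2 and 3 correspond to petal length and petal width in the loader’s feature order):

```python
from sklearn.datasets import load_iris

iris = load_iris()
X = iris.data[:, 2:4]               # petal length (cm), petal width (cm)
Y = (iris.target == 1).astype(int)  # 1 = versicolor, 0 = not versicolor
```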
We will split our data into a test and train sets, then start building our Logistic Regression model. We will use an 80/20 split.
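A sketch of the split with `train_test_split` (the `random_state` value is my own choice, for reproducibility); with 150 samples, an 80/20 split leaves 30 test points:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

iris = load_iris()
X = iris.data[:, 2:4]
Y = (iris.target == 1).astype(int)

# 80/20 train/test split
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=0)
```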
Before we create our classifier, we will need to normalize the data (feature scaling) using the `StandardScaler` utility from Scikit-Learn’s preprocessing package.
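A sketch of the scaling step (the setup from the previous steps is repeated so the snippet runs standalone):

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

iris = load_iris()
X = iris.data[:, 2:4]
Y = (iris.target == 1).astype(int)
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=0)

# Fit the scaler on the training set only, then apply it to both sets
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
```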
Now we are ready to build our Logistic Classifier:
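A sketch of the classifier (again repeating the setup so it runs on its own; `random_state` is my choice):

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

iris = load_iris()
X = iris.data[:, 2:4]
Y = (iris.target == 1).astype(int)
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=0)
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

# Fit the logistic regression model and predict the test set
classifier = LogisticRegression(random_state=0)
classifier.fit(X_train, Y_train)
Y_pred = classifier.predict(X_test)
```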
Now, let’s evaluate our classifier with the confusion matrix:
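A sketch of that evaluation; `confusion_matrix` returns the 2x2 counts as `[[TN, FP], [FN, TP]]` (true labels on rows, predictions on columns). The exact counts depend on the split, so I don’t show expected output here:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

iris = load_iris()
X = iris.data[:, 2:4]
Y = (iris.target == 1).astype(int)
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=0)
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
classifier = LogisticRegression(random_state=0).fit(X_train, Y_train)

# Raw 2x2 confusion matrix: [[TN, FP], [FN, TP]]
cm = confusion_matrix(Y_test, classifier.predict(X_test))
print(cm)
```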
Visually, the above doesn’t easily convey how our classifier is performing. We mainly focus on the top-right and bottom-left cells (these are the errors, or misclassifications).
The confusion matrix tells us we have a total of 15 (13 + 2) misclassified data points out of the 30 test points (in terms of: Versicolor or Not Versicolor). A better way to visualize this can be accomplished with the code below:
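A sketch of such a plot with Matplotlib’s `imshow`, annotating each cell with its TN/FP/FN/TP label and count (the pipeline from the earlier steps is repeated so the snippet runs on its own; the exact counts depend on the train/test split, so they may differ from the numbers above):

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Rebuild the classifier from the previous steps
iris = load_iris()
X = iris.data[:, 2:4]
Y = (iris.target == 1).astype(int)
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=0)
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
clf = LogisticRegression(random_state=0).fit(X_train, Y_train)
cm = confusion_matrix(Y_test, clf.predict(X_test))

# Plot the 2x2 matrix, labelling each cell as TN/FP/FN/TP with its count
labels = np.array([["TN", "FP"], ["FN", "TP"]])
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
fig.colorbar(im)
ax.set_xticks([0, 1])
ax.set_xticklabels(["Not Versicolor", "Versicolor"])
ax.set_yticks([0, 1])
ax.set_yticklabels(["Not Versicolor", "Versicolor"])
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
thresh = cm.max() / 2
for i in range(2):
    for j in range(2):
        ax.text(j, i, f"{labels[i, j]} = {cm[i, j]}",
                ha="center", va="center",
                color="white" if cm[i, j] > thresh else "black")
plt.show()
```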
To plot and display the decision boundary that separates the two classes (Versicolor or Not Versicolor ):
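A sketch of the decision-boundary plot: evaluate the model over a dense grid covering the (scaled) feature plane, shade the two predicted regions with `contourf`, and overlay the test points (the grid range is my assumption, chosen to cover the standardized features):

```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

iris = load_iris()
X = iris.data[:, 2:4]
Y = (iris.target == 1).astype(int)
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=0)
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
clf = LogisticRegression(random_state=0).fit(X_train, Y_train)

# Predict the class at every point of a grid over the scaled feature plane
xx, yy = np.meshgrid(np.arange(-2.5, 2.5, 0.02), np.arange(-2.5, 2.5, 0.02))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

# Shade the two predicted regions and overlay the test points
plt.contourf(xx, yy, Z, alpha=0.3, cmap=plt.cm.coolwarm)
plt.scatter(X_test[:, 0], X_test[:, 1], c=Y_test,
            cmap=plt.cm.coolwarm, edgecolors="k")
plt.xlabel("petal length (scaled)")
plt.ylabel("petal width (scaled)")
plt.title("Logistic regression decision boundary")
plt.show()
```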
From the two plots we can easily see that the classifier is not doing a good job. Before digging into why (which will be another post, on how to determine whether data is linearly separable), we can assume it’s because the data is not linearly separable (in the Iris dataset, in fact, only the Setosa class is linearly separable).
We can try a non-linear classifier instead; in this case, an SVM with a Gaussian RBF kernel:
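A sketch using Scikit-Learn’s `SVC` (RBF is its default kernel; the setup is repeated so the snippet runs standalone — the error count will depend on the split):

```python
from sklearn.datasets import load_iris
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

iris = load_iris()
X = iris.data[:, 2:4]
Y = (iris.target == 1).astype(int)
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.2, random_state=0)
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)

# SVM with the Gaussian RBF kernel on the same scaled features
svm = SVC(kernel="rbf", random_state=0)
svm.fit(X_train, Y_train)
cm = confusion_matrix(Y_test, svm.predict(X_test))
print(cm)
```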
Here is the plot showing the decision boundary:
SVM with RBF Kernel produced a significant improvement: down from 15 misclassifications to only 1.
Hope this helps.
Note: the code was written in a Jupyter Notebook.