Generate a Confusion Matrix for Binary Classification

Task: Generate a Confusion Matrix

Your task is to implement the function confusion_matrix(data) that generates a confusion matrix for a binary classification problem. The confusion matrix provides a summary of the prediction results on a classification problem, allowing you to visualize how many data points were correctly or incorrectly labeled.

Input: A list of lists, where each inner list is a pair [y_true, y_pred] for one observation. Labels are 0 (negative) or 1 (positive).

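Output: A \(2 \times 2\) confusion matrix as a nested list [[TP, FN], [FP, TN]]: rows are the true label, columns the predicted label, with the positive class first.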

Example

data = [[1, 1], [1, 0], [0, 1], [0, 0], [0, 1]]
print(confusion_matrix(data))
Output:
[[1, 1], [2, 1]]
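
To see where that output comes from, the sketch below names the cell each pair falls into. It assumes labels are 0/1 with 1 as the positive class and the [[TP, FN], [FP, TN]] layout; `CELL_NAMES` and `label_pairs` are hypothetical helpers for illustration, not part of the task.

```python
# Hypothetical helper for illustration; assumes labels are 0/1 with 1 as positive
CELL_NAMES = {(1, 1): "TP", (1, 0): "FN", (0, 1): "FP", (0, 0): "TN"}

def label_pairs(data):
    """Return the cell name for each [y_true, y_pred] pair, in input order."""
    return [CELL_NAMES[tuple(pair)] for pair in data]

data = [[1, 1], [1, 0], [0, 1], [0, 0], [0, 1]]
labels = label_pairs(data)
print(labels)  # ['TP', 'FN', 'FP', 'TN', 'FP']
```

Counting the names gives one TP, one FN, two FP, and one TN, which is the matrix [[1, 1], [2, 1]] shown above.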

Generate Confusion Matrix

The confusion matrix is a useful tool for understanding the performance of a classification model. It shows how many data points were assigned to their correct category and how many were not.

For a binary classification problem on a dataset with \(n\) observations, the confusion matrix is a \(2 \times 2\) matrix whose four entries sum to \(n\), with the following structure:

\[ M = \begin{pmatrix} TP & FN \\ FP & TN \end{pmatrix} \]

where:
  • TP: True positives, the number of observations from the positive label that were correctly labeled as positive
  • FN: False negatives, the number of observations from the positive label that were incorrectly labeled as negative
  • FP: False positives, the number of observations from the negative label that were incorrectly labeled as positive
  • TN: True negatives, the number of observations from the negative label that were correctly labeled as negative

A confusion matrix is a starting point for computing further metrics, such as precision and recall, that summarize the model's performance.

from collections import Counter

def confusion_matrix(data):
    # Count how often each (y_true, y_pred) pair occurs
    counts = Counter(tuple(pair) for pair in data)
    # Look up the four cells; a Counter returns 0 for pairs that never occur
    TP, FN, FP, TN = counts[(1, 1)], counts[(1, 0)], counts[(0, 1)], counts[(0, 0)]
    # Rows are the true label (positive first), columns the predicted label
    return [[TP, FN], [FP, TN]]
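
The precision and recall mentioned above can be read straight off the four cells: precision is TP / (TP + FP) and recall is TP / (TP + FN). The following is a minimal sketch, not part of the task; `precision_recall` is a hypothetical helper, and returning 0.0 when a denominator is zero is an assumed convention.

```python
def precision_recall(matrix):
    """Return (precision, recall) from a [[TP, FN], [FP, TN]] matrix."""
    (tp, fn), (fp, _tn) = matrix
    # Precision: share of predicted positives that are truly positive.
    # Assumption: report 0.0 when nothing was predicted positive.
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    # Recall: share of actual positives that the model found.
    # Assumption: report 0.0 when there are no actual positives.
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    return precision, recall

precision, recall = precision_recall([[1, 1], [2, 1]])
print(precision, recall)
```

For the example matrix [[1, 1], [2, 1]], precision is 1/3 (one of three predicted positives is correct) and recall is 1/2 (one of two actual positives is found).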
