Your task is to implement a function that calculates the Gini Impurity for a set of classes. Gini impurity is commonly used in decision tree algorithms to measure the impurity or disorder within a node.
Write a function gini_impurity(y)
that takes in a list of class labels y
and returns the Gini Impurity rounded to three decimal places.
Example: y = [0, 1, 1, 1, 0] print(gini_impurity(y)) # Expected Output: # 0.48
Gini impurity is a statistical measurement of the impurity or disorder in a list of elements. It is commonly used in decision tree algorithms to decide the optimal split at tree nodes. It is calculated as follows, where \( p_i \) is the probability of each class - \( \frac{n_i}{n} \):
\[ \text{Gini Impurity} = 1 - \sum_{i=1}^{C} p_i^2 \]A Gini impurity of 0 indicates a node where all elements belong to the same class, whereas a Gini impurity of 0.5 indicates maximum impurity, where elements are evenly distributed among each class. This means that a lower impurity implies a more homogeneous distribution of elements, suggesting a good split, as decision trees aim to minimize it at each node.
Advantages:
Limitations:
Suppose we have the set: [0, 1, 1, 1, 0]. The probability of each class is calculated as follows:
\[ p_{0} = \frac{2}{5} \quad p_{1} = \frac{3}{5} \]The Gini Impurity is then calculated as follows:
\[ \text{Gini Impurity} = 1 - (p_0^2 + p_1^2) = 1 - \left(\left(\frac{2}{5}\right)^2 + \left(\frac{3}{5}\right)^2\right) = 0.48 \]import numpy as np def gini_impurity(y: list[int]) -> float: classes = set(y) n = len(y) gini_impurity = 0 for cls in classes: gini_impurity += (y.count(cls)/n)**2 return round(1-gini_impurity,3)
There’s no video solution available yet 😔, but you can be the first to submit one at: GitHub link.