Precision and Recall
The function precision_recall_scores_per_class in puma.utils.precision_recall_scores computes the per-class precision and recall for a multiclass classification task with N_c classes. It compares the array of labels predicted by the classifier with the array of target labels. It returns two arrays, the per-class precision and the per-class recall, whose i-th entry is the score for class i. The scores are defined as follows.
Precision
For a given class i among the N_c classes, the precision (also called purity) measures the ability of the classifier not to label as class i a sample that belongs to another class. It is defined as:

$$\mathrm{precision}_i = \frac{tp}{tp + fp}$$

where tp is the number of true positives (samples of class i predicted as class i) and fp is the number of false positives (samples of other classes predicted as class i) in the test set.
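To make the definition concrete, here is a minimal NumPy sketch that counts tp and fp for each class and forms their ratio. The helper name precision_per_class is hypothetical and not part of puma. It assumes integer labels in 0..N_c-1 and returns NaN for a class that is never predicted (tp + fp = 0), which is an assumed convention for that edge case.

```python
import numpy as np


def precision_per_class(targets: np.ndarray, predictions: np.ndarray, n_classes: int) -> np.ndarray:
    """Hypothetical helper: per-class precision tp / (tp + fp) from integer labels."""
    precision = np.full(n_classes, np.nan)
    for i in range(n_classes):
        predicted_i = predictions == i
        tp = np.sum(predicted_i & (targets == i))  # predicted as i, truly class i
        fp = np.sum(predicted_i & (targets != i))  # predicted as i, truly another class
        if tp + fp > 0:  # leave NaN if class i is never predicted (assumed convention)
            precision[i] = tp / (tp + fp)
    return precision


targets = np.array([0, 0, 1, 1, 2, 2])
predictions = np.array([0, 1, 1, 1, 2, 0])
# Class 0: tp=1, fp=1; class 1: tp=2, fp=1; class 2: tp=1, fp=0
print(precision_per_class(targets, predictions, n_classes=3))  # values 0.5, 2/3, 1.0
```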
Recall
For a given class i among the N_c classes, the recall measures the ability of the classifier to detect every sample of class i in the set. It is defined as:

$$\mathrm{recall}_i = \frac{tp}{tp + fn}$$

where tp is the number of true positives and fn is the number of false negatives (samples of class i predicted as another class) in the test set.
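Both scores can also be read off a confusion matrix C, where C[t, p] counts the samples with target t and predicted label p. The diagonal holds tp for each class, row sums give tp + fn, and column sums give tp + fp. The sketch below builds C with np.add.at and derives recall (and, for comparison, precision) from it. The name confusion_counts is hypothetical, not a puma API; a class with an empty row or column yields NaN here, which is an assumed convention.

```python
import numpy as np


def confusion_counts(targets: np.ndarray, predictions: np.ndarray, n_classes: int) -> np.ndarray:
    """Hypothetical helper: C[t, p] = number of samples with target t predicted as p."""
    counts = np.zeros((n_classes, n_classes))
    # Unbuffered in-place add, so repeated (target, prediction) pairs are all counted
    np.add.at(counts, (targets, predictions), 1)
    return counts


targets = np.array([0, 0, 1, 1, 2, 2])
predictions = np.array([0, 1, 1, 1, 2, 0])
counts = confusion_counts(targets, predictions, n_classes=3)

tp = np.diag(counts)
with np.errstate(invalid="ignore", divide="ignore"):
    recall = tp / counts.sum(axis=1)     # row sums: all samples of class i (tp + fn)
    precision = tp / counts.sum(axis=0)  # column sums: all predictions of class i (tp + fp)
print("recall:", recall)        # values 0.5, 1.0, 0.5
print("precision:", precision)  # values 0.5, 2/3, 1.0
```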
Example
```python
import numpy as np

from puma.utils.precision_recall_scores import precision_recall_scores_per_class

# Fixed seed so that the dummy data (and the printed scores) are reproducible
rng = np.random.default_rng(42)

# Sample size
N = 100
# Number of target classes
Nclass = 3

# Dummy target labels in [0, Nclass)
targets = rng.integers(0, Nclass, size=N)
# Make sure that there is at least one sample for each class
targets = np.append(targets, np.arange(Nclass))
# Dummy predicted labels, one per target
predictions = rng.integers(0, Nclass, size=N + Nclass)

# Unweighted precision and recall
uw_precision, uw_recall = precision_recall_scores_per_class(targets, predictions)
print("Unweighted case:")
print("Per-class precision:")
print(uw_precision)
print("Per-class recall:")
print(uw_recall)
print()

# Weighted precision and recall
# Dummy sample weights, one per sample
sample_weights = rng.random(N + Nclass)
w_precision, w_recall = precision_recall_scores_per_class(targets, predictions, sample_weights)
print("Weighted case:")
print("Per-class precision:")
print(w_precision)
print("Per-class recall:")
print(w_recall)
```
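The example passes sample_weights without spelling out how they enter the scores. A common convention, used for instance by scikit-learn's sample_weight, is that each sample contributes its weight instead of 1 to tp, fp and fn. The sketch below implements that assumed convention in plain NumPy so the weighted output can be reasoned about. The name weighted_precision_recall is hypothetical, and whether puma weights the counts exactly this way is an assumption to verify against its source.

```python
import numpy as np


def weighted_precision_recall(targets, predictions, sample_weights=None):
    """Hypothetical sketch: per-class precision and recall with weighted counts."""
    targets = np.asarray(targets)
    predictions = np.asarray(predictions)
    if sample_weights is None:
        sample_weights = np.ones(len(targets))  # unweighted case: every sample counts as 1
    sample_weights = np.asarray(sample_weights, dtype=float)
    # Assumes integer labels 0..N_c-1, with the highest class present in the data
    n_classes = int(max(targets.max(), predictions.max())) + 1
    precision = np.full(n_classes, np.nan)
    recall = np.full(n_classes, np.nan)
    for i in range(n_classes):
        is_target = targets == i
        is_pred = predictions == i
        tp = sample_weights[is_target & is_pred].sum()   # weighted true positives
        fp = sample_weights[~is_target & is_pred].sum()  # weighted false positives
        fn = sample_weights[is_target & ~is_pred].sum()  # weighted false negatives
        if tp + fp > 0:
            precision[i] = tp / (tp + fp)
        if tp + fn > 0:
            recall[i] = tp / (tp + fn)
    return precision, recall


rng = np.random.default_rng(0)
targets = rng.integers(0, 3, size=50)
predictions = rng.integers(0, 3, size=50)

p_u, r_u = weighted_precision_recall(targets, predictions)
p_c, r_c = weighted_precision_recall(targets, predictions, np.full(50, 2.5))
# Equal weights cancel in each ratio, so they reproduce the unweighted scores
print(np.allclose(p_u, p_c, equal_nan=True) and np.allclose(r_u, r_c, equal_nan=True))  # True
```

Under this convention, only the relative size of the weights matters: multiplying every weight by the same constant leaves both scores unchanged.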