Precision and Recall
The function precision_recall_scores_per_class in puma.utils.precision_recall_scores computes the per-class precision and recall for a multiclass classification task with \(N_c\) classes. It does this by comparing the classifier's predicted labels with the target labels. It returns two arrays, one with the precision scores and one with the recall scores. Entry \(i\) of each array is the score for class \(i\). Optional per-sample weights can be passed as a third argument. The scores are defined as follows.
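Both returned arrays can be read off a single confusion matrix. The sketch below is a minimal illustration of the definitions, not puma's implementation. The function name `per_class_precision_recall` and the explicit `n_classes` parameter are hypothetical. It also assumes that, in the weighted case, each count is replaced by the sum of the weights of the samples it contains.

```python
import numpy as np


def per_class_precision_recall(targets, predictions, n_classes, sample_weights=None):
    """Hypothetical sketch (not puma's code): per-class precision and recall."""
    targets = np.asarray(targets, dtype=int)
    predictions = np.asarray(predictions, dtype=int)
    if sample_weights is None:
        sample_weights = np.ones(len(targets), dtype=float)
    sample_weights = np.asarray(sample_weights, dtype=float)

    # cm[t, p] accumulates the (assumed: summed-weight) count of samples
    # with target class t that were predicted as class p
    cm = np.zeros((n_classes, n_classes), dtype=float)
    np.add.at(cm, (targets, predictions), sample_weights)

    tp = np.diag(cm)
    predicted_totals = cm.sum(axis=0)  # column sums: tp + fp for each class
    target_totals = cm.sum(axis=1)     # row sums: tp + fn for each class

    # Classes that are never predicted (or never present) give 0/0 -> NaN
    with np.errstate(divide="ignore", invalid="ignore"):
        precision = tp / predicted_totals
        recall = tp / target_totals
    return precision, recall


targets = np.array([0, 0, 1, 1, 2, 2])
predictions = np.array([0, 1, 1, 1, 2, 0])
precision, recall = per_class_precision_recall(targets, predictions, n_classes=3)
print(precision, recall)
```

With these toy labels, the per-class precision is [0.5, 2/3, 1] and the recall is [0.5, 1, 0.5]. Column sums give the denominators for precision, and row sums give them for recall.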
Precision
For a given class \(i\) among the \(N_c\) possible classes, the precision score (also called purity) measures how well the classifier avoids labelling as class \(i\) a sample that belongs to another class. It is defined as:
\[ \text{precision}_i = \frac{tp}{tp + fp} \]
where \(tp\) is the number of true positives (samples of class \(i\) predicted as class \(i\)) and \(fp\) is the number of false positives (samples of other classes predicted as class \(i\)) in the test set.
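To make the formula concrete, here is a minimal sketch that counts \(tp\) and \(fp\) for a single class using NumPy boolean masks. The helper `precision_for_class` and the toy labels are hypothetical and not part of puma. Returning NaN when no sample is predicted as class \(i\) is an assumption of this sketch, since the behaviour of puma in that case is not described here.

```python
import numpy as np


def precision_for_class(targets, predictions, cls):
    """Hypothetical helper (not part of puma): precision = tp / (tp + fp) for one class."""
    targets = np.asarray(targets)
    predictions = np.asarray(predictions)
    predicted_as_cls = predictions == cls
    # True positives: predicted as cls and truly of class cls
    tp = np.count_nonzero(predicted_as_cls & (targets == cls))
    # False positives: predicted as cls but truly of another class
    fp = np.count_nonzero(predicted_as_cls & (targets != cls))
    if tp + fp == 0:
        # Assumption: report NaN when the class is never predicted
        return float("nan")
    return tp / (tp + fp)


targets = np.array([0, 0, 1, 1, 2, 2])
predictions = np.array([0, 1, 1, 1, 2, 0])
print([precision_for_class(targets, predictions, c) for c in range(3)])
```

For class 1, three samples are predicted as class 1 and two of them truly belong to it. That gives \(tp = 2\), \(fp = 1\), and a precision of 2/3.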
Recall
For a given class \(i\) among the \(N_c\) possible classes, the recall score measures how well the classifier finds every sample in the set that belongs to class \(i\). It is defined as:
\[ \text{recall}_i = \frac{tp}{tp + fn} \]
where \(tp\) is the number of true positives and \(fn\) is the number of false negatives (samples of class \(i\) predicted as another class) in the test set.
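The recall formula can be sketched the same way. The mask now selects the samples whose target is class \(i\) instead of those predicted as class \(i\). The helper `recall_for_class` is hypothetical and not part of puma. As before, returning NaN for a class with no target samples is an assumption of this sketch.

```python
import numpy as np


def recall_for_class(targets, predictions, cls):
    """Hypothetical helper (not part of puma): recall = tp / (tp + fn) for one class."""
    targets = np.asarray(targets)
    predictions = np.asarray(predictions)
    truly_cls = targets == cls
    # True positives: truly of class cls and predicted as cls
    tp = np.count_nonzero(truly_cls & (predictions == cls))
    # False negatives: truly of class cls but predicted as another class
    fn = np.count_nonzero(truly_cls & (predictions != cls))
    if tp + fn == 0:
        # Assumption: report NaN when the class never appears in the targets
        return float("nan")
    return tp / (tp + fn)


targets = np.array([0, 0, 1, 1, 2, 2])
predictions = np.array([0, 1, 1, 1, 2, 0])
print([recall_for_class(targets, predictions, c) for c in range(3)])
```

With the same toy labels, class 1 has a recall of 1, because both of its samples are found. However, its precision is only 2/3, which shows that the two metrics capture different kinds of error.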
Example
```python
"""Calculate and print the precision recall."""
from __future__ import annotations
import numpy as np
from puma.utils.precision_recall_scores import precision_recall_scores_per_class
# Sample size
N = 100
# Number of target classes
Nclass = 3
# Dummy target labels
targets = np.random.randint(0, Nclass, size=N)
# Making sure that there is at least one sample for each class
targets = np.append(targets, np.array(list(range(Nclass))))
# Dummy predicted labels
predictions = np.random.randint(0, Nclass, size=(N + Nclass))
# Unweighted precision and recall
uw_precision, uw_recall = precision_recall_scores_per_class(targets, predictions)
print("Unweighted case:")
print("Per-class precision:")
print(uw_precision)
print("Per-class recall:")
print(uw_recall)
print(" ")
# Weighted precision and recall
# Dummy sample weights
sample_weights = np.random.rand(N + Nclass)
w_precision, w_recall = precision_recall_scores_per_class(targets, predictions, sample_weights)
print("Weighted case:")
print("Per-class precision:")
print(w_precision)
print("Per-class recall:")
print(w_recall)
print(" ")
```