Skip to content

ROC curves#

In the following a small example how to plot a roc curve with the puma API.

Then we can start the actual plotting part:

"""Produce roc curves from tagger output and labels."""

from __future__ import annotations

import numpy as np
from ftag import Flavours
from ftag.utils import calculate_rejection, get_discriminant

from puma import Roc, RocPlot
from puma.utils import get_dummy_2_taggers, logger

# The line below generates dummy data which is similar to a NN output
df = get_dummy_2_taggers(add_pt=True)

logger.info("caclulate tagger discriminants")
discs_dips = get_discriminant(
    jets=df,
    tagger="dips",
    signal=Flavours["bjets"],
    flavours=Flavours.by_category("single-btag"),
    fraction_values={
        "fc": 0.018,
        "fu": 0.982,
        "ftau": 0,
    },
)
discs_rnnip = get_discriminant(
    jets=df,
    tagger="rnnip",
    signal=Flavours["bjets"],
    flavours=Flavours.by_category("single-btag"),
    fraction_values={
        "fc": 0.018,
        "fu": 0.982,
        "ftau": 0,
    },
)

# defining target efficiency
sig_eff = np.linspace(0.49, 1, 20)

# defining boolean arrays to select the different flavour classes
is_light = df["HadronConeExclTruthLabelID"] == 0
is_c = df["HadronConeExclTruthLabelID"] == 4
is_b = df["HadronConeExclTruthLabelID"] == 5

n_jets_light = sum(is_light)
n_jets_c = sum(is_c)

logger.info("Calculate rejection")
rnnip_ujets_rej = calculate_rejection(discs_rnnip[is_b], discs_rnnip[is_light], sig_eff)
rnnip_cjets_rej = calculate_rejection(discs_rnnip[is_b], discs_rnnip[is_c], sig_eff)
dips_ujets_rej = calculate_rejection(discs_dips[is_b], discs_dips[is_light], sig_eff)
dips_cjets_rej = calculate_rejection(discs_dips[is_b], discs_dips[is_c], sig_eff)

# Define now the actual ROC curve objects
rnnip_ujets_roc = Roc(
    sig_eff=sig_eff,
    bkg_rej=rnnip_ujets_rej,
    n_test=n_jets_light,
    rej_class="ujets",
    signal_class="bjets",
    label="RNNIP",
)
dips_ujets_roc = Roc(
    sig_eff=sig_eff,
    bkg_rej=dips_ujets_rej,
    n_test=n_jets_light,
    rej_class="ujets",
    signal_class="bjets",
    label="DIPS r22",
)
rnnip_cjets_roc = Roc(
    sig_eff=sig_eff,
    bkg_rej=rnnip_cjets_rej,
    n_test=n_jets_c,
    rej_class="cjets",
    signal_class="bjets",
    label="RNNIP",
)
dips_cjets_roc = Roc(
    sig_eff=sig_eff,
    bkg_rej=dips_cjets_rej,
    n_test=n_jets_c,
    rej_class="cjets",
    signal_class="bjets",
    label="DIPS r22",
)

# ROC curve objects can also be stored as yaml or json files
rnnip_ujets_roc.save("rnnip_ujets_roc.yaml")
dips_ujets_roc.save("dips_ujets_roc.yaml")
rnnip_cjets_roc.save("rnnip_cjets_roc.yaml")
dips_cjets_roc.save("dips_cjets_roc.yaml")

# here the plotting of the roc starts
logger.info("Plotting ROC curves.")
plot_roc = RocPlot(
    n_ratio_panels=2,
    ylabel="Background rejection",
    xlabel="$b$-jet efficiency",
    atlas_second_tag="$\\sqrt{s}=13$ TeV, dummy jets \ndummy sample, $f_{c}=0.018$",
    figsize=(6.5, 6),
    y_scale=1.4,
)

# Add the ROC curve objects to the plot
plot_roc.add_roc(roc_curve=rnnip_ujets_roc, reference=True)
plot_roc.add_roc(roc_curve=dips_ujets_roc)
plot_roc.add_roc(roc_curve=rnnip_cjets_roc, reference=True)
plot_roc.add_roc(roc_curve=dips_cjets_roc)

# setting which flavour rejection ratio is drawn in which ratio panel
plot_roc.set_ratio_class(1, "ujets")
plot_roc.set_ratio_class(2, "cjets")

plot_roc.draw()
plot_roc.savefig("roc.png", transparent=False)

# If you want to load now the ROC curves from file, you can do so
loaded_rnnip_ujets_roc = Roc.load("rnnip_ujets_roc.yaml")
loaded_dips_ujets_roc = Roc.load("dips_ujets_roc.yaml")
loaded_rnnip_cjets_roc = Roc.load("rnnip_cjets_roc.yaml", colour="red")
loaded_dips_cjets_roc = Roc.load("dips_cjets_roc.yaml", colour="red")


# Now init a new plot
loaded_plot_roc = RocPlot(
    n_ratio_panels=2,
    ylabel="Background rejection",
    xlabel="$b$-jet efficiency",
    atlas_second_tag="$\\sqrt{s}=13$ TeV, Loaded ROC Curves",
    figsize=(6.5, 6),
    y_scale=1.4,
)

loaded_plot_roc.add_roc(roc_curve=rnnip_ujets_roc, reference=True)
loaded_plot_roc.add_roc(roc_curve=dips_ujets_roc)
loaded_plot_roc.add_roc(roc_curve=rnnip_cjets_roc, reference=True)
loaded_plot_roc.add_roc(roc_curve=dips_cjets_roc)

# setting which flavour rejection ratio is drawn in which ratio panel
loaded_plot_roc.set_ratio_class(1, "ujets")
loaded_plot_roc.set_ratio_class(2, "cjets")

loaded_plot_roc.draw()
loaded_plot_roc.savefig("roc_loaded.png", transparent=False)