Skip to content

Plot Base

puma.plot_base.PlotLineObject dataclass #

Base data class defining properties of a plot object.

Attributes:

Name Type Description
xmin (float | None, optional)

Minimum value of the x-axis, by default None

xmax (float | None, optional)

Maximum value of the x-axis, by default None

colour (str | None, optional)

Colour of the object, by default None

label (str | None, optional)

Label of object, by default None

linestyle (str | None, optional)

Linestyle following numpy style, by default None

linewidth (float | None, optional)

Linewidth that will be used, by default None

alpha (float | None, optional)

Value for visibility of the plot lines, by default None

marker (str | None, optional)

Marker that is used in the plot. For example an x. By default None

markersize (int | None, optional)

Size of the marker. By default None

markeredgewidth (int | None, optional)

Edge width of the marker. By default None

is_marker (bool | None, optional)

Bool, to give info about if this is a marker or a line. By default None

args_to_store property #

Returns the arguments that need to be stored/loaded.

Returns:

Type Description
dict[str, typing.Any]

Dict with the arguments

decode staticmethod #

Inverse of encode, turning tags back into real objects.

Parameters:

Name Type Description Default
obj typing.Any

Object that is to be decoded

required

Returns:

Type Description
typing.Any

The decoded object

Source code in puma/plot_base.py
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
@staticmethod
def decode(obj: Any) -> Any:
    """Inverse of encode, turning tags back into real objects.

    Parameters
    ----------
    obj : Any
        Object that is to be decoded

    Returns
    -------
    Any
        The decoded object
    """
    # Setup the object so that it can be returned if no decoding is needed
    decoded: Any = obj

    # If a dict was used, go through and check for types
    if isinstance(obj, dict):
        if "__ndarray__" in obj:
            decoded = np.asarray(obj["__ndarray__"], dtype=obj["dtype"])
        elif "__label__" in obj:
            decoded = Label(**{
                k: PlotLineObject.decode(v) for k, v in obj["__label__"].items()
            })
        elif "__tuple__" in obj:
            decoded = tuple(PlotLineObject.decode(v) for v in obj["__tuple__"])
        elif "__cuts__" in obj:
            decoded = Cuts(**{k: PlotLineObject.decode(v) for k, v in obj["__cuts__"].items()})
        elif "__cut__" in obj:
            decoded = Cut(**{k: PlotLineObject.decode(v) for k, v in obj["__cut__"].items()})
        else:
            # If it's a regular dict, walk down the keys
            decoded = {k: PlotLineObject.decode(v) for k, v in obj.items()}

    # If a list was used, check that all sub-objects are correctly loaded
    elif isinstance(obj, list):
        decoded = [PlotLineObject.decode(v) for v in obj]

    # If no decoding is needed, return the object
    return decoded

encode staticmethod #

Return a JSON/YAML-safe version of obj, tagging special types.

Parameters:

Name Type Description Default
obj typing.Any

Object that is to be encoded

required

Returns:

Type Description
typing.Any

The encoded object

Source code in puma/plot_base.py
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
@staticmethod
def encode(obj: Any) -> Any:
    """Return a JSON/YAML-safe version of obj, tagging special types.

    Parameters
    ----------
    obj : Any
        Object that is to be encoded

    Returns
    -------
    Any
        The encoded object
    """
    # Setup the object so that it can be returned if no encoding is needed
    encoded: Any = obj

    # Encode special cases which can't be easily stored in json and yaml
    if isinstance(obj, np.ndarray):
        encoded = {"__ndarray__": obj.tolist(), "dtype": str(obj.dtype)}
    elif isinstance(obj, Label):
        encoded = {"__label__": {k: PlotLineObject.encode(v) for k, v in obj.__dict__.items()}}
    elif isinstance(obj, tuple):
        encoded = {"__tuple__": [PlotLineObject.encode(v) for v in obj]}
    elif isinstance(obj, Cuts):
        encoded = {"__cuts__": {k: PlotLineObject.encode(v) for k, v in obj.__dict__.items()}}
    elif isinstance(obj, Cut):
        encoded = {"__cut__": {k: PlotLineObject.encode(v) for k, v in obj.__dict__.items()}}

    # For lists and dicts, walk through them and ensure correct encoding for sub-objects
    elif isinstance(obj, list):
        encoded = [PlotLineObject.encode(v) for v in obj]
    elif isinstance(obj, dict):
        encoded = {k: PlotLineObject.encode(v) for k, v in obj.items()}

    # If no encoding is needed, return the object
    return encoded

load classmethod #

Load attributes from file and construct the object without init.

Parameters:

Name Type Description Default
path str | pathlib.Path

Path in which the attributes are stored.

required
**extra_kwargs typing.Any

Extra kwargs to overwrite certain stored options.

{}

Returns:

Type Description
Class Instance

Instance of class with the given attributes.

Raises:

Type Description
ValueError

If the given file is neither json nor a yaml file.

Source code in puma/plot_base.py
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
@classmethod
def load(cls, path: str | Path, **extra_kwargs: Any) -> Self:
    """Load attributes from file and construct the object without __init__.

    Parameters
    ----------
    path : str | Path
        Path in which the attributes are stored.
    **extra_kwargs : Any
        Extra kwargs to overwrite certain stored options.

    Returns
    -------
    Class Instance
        Instance of class with the given attributes.

    Raises
    ------
    ValueError
        If the given file is neither json nor a yaml file.
    """
    # Ensure path is a path object
    path = Path(path)

    # Check if json and load it as such
    if path.suffix == ".json":
        with path.open() as f:
            data = json.load(f)

    # Check if yaml and load it as such
    elif path.suffix in {".yaml", ".yml"}:
        with path.open() as f:
            data = yaml.safe_load(f)

    # Else ValueError
    else:
        raise ValueError("Unknown file extension. Use '.json', '.yaml' or '.yml'.")

    # Convert back to numpy where appropriate
    data = cls.decode(data)

    # allow caller to override
    data.update(extra_kwargs)

    # Init the class without running __init__
    obj: Self = cls.__new__(cls)

    # Set attributes verbatim
    for key, val in data.items():
        setattr(obj, key, val)
    return obj

save #

Store class attributes in a file (json or yaml).

Saving can be performed to a yaml and a json file.

Parameters:

Name Type Description Default
path str | pathlib.Path

Path to which the class object attributes are written.

required

Raises:

Type Description
ValueError

If an unknown file extension was given

Source code in puma/plot_base.py
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
def save(self, path: str | Path) -> None:
    """Store class attributes in a file (json or yaml).

    Saving can be performed to a yaml and a json file.

    Parameters
    ----------
    path : str | Path
        Path to which the class object attributes are written.

    Raises
    ------
    ValueError
        If an unknown file extension was given
    """
    # Ensure path is a path object
    path = Path(path)

    # Get the attributes as a dict
    data = self.encode(self.args_to_store)

    # Check for json and store it as such
    if path.suffix == ".json":
        with path.open("w") as f:
            json.dump(data, f, indent=2)

    # Check for yaml and store it as such
    elif path.suffix in {".yaml", ".yml"}:
        with path.open("w") as f:
            yaml.safe_dump(data, f)

    # Else ValueError
    else:
        raise ValueError("Unknown file extension. Use '.json', '.yaml' or '.yml'!")

puma.plot_base.PlotObject dataclass #

Data base class defining properties of a plot object.

Attributes:

Name Type Description
title (str, optional)

Title of the plot, by default ""

draw_errors (bool, optional)

Draw statistical uncertainty on the lines, by default True

xmin (float | None, optional)

Minimum value of the x-axis, by default None

xmax (float | None, optional)

Maximum value of the x-axis, by default None

ymin (float | None, optional)

Minimum value of the y-axis, by default None

ymax (float | None, optional)

Maximum value of the y-axis, by default None

ymin_ratio (list[float | None] | None, optional)

Set the lower y limit of each of the ratio subplots, by default None.

ymax_ratio (list[float | None] | None, optional)

Set the upper y limit of each of the ratio subplots, by default None.

y_scale (float, optional)

Scaling up the y axis, e.g. to fit the ATLAS Tag. Applied if ymax not defined, by default 1.3

logx (bool, optional)

Set the log of x-axis, by default False

logy (bool, optional)

Set log of y-axis of main panel, by default True

xlabel (str | None, optional)

Label of the x-axis, by default None

ylabel (str | None, optional)

Label of the y-axis, by default None

ylabel_ratio (list[str] | None, optional)

List of labels for the y-axis in the ratio plots, by default "Ratio"

label_fontsize (int, optional)

Used fontsize in label, by default 12

fontsize (int, optional)

Used fontsize, by default 10

n_ratio_panels (int, optional)

Amount of ratio panels between 0 and 2, by default 0

vertical_split (bool, optional)

Set to False if you would like to split the figure horizontally. If set to True the figure is split vertically (e.g. for pie chart), by default False.

figsize (tuple[float, float] | None, optional)

Tuple of figure size (width, height) in inches, by default None

dpi (int, optional)

DPI used for plotting, by default 400

transparent (bool, optional)

Specify if the background of the plot should be transparent, by default False

grid (bool, optional)

Set the grid for the plots, by default True.

figure_layout (str, optional)

Set the layout that is used for the plot, by default "constrained"

leg_fontsize (int | None, optional)

Fontsize of the legend, by default None (falls back to fontsize)

leg_loc (str, optional)

Position of the legend in the plot, by default "upper right"

leg_linestyle_loc (str, optional)

Position of the linestyle legend in the plot, by default "upper center"

leg_ncol (int, optional)

Number of legend columns, by default 1

apply_atlas_style (bool, optional)

Apply ATLAS style for matplotlib, by default True

use_atlas_tag (bool, optional)

Use the ATLAS Tag in the plots, by default True

atlas_first_tag (str, optional)

First row of the ATLAS tag (i.e. "ATLAS "), by default "Simulation Internal"

atlas_second_tag (str | None, optional)

Second row of the ATLAS tag, by default None

atlas_fontsize (int | None, optional)

Fontsize of ATLAS label, by default None (falls back to fontsize)

atlas_vertical_offset (float, optional)

Vertical offset of the ATLAS tag, by default 7

atlas_horizontal_offset (float, optional)

Horizontal offset of the ATLAS tag, by default 8

atlas_brand (str | None, optional)

brand argument handed to atlasify. Use an empty string or None to remove it, by default "ATLAS"

atlas_tag_outside (bool, optional)

outside argument handed to atlasify. Decides if the ATLAS logo is plotted outside of the plot (on top), by default False

atlas_second_tag_distance (float, optional)

Distance between atlas_first_tag and atlas_second_tag in units of line spacing, by default 0

plotting_done (bool, optional)

Indicates if plotting is done. Only then atlasify() can be called, by default False

__check_figsize #

Check figsize is a tuple/list of length 2.

Raises:

Type Description
ValueError

If shape of figsize is not a tuple or list with length 2

Source code in puma/plot_base.py
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
def __check_figsize(self) -> None:
    """Check `figsize` is a tuple/list of length 2.

    Raises
    ------
    ValueError
        If shape of `figsize` is not a tuple or list with length 2
    """
    if self.figsize is None:
        return
    if isinstance(self.figsize, list) and len(self.figsize) == 2:
        self.figsize = tuple(self.figsize)
    elif not isinstance(self.figsize, tuple) or len(self.figsize) != 2:
        raise ValueError(
            f"You passed `figsize` as {self.figsize} which is not allowed. "
            "Either a tuple or a list of size 2 is allowed"
        )

__check_yratio #

Check yratio is a sequence of length n_ratio_panels.

Parameters:

Name Type Description Default
yratio typing.Sequence[float | None] | None

List of min or max limits of ratio plots

required

Raises:

Type Description
ValueError

If yratio is not a list and it's length is not equal to number of ratio panels

Source code in puma/plot_base.py
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
def __check_yratio(self, yratio: Sequence[float | None] | None) -> None:
    """Check `yratio` is a sequence of length n_ratio_panels.

    Parameters
    ----------
    yratio : Sequence[float | None] | None
        List of min or max limits of ratio plots

    Raises
    ------
    ValueError
        If `yratio` is not a list and it's length
        is not equal to number of ratio panels
    """
    if yratio is None:
        return
    if not isinstance(yratio, (list, tuple)) or len(yratio) != self.n_ratio_panels:
        raise ValueError(
            f"You passed `min/max_yratio` as {yratio} which is not allowed. "
            f"Either a tuple or a list of size {self.n_ratio_panels} is allowed"
        )

__post_init__ #

Check for allowed values.

Raises:

Type Description
ValueError

If n_ratio_panels not in [0, 1, 2, 3]

Source code in puma/plot_base.py
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
def __post_init__(self) -> None:
    """Check for allowed values.

    Raises
    ------
    ValueError
        If n_ratio_panels not in [0, 1, 2, 3]
    """
    self.__check_figsize()
    allowed_n_ratio_panels = [0, 1, 2, 3]
    if self.n_ratio_panels not in allowed_n_ratio_panels:
        raise ValueError(
            f"{self.n_ratio_panels} not allwed value for `n_ratio_panels`. "
            f"Allowed are {allowed_n_ratio_panels}"
        )
    self.__check_yratio(self.ymin_ratio)
    self.ymin_ratio = (
        [None] * self.n_ratio_panels if self.ymin_ratio is None else self.ymin_ratio
    )
    self.__check_yratio(self.ymax_ratio)
    self.ymax_ratio = (
        [None] * self.n_ratio_panels if self.ymax_ratio is None else self.ymax_ratio
    )

    if self.ylabel_ratio is None:
        self.ylabel_ratio = ["Ratio"] * self.n_ratio_panels
    elif isinstance(self.ylabel_ratio, str):
        self.ylabel_ratio = [self.ylabel_ratio]
    if len(self.ylabel_ratio) != self.n_ratio_panels:
        raise ValueError(
            f"You passed `ylabel_ratio` of length {len(self.ylabel_ratio)}, "
            f"but `n_ratio_panels` of {self.n_ratio_panels}. "
            f"These should be equal."
        )
    if self.leg_fontsize is None:
        self.leg_fontsize = self.fontsize
    if self.atlas_fontsize is None:
        self.atlas_fontsize = self.fontsize
    if not self.apply_atlas_style and (
        self.atlas_first_tag is not None or self.atlas_second_tag is not None
    ):
        logger.warning(
            "You specified an ATLAS tag, but `apply_atlas_style` is set to false. "
            "Tag will therefore not be shown on plot."
        )

puma.plot_base.PlotBase #

Bases: puma.plot_base.PlotObject

Base class for plotting.

Parameters:

Name Type Description Default
**kwargs typing.Any

Keyword arguments from puma.PlotObject

{}
Source code in puma/plot_base.py
505
506
507
508
509
510
def __init__(self, **kwargs: Any) -> None:
    super().__init__(**kwargs)
    self.axis_top: Axes | None = None
    self.ratio_axes: list[Axes] = []
    self.axis_leg: Axes | None = None
    self.fig: Figure | None = None

atlasify #

Apply ATLAS style to all axes using the atlasify package.

Parameters:

Name Type Description Default
force bool

Force ATLAS style also if class variable is False, by default False

False
Source code in puma/plot_base.py
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
def atlasify(self, force: bool = False) -> None:
    """Apply ATLAS style to all axes using the atlasify package.

    Parameters
    ----------
    force : bool, optional
        Force ATLAS style also if class variable is False, by default False
    """
    if not self.plotting_done and not force:
        logger.warning(
            "`atlasify()` has to be called after plotting --> "
            "ATLAS style will not be adapted. If you want to do it anyway, "
            "you can use `force`."
        )
        return

    if self.apply_atlas_style or force:
        assert self.axis_top is not None
        logger.debug("Initialise ATLAS style using atlasify.")
        if self.use_atlas_tag:
            atlasify.atlasify(
                atlas=self.atlas_first_tag,
                subtext=self.atlas_second_tag,
                axes=self.axis_top,
                font_size=self.atlas_fontsize,
                label_font_size=self.atlas_fontsize,
                sub_font_size=self.atlas_fontsize,
                offset=self.atlas_vertical_offset,
                indent=self.atlas_horizontal_offset,
                enlarge=1,
                brand="" if self.atlas_brand is None else self.atlas_brand,
                outside=self.atlas_tag_outside,
                subtext_distance=self.atlas_second_tag_distance,
            )
        else:
            atlasify.atlasify(atlas=False, axes=self.axis_top, enlarge=1)

        for ratio_axis in self.ratio_axes:
            atlasify.atlasify(atlas=False, axes=ratio_axis, enlarge=1)

        if self.vertical_split and self.axis_leg is not None:
            atlasify.atlasify(atlas=False, axes=self.axis_leg, enlarge=1)

        if force:
            if not self.apply_atlas_style:
                logger.warning(
                    "Initialising ATLAS style even though `apply_atlas_style` is set to False."
                )
            if not self.plotting_done:
                logger.warning(
                    "Initialising ATLAS style even though `plotting_done` is set to False."
                )

close_window #

Properly close the Tkinter window and exit the main loop.

Parameters:

Name Type Description Default
root tkinter.Tk | None

The Tkinter root window instance to be closed.

required
Source code in puma/plot_base.py
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
def close_window(self, root: tk.Tk | None) -> None:
    """Properly close the Tkinter window and exit the main loop.

    Parameters
    ----------
    root : tk.Tk | None
        The Tkinter root window instance to be closed.
    """
    if root is not None:
        logger.debug("Closing plot window.")

        # Stop the Tkinter main loop and destroy the window
        root.quit()
        root.destroy()

        # Explicitly delete the root object (optional but helps with garbage collection)
        del root

draw_vlines #

Draw vertical lines and optional labels on the axes.

Parameters:

Name Type Description Default
xs typing.Sequence[float]

List of working points x values to draw

required
labels typing.Sequence[str | None] | None

List with labels for the vertical lines. Must be the same order as the xs. If None, the xvalues * 100 will be used as labels. By default None

None
ys typing.Sequence[float] | None

List with the y height of the vertical lines in percent of the upper plot (0 is bottom, 1 is top). Must be the same order as the xs and the labels. By default None

None
same_height bool

Working point lines on same height, by default False

False
colour str

Colour of the vertical line, by default "#000000" (black)

'#000000'
linestyle str

Linestyle of the vertical line, by default "dashed"

'dashed'
fontsize int

Fontsize of the vertical line text. By default 10.

10
Source code in puma/plot_base.py
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
def draw_vlines(
    self,
    xs: Sequence[float],
    labels: Sequence[str | None] | None = None,
    ys: Sequence[float] | None = None,
    same_height: bool = False,
    colour: str = "#000000",
    linestyle: str = "dashed",
    fontsize: int = 10,
) -> None:
    """Draw vertical lines and optional labels on the axes.

    Parameters
    ----------
    xs : Sequence[float]
        List of working points x values to draw
    labels : Sequence[str | None] | None, optional
        List with labels for the vertical lines. Must be the same
        order as the xs. If None, the xvalues * 100 will be
        used as labels. By default None
    ys : Sequence[float] | None, optional
        List with the y height of the vertical lines in percent of the
        upper plot (0 is bottom, 1 is top). Must be the same
        order as the xs and the labels. By default None
    same_height : bool, optional
        Working point lines on same height, by default False
    colour : str, optional
        Colour of the vertical line, by default "#000000" (black)
    linestyle : str, optional
        Linestyle of the vertical line, by default "dashed"
    fontsize : int, optional
        Fontsize of the vertical line text. By default 10.
    """
    assert self.axis_top is not None
    for i, vline_x in enumerate(xs):
        # Set y-point of the WP lines/text
        ytext = (0.65 if same_height else 0.65 - i * 0.07) if ys is None else ys[i]

        self.axis_top.axvline(
            x=vline_x,
            ymax=ytext,
            color=colour,
            linestyle=linestyle,
            linewidth=1.0,
        )

        # Set the number above the line
        self.axis_top.text(
            x=vline_x - 0.005,
            y=ytext + 0.005,
            s=(labels[i] if labels else None),
            transform=self.axis_top.get_xaxis_text1_transform(0)[0],
            fontsize=fontsize,
        )

        for ratio_axis in self.ratio_axes:
            ratio_axis.axvline(x=vline_x, color=colour, linestyle=linestyle, linewidth=1.0)

initialise_figure #

Create matplotlib Figure and subplots based on layout options.

Source code in puma/plot_base.py
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
def initialise_figure(self) -> None:
    """Create matplotlib Figure and subplots based on layout options."""
    if self.vertical_split:  # split figure vertically instead of horizonally
        if self.n_ratio_panels >= 1:
            logger.warning(
                "You set the number of ratio panels to %i but also set the"
                " vertical splitting to True. Therefore no ratiopanels are"
                " created.",
                self.n_ratio_panels,
            )
        self.fig = Figure(figsize=(6, 4.5) if self.figsize is None else self.figsize)
        g_spec = gridspec.GridSpec(1, 11, figure=self.fig)
        self.axis_top = self.fig.add_subplot(g_spec[0, :9])
        self.axis_leg = self.fig.add_subplot(g_spec[0, 9:])

    else:
        # you must use increments of 0.1 for the dimensions
        width = 4.7
        top_height = 2.7 if self.n_ratio_panels else 3.5
        ratio_height = 1.0
        height = top_height + self.n_ratio_panels * ratio_height
        figsize = (width, height) if self.figsize is None else self.figsize
        self.fig = Figure(figsize=figsize, layout=self.figure_layout)

        if self.n_ratio_panels == 0:
            self.axis_top = self.fig.gca()
        elif self.n_ratio_panels > 0:
            g_spec_height = (top_height + ratio_height * self.n_ratio_panels) * 10
            g_spec = gridspec.GridSpec(int(g_spec_height), 1, figure=self.fig)
            self.axis_top = self.fig.add_subplot(g_spec[: int(top_height * 10), 0])
            self.set_xaxis_ticklabels_invisible(self.axis_top)
            for i in range(1, self.n_ratio_panels + 1):
                start = int((top_height + ratio_height * (i - 1)) * 10)
                stop = int(start + ratio_height * 10)
                sub_axis = self.fig.add_subplot(g_spec[start:stop, 0], sharex=self.axis_top)
                if i < self.n_ratio_panels:
                    self.set_xaxis_ticklabels_invisible(sub_axis)
                self.ratio_axes.append(sub_axis)

    # type-narrowing: required before any use
    assert self.axis_top is not None
    assert self.fig is not None

    # Add the locator to all axes
    self.axis_top.yaxis.set_major_locator(
        locator=MaxNLocator(
            nbins="auto",
            prune="both",
            steps=[1, 2, 5, 10],
        )
    )
    for ratio_axis in self.ratio_axes:
        ratio_axis.yaxis.set_major_locator(
            locator=MaxNLocator(
                nbins="auto",
                prune="both",
                steps=[1, 2, 5, 10],
            )
        )

    if self.grid:
        self.axis_top.grid(lw=0.3)
        for ratio_axis in self.ratio_axes:
            ratio_axis.grid(lw=0.3)

is_running_in_jupyter #

Detect if running inside a Jupyter notebook.

Returns:

Type Description
bool

If the code is run inside a jupyter notebook

Source code in puma/plot_base.py
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
def is_running_in_jupyter(self) -> bool:
    """Detect if running inside a Jupyter notebook.

    Returns
    -------
    bool
        If the code is run inside a jupyter notebook
    """
    try:
        shell = get_ipython()

        # Running in standard Python interpreter
        if shell is None:
            return False

        shell_name = shell.__class__.__name__

        # Jupyter notebook or qtconsole
        if shell_name == "ZMQInteractiveShell":
            return True

        # Terminal running IPython
        if shell_name == "TerminalInteractiveShell":
            return False

    # Probably standard Python interpreter
    except (NameError, ImportError):
        return False

    else:
        # Other type (?)
        return False

make_legend #

Draw legend on a given axis.

Parameters:

Name Type Description Default
handles list[matplotlib.lines.Line2D]

List of matplotlib.lines.Line2D object returned when plotting

required
ax_mpl matplotlib.axes.Axes

matplotlib.axis.Axes object where the legend should be plotted

required
labels list[str] | None

Plot labels. If None, the labels are extracted from the handles. By default None

None
**kwargs typing.Any

Keyword arguments which can be passed to matplotlib axis

{}
Source code in puma/plot_base.py
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
def make_legend(
    self,
    handles: list[lines.Line2D],
    ax_mpl: Axes,
    labels: list[str] | None = None,
    **kwargs: Any,
) -> None:
    """Draw legend on a given axis.

    Parameters
    ----------
    handles : list[lines.Line2D]
        List of matplotlib.lines.Line2D object returned when plotting
    ax_mpl : Axes
        `matplotlib.axis.Axes` object where the legend should be plotted
    labels : list[str] | None, optional
        Plot labels. If None, the labels are extracted from the `handles`.
        By default None
    **kwargs : Any
        Keyword arguments which can be passed to matplotlib axis
    """
    if labels is None:
        # remove the handles which have label=None
        handles = [handle for handle in handles if handle.get_label() is not None]
    ax_mpl.add_artist(
        ax_mpl.legend(
            handles=handles,
            labels=([handle.get_label() for handle in handles] if labels is None else labels),
            loc=self.leg_loc,
            fontsize=self.leg_fontsize,
            ncol=self.leg_ncol,
            **kwargs,
        )
    )

make_linestyle_legend #

Create a legend to indicate what different linestyles correspond to.

Parameters:

Name Type Description Default
linestyles typing.Sequence[str]

List of the linestyles to draw in the legend

required
labels typing.Sequence[str]

List of the corresponding labels. Has to be in the same order as the linestyles

required
loc str | None

Location of the legend (matplotlib supported locations), by default None

None
bbox_to_anchor tuple[float, float] | tuple[float, float, float, float] | None

Allows to specify the precise position of this legend. Either a 2-tuple (x, y) or a 4-tuple (x, y, width, height), by default None

None
axis_for_legend matplotlib.axes.Axes | None

Axis on which to draw the legend, by default None

None
Source code in puma/plot_base.py
 997
 998
 999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
def make_linestyle_legend(
    self,
    linestyles: Sequence[str],
    labels: Sequence[str],
    loc: str | None = None,
    bbox_to_anchor: tuple[float, float] | tuple[float, float, float, float] | None = None,
    axis_for_legend: Axes | None = None,
) -> None:
    """Create a legend to indicate what different linestyles correspond to.

    Parameters
    ----------
    linestyles : Sequence[str]
        List of the linestyles to draw in the legend
    labels : Sequence[str]
        List of the corresponding labels. Has to be in the same order as the
        linestyles
    loc : str | None, optional
        Location of the legend (matplotlib supported locations), by default None
    bbox_to_anchor : tuple[float, float] | tuple[float, float, float, float] | None, optional
        Allows to specify the precise position of this legend. Either a 2-tuple
        (x, y) or a 4-tuple (x, y, width, height), by default None
    axis_for_legend : Axes | None, optional
        Axis on which to draw the legend, by default None
    """
    axis_for_legend = self.axis_top if axis_for_legend is None else axis_for_legend
    assert axis_for_legend is not None

    lines_list: list[lines.Line2D] = []
    for linestyle, label in zip(linestyles, labels):
        lines_list.append(
            lines.Line2D(
                [],
                [],
                color="k",
                label=label,
                linestyle=linestyle,
            )
        )

    linestyle_legend = axis_for_legend.legend(
        handles=lines_list,
        labels=[handle.get_label() for handle in lines_list],
        loc=(loc if loc is not None else self.leg_linestyle_loc),
        fontsize=self.leg_fontsize,
        bbox_to_anchor=bbox_to_anchor,
        frameon=False,
    )
    axis_for_legend.add_artist(linestyle_legend)

savefig #

Save plot to disk.

Parameters:

Name Type Description Default
plot_name str

File name of the plot

required
transparent bool | None

Specify if plot background is transparent, by default False

None
dpi int | None

DPI for plotting, by default 400

None
**kwargs typing.Any

Keyword arguments passed to matplotlib.figure.Figure.savefig()

{}
Source code in puma/plot_base.py
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
def savefig(
    self,
    plot_name: str,
    transparent: bool | None = None,
    dpi: int | None = None,
    **kwargs: Any,
) -> None:
    """Save plot to disk.

    Parameters
    ----------
    plot_name : str
        File name of the plot
    transparent : bool | None, optional
        Specify if plot background is transparent, by default False
    dpi : int | None, optional
        DPI for plotting, by default 400
    **kwargs : Any
        Keyword arguments passed to `matplotlib.figure.Figure.savefig()`
    """
    assert self.fig is not None
    logger.debug("Saving plot to %s", plot_name)
    self.fig.savefig(
        plot_name,
        transparent=self.transparent if transparent is None else transparent,
        dpi=self.dpi if dpi is None else dpi,
        bbox_inches="tight",
        pad_inches=0.04,
        **kwargs,
    )

set_log #

Set log scale of axes as configured.

Source code in puma/plot_base.py
649
650
651
652
653
654
655
656
657
658
659
660
def set_log(self) -> None:
    """Set log scale of axes as configured."""
    assert self.axis_top is not None
    if self.logx:
        self.axis_top.set_xscale("log")
        for ratio_axis in self.ratio_axes:
            ratio_axis.set_xscale("log")

    if self.logy:
        self.axis_top.set_yscale("log")
        ymin, ymax = self.axis_top.get_ylim()
        self.y_scale = ymin * ((ymax / ymin) ** self.y_scale) / ymax

set_ratio_label #

Associate the rejection class to a ratio panel.

Parameters:

Name Type Description Default
ratio_panel int

Index of the ratio panel to label.

required
label str

y-axis label of the ratio panel.

required

Raises:

Type Description
ValueError

If the requested ratio panel does not exist.

Source code in puma/plot_base.py
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
def set_ratio_label(self, ratio_panel: int, label: str) -> None:
    """Associate the rejection class to a ratio panel.

    Parameters
    ----------
    ratio_panel : int
        Index of the ratio panel to label.
    label : str
        y-axis label of the ratio panel.

    Raises
    ------
    ValueError
        If the requested ratio panel does not exist.
    """
    if ratio_panel > self.n_ratio_panels:
        raise ValueError(f"Plot has {self.n_ratio_panels} ratio panels, not {ratio_panel}")
    assert self.ylabel_ratio is not None
    self.ylabel_ratio[ratio_panel - 1] = label

set_tick_params #

Set tick params on all relevant axes.

Parameters:

Name Type Description Default
labelsize int | None

Label size of x- and y- axis ticks, by default None. If None then using global fontsize

None
**kwargs typing.Any

Keyword arguments passed to matplotlib.axes.Axes.set_xlabel()

{}
Source code in puma/plot_base.py
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
def set_tick_params(self, labelsize: int | None = None, **kwargs: Any) -> None:
    """Set tick params on all relevant axes.

    Parameters
    ----------
    labelsize : int | None, optional
        Label size of x- and y- axis ticks, by default None.
        If None then using global fontsize
    **kwargs : Any
        Keyword arguments passed to `matplotlib.axes.Axes.set_xlabel()`
    """
    assert self.axis_top is not None
    labelsize_eff = self.fontsize if labelsize is None else labelsize
    self.axis_top.tick_params(axis="y", labelsize=labelsize_eff, **kwargs)
    if self.n_ratio_panels == 0:
        self.axis_top.tick_params(axis="x", labelsize=labelsize_eff, **kwargs)
    for i, ratio_axis in enumerate(self.ratio_axes):
        ratio_axis.tick_params(axis="y", labelsize=labelsize_eff, **kwargs)
        if i == self.n_ratio_panels - 1:
            ratio_axis.tick_params(axis="x", labelsize=labelsize_eff, **kwargs)

set_title #

Set title of top panel.

Parameters:

Name Type Description Default
title str | None

Title of top panel, if None using the value form the class variables, by default None

None
**kwargs typing.Any

Keyword arguments passed to matplotlib.axes.Axes.set_title()

{}
Source code in puma/plot_base.py
635
636
637
638
639
640
641
642
643
644
645
646
647
def set_title(self, title: str | None = None, **kwargs: Any) -> None:
    """Set title of top panel.

    Parameters
    ----------
    title : str | None, optional
        Title of top panel, if None using the value form the class variables,
        by default None
    **kwargs : Any
        Keyword arguments passed to `matplotlib.axes.Axes.set_title()`
    """
    assert self.axis_top is not None
    self.axis_top.set_title(self.title if title is None else title, **kwargs)

set_xaxis_ticklabels_invisible #

Helper function to set the ticklabels of the xaxis invisible.

Parameters:

Name Type Description Default
ax matplotlib.axes.Axes

Axis you want to modify

required
Source code in puma/plot_base.py
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
def set_xaxis_ticklabels_invisible(self, ax: Axes):
    """Helper function to set the ticklabels of the xaxis invisible.

    Parameters
    ----------
    ax : Axes
        Axis you want to modify
    """
    for label in ax.get_xticklabels():
        label.set_visible(False)

set_xlabel #

Set x-axis label on the bottom-most axis.

Parameters:

Name Type Description Default
label str | None

x-axis label, by default None

None
**kwargs typing.Any

Keyword arguments passed to matplotlib.axes.Axes.set_xlabel()

{}
Source code in puma/plot_base.py
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
def set_xlabel(self, label: str | None = None, **kwargs: Any) -> None:
    """Set x-axis label on the bottom-most axis.

    Parameters
    ----------
    label : str | None, optional
        x-axis label, by default None
    **kwargs : Any
        Keyword arguments passed to `matplotlib.axes.Axes.set_xlabel()`
    """
    assert self.axis_top is not None
    xlabel_args = {
        "xlabel": self.xlabel if label is None else label,
        "horizontalalignment": "right",
        "x": 1.0,
        "fontsize": self.label_fontsize,
    }
    if self.n_ratio_panels == 0:
        self.axis_top.set_xlabel(**xlabel_args, **kwargs)
    else:
        self.ratio_axes[-1].set_xlabel(**xlabel_args, **kwargs)

set_xlim #

Set limits of x-axis.

Parameters:

Name Type Description Default
xmin float | None

Min of x-axis, by default None

None
xmax float | None

Max of x-axis, by default None

None
**kwargs typing.Any

Keyword arguments passed to matplotlib.axes.Axes.set_xlim()

{}
Source code in puma/plot_base.py
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
def set_xlim(self, xmin: float | None = None, xmax: float | None = None, **kwargs: Any) -> None:
    """Set limits of x-axis.

    Parameters
    ----------
    xmin : float | None, optional
        Min of x-axis, by default None
    xmax : float | None, optional
        Max of x-axis, by default None
    **kwargs : Any
        Keyword arguments passed to `matplotlib.axes.Axes.set_xlim()`
    """
    assert self.axis_top is not None
    self.axis_top.set_xlim(
        self.xmin if xmin is None else xmin,
        self.xmax if xmax is None else xmax,
        **kwargs,
    )

set_y_lim #

Set limits of y-axis (main and ratios).

Source code in puma/plot_base.py
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
def set_y_lim(self) -> None:
    """Set limits of y-axis (main and ratios)."""
    assert self.axis_top is not None
    ymin, ymax = self.axis_top.get_ylim()
    self.axis_top.set_ylim(
        self.ymin if self.ymin is not None else ymin,
        (ymin + (ymax - ymin) * self.y_scale) if self.ymax is None else self.ymax,
    )

    if self.ymin_ratio is None or self.ymax_ratio is None:
        return

    for i, ratio_axis in enumerate(self.ratio_axes):
        if self.ymin_ratio[i] is not None or self.ymax_ratio[i] is not None:
            ymin_i, ymax_i = ratio_axis.get_ylim()
            ymin_i = self.ymin_ratio[i] if self.ymin_ratio[i] is not None else ymin_i
            ymax_i = self.ymax_ratio[i] if self.ymax_ratio[i] is not None else ymax_i
            ratio_axis.set_ylim(bottom=ymin_i, top=ymax_i)

set_ylabel #

Set y-axis label.

Parameters:

Name Type Description Default
ax_mpl matplotlib.axes.Axes

matplotlib axis object

required
label str | None

x-axis label, by default None

None
align str | None

Alignment of y-axis label, by default "right"

'right'
**kwargs typing.Any

Keyword arguments passed to matplotlib.axes.Axes.set_ylabel()

{}
Source code in puma/plot_base.py
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
def set_ylabel(
    self,
    ax_mpl: Axes,
    label: str | None = None,
    align: str | None = "right",
    **kwargs: Any,
) -> None:
    """Set y-axis label.

    Parameters
    ----------
    ax_mpl : Axes
        matplotlib axis object
    label : str | None, optional
        x-axis label, by default None
    align : str | None, optional
        Alignment of y-axis label, by default "right"
    **kwargs : Any
        Keyword arguments passed to `matplotlib.axes.Axes.set_ylabel()`
    """
    assert self.fig is not None
    label_options: dict[str, Any] = {"fontsize": self.label_fontsize}
    if align:
        label_options["horizontalalignment"] = align
        if align == "right":
            label_options["y"] = 1
        elif align == "left":
            label_options["y"] = 0

    ax_mpl.set_ylabel(
        self.ylabel if label is None else label,
        **label_options,
        **kwargs,
    )
    self.fig.align_labels()

show #

Show the plot using tkinter in CLI and detect Jupyter to avoid issues.

Parameters:

Name Type Description Default
auto_close_after int | None

After how many milliseconds, the window is automatically closed, by default None

None

Raises:

Type Description
ValueError

If the figure is not initalized yet

Source code in puma/plot_base.py
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
def show(self, auto_close_after: int | None = None) -> None:
    """Show the plot using tkinter in CLI and detect Jupyter to avoid issues.

    Parameters
    ----------
    auto_close_after : int | None, optional
        After how many milliseconds, the window is automatically closed, by default None

    Raises
    ------
    ValueError
        If the figure is not initalized yet
    """
    if self.is_running_in_jupyter():
        logger.debug("Detected Jupyter Notebook, displaying inline.")
        assert self.fig is not None
        display(self.fig)
        return

    logger.debug("Showing plot using tkinter")

    # Ensure figure is initialized
    if self.fig is None:
        raise ValueError("You need to initalize the figure before using show().")

    # Create the Tkinter root window
    root = tk.Tk()
    root.title("Plot Display")

    # Embed the figure into a Tkinter canvas
    canvas = FigureCanvasTkAgg(self.fig, master=root)
    canvas_widget = canvas.get_tk_widget()
    canvas_widget.pack(fill="both", expand=True)

    # Render the figure
    canvas.draw()

    # If auto_close_after is set, close the window automatically
    if auto_close_after:
        logger.debug(f"Auto-closing window after {auto_close_after} ms")
        root.after(auto_close_after, lambda: self.close_window(root))

    # Handle window close event manually
    root.protocol("WM_DELETE_WINDOW", lambda: self.close_window(root))

    # Start Tkinter event loop
    root.mainloop()