Source code for process_improve.batch.plotting
from __future__ import annotations
# Built-in libraries
import json
import math
import random
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any, cast

# Third-party libraries
import numpy as np
from pydantic import BaseModel, ConfigDict, Field, field_validator
# Plotting settings -- ENG-13 (#295): live in the optional ``[plotting]``
# extra. Module import succeeds without them; any actual plot call raises
# a clear "install the extra" ImportError via ``_MissingExtra``.
try:
import plotly.graph_objects as go
import seaborn as sns
from plotly.offline import plot as plotoffline
except ImportError: # pragma: no cover - exercised via env-without-plotly
from process_improve._extras import _MissingExtra
go = _MissingExtra("plotly", "plotting") # type: ignore[assignment]
sns = _MissingExtra("seaborn", "plotting") # type: ignore[assignment]
plotoffline = _MissingExtra("plotly", "plotting") # type: ignore[assignment]
from .data_input import check_valid_batch_dict
[docs]
class MultiTagPlotSettings(BaseModel):
    """Settings for :func:`plot_multitags`.

    All fields have sensible defaults; pass a plain dict of overrides to
    ``plot_multitags(..., settings=...)``.
    """

    # Permit arbitrary (non-pydantic) types as field values, e.g. the
    # ``colour_map`` callable below.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Number of subplot rows. Must be >= 1: plot_multitags divides by it when
    # computing the subplot spacing and the automatic column count.
    nrows: int = Field(default=1, ge=1)
    # Number of subplot columns. 0 means "auto" (enough columns to fit every
    # tag into ``nrows`` rows); ``None`` is accepted as a synonym for 0.
    ncols: int = Field(default=0, ge=0)
    x_axis_label: str = "Time, grouped per tag"
    title: str = ""
    show_legend: bool = True
    # Plotly trace draw mode, e.g. "lines", "markers" or "lines+markers".
    mode: str = "lines"
    html_image_height: int = 900
    html_aspect_ratio_w_over_h: float = 16 / 9
    default_line_width: float = 2
    # Called with the number of colours wanted; must return that many RGB
    # float triplets. NOTE(review): this default is evaluated at class
    # definition time -- confirm ``_MissingExtra`` tolerates attribute access
    # when the [plotting] extra is not installed.
    colour_map: Callable = sns.husl_palette
    # plot_multitags switches animation off when
    # ``animate_batches_to_highlight`` is empty.
    animate: bool = False
    animate_batches_to_highlight: list = Field(default_factory=list)
    animate_show_slider: bool = True
    animate_show_pause: bool = True
    animate_slider_prefix: str = "Index: "
    # Fraction of figure height; depends on whether the legend is shown and the
    # length of the batch names.
    animate_slider_vertical_offset: float = -0.3
    # The animated lines are drawn on top of the historical lines.
    animate_line_width: float = 4
    # If None, uses the max number of frames needed for one frame per time step.
    animate_n_frames: int | None = None
    animate_framerate_milliseconds: int = 0

    @field_validator("ncols", mode="before")
    @classmethod
    def coerce_none_ncols_to_auto(cls, value: Any) -> Any:
        """Map ``ncols=None`` (documented as "auto") onto the 0 sentinel."""
        return 0 if value is None else value
[docs]
def get_rgba_from_triplet(incolour: Sequence[float], alpha: float = 1, as_string: bool = False) -> str | list:
    """
    Convert a colour given as floats in [0, 1] to 0-255 integer RGB values.

    Parameters
    ----------
    incolour : Sequence[float]
        3 (RGB) or 4 (RGBA) channel values, nominally in [0, 1]. Only the first
        three are used; a 4th (alpha) entry is ignored in favour of `alpha`.
        Each channel is scaled by 255, floored and clamped to [0, 255], so
        slightly out-of-range inputs still give a valid colour.
    alpha : float, optional
        Opacity written into the rgba string; default 1.
    as_string : bool, optional
        If True, return a Plotly ``rgba(r,g,b,a)`` string; if False (default),
        return the list of 3 integer RGB values.

    Returns
    -------
    str | list
        E.g. [0.9677975592919913, 0.44127456009157356, 0.5358103155058701]
        -> 'rgba(246,112,136,1.0)' when `as_string` is True,
        or [246, 112, 136] otherwise.

    Raises
    ------
    ValueError
        If `incolour` does not have 3 or 4 entries.
    """
    if not 3 <= len(incolour) <= 4:
        raise ValueError(
            f"`incolour` must be a list of 3 or 4 values; got {len(incolour)} entries."
        )
    # Clamp on both sides: values above 1.0 would otherwise exceed 255, which
    # is not a valid rgba channel.
    colours = [min(255, max(0, math.floor(c * 255))) for c in list(incolour)[0:3]]
    if as_string:
        return f"rgba({colours[0]},{colours[1]},{colours[2]},{float(alpha)})"
    return colours
[docs]
def plot_to_HTML(filename: str, fig: dict | go.Figure) -> str:  # noqa: N802
    """Export a Plotly figure to a standalone HTML file.

    Parameters
    ----------
    filename : str
        Path of the HTML file to write.
    fig : dict | go.Figure
        The Plotly figure (or figure dictionary) to export.

    Returns
    -------
    str
        The return value of ``plotly.offline.plot``.

    Notes
    -----
    plotly.js and MathJax are loaded from a CDN, so the file is small but needs
    network access to render. The browser is not opened automatically.
    """
    config = dict(
        scrollZoom=True,
        displayModeBar=True,
        editable=False,
        displaylogo=False,
        showLink=False,
        # Was misspelled "resonsive", so Plotly never received this option.
        responsive=True,
    )
    return plotoffline(
        figure_or_data=fig,
        config=config,
        filename=filename,
        include_plotlyjs="cdn",
        include_mathjax="cdn",
        auto_open=False,
    )
[docs]
def _add_batch_traces(
    fig: go.Figure,
    batch_id: Any,
    batch_df: Any,
    tag: str,
    tag_y2: str | None,
    time_column: str | None,
    line_style: dict,
    mode: str,
) -> None:
    """Add one batch's `tag` trace (y1 axis) and, if given, its `tag_y2` trace (y2 axis) to `fig`."""
    if time_column in batch_df.columns:
        time_data = batch_df[time_column]
    else:
        # No usable time column: use the sample sequence number instead.
        time_data = list(range(batch_df.shape[0]))
    axes = [("y1", tag)]
    if tag_y2:
        axes.append(("y2", tag_y2))
    for axis_name, column in axes:
        fig.add_trace(
            go.Scatter(
                x=time_data,
                y=batch_df[column],
                name=batch_id,
                line=line_style,
                mode=mode,
                opacity=0.8,
                yaxis=axis_name,
            )
        )


def plot_all_batches_per_tag(  # noqa: PLR0912, PLR0913
    df_dict: dict,
    tag: str,
    tag_y2: str | None = None,
    time_column: str | None = None,
    extra_info: str = "",
    batches_to_highlight: dict | None = None,
    x_axis_label: str = "Time [sequence order]",
    highlight_width: int = 5,
    html_image_height: int = 900,
    html_aspect_ratio_w_over_h: float = 16 / 9,
    y1_limits: tuple = (None, None),
    y2_limits: tuple = (None, None),
    mode: str = "lines",
) -> go.Figure:
    """Plot a particular `tag` over all batches in the given dataframe `df`.

    Parameters
    ----------
    df_dict : dict
        Standard data format for batches.
    tag : str
        Which tag to plot? [on the y1 (left) axis]
    tag_y2 : str, optional
        Which tag to plot? [on the y2 (right) axis]
        Tag will be plotted with different scaling on the secondary axis, to allow time-series
        comparisons to be easier.
    time_column : str, optional
        Which tag on the x-axis. If not specified, creates sequential integers, starting from 0
        if left as the default, `None`.
    extra_info : str, optional
        Used in the plot title to add any extra details, by default ""
    batches_to_highlight : dict, optional
        Keys are JSON strings parseable by ``json.loads`` into a Plotly line specifier.
        For example::

            batches_to_highlight = {'{"width": 2, "color": "rgba(255,0,0,0.5)"}': redlist}

        will plot the batch identifiers in ``redlist`` with that colour and linewidth.
        Highlighted batches are drawn after (i.e. on top of) all other batches.
    x_axis_label : str, optional
        String label for the x-axis, by default "Time [sequence order]"
    highlight_width : int, optional
        Line width for highlighted batches whose style spec does not set
        ``"width"`` itself; default = 5.
    html_image_height : int, optional
        HTML image output height, by default 900
    html_aspect_ratio_w_over_h : float, optional
        HTML image aspect ratio: 16/9 (therefore the default width will be 1600 px)
    y1_limits : tuple, optional
        Axis limits enforced on the y1 (left) axis. Default is (None, None) which means the data
        themselves are used to determine the limits. Specify BOTH limits. Plotly requires
        (at the moment https://github.com/plotly/plotly.js/issues/400) that you specify both.
        Order: (low limit, high limit)
    y2_limits : tuple, optional
        Axis limits enforced on the y2 (right) axis. Same conventions as `y1_limits`.
    mode : str, optional
        Plotly trace draw mode, by default "lines". Use "lines+markers" to also
        show a marker at each data point, or "markers" for markers only.

    Returns
    -------
    go.Figure
        Standard Plotly fig object (dictionary-like).

    Raises
    ------
    KeyError
        If `tag` (or `tag_y2`, when given) is missing from any batch.
    ValueError
        If a `batches_to_highlight` key is not valid JSON.
    """
    if batches_to_highlight is None:
        batches_to_highlight = {}
    default_line_width = 2
    unique_items = list(df_dict.keys())

    # Deterministic colour shuffle. A private RNG gives the same permutation
    # as ``random.seed(13); random.shuffle(...)`` without resetting the
    # caller's global ``random`` state.
    colours = list(sns.husl_palette(len(unique_items)))
    random.Random(13).shuffle(colours)
    rgba_colours = [get_rgba_from_triplet(c, as_string=True) for c in colours]
    line_styles: dict[Any, Any] = {
        k: dict(width=default_line_width, color=v) for k, v in zip(unique_items, rgba_colours, strict=False)
    }

    for key, val in batches_to_highlight.items():
        # SEC-32 (#281): each key must be a JSON-encoded plotly line-style
        # spec. Decode once outside the comprehension so a bad key raises
        # a clear ValueError at the API surface rather than a confusing
        # ``JSONDecodeError`` deep inside the comprehension.
        try:
            style_spec = json.loads(key)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"batches_to_highlight: each key must be a JSON-encoded "
                f"plotly line-style spec (e.g. '{{\"width\":4,\"color\":\"red\"}}'). "
                f"Got {key!r}."
            ) from exc
        if isinstance(style_spec, dict):
            # Honour the documented `highlight_width` unless the spec sets a width.
            style_spec = {"width": highlight_width, **style_spec}
        line_styles.update({item: style_spec for item in val if item in df_dict})

    highlight_ids = {item for val in batches_to_highlight.values() for item in val}

    fig = go.Figure()
    # First pass: validate every batch and draw the non-highlighted ones.
    for batch_id, batch_df in df_dict.items():
        if tag not in batch_df.columns:
            raise KeyError(f"Tag '{tag}' not found in the batch with id {batch_id}.")
        if tag_y2 and tag_y2 not in batch_df.columns:
            raise KeyError(f"Tag '{tag_y2}' not found in the batch with id {batch_id}.")
        if batch_id not in highlight_ids:
            _add_batch_traces(fig, batch_id, batch_df, tag, tag_y2, time_column, line_styles[batch_id], mode)

    # Second pass: add the highlighted batches last so they are drawn on top
    # (trace order determines drawing order here).
    for batch_id, batch_df in df_dict.items():
        if batch_id in highlight_ids:
            _add_batch_traces(fig, batch_id, batch_df, tag, tag_y2, time_column, line_styles[batch_id], mode)

    yaxis1_dict: dict[str, Any] = dict(title=tag, matches="y1", showticklabels=True, side="left")
    if (y1_limits[0] is not None) or (y1_limits[1] is not None):
        yaxis1_dict["autorange"] = False
        yaxis1_dict["range"] = y1_limits
    yaxis2_dict: dict[str, Any] = dict(title=tag_y2, matches="y2", showticklabels=True, side="right")
    if (y2_limits[0] is not None) or (y2_limits[1] is not None):
        yaxis2_dict["autorange"] = False
        yaxis2_dict["range"] = y2_limits

    fig.update_layout(
        title=f"Plot of: '{tag}'"
        + (f" on left axis; with '{tag_y2}' on right axis." if tag_y2 else ".")
        + (f" [{extra_info}]" if extra_info else ""),
        hovermode="closest",
        showlegend=True,
        autosize=False,
        xaxis=dict(title=x_axis_label),
        yaxis=yaxis1_dict,
        width=html_aspect_ratio_w_over_h * html_image_height,
        height=html_image_height,
    )
    if tag_y2:
        fig.update_layout(yaxis2=yaxis2_dict)
    return fig
[docs]
def colours_per_batch_id(
    batch_ids: list,
    batches_to_highlight: dict,
    default_line_width: float,
    use_default_colour: bool = False,
    colour_map: Callable | None = None,
) -> dict[Any, dict]:
    """
    Return the Plotly line style to use for each batch's traces.

    Parameters
    ----------
    batch_ids : list
        Batch identifiers that need a style.
    batches_to_highlight : dict
        Keys are JSON-encoded Plotly line specs; values are lists of batch ids
        that get that spec verbatim (ids not in `batch_ids` are ignored).
    default_line_width : float
        Width for every batch that is not highlighted.
    use_default_colour : bool
        If True, then the default colour is used (grey: 0.5, 0.5, 0.5) for all
        non-highlighted batches and `colour_map` is not consulted.
    colour_map : Callable, optional
        Called with the number of batches; must return that many RGB float
        triplets. Defaults to seaborn's "hls" ``color_palette``.

    Returns
    -------
    dict
        Batch id -> ``{"width": ..., "color": "rgba(...)"}`` (or the decoded
        highlight spec).

    Raises
    ------
    ValueError
        If a `batches_to_highlight` key is not valid JSON.
    """
    batch_ids = list(batch_ids)
    n_colours = len(batch_ids)
    if use_default_colour:
        colours = [(0.5, 0.5, 0.5)] * n_colours
    else:
        if colour_map is None:
            colour_map = partial(sns.color_palette, "hls")
        colours = list(colour_map(n_colours))
        # Deterministic shuffle so neighbouring batches get dissimilar colours.
        # A private RNG reproduces ``random.seed(13); random.shuffle(...)``
        # without resetting the caller's global ``random`` state.
        random.Random(13).shuffle(colours)
    rgba_colours = [get_rgba_from_triplet(c, as_string=True) for c in colours]

    colour_assignment = {
        key: dict(width=default_line_width, color=val)
        for key, val in zip(batch_ids, rgba_colours, strict=False)
    }
    known_ids = set(batch_ids)  # O(1) membership tests below
    for key, val in batches_to_highlight.items():
        # SEC-32 (#281) -- decode outside the comprehension so a bad key
        # raises ValueError at the API surface rather than JSONDecodeError
        # inside the dict-merge.
        try:
            colour_spec = json.loads(key)
        except json.JSONDecodeError as exc:
            raise ValueError(
                f"batches_to_highlight: each key must be a JSON-encoded "
                f"colour spec. Got {key!r}."
            ) from exc
        colour_assignment.update({item: colour_spec for item in val if item in known_ids})
    return colour_assignment
# flake8: noqa: C901
[docs]
def plot_multitags(  # noqa: PLR0912, PLR0913, PLR0915
    df_dict: dict,
    batch_list: list | None = None,
    tag_list: list | None = None,
    time_column: str | None = None,
    batches_to_highlight: dict | None = None,
    settings: dict | None = None,
    fig: go.Figure | None = None,
) -> go.Figure:
    """
    Plot all the tags for a batch; or a subset of tags, if specified in `tag_list`.

    Each tag gets its own subplot, and every selected batch is one line in every
    subplot. If ``settings["animate_batches_to_highlight"]`` is non-empty and
    ``settings["animate"]`` is True, those batches are also animated on top of
    the others, with a slider and Play/Pause buttons.

    Parameters
    ----------
    df_dict : dict
        Standard data format for batches.
    batch_list : list [default: None, will plot all batches in df_dict]
        Which batches to plot; if provided, must be a list of valid keys into df_dict.
    tag_list : list [default: None, will plot all tags in the dataframes]
        Which tags to plot; tags will also be plotted in this order, or in the order of the
        first dataframe if not specified. `time_column` is never given its own subplot.
    time_column : str, optional
        Which tag on the x-axis. If not specified, creates sequential integers, starting from 0
        if left as the default, `None`.
    batches_to_highlight : dict, optional
        Keys are JSON strings parseable by ``json.loads`` into a Plotly line specifier.
        For example::

            batches_to_highlight = {'{"width": 2, "color": "rgba(255,0,0,0.5)"}': redlist}

        will plot the batch identifiers in ``redlist`` with that colour and linewidth.
    settings : dict, optional
        Overrides for :class:`MultiTagPlotSettings`. Commonly used keys, with defaults::

            {
                "nrows": 1,                              # int: number of subplot rows
                "ncols": 0,                              # int: columns (0 = auto)
                "x_axis_label": "Time, grouped per tag", # str: x-axis label
                "title": "",                             # str: overall plot title
                "show_legend": True,                     # bool: show legend
                "mode": "lines",                         # str: Plotly trace mode,
                                                         # e.g. "lines+markers"
                "html_image_height": 900,                # int: image height in pixels
                "html_aspect_ratio_w_over_h": 16/9,      # float: width as ratio of height
                "animate": False,                        # bool: animate some batches
                "animate_batches_to_highlight": [],      # list: batch ids to animate
            }

    fig : go.Figure, optional
        If supplied, uses the existing Plotly figure to draw in.

    Returns
    -------
    go.Figure
        The figure with all traces (and animation frames, if animating).

    Raises
    ------
    ValueError
        If `df_dict` is empty, an animated batch id is not in `df_dict`, or the
        selected batches fail ``check_valid_batch_dict``.
    """
    if batches_to_highlight is None:
        batches_to_highlight = {}
    font_size = 12
    hovertemplate = "Time: %{x}\ny: %{y}"
    config = MultiTagPlotSettings(**(settings or {}))

    # Animation needs at least one batch to animate; otherwise switch it off.
    if len(config.animate_batches_to_highlight) == 0:
        config.animate = False

    # Line styles for the animated traces. Stays empty when not animating; the
    # frame loop below then runs zero times (animate_n_frames == 0).
    animation_colour_assignment: dict[Any, dict] = {}
    if config.animate:
        # Thin background lines, so everything is visible in frame zero and the
        # (thicker) animated lines stand out.
        config.default_line_width = 0.5
        animation_colour_assignment = colours_per_batch_id(
            batch_ids=list(df_dict.keys()),
            batches_to_highlight=batches_to_highlight,
            default_line_width=config.animate_line_width,
            use_default_colour=False,
            colour_map=config.colour_map,
        )
    else:
        # Neutralise the animate settings so the regular (static) plot is produced.
        config.animate_show_slider = False
        config.animate_show_pause = False
        config.animate_line_width = 0
        config.animate_n_frames = 0
        config.animate_batches_to_highlight = []

    if fig is None:
        fig = go.Figure()

    if not df_dict:
        raise ValueError("`df_dict` is empty: there are no batches to plot.")
    first_batch = df_dict[next(iter(df_dict))]
    # Force lists: callers sometimes pass non-list iterables.
    tag_list = list(first_batch.columns) if tag_list is None else list(tag_list)
    batch_list = list(df_dict.keys()) if batch_list is None else list(batch_list)

    if config.animate:
        animated_ids = list(config.animate_batches_to_highlight)
        missing = [b for b in animated_ids if b not in df_dict]
        if missing:
            raise ValueError(f"animate_batches_to_highlight: batch id(s) {missing!r} not found in df_dict.")
        # Move the animated batches to the end so they are drawn last (on top).
        batch_list = [b for b in batch_list if b not in animated_ids] + animated_ids

    # The time column is the x-axis; it never gets a subplot of its own.
    if time_column in tag_list:
        tag_list.remove(time_column)

    # Check that the tag_list is present in all batches.
    if not check_valid_batch_dict(
        {k: v[tag_list] for k, v in df_dict.items() if k in batch_list},
        no_nan=False,
    ):
        raise ValueError("One or more batches in df_dict failed validation.")

    # ncols of 0 (or None) means "auto": enough columns to fit all tags in nrows rows.
    if not config.ncols:
        config.ncols = int(np.ceil(len(tag_list) / int(config.nrows)))

    # Build a fresh dict per cell: `[[...] * ncols] * nrows` would alias one
    # row list across every row, so a placeholder like `[[None]]` does not work.
    specs: list[list[dict[str, str | bool | int | float] | None]] = [
        [{"type": "scatter"} for _ in range(int(config.ncols))] for _ in range(int(config.nrows))
    ]
    fig.set_subplots(
        rows=config.nrows,
        cols=config.ncols,
        shared_xaxes="all",
        shared_yaxes=False,
        start_cell="top-left",
        vertical_spacing=0.2 / config.nrows,
        horizontal_spacing=0.2 / config.ncols,
        subplot_titles=tag_list,
        specs=specs,
    )

    colour_assignment = colours_per_batch_id(
        batch_ids=list(df_dict.keys()),
        batches_to_highlight=batches_to_highlight,
        default_line_width=config.default_line_width,
        # When animating, draw the background batches in grey, unless
        # `batches_to_highlight` was specified.
        use_default_colour=config.animate and not batches_to_highlight,
        colour_map=config.colour_map,
    )

    # Initial plot (what is visible before animation starts)
    longest_time_length: int = 0
    for batch_id in batch_list:
        batch_df = df_dict[batch_id]
        # Time axis values
        time_data = batch_df[time_column] if time_column in batch_df.columns else list(range(batch_df.shape[0]))
        longest_time_length = max(longest_time_length, len(time_data))
        row = col = 1
        for tag in tag_list:
            # Only add batch_id to legend the first time it is plotted (the first subplot).
            # NOTE: animated batches still end up taking the legend slots of the first
            # few non-animated batches; hiding them here did not prevent that.
            showlegend = config.show_legend if tag == tag_list[0] else False
            subplot = cast("tuple[Any, ...]", fig.get_subplot(row, col))
            trace = go.Scatter(
                x=time_data,
                y=batch_df[tag],
                name=batch_id,
                mode=config.mode,
                hovertemplate=hovertemplate,
                line=colour_assignment[batch_id],
                legendgroup=batch_id,
                showlegend=showlegend,
                xaxis=subplot[1]["anchor"],
                yaxis=subplot[0]["anchor"],
            )
            fig.add_trace(trace)
            col += 1
            if col > config.ncols:
                row += 1
                col = 1

    # Create the slider; it is dropped later if not required.
    # https://plotly.com/python/reference/layout/sliders/
    slider_baseline_dict = {
        "active": 0,
        "yanchor": "top",
        "xanchor": "left",
        "font": {"size": font_size},
        "currentvalue": {
            "font": {"size": font_size},
            "prefix": config.animate_slider_prefix,
            "visible": True,
            "xanchor": "left",
        },
        "transition": {
            "duration": config.animate_framerate_milliseconds,
            "easing": "linear",
        },
        "pad": {"b": 0, "t": 0},
        "lenmode": "fraction",
        "len": 0.9,
        "x": 0.05,
        "y": config.animate_slider_vertical_offset,
        "name": "Slider",
        "steps": [],
    }

    # Create other animation settings. Again, these will be ignored if not needed
    frames: list = []
    slider_steps = []
    frame_settings = dict(
        frame={"duration": config.animate_framerate_milliseconds, "redraw": True},
        mode="immediate",
        transition={"duration": 0},
    )
    n_frames = (
        config.animate_n_frames
        if config.animate_n_frames is not None and config.animate_n_frames >= 0
        else longest_time_length
    )
    config.animate_n_frames = n_frames
    for raw_index in np.linspace(0, longest_time_length, n_frames):
        # TO OPTIMIZE: add hover template only on the last iteration
        # TO OPTIMIZE: can you add only the incremental new piece of animation?
        index = int(np.floor(raw_index))
        frame_name = f"{index}"  # links the slider step, the frame and the Play button
        one_frame = generate_one_frame(
            df_dict,
            tag_list,
            fig,
            up_to_index=index + 1,
            time_column=time_column,
            batch_ids_to_animate=config.animate_batches_to_highlight,
            animation_colour_assignment=animation_colour_assignment,
            show_legend=config.show_legend,
            hovertemplate=hovertemplate,
            max_columns=config.ncols,
            mode=config.mode,
        )
        frames.append(go.Frame(data=one_frame, name=frame_name))
        slider_dict = dict(
            args=[
                [frame_name],
                frame_settings,
            ],
            label=frame_name,
            method="animate",
        )
        slider_steps.append(slider_dict)

    # Buttons: for animations
    button_play = dict(
        label="Play",
        method="animate",
        args=[
            None,
            dict(
                frame=dict(duration=0, redraw=False),
                transition=dict(duration=30, easing="quadratic-in-out"),
                fromcurrent=True,
                mode="immediate",
            ),
        ],
    )
    button_pause = dict(
        label="Pause",
        method="animate",
        args=[
            # https://plotly.com/python/animations/
            # Note the None is in a list.
            [None],
            dict(
                frame=dict(duration=0, redraw=False),
                transition=dict(duration=0),
                mode="immediate",
            ),
        ],
    )

    # OK, pull things together to render the fig
    slider_baseline_dict["steps"] = slider_steps
    button_list: list[Any] = []
    if config.animate:
        fig.update(frames=frames)
        button_list.append(button_play)
        if config.animate_show_pause:
            button_list.append(button_pause)

    fig.update_layout(
        title=config.title,
        hovermode="closest",
        showlegend=config.show_legend,
        autosize=False,
        width=config.html_aspect_ratio_w_over_h * config.html_image_height,
        height=config.html_image_height,
        sliders=[slider_baseline_dict] if config.animate_show_slider else [],
        updatemenus=[
            dict(
                type="buttons",
                showactive=False,
                y=0,
                x=1.05,
                xanchor="left",
                yanchor="bottom",
                buttons=button_list,
            )
        ],
    )
    return fig
[docs]
def generate_one_frame(  # noqa: PLR0913
    df_dict: dict,
    tag_list: list,
    fig: go.Figure,
    up_to_index: int,
    time_column: str | None,
    batch_ids_to_animate: list,
    animation_colour_assignment: dict,
    show_legend: bool = False,
    hovertemplate: str = "",
    max_columns: int = 0,
    mode: str = "lines",
) -> list[go.Scatter]:
    """
    Build the traces that make up one animation frame.

    One ``go.Scatter`` is returned per (tag, animated batch) pair. The traces are
    ordered tag by tag (in `tag_list` order) and, within a tag, in
    `batch_ids_to_animate` order. Each trace contains only the first
    `up_to_index` samples.

    Each tag is placed on its own subplot of `fig`. Subplots are filled left to
    right and wrap to a new row once the column counter exceeds `max_columns`.
    This mirrors the layout used by :func:`plot_multitags`.

    Parameters
    ----------
    df_dict : dict
        Standard data format for batches.
    tag_list : list
        Tags to draw, one subplot per tag.
    fig : go.Figure
        Figure whose subplot grid has already been created; used only to look
        up the axis ids for each subplot.
    up_to_index : int
        Number of leading samples to include in each trace.
    time_column : str or None
        Column to use for the x-axis. If it is absent from a batch, sequential
        integers starting at 0 are used instead.
    batch_ids_to_animate : list
        Batches to draw in this frame.
    animation_colour_assignment : dict
        Batch id -> Plotly line style.
    show_legend : bool
        Whether the traces of the first tag appear in the legend.
    hovertemplate : str
        Plotly hover template for every trace.
    max_columns : int
        Number of subplot columns.
    mode : str
        Plotly trace draw mode.

    Returns
    -------
    list[go.Scatter]
        The frame's traces, in the order described above.
    """
    # Nothing to draw: return early (and avoid touching fig/df_dict needlessly).
    if not tag_list or not batch_ids_to_animate:
        return []

    # The x-axis data depends only on the batch, not on the tag: compute it once.
    time_data_per_batch: dict[Any, Any] = {}
    for batch_id in batch_ids_to_animate:
        batch_df = df_dict[batch_id]
        if time_column in batch_df.columns:
            time_data_per_batch[batch_id] = batch_df[time_column]
        else:
            time_data_per_batch[batch_id] = list(range(batch_df.shape[0]))

    output: list[go.Scatter] = []
    row = col = 1
    for tag in tag_list:
        # All batches of this tag share one subplot: look its axes up once.
        # get_subplot()[0] is the subplot's x-axis and [1] its y-axis; each
        # axis's "anchor" names the *other* axis of the pair.
        subplot = cast("tuple[Any, ...]", fig.get_subplot(row, col))
        xaxis_id = subplot[1]["anchor"]
        yaxis_id = subplot[0]["anchor"]
        showlegend = show_legend if tag == tag_list[0] else False
        for batch_id in batch_ids_to_animate:
            output.append(
                go.Scatter(
                    x=time_data_per_batch[batch_id][0:up_to_index],
                    y=df_dict[batch_id][tag][0:up_to_index],
                    name=batch_id,
                    mode=mode,
                    hovertemplate=hovertemplate,
                    line=animation_colour_assignment[batch_id],
                    legendgroup=batch_id,
                    showlegend=showlegend,
                    xaxis=xaxis_id,
                    yaxis=yaxis_id,
                )
            )
        # Advance to the next subplot once per tag (not per batch).
        col += 1
        if col > max_columns:
            row += 1
            col = 1
    return output