Source code for process_improve.batch.plotting
from __future__ import annotations
# Built-in libraries
import json
import math
import random
from collections.abc import Callable
from functools import partial
from typing import Any
import numpy as np
# Plotting settings
import plotly.graph_objects as go
import seaborn as sns
from plotly.offline import plot as plotoffline
from .data_input import check_valid_batch_dict
[docs]
def get_rgba_from_triplet(incolour: list, alpha: float = 1, as_string: bool = False) -> str | list:
"""
Convert the input colour triplet (list) to a Plotly rgba(r,g,b,a) string if
`as_string` is True. If `False` it will return the list of 3 integer RGB
values.
E.g. [0.9677975592919913, 0.44127456009157356, 0.5358103155058701] -> 'rgba(246,112,136,1)'
"""
assert 3 <= len(incolour) <= 4, "`incolour` must be a list of 3 or 4 values; ignores 4th entry"
colours = [max(0, math.floor(c * 255)) for c in list(incolour)[0:3]]
if as_string:
return f"rgba({colours[0]},{colours[1]},{colours[2]},{float(alpha)})"
else:
return colours
[docs]
def plot_to_HTML(filename: str, fig: dict) -> str: # noqa: N802
"""Export a Plotly figure to an HTML file."""
config = dict(
scrollZoom=True,
displayModeBar=True,
editable=False,
displaylogo=False,
showLink=False,
resonsive=True,
)
return plotoffline(
figure_or_data=fig,
config=config,
filename=filename,
include_plotlyjs="cdn",
include_mathjax="cdn",
auto_open=False,
)
[docs]
def plot_all_batches_per_tag( # noqa: PLR0912, PLR0913
df_dict: dict,
tag: str,
tag_y2: str | None = None,
time_column: str | None = None,
extra_info: str = "",
batches_to_highlight: dict | None = None,
x_axis_label: str = "Time [sequence order]",
highlight_width: int = 5,
html_image_height: int = 900,
html_aspect_ratio_w_over_h: float = 16 / 9,
y1_limits: tuple = (None, None),
y2_limits: tuple = (None, None),
) -> go.Figure:
"""Plot a particular `tag` over all batches in the given dataframe `df`.
Parameters
----------
df_dict : dict
Standard data format for batches.
tag : str
Which tag to plot? [on the y1 (left) axis]
tag_y2 : str, optional
Which tag to plot? [on the y2 (right) axis]
Tag will be plotted with different scaling on the secondary axis, to allow time-series
comparisons to be easier.
time_column : str, optional
Which tag on the x-axis. If not specified, creates sequential integers, starting from 0
if left as the default, `None`.
extra_info : str, optional
Used in the plot title to add any extra details, by default ""
batches_to_highlight : dict, optional
Keys are JSON strings parseable by ``json.loads`` into a Plotly line specifier.
For example::
batches_to_highlight = {'{"width": 2, "color": "rgba(255,0,0,0.5)"}': redlist}
will plot the batch identifiers in ``redlist`` with that colour and linewidth.
x_axis_label : str, optional
String label for the x-axis, by default "Time [sequence order]"
highlight_width: int, optional
The width of the highlighted lines; default = 5.
html_image_height : int, optional
HTML image output height, by default 900
html_aspect_ratio_w_over_h : float, optional
HTML image aspect ratio: 16/9 (therefore the default width will be 1600 px)
y1_limits: tuple, optional
Axis limits enforced on the y1 (left) axis. Default is (None, None) which means the data
themselves are used to determine the limits. Specify BOTH limits. Plotly requires
(at the moment https://github.com/plotly/plotly.js/issues/400) that you specify both.
Order: (low limit, high limit)
y2_limits: tuple, optional
Axis limits enforced on the y2 (right) axis. Default is (None, None) which means the data
themselves are used to determine the limits. Specify BOTH limits. Plotly requires
(at the moment https://github.com/plotly/plotly.js/issues/400) that you specify both.
Returns
-------
go.Figure
Standard Plotly fig object (dictionary-like).
"""
if batches_to_highlight is None:
batches_to_highlight = {}
default_line_width = 2
unique_items = list(df_dict.keys())
n_colours = len(unique_items)
random.seed(13)
colours = list(sns.husl_palette(n_colours))
random.shuffle(colours)
colours = [get_rgba_from_triplet(c, as_string=True) for c in colours]
line_styles = {k: dict(width=default_line_width, color=v) for k, v in zip(unique_items, colours, strict=False)}
for key, val in batches_to_highlight.items():
line_styles.update({item: json.loads(key) for item in val if item in df_dict})
highlight_list = []
for val in batches_to_highlight.values():
highlight_list.extend(val)
highlight_list = list(set(highlight_list))
fig = go.Figure()
for batch_id, batch_df in df_dict.items():
assert tag in batch_df.columns, f"Tag '{tag}' not found in the batch with id {batch_id}."
if tag_y2:
assert tag_y2 in batch_df.columns, f"Tag '{tag}' not found in the batch with id {batch_id}."
time_data = batch_df[time_column] if time_column in batch_df.columns else list(range(batch_df.shape[0]))
if batch_id in highlight_list:
continue # come to this later
fig.add_trace(
go.Scatter(
x=time_data,
y=batch_df[tag],
name=batch_id,
line=line_styles[batch_id],
mode="lines",
opacity=0.8,
yaxis="y1",
)
)
if tag_y2:
fig.add_trace(
go.Scatter(
x=time_data,
y=batch_df[tag_y2],
name=batch_id,
line=line_styles[batch_id],
mode="lines",
opacity=0.8,
yaxis="y2",
)
)
# Add the highlighted batches last: therefore, sadly, we have to do another run-through.
# Plotly does not yet support z-orders.
for batch_id, batch_df in df_dict.items():
time_data = batch_df[time_column] if time_column in batch_df.columns else list(range(batch_df.shape[0]))
if batch_id in highlight_list:
fig.add_trace(
go.Scatter(
x=time_data,
y=batch_df[tag],
line=line_styles[batch_id],
name=batch_id,
mode="lines",
opacity=0.8,
yaxis="y1",
)
)
if tag_y2:
fig.add_trace(
go.Scatter(
x=time_data,
y=batch_df[tag_y2],
line=line_styles[batch_id],
name=batch_id,
mode="lines",
opacity=0.8,
yaxis="y2",
)
)
yaxis1_dict = dict(title=tag, gridwidth=2, matches="y1", showticklabels=True, side="left")
if (y1_limits[0] is not None) or (y1_limits[1] is not None):
yaxis1_dict["autorange"] = False
yaxis1_dict["range"] = y1_limits
yaxis2_dict: dict[str, Any] = dict(title=tag_y2, gridwidth=2, matches="y2", showticklabels=True, side="right")
if (y2_limits[0] is not None) or (y2_limits[1] is not None):
yaxis2_dict["autorange"] = False
yaxis2_dict["range"] = y2_limits
fig.update_layout(
title=f"Plot of: '{tag}'"
+ (f" on left axis; with '{tag_y2}' on right axis." if tag_y2 else ".")
+ (f" [{extra_info}]" if extra_info else ""),
hovermode="closest",
showlegend=True,
legend=dict(
orientation="h",
traceorder="normal",
font=dict(family="sans-serif", size=12, color="#000"),
bordercolor="#DDDDDD",
borderwidth=1,
),
autosize=False,
xaxis=dict(title=x_axis_label, gridwidth=1),
yaxis=yaxis1_dict,
width=html_aspect_ratio_w_over_h * html_image_height,
height=html_image_height,
)
if tag_y2:
fig.update_layout(yaxis2=yaxis2_dict)
return fig
[docs]
def colours_per_batch_id(
batch_ids: list,
batches_to_highlight: dict,
default_line_width: float,
use_default_colour: bool = False,
colour_map: Callable | None = None,
) -> dict[Any, dict]:
"""
Return a colour to use for each trace in the plot. A dictionary: keys are batch ids, and
the value is a colour and line width setting for Plotly.
override_default_colour: bool
If True, then the default colour is used (grey: 0.5, 0.5, 0.5)
"""
if colour_map is None:
colour_map = partial(sns.color_palette, "hls")
random.seed(13)
n_colours = len(batch_ids)
colours = list(colour_map(n_colours)) if not (use_default_colour) else [(0.5, 0.5, 0.5)] * n_colours
random.shuffle(colours)
colours = [get_rgba_from_triplet(c, as_string=True) for c in colours]
colour_assignment = {
key: dict(width=default_line_width, color=val)
for key, val in zip(list(batch_ids), colours, strict=False)
}
for key, val in batches_to_highlight.items():
colour_assignment.update({item: json.loads(key) for item in val if item in batch_ids})
return colour_assignment
# flake8: noqa: C901
[docs]
def plot_multitags( # noqa: PLR0912, PLR0913, PLR0915
df_dict: dict,
batch_list: list | None = None,
tag_list: list | None = None,
time_column: str | None = None,
batches_to_highlight: dict | None = None,
settings: dict | None = None,
fig: go.Figure | None = None,
) -> go.Figure:
"""
Plot all the tags for a batch; or a subset of tags, if specified in `tag_list`.
Parameters
----------
df_dict : dict
Standard data format for batches.
batch_list : list [default: None, will plot all batches in df_dict]
Which batches to plot; if provided, must be a list of valid keys into df_dict.
tag_list : list [default: None, will plot all tags in the dataframes]
Which tags to plot; tags will also be plotted in this order, or in the order of the
first dataframe if not specified.
time_column : str, optional
Which tag on the x-axis. If not specified, creates sequential integers, starting from 0
if left as the default, `None`.
batches_to_highlight : dict, optional
Keys are JSON strings parseable by ``json.loads`` into a Plotly line specifier.
For example::
batches_to_highlight = {'{"width": 2, "color": "rgba(255,0,0,0.5)"}': redlist}
will plot the batch identifiers in ``redlist`` with that colour and linewidth.
settings : dict
Default settings::
{
"nrows": 1, # int: number of subplot rows
"ncols": None, # int or None: columns (None = auto)
"x_axis_label": "Time, grouped per tag",# str: x-axis label
"title": "", # str: overall plot title
"show_legend": True, # bool: show legend
"html_image_height": 900, # int: image height in pixels
"html_aspect_ratio_w_over_h": 16/9, # float: width as ratio of height
}
fig : go.Figure
If supplied, uses the existing Plotly figure to draw in.
"""
if batches_to_highlight is None:
batches_to_highlight = {}
font_size = 12
margin_dict = dict(l=10, r=10, b=5, t=80) # Defaults: l=80, r=80, t=100, b=80
hovertemplate = "Time: %{x}\ny: %{y}"
# This will be clumsy, until we have Python 3.9. TODO: use pydantic instead
# This will be clumsy, until we have Python 3.9. TODO: use pydantic instead
default_settings: dict[str, Any] = dict(
# Pydantic: int
nrows=1,
# Pydantic: int
ncols=0,
# Pydantic: str
x_axis_label="Time, grouped per tag",
# Pydantic: str
title="",
# Pydantic: bool
show_legend=True,
# Pydantic: >0
html_image_height=900,
# Pydantic: >0
html_aspect_ratio_w_over_h=16 / 9,
# Pydantic: >0
default_line_width=2,
# Pydantic: callable
colour_map=sns.husl_palette,
# Pydantic: bool
animate=False,
# Pydantic: list
animate_batches_to_highlight=[],
# Pydantic: bool
animate_show_slider=True,
# Pydantic: bool
animate_show_pause=True,
# Pydantic: str
animate_slider_prefix="Index: ",
# Pydantic: bool
# fraction of figure height. Default should be OK, but depends if the
# legend is show and length of batch names
animate_slider_vertical_offset=-0.3,
# Pydantic: > 0
animate_line_width=4, # the animated lines are drawn on top of the historical lines
# Pydantic: optional or int
animate_n_frames=None, # takes max frames required to give every time step 1 frame.
# Pydantic: int >= 0
animate_framerate_milliseconds=0,
)
if settings:
default_settings.update(settings)
settings = default_settings
if len(settings["animate_batches_to_highlight"]) == 0:
settings["animate"] = False
if settings["animate"]:
# override for animations, because we want to see everything in frame zero
settings["default_line_width"] = 0.5
# Override these settings for animations, because we want to see everything in frame zero
animation_colour_assignment = colours_per_batch_id(
batch_ids=list(df_dict.keys()),
batches_to_highlight=batches_to_highlight or dict(),
default_line_width=settings["animate_line_width"],
use_default_colour=False,
colour_map=settings["colour_map"],
)
else:
# Adjust the other animate settings in such a way that the regular functionality works
settings["animate_show_slider"] = False
settings["animate_show_pause"] = False
settings["animate_line_width"] = 0
settings["animate_n_frames"] = 0
settings["animate_batches_to_highlight"] = []
if fig is None:
fig = go.Figure()
batch1 = df_dict[next(iter(df_dict.keys()))]
if tag_list is None:
tag_list = list(batch1.columns)
tag_list = list(tag_list) # Force it; sometimes we get non-list inputs
if batch_list is None:
batch_list = list(df_dict.keys())
batch_list = list(batch_list)
if settings["animate"]:
for batch_id in settings["animate_batches_to_highlight"]:
batch_list.remove(batch_id)
# Afterwards, add them back, at the end.
batch_list.extend(settings["animate_batches_to_highlight"])
if time_column in tag_list:
tag_list.remove(time_column)
# Check that the tag_list is present in all batches.
assert check_valid_batch_dict({k: v[tag_list] for k, v in df_dict.items() if k in batch_list}, no_nan=False)
if settings["ncols"] == 0:
settings["ncols"] = int(np.ceil(len(tag_list) / int(settings["nrows"])))
specs = [[{"type": "scatter"}] * int(settings["ncols"])] * int(settings["nrows"])
fig.set_subplots(
rows=settings["nrows"],
cols=settings["ncols"],
shared_xaxes="all",
shared_yaxes=False,
start_cell="top-left",
vertical_spacing=0.2 / settings["nrows"],
horizontal_spacing=0.2 / settings["ncols"],
subplot_titles=tag_list,
specs=specs,
)
colour_assignment = colours_per_batch_id(
batch_ids=list(df_dict.keys()),
batches_to_highlight=batches_to_highlight,
default_line_width=settings["default_line_width"],
# if animating, yes, use grey for all lines; unless `batches_to_highlight` was specified
use_default_colour=settings["animate"] if settings["animate"] and (len(batches_to_highlight) == 0) else False,
colour_map=settings["colour_map"],
)
# Initial plot (what is visible before animation starts)
longest_time_length: int = 0
for batch_id in batch_list:
batch_df = df_dict[batch_id]
# Time axis values
time_data = batch_df[time_column] if time_column in batch_df.columns else list(range(batch_df.shape[0]))
longest_time_length = max(longest_time_length, len(time_data))
row = col = 1
for tag in tag_list:
showlegend = settings["show_legend"] if tag == tag_list[0] else False
# This feels right, but leads to the animated batched taking the places of the
# first few non-animated batches in the legend.
# Ugh, even without this, they still overwrite them. Sadly.
# if batch_id in settings["animate_batches_to_highlight"]:
# showlegend = False # overridden. If required, we will add it during the animation
trace = go.Scatter(
x=time_data,
y=batch_df[tag],
name=batch_id,
mode="lines",
hovertemplate=hovertemplate,
line=colour_assignment[batch_id],
legendgroup=batch_id,
# Only add batch_id to legend the first time it is plotted (the first subplot)
showlegend=showlegend,
xaxis=fig.get_subplot(row, col)[1]["anchor"],
yaxis=fig.get_subplot(row, col)[0]["anchor"],
)
fig.add_trace(trace)
col += 1
if col > settings["ncols"]:
row += 1
col = 1
# Create the slider; will be ignore later if not required
# https://plotly.com/python/reference/layout/sliders/
slider_baseline_dict = {
"active": 0,
"yanchor": "top",
"xanchor": "left",
"font": {"size": font_size},
"currentvalue": {
"font": {"size": font_size},
"prefix": settings["animate_slider_prefix"],
"visible": True,
"xanchor": "left",
},
"transition": {
"duration": settings["animate_framerate_milliseconds"],
"easing": "linear",
},
"pad": {"b": 0, "t": 0},
"lenmode": "fraction",
"len": 0.9,
"x": 0.05,
"y": settings["animate_slider_vertical_offset"],
"name": "Slider",
"steps": [],
}
# Create other animation settings. Again, these will be ignored if not needed
frames: list = []
slider_steps = []
frame_settings = dict(
frame={"duration": settings["animate_framerate_milliseconds"], "redraw": True},
mode="immediate",
transition={"duration": 0},
)
settings["animate_n_frames"] = (
settings["animate_n_frames"] if settings["animate_n_frames"] >= 0 else longest_time_length
)
for raw_index in np.linspace(0, longest_time_length, settings["animate_n_frames"]):
# TO OPTIMIZE: add hover template only on the last iteration
# TO OPTIMIZE: can you add only the incremental new piece of animation?
index = int(np.floor(raw_index))
frame_name = f"{index}" # this is the link with the slider and the animation in the play button
one_frame = generate_one_frame(
df_dict,
tag_list,
fig,
up_to_index=index + 1,
time_column=time_column,
batch_ids_to_animate=settings["animate_batches_to_highlight"],
animation_colour_assignment=animation_colour_assignment,
show_legend=settings["show_legend"],
hovertemplate=hovertemplate,
max_columns=settings["ncols"],
)
frames.append(go.Frame(data=one_frame, name=frame_name))
slider_dict = dict(
args=[
[frame_name],
frame_settings,
],
label=frame_name,
method="animate",
)
slider_steps.append(slider_dict)
# Buttons: for animations
button_play = dict(
label="Play",
method="animate",
args=[
None,
dict(
frame=dict(duration=0, redraw=False),
transition=dict(duration=30, easing="quadratic-in-out"),
fromcurrent=True,
mode="immediate",
),
],
)
button_pause = dict(
label="Pause",
method="animate",
args=[
# https://plotly.com/python/animations/
# Note the None is in a list!
[[None]], # TODO: does not work at the moment.
dict(
frame=dict(duration=0, redraw=False),
transition=dict(duration=0),
mode="immediate",
),
],
)
# OK, pull things together to render the fig
slider_baseline_dict["steps"] = slider_steps
button_list: list[Any] = []
if settings["animate"]:
fig.update(frames=frames)
button_list.append(button_play)
if settings["animate_show_pause"]:
button_list.append(button_pause)
fig.update_layout(
title=settings["title"],
margin=margin_dict,
hovermode="closest",
showlegend=settings["show_legend"],
legend=dict(
orientation="h",
traceorder="normal",
font=dict(family="sans-serif", size=12, color="#000"),
bordercolor="#DDDDDD",
borderwidth=1,
),
autosize=False,
xaxis=dict(
gridwidth=1,
mirror=True, # ticks are mirror at the top of the frame also
showspikes=True,
visible=True,
),
yaxis=dict(
gridwidth=2,
type="linear",
autorange=True,
showspikes=True,
visible=True,
showline=True, # show a separating line
side="left", # show on the RHS
),
width=settings["html_aspect_ratio_w_over_h"] * settings["html_image_height"],
height=settings["html_image_height"],
sliders=[slider_baseline_dict] if settings["animate_show_slider"] else [],
updatemenus=[
dict(
type="buttons",
showactive=False,
y=0,
x=1.05,
xanchor="left",
yanchor="bottom",
buttons=button_list,
)
],
)
return fig
[docs]
def generate_one_frame( # noqa: PLR0913
df_dict: dict,
tag_list: list,
fig: go.Figure,
up_to_index: int,
time_column: str | None,
batch_ids_to_animate: list,
animation_colour_assignment: dict,
show_legend: bool = False,
hovertemplate: str = "",
max_columns: int = 0,
) -> list[dict]:
"""
Return a list of dictionaries.
Each entry in the list is for each subplot; in the order of the subplots.
Since each subplot is a tag, we need the `tag_list` as input.
"""
output = []
row = col = 1
for tag in tag_list:
for batch_id in batch_ids_to_animate:
# These 4 lines are duplicated from the outside function
if time_column in df_dict[batch_id].columns:
time_data = df_dict[batch_id][time_column]
else:
time_data = list(range(df_dict[batch_id].shape[0]))
output.append(
go.Scatter(
x=time_data[0:up_to_index],
y=df_dict[batch_id][tag][0:up_to_index],
name=batch_id,
mode="lines",
hovertemplate=hovertemplate,
line=animation_colour_assignment[batch_id],
legendgroup=batch_id,
showlegend=show_legend if tag == tag_list[0] else False,
xaxis=fig.get_subplot(row, col)[1]["anchor"],
yaxis=fig.get_subplot(row, col)[0]["anchor"],
)
)
# One level outdented: if the loop for the tags, not in the loop for
# the `batch_ids_to_animate`!
col += 1
if col > max_columns:
row += 1
col = 1
return output