"""Plotly backend adapter: ChartSpec → Plotly figure dict.
Converts a :class:`ChartSpec` into a Plotly-compatible dict that can be
passed directly to ``plotly.graph_objects.Figure(data_dict)`` or
serialised to JSON via ``json.dumps``.
"""
from __future__ import annotations
from typing import Any
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from process_improve.visualization.adapters.base import AbstractAdapter
from process_improve.visualization.colors import (
DOE_PALETTE,
SURFACE_COLORSCALE,
)
from process_improve.visualization.spec import (
Annotation,
ChartSpec,
LayerSpec,
PanelSpec,
)
from process_improve.visualization.types import AnnotationType, MarkType
[docs]
class PlotlyAdapter(AbstractAdapter):
"""Translate a :class:`ChartSpec` to a Plotly figure dict."""
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
[docs]
def render(self, spec: ChartSpec) -> dict[str, Any]:
    """Render a complete chart specification as a Plotly figure dict.

    Parameters
    ----------
    spec : ChartSpec
        The backend-agnostic chart specification.

    Returns
    -------
    dict
        ``to_dict()`` output of the built figure: an empty figure when
        the spec has no panels, a single-panel figure (titled with
        ``spec.title``) when it has exactly one, and a subplot grid
        otherwise.
    """
    panel_count = len(spec.panels)
    if panel_count == 0:
        figure = go.Figure()
    elif panel_count == 1:
        # The chart-level title is handed to the single panel.
        figure = self._single_panel(spec.panels[0], spec.title)
    else:
        figure = self._multi_panel(spec)
    return figure.to_dict()
[docs]
def render_panel(self, panel: PanelSpec) -> dict[str, Any]:
    """Render one panel on its own as a Plotly figure dict.

    Parameters
    ----------
    panel : PanelSpec
        One chart panel; its own ``title`` is used as the figure title.

    Returns
    -------
    dict
        ``to_dict()`` output of the single-panel figure.
    """
    figure = self._single_panel(panel, panel.title)
    return figure.to_dict()
# ------------------------------------------------------------------
# Single-panel rendering
# ------------------------------------------------------------------
def _single_panel(self, panel: PanelSpec, title: str = "") -> go.Figure:
    """Build a standalone figure for one panel.

    Parameters
    ----------
    panel : PanelSpec
        Panel whose layers, annotations, axis titles, size and backend
        hints are rendered.
    title : str, optional
        Figure title; ``panel.title`` is used when this is empty.

    Returns
    -------
    go.Figure
        The populated figure.
    """
    dual_axis = panel.secondary_y
    # A secondary y-axis needs a subplot-backed figure; otherwise a plain
    # figure is enough.
    fig = make_subplots(specs=[[{"secondary_y": True}]]) if dual_axis else go.Figure()

    for layer in panel.layers:
        trace, wants_secondary = self._layer_to_trace(layer)
        # Only route to the secondary axis when the panel actually has one.
        placement = {"secondary_y": True} if (dual_axis and wants_secondary) else {}
        fig.add_trace(trace, **placement)

    for annotation in panel.annotations:
        self._add_annotation(fig, annotation)

    fig.update_layout(
        title={"text": title or panel.title},
        xaxis={"title": panel.x_title},
        width=panel.width,
        height=panel.height,
        template="plotly_white",
        font={"size": 12},
    )

    if dual_axis:
        fig.update_yaxes(title_text=panel.y_title, secondary_y=False)
        fig.update_yaxes(title_text=panel.secondary_y_title, secondary_y=True)
    else:
        fig.update_layout(yaxis={"title": panel.y_title})

    if panel.backend_hints.get("equal_aspect"):
        # Anchor the y scale to x with a 1:1 ratio.
        fig.update_yaxes(scaleanchor="x", scaleratio=1)
    return fig
# ------------------------------------------------------------------
# Multi-panel rendering
# ------------------------------------------------------------------
def _multi_panel(self, spec: ChartSpec) -> go.Figure:
    """Lay out every panel of *spec* on a grid of subplots.

    Panels are placed row-major, ``spec.columns`` per row, and each
    panel's title becomes its subplot title. The secondary-axis flag
    returned by ``_layer_to_trace`` is discarded here, so every trace is
    added to the regular axes of its grid cell.

    Parameters
    ----------
    spec : ChartSpec
        Chart specification providing ``panels``, ``columns`` and
        ``title``.

    Returns
    -------
    go.Figure
        Subplot figure containing all panels.
    """
    n = len(spec.panels)
    # Clamp the grid to at least 1x1: a non-positive ``spec.columns`` (or
    # an empty panel list) would otherwise make the ceiling division
    # below divide by zero or hand a non-positive grid to make_subplots.
    cols = max(1, min(spec.columns, n))
    rows = max(1, (n + cols - 1) // cols)  # ceil(n / cols)
    subtitles = [p.title for p in spec.panels]
    fig = make_subplots(rows=rows, cols=cols, subplot_titles=subtitles)

    for idx, panel in enumerate(spec.panels):
        # Row-major placement; Plotly grid coordinates are 1-based.
        row_offset, col_offset = divmod(idx, cols)
        row = row_offset + 1
        col = col_offset + 1
        for layer in panel.layers:
            trace, _ = self._layer_to_trace(layer)
            fig.add_trace(trace, row=row, col=col)
        fig.update_xaxes(title_text=panel.x_title, row=row, col=col)
        fig.update_yaxes(title_text=panel.y_title, row=row, col=col)
        for ann in panel.annotations:
            self._add_annotation(fig, ann, row=row, col=col)

    fig.update_layout(
        title=dict(text=spec.title),
        template="plotly_white",
        font=dict(size=12),
        showlegend=True,
    )
    return fig
# ------------------------------------------------------------------
# Layer → Plotly trace
# ------------------------------------------------------------------
def _layer_to_trace(self, layer: LayerSpec) -> tuple[go.BaseTraceType, bool]:
    """Convert a :class:`LayerSpec` to a Plotly trace.

    Returns
    -------
    tuple[go.BaseTraceType, bool]
        The trace, and the layer's ``secondary_y`` style value
        (``False`` when absent) telling the caller whether the trace
        targets the secondary y-axis.
    """
    on_secondary = layer.style.get("secondary_y", False)
    mark = layer.mark if isinstance(layer.mark, MarkType) else MarkType(layer.mark)

    # Builders that take only the layer: they read grids from
    # ``layer.style`` or pull their own columns out of ``layer.data``.
    layer_only_builders = {
        MarkType.contour: self._contour_trace,
        MarkType.surface: self._surface_trace,
        MarkType.heatmap: self._heatmap_trace,
        MarkType.wireframe: self._wireframe_trace,
        MarkType.boxplot: self._boxplot_trace,
    }
    build_from_layer = layer_only_builders.get(mark)
    if build_from_layer is not None:
        return build_from_layer(layer), on_secondary

    def column(channel: Any) -> list:
        # One value per data row for the given encoding channel, or an
        # empty list when the channel is not set.
        return [row[channel.field] for row in layer.data] if channel else []

    x_vals = column(layer.x)
    y_vals = column(layer.y)

    xy_builders = {
        MarkType.bar: self._bar_trace,
        MarkType.line: self._line_trace,
        MarkType.text: self._text_trace,
    }
    # Scatter marks, and anything not listed above, render as a scatter.
    build_from_xy = xy_builders.get(mark, self._scatter_trace)
    return build_from_xy(layer, x_vals, y_vals), on_secondary
def _bar_trace(
    self,
    layer: LayerSpec,
    x_vals: list,
    y_vals: list,
) -> go.Bar:
    """Build a bar trace, optionally with y error bars.

    The bar colour is the first truthy value among ``style["colors"]``,
    ``layer.color`` and the palette's primary colour. Error bars come
    from ``style["error_y"]`` and are attached only when at least one
    magnitude is greater than zero.
    """
    fill = layer.style.get("colors") or layer.color or DOE_PALETTE["primary"]

    extras: dict[str, Any] = {}
    raw_errors = layer.style.get("error_y")
    if raw_errors is not None:
        # Missing entries count as zero; signs are dropped.
        magnitudes = [0.0 if e is None else abs(float(e)) for e in raw_errors]
        if any(m > 0 for m in magnitudes):
            extras["error_y"] = {"type": "data", "array": magnitudes, "visible": True}

    return go.Bar(
        x=x_vals,
        y=y_vals,
        name=layer.name,
        marker={"color": fill},
        opacity=layer.opacity,
        **extras,
    )
def _line_trace(
    self,
    layer: LayerSpec,
    x_vals: list,
    y_vals: list,
) -> go.Scatter:
    """Build a lines-only scatter trace.

    Dash pattern and width are read from ``layer.style`` (defaults
    ``"solid"`` and ``2``); the colour falls back to the palette's
    primary colour when ``layer.color`` is falsy.
    """
    stroke = {
        "color": layer.color or DOE_PALETTE["primary"],
        "dash": layer.style.get("dash", "solid"),
        "width": layer.style.get("width", 2),
    }
    return go.Scatter(
        x=x_vals,
        y=y_vals,
        mode="lines",
        name=layer.name,
        line=stroke,
        opacity=layer.opacity,
    )
def _scatter_trace(
    self,
    layer: LayerSpec,
    x_vals: list,
    y_vals: list,
) -> go.Scatter:
    """Build a markers-only scatter trace.

    Marker size, symbol, colours and optional outline come from
    ``layer.style``. When ``style["hover_field"]`` is set, that field of
    each data row (``""`` when missing) becomes the hover text and the
    hover is restricted to that text.
    """
    style = layer.style

    marker: dict[str, Any] = {
        "color": style.get("colors") or layer.color or DOE_PALETTE["primary"],
        "size": style.get("size", 8),
        "symbol": style.get("symbol", "circle"),
    }
    outline = style.get("edge_color")
    if outline is not None:
        marker["line"] = {"color": outline, "width": style.get("edge_width", 1)}

    hover_key = style.get("hover_field")
    if hover_key:
        hover_text = [row.get(hover_key, "") for row in layer.data]
        hover_mode = "text"
    else:
        hover_text = None
        hover_mode = None

    return go.Scatter(
        x=x_vals,
        y=y_vals,
        mode="markers",
        name=layer.name,
        marker=marker,
        opacity=layer.opacity,
        text=hover_text,
        hoverinfo=hover_mode,
    )
def _contour_trace(self, layer: LayerSpec) -> go.Contour:
    """Build a labelled contour trace from the grids in ``layer.style``.

    Reads ``x_grid``, ``y_grid`` and ``z_matrix`` from the style dict,
    each defaulting to an empty list.
    """
    grids = layer.style
    return go.Contour(
        x=grids.get("x_grid", []),
        y=grids.get("y_grid", []),
        z=grids.get("z_matrix", []),
        name=layer.name,
        colorscale=SURFACE_COLORSCALE,
        contours={"showlabels": True},
    )
def _surface_trace(self, layer: LayerSpec) -> go.Surface:
    """Build a surface trace from the grids in ``layer.style``.

    Reads ``x_grid``, ``y_grid`` and ``z_matrix`` from the style dict,
    each defaulting to an empty list.
    """
    grids = layer.style
    trace_args = {
        "x": grids.get("x_grid", []),
        "y": grids.get("y_grid", []),
        "z": grids.get("z_matrix", []),
        "name": layer.name,
        "colorscale": SURFACE_COLORSCALE,
    }
    return go.Surface(**trace_args)
def _heatmap_trace(self, layer: LayerSpec) -> go.Heatmap:
    """Build a heatmap trace from the grids in ``layer.style``.

    Reads ``x_grid``, ``y_grid`` and ``z_matrix`` from the style dict,
    each defaulting to an empty list.
    """
    lookup = layer.style.get
    x_axis = lookup("x_grid", [])
    y_axis = lookup("y_grid", [])
    cells = lookup("z_matrix", [])
    return go.Heatmap(
        x=x_axis,
        y=y_axis,
        z=cells,
        name=layer.name,
        colorscale=SURFACE_COLORSCALE,
    )
def _text_trace(
    self,
    layer: LayerSpec,
    x_vals: list,
    y_vals: list,
) -> go.Scatter:
    """Build a text-only scatter trace.

    Each label is the ``"text"`` entry of the corresponding data row
    (``""`` when missing); the font size is ``style["size"]``,
    defaulting to 12.
    """
    labels = [row.get("text", "") for row in layer.data]
    font = {"size": layer.style.get("size", 12)}
    return go.Scatter(
        x=x_vals,
        y=y_vals,
        mode="text",
        text=labels,
        name=layer.name,
        textfont=font,
    )
def _boxplot_trace(self, layer: LayerSpec) -> go.Box:
    """Build a Plotly box trace from pre-computed quartiles.

    Every row of ``layer.data`` supplies ``group`` (the category label)
    and ``q_stats``, indexed here as ``[min, Q1, median, Q3, max]``.
    Individual points are not drawn.
    """
    rows = layer.data

    def stat(position: int) -> list:
        # Pull one quartile statistic out of every row.
        return [row["q_stats"][position] for row in rows]

    return go.Box(
        x=[row["group"] for row in rows],
        lowerfence=stat(0),
        q1=stat(1),
        median=stat(2),
        q3=stat(3),
        upperfence=stat(4),
        name=layer.name,
        marker={"color": layer.color or DOE_PALETTE["primary"]},
        opacity=layer.opacity,
        boxpoints=False,
    )
def _wireframe_trace(self, layer: LayerSpec) -> go.Scatter3d:
    """Build a 3-D lines+markers+text trace from the layer's rows.

    Coordinates are read per row through the ``x``/``y``/``z`` channel
    fields (an unset channel yields an empty list); each row's
    ``"text"`` entry (``""`` when missing) labels its point.
    """
    rows = layer.data

    def coords(channel: Any) -> list:
        # One coordinate per row, or nothing when the channel is unset.
        return [row[channel.field] for row in rows] if channel else []

    xs = coords(layer.x)
    ys = coords(layer.y)
    zs = coords(layer.z)
    stroke = {"color": layer.color or DOE_PALETTE["primary"], "width": 3}
    return go.Scatter3d(
        x=xs,
        y=ys,
        z=zs,
        mode="lines+markers+text",
        name=layer.name,
        line=stroke,
        marker={"size": 5},
        text=[row.get("text", "") for row in rows],
        textposition="top center",
    )
# ------------------------------------------------------------------
# Annotations → Plotly shapes / annotations
# ------------------------------------------------------------------
def _add_annotation(  # noqa: C901
    self,
    fig: go.Figure,
    ann: Annotation,
    *,
    row: int | None = None,
    col: int | None = None,
) -> None:
    """Draw one annotation on *fig* as a Plotly line or rectangle.

    Parameters
    ----------
    fig : go.Figure
        Figure to draw on (modified in place).
    ann : Annotation
        Annotation to draw. Reference lines and significance thresholds
        become horizontal/vertical lines, reference bands become
        rectangles along one axis, and constraint regions become a
        vertical and/or horizontal rectangle depending on which bounds
        are present in ``ann.style``. Other annotation types are
        ignored, as are lines/bands whose required values are ``None``.
    row, col : int or None, optional
        Target subplot cell; a falsy value is passed on as ``"all"``.
    """
    kind = ann.annotation_type
    if isinstance(kind, str):
        kind = AnnotationType(kind)

    stroke_color = ann.style.get("color", DOE_PALETTE["threshold_me"])
    stroke_dash = ann.style.get("dash", "dash")
    stroke_width = ann.style.get("width", 2)

    # Subplot targeting shared by every shape drawn below.
    target = {"row": row or "all", "col": col or "all"}

    if kind in (AnnotationType.reference_line, AnnotationType.significance_threshold):
        if ann.value is None:
            return
        line_args = {
            "line_dash": stroke_dash,
            "line_color": stroke_color,
            "line_width": stroke_width,
            "annotation_text": ann.label,
            "annotation_position": "top right",
        }
        if ann.axis == "y":
            fig.add_hline(y=ann.value, **line_args, **target)
        else:
            fig.add_vline(x=ann.value, **line_args, **target)
    elif kind == AnnotationType.reference_band:
        if ann.value is None or ann.value_end is None:
            return
        band_args = {
            "fillcolor": ann.style.get("fill_color", "rgba(37, 99, 235, 0.1)"),
            "line_width": 0,
        }
        if ann.axis == "y":
            fig.add_hrect(y0=ann.value, y1=ann.value_end, **band_args, **target)
        else:
            fig.add_vrect(x0=ann.value, x1=ann.value_end, **band_args, **target)
    elif kind == AnnotationType.constraint_region:
        x_lo = ann.style.get("x_min")
        x_hi = ann.style.get("x_max")
        y_lo = ann.style.get("y_min")
        y_hi = ann.style.get("y_max")
        # The fill reads the same "color" style key as the line colour
        # above, but with its own translucent default.
        region_args = {
            "fillcolor": ann.style.get("color", "rgba(220, 38, 38, 0.15)"),
            "line_width": 0,
            "annotation_text": ann.label,
        }
        if x_lo is not None and x_hi is not None:
            fig.add_vrect(x0=x_lo, x1=x_hi, **region_args, **target)
        if y_lo is not None and y_hi is not None:
            fig.add_hrect(y0=y_lo, y1=y_hi, **region_args, **target)