"""Matplotlib rendering for :class:`FeynmanDiagram` objects.

Visual conventions are controlled by a
:class:`~sft_wick.render_style.RenderStyle` instance: colours,
linestyles, marker sizes, fonts, layout parameters, and label
formatting are all configurable. See :mod:`sft_wick.render_style`
for the available presets (default / publication / grayscale /
minimal) and :mod:`sft_wick.render_labels` for the label
override hooks.

The constructor remains backward-compatible: calling
``DiagramRenderer(figsize=(8, 6))`` still works and uses the
``default_style`` preset.

For a LaTeX-native alternative, see
:class:`sft_wick.drawing_tikz.TikzRenderer`.
"""
from __future__ import annotations
from collections import Counter
from collections.abc import Callable, Mapping, Sequence
from typing import Any
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from .diagrams import FeynmanDiagram
from .render_labels import (
LabelCallable,
default_external_label,
default_vertex_label,
resolve_label,
)
from .render_layout import compute_layout, label_offset, neighbor_center
from .render_style import (
NodeStyle,
PropagatorStyle,
RenderStyle,
default_style,
)
# Translate the backend-agnostic linestyle names used by RenderStyle
# into matplotlib's short linestyle codes.
_MPL_LINESTYLE: dict[str, str] = dict(
    solid="-",
    dashed="--",
    dotted=":",
    dashdot="-.",
)

# Translate the backend-agnostic node-shape names into matplotlib
# marker codes.
_MPL_MARKER: dict[str, str] = dict(
    circle="o",
    square="s",
    diamond="D",
    triangle="^",
)
class DiagramRenderer:
    """Render :class:`FeynmanDiagram` objects with matplotlib.

    Args:
        figsize: Figure size in inches for the single-diagram entry
            point and the per-cell size for ``draw_all``.
        style: Visual style. ``None`` → :func:`default_style`.
        external_label_fn: Optional callable
            ``fn(node_id, node_attrs) -> str | None`` invoked per
            external node. Returning ``None`` falls through to the
            built-in formatter.
        vertex_label_fn: Same shape, for interaction vertices.
        label_format: Default external-label format
            (:data:`~sft_wick.render_style.LABEL_COMPACT` /
            ``LABEL_FULL`` / ``LABEL_TIME_F``). Overrides
            ``style.label_format`` when supplied explicitly.
    """

    # Backwards-compatible class attribute. Users that previously
    # mutated ``DiagramRenderer.PROP_STYLES["C"]["color"]`` to retheme
    # globally still see their override take effect via
    # :meth:`_resolve_style` (only when no explicit ``style`` was given).
    PROP_STYLES: dict[str, dict[str, Any]] = {
        "C": {"linestyle": "-", "color": "#2166ac", "linewidth": 2.0},
        "R": {"linestyle": "--", "color": "#d6604d", "linewidth": 2.0},
    }

    def __init__(
        self,
        figsize: tuple[float, float] = (6, 5),
        style: RenderStyle | None = None,
        external_label_fn: LabelCallable | None = None,
        vertex_label_fn: LabelCallable | None = None,
        label_format: str | None = None,
    ) -> None:
        self.figsize = figsize
        # Track whether the user supplied a custom style. When they
        # did not, we honour any monkey-patches to the legacy class
        # attribute ``PROP_STYLES`` (for backward compatibility);
        # when they did, we respect their style verbatim.
        self._user_supplied_style: bool = style is not None
        self._raw_style: RenderStyle = (
            style if style is not None else default_style()
        )
        self.external_label_fn = external_label_fn
        self.vertex_label_fn = vertex_label_fn
        self._label_format_override = label_format

    # ------------------------------------------------------------------
    # Style helpers
    # ------------------------------------------------------------------
    @property
    def style(self) -> RenderStyle:
        """The active :class:`RenderStyle`.

        When the user did *not* pass a ``style`` to the constructor,
        any monkey-patches to the legacy class attribute
        :attr:`PROP_STYLES` are folded in for backward compatibility.
        When the user did pass a ``style``, it is respected verbatim.
        """
        return self._resolve_style()

    def _resolve_style(self) -> RenderStyle:
        """Return the raw style, merged with ``PROP_STYLES`` if applicable.

        For every propagator kind present in both ``PROP_STYLES`` and
        the raw style, the legacy ``color`` / ``linestyle`` /
        ``linewidth`` entries replace the style's values; ``arrow`` and
        ``legend_label`` are always kept from the raw style.
        """
        base = self._raw_style
        if self._user_supplied_style:
            return base
        # Backwards-compat: pick up any tweaks the user made to
        # ``DiagramRenderer.PROP_STYLES`` at the class level.
        legacy_props: dict[str, PropagatorStyle] = {}
        for kind, defaults in type(self).PROP_STYLES.items():
            current = base.propagators.get(kind)
            if current is None:
                continue
            ls_name = _matplotlib_linestyle_to_name(defaults.get("linestyle"))
            legacy_props[kind] = PropagatorStyle(
                color=defaults.get("color", current.color),
                linestyle=ls_name or current.linestyle,
                linewidth=defaults.get("linewidth", current.linewidth),
                arrow=current.arrow,
                legend_label=current.legend_label,
            )
        if legacy_props:
            new_props = dict(base.propagators)
            new_props.update(legacy_props)
            return base.with_overrides(propagators=new_props)
        return base

    @property
    def label_format(self) -> str:
        """External-label format: the constructor override, else the style's."""
        return (
            self._label_format_override
            if self._label_format_override is not None
            else self._raw_style.label_format
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def draw(
        self,
        diagram: FeynmanDiagram,
        ax: plt.Axes | None = None,
        title: str = "",
        title_kwargs: Mapping[str, Any] | None = None,
        external_labels: Mapping[str, str] | None = None,
        vertex_labels: Mapping[str, str] | None = None,
        positions: Mapping[str, tuple[float, float]] | None = None,
        show_legend: bool | None = None,
    ) -> plt.Axes:
        """Draw a single Feynman diagram on a matplotlib axes.

        Args:
            diagram: The diagram to render.
            ax: Existing axes to draw on; ``None`` →
                a new figure of size ``self.figsize``.
            title: Per-axes title. Empty string → use
                ``diagram.summary()``.
            title_kwargs: Extra kwargs passed to ``ax.set_title``
                (overrides the style's ``title_fontsize``).
            external_labels: Map ``{node_id: label}`` overriding
                specific external-vertex labels.
            vertex_labels: Same shape, for interaction vertices.
            positions: Map ``{node_id: (x, y)}`` pinning
                specific node coordinates and skipping
                spring layout for them.
            show_legend: Override ``style.show_legend``.

        Returns:
            The axes that were drawn on.
        """
        style = self._resolve_style()
        with mpl.rc_context(style.effective_rcparams()):
            return self._draw_within_context(
                diagram=diagram,
                ax=ax,
                title=title,
                title_kwargs=title_kwargs,
                external_labels=external_labels,
                vertex_labels=vertex_labels,
                positions=positions,
                show_legend=show_legend,
                style=style,
            )

    def draw_all(
        self,
        diagrams: list[FeynmanDiagram],
        ncols: int = 3,
        suptitle: str = "",
        suptitle_kwargs: Mapping[str, Any] | None = None,
        subtitle_fn: Callable[[int, FeynmanDiagram, int], str] | None = None,
        multiplicities: Sequence[int] | None = None,
        shared_legend: bool = True,
        wspace: float | None = 0.04,
        hspace: float | None = None,
        show: bool = False,
        external_labels: Mapping[int, Mapping[str, str]] | None = None,
    ) -> plt.Figure:
        """Draw a grid of Feynman diagrams.

        Args:
            diagrams: Diagrams to lay out, one per subplot.
            ncols: Maximum columns (clipped to ``len(diagrams)`` and
                to at least 1).
            suptitle: Figure-level title; pass ``""`` to skip.
            suptitle_kwargs: Extra kwargs forwarded to ``fig.suptitle``.
            subtitle_fn: Per-subplot title formatter
                ``fn(index, diagram, multiplicity) -> str``.
                ``None`` → ``"#i: <summary> [xN]"``.
            multiplicities: Optional per-diagram multiplicities (for
                default subtitles). Must have at least
                ``len(diagrams)`` entries.
            shared_legend: If ``True`` (default), suppress per-panel
                legends and draw one figure-level legend for
                all propagator kinds present in the grid.
                ``False`` restores per-panel legends.
            wspace, hspace: Optional subplot spacing passed to
                ``Figure.subplots_adjust`` after ``tight_layout``.
                ``wspace=None`` leaves the horizontal spacing alone.
                ``hspace=None`` uses a compact automatic value and
                relaxes it if rendered rows would overlap.
            show: If ``True``, call ``plt.show()`` after building the
                figure. Default ``False`` — callers that ``savefig``
                should leave this off.
            external_labels: Optional map keyed by *subplot index*
                whose values are ``{node_id: label}`` overrides for
                that subplot.

        Returns:
            The matplotlib :class:`Figure` (always returned, even when
            ``show=True``).

        Raises:
            ValueError: If *multiplicities* has fewer entries than
                *diagrams*.
        """
        n = len(diagrams)
        if multiplicities is None:
            multiplicities = [1] * n
        elif len(multiplicities) < n:
            # Validate before any figure exists, so a bad call does not
            # leave a half-drawn figure registered with pyplot.
            raise ValueError(
                f"multiplicities has {len(multiplicities)} entries but "
                f"{n} diagrams were given"
            )
        style = self._resolve_style()
        # Both branches build their figure under the style's rcParams so
        # the empty placeholder matches the look of a populated grid.
        with mpl.rc_context(style.effective_rcparams()):
            if n == 0:
                fig, ax = plt.subplots(1, 1, figsize=self.figsize)
                ax.text(0.5, 0.5, "No diagrams", ha="center", va="center")
                ax.axis("off")
            else:
                fig = self._draw_grid(
                    diagrams=diagrams,
                    multiplicities=multiplicities,
                    ncols=ncols,
                    suptitle=suptitle,
                    suptitle_kwargs=suptitle_kwargs,
                    subtitle_fn=subtitle_fn,
                    shared_legend=shared_legend,
                    wspace=wspace,
                    hspace=hspace,
                    external_labels=external_labels,
                    style=style,
                )
        if show:
            plt.show()
        return fig

    # ------------------------------------------------------------------
    # Internal: grid assembly for draw_all
    # ------------------------------------------------------------------
    def _draw_grid(
        self,
        *,
        diagrams: Sequence[FeynmanDiagram],
        multiplicities: Sequence[int],
        ncols: int,
        suptitle: str,
        suptitle_kwargs: Mapping[str, Any] | None,
        subtitle_fn: Callable[[int, FeynmanDiagram, int], str] | None,
        shared_legend: bool,
        wspace: float | None,
        hspace: float | None,
        external_labels: Mapping[int, Mapping[str, str]] | None,
        style: RenderStyle,
    ) -> plt.Figure:
        """Lay out a non-empty list of diagrams on a subplot grid.

        Helper for :meth:`draw_all`; expects to run inside the style's
        ``rc_context`` and returns the finished figure.
        """
        n = len(diagrams)
        # Clamp to [1, n]: a non-positive ncols would divide by zero below.
        ncols = max(1, min(ncols, n))
        nrows = (n + ncols - 1) // ncols
        # squeeze=False always returns a 2-D array, so 1x1 / 1xN grids
        # need no special-casing before flattening.
        fig, axes = plt.subplots(
            nrows, ncols,
            figsize=(self.figsize[0] * ncols, self.figsize[1] * nrows),
            squeeze=False,
        )
        axes_flat = axes.ravel()

        edge_kinds = {
            data.get("kind")
            for diagram in diagrams
            for _, _, data in diagram.graph.edges(data=True)
        }
        handles = (
            self._legend_handles(edge_kinds, style)
            if shared_legend and style.show_legend else []
        )
        # If there is a spare panel, the legend goes there instead of
        # being drawn as a figure-level legend above the grid.
        legend_panel_idx = (
            _legend_panel_index(n, ncols, nrows) if handles else None
        )
        diagram_axes = [
            ax for idx, ax in enumerate(axes_flat)
            if idx != legend_panel_idx
        ][:n]

        for i, (diagram, ax) in enumerate(zip(diagrams, diagram_axes)):
            mult = multiplicities[i]
            if subtitle_fn is not None:
                label = subtitle_fn(i, diagram, mult)
            else:
                label = f"#{i + 1}: {diagram.summary(short=True)}"
                if mult > 1:
                    label += f" [x{mult}]"
            ext_overrides = (
                external_labels.get(i)
                if external_labels is not None else None
            )
            self._draw_within_context(
                diagram=diagram,
                ax=ax,
                title=label,
                title_kwargs=None,
                external_labels=ext_overrides,
                vertex_labels=None,
                positions=None,
                show_legend=False if shared_legend else None,
                style=style,
            )

        used_axes = set(diagram_axes)
        if legend_panel_idx is not None:
            legend_ax = axes_flat[legend_panel_idx]
            _draw_shared_legend_panel(legend_ax, handles, style)
            used_axes.add(legend_ax)
        for ax in axes_flat:
            if ax not in used_axes:
                ax.axis("off")

        fig_legend_at_top = bool(handles and legend_panel_idx is None)
        if fig_legend_at_top:
            fig.legend(
                handles=handles,
                loc="upper center",
                bbox_to_anchor=(0.5, 0.995 if not suptitle else 0.94),
                ncol=len(handles),
                fontsize=style.legend_fontsize,
                frameon=False,
            )

        # Reserve headroom for the suptitle and/or the top legend.
        bottom = 0.02
        if fig_legend_at_top:
            top = 0.84 if suptitle else 0.89
        else:
            top = 0.94 if suptitle else 0.97
        if suptitle:
            kw: dict[str, Any] = {"fontsize": style.suptitle_fontsize}
            if suptitle_kwargs:
                kw.update(dict(suptitle_kwargs))
            fig.suptitle(suptitle, **kw)
        fig.tight_layout(rect=(0, bottom, 1, top), pad=0.35)

        adjust_kwargs: dict[str, float] = {
            "hspace": 0.22 if hspace is None else hspace,
        }
        if wspace is not None:
            adjust_kwargs["wspace"] = wspace
        fig.subplots_adjust(**adjust_kwargs)
        # Only auto-relax row spacing when the caller did not pin it.
        if hspace is None and nrows > 1:
            _relax_vertical_spacing(fig, diagram_axes)
        return fig

    # ------------------------------------------------------------------
    # Internal: do the actual drawing inside the rc_context block
    # ------------------------------------------------------------------
    def _draw_within_context(
        self,
        diagram: FeynmanDiagram,
        ax: plt.Axes | None,
        title: str,
        title_kwargs: Mapping[str, Any] | None,
        external_labels: Mapping[str, str] | None,
        vertex_labels: Mapping[str, str] | None,
        positions: Mapping[str, tuple[float, float]] | None,
        show_legend: bool | None,
        style: RenderStyle,
    ) -> plt.Axes:
        """Draw *diagram* on *ax*; the caller owns the ``rc_context``.

        Shared by :meth:`draw` and :meth:`draw_all`. A new figure of
        size ``self.figsize`` is created when *ax* is ``None``.
        Returns the axes drawn on.
        """
        if ax is None:
            _fig, ax = plt.subplots(1, 1, figsize=self.figsize)
        g = diagram.graph
        if g.number_of_nodes() == 0:
            shown_title = title or "Empty diagram"
            self._set_title(ax, shown_title, title_kwargs, style)
            ax.axis("off")
            return ax
        pos = compute_layout(diagram, style.layout, positions)

        # Pre-count parallel edges and self-loops so each edge knows its
        # index among its siblings (used to fan them apart).
        edge_pair_count: Counter = Counter()
        self_loop_count: Counter = Counter()
        for u, v, _key in g.edges(keys=True):
            if u == v:
                self_loop_count[u] += 1
            else:
                pair = (min(u, v), max(u, v))
                edge_pair_count[pair] += 1
        edge_pair_idx: Counter = Counter()
        self_loop_idx: Counter = Counter()

        # Edges
        for u, v, _key, data in g.edges(keys=True, data=True):
            kind = data.get("kind", "C")
            prop_style = style.propagators.get(kind)
            if prop_style is None:
                # Unknown kind: fall back to the first configured style.
                prop_style = next(iter(style.propagators.values()))
            p1 = np.array(pos[u])
            p2 = np.array(pos[v])
            if u == v:
                idx = self_loop_idx[u]
                total = self_loop_count[u]
                self_loop_idx[u] += 1
                ncenter = neighbor_center(g, u, pos)
                self._draw_self_loop(
                    ax, p1, prop_style,
                    loop_index=idx, n_loops=total,
                    away_from=ncenter,
                )
            else:
                pair = (min(u, v), max(u, v))
                idx = edge_pair_idx[pair]
                total = edge_pair_count[pair]
                edge_pair_idx[pair] += 1
                # R propagators are directed: the arrowhead must land
                # on the physical (φ) end, i.e. the arrow points ψ → φ
                # (matches the TikZ renderer — see drawing_tikz.py).
                #
                # The edge geometry is ALWAYS drawn p1→p2 (the stored
                # u→v orientation) so that the curvature side of
                # parallel edges (``rad`` in _draw_edge) is independent
                # of the arrow direction. If we flipped p1/p2 to steer
                # the head, a reversed R edge would bow onto its
                # parallel C/R partner and overlap it. Instead the head
                # is placed on ``phi_end`` via the arrowstyle:
                # ``"<|-"`` puts it on posA (=u), ``"-|>"`` on posB (=v).
                phi_end = data.get("phi_end")
                head_at_start = (
                    prop_style.arrow
                    and phi_end is not None
                    and phi_end == u
                )
                self._draw_edge(
                    ax, p1, p2, prop_style,
                    key=idx, n_parallel=total,
                    head_at_start=head_at_start,
                    curvature=style.layout.parallel_edge_curvature,
                )

        # Nodes
        ext_nodes = diagram.external_nodes
        vert_nodes = diagram.vertex_nodes
        all_pts = np.array([pos[node] for node in g.nodes()])
        # Labels are pushed away from the centroid of all nodes.
        center = np.mean(all_pts, axis=0)
        label_format = self.label_format

        def external_default(attrs):
            # Built-in external-label formatter, bound to this renderer's
            # label format (hoisted out of the per-node loop).
            return default_external_label(attrs, fmt=label_format)

        for node in ext_nodes:
            self._draw_node(ax, pos[node], style.external_node, zorder=5)
            label = resolve_label(
                node_id=node,
                node_attrs=g.nodes[node],
                overrides=external_labels,
                callable_fn=self.external_label_fn,
                default_fn=external_default,
            )
            self._draw_label(
                ax, pos[node], label, style.external_label, center=center,
            )
        for node in vert_nodes:
            self._draw_node(ax, pos[node], style.vertex_node, zorder=5)
            label = resolve_label(
                node_id=node,
                node_attrs=g.nodes[node],
                overrides=vertex_labels,
                callable_fn=self.vertex_label_fn,
                default_fn=default_vertex_label,
            )
            if label:
                self._draw_label(
                    ax, pos[node], label, style.vertex_label, center=center,
                )

        # Legend
        legend_visible = (
            style.show_legend if show_legend is None else show_legend
        )
        if legend_visible:
            self._draw_legend(ax, g, style)

        # Title
        shown_title = title if title else diagram.summary()
        self._set_title(ax, shown_title, title_kwargs, style)
        ax.axis("off")
        ax.set_aspect("equal")

        margin = style.layout.margin
        # The layout is normalised iff bbox-normalisation is on AND the
        # caller did not pin any nodes (manual positions skip the
        # normalisation pass — see render_layout.compute_layout).
        layout_was_normalised = (
            style.layout.normalize_bbox and not positions
        )
        if layout_was_normalised:
            # Use the standardised target extent so every panel in a
            # grid has identical coordinate limits — this is what
            # makes loop radii, edge thicknesses, and marker sizes
            # look consistent across diagrams. Very flat diagrams
            # (for example a bare two-point line) get a shallower
            # y-range anchored near the panel title so they do not
            # float in the middle of an otherwise empty subplot.
            half_w = style.layout.target_extent[0] / 2.0
            half_h = style.layout.target_extent[1] / 2.0
            ax.set_xlim(-half_w - margin, half_w + margin)
            data_h = ax.dataLim.height
            bare_propagator = not diagram.vertex_nodes
            if (
                bare_propagator
                and np.isfinite(data_h)
                and data_h < 0.25 * style.layout.target_extent[1]
            ):
                data_center_y = 0.5 * (ax.dataLim.y0 + ax.dataLim.y1)
                min_h = 0.6
                ax.set_ylim(
                    data_center_y - min_h / 2.0 - 2.4 * margin,
                    data_center_y + min_h / 2.0 + 0.15 * margin,
                )
                ax.set_anchor("N")
            else:
                ax.set_ylim(
                    -half_h - 1.8 * margin,
                    half_h + 0.2 * margin,
                )
        else:
            ax.set_xlim(
                all_pts[:, 0].min() - margin, all_pts[:, 0].max() + margin,
            )
            ax.set_ylim(
                all_pts[:, 1].min() - margin, all_pts[:, 1].max() + margin,
            )
        return ax

    # ------------------------------------------------------------------
    # Drawing primitives
    # ------------------------------------------------------------------
    def _draw_edge(
        self,
        ax: plt.Axes,
        p1: np.ndarray,
        p2: np.ndarray,
        style: PropagatorStyle,
        key: int = 0,
        n_parallel: int = 1,
        head_at_start: bool = False,
        curvature: float = 0.6,
    ) -> None:
        """Draw one propagator from *p1* to *p2* as a FancyArrowPatch.

        Parallel edges between the same node pair (``n_parallel > 1``)
        are bent symmetrically: edge *key* gets arc radius
        ``(key - (n_parallel - 1) / 2) * curvature``. Directed styles
        get an arrowhead at *p2*, or at *p1* when *head_at_start*.
        """
        from matplotlib.patches import FancyArrowPatch

        if n_parallel <= 1:
            rad = 0.0
        else:
            idx = key - (n_parallel - 1) / 2.0
            rad = idx * curvature
        # ``head_at_start`` steers a directed edge's arrowhead onto
        # posA (the start) instead of posB (the end), without changing
        # the geometry — see the directed-edge comment in the edge loop.
        if not style.arrow:
            arrowstyle = "-"
        elif head_at_start:
            arrowstyle = "<|-"
        else:
            arrowstyle = "-|>"
        arrow = FancyArrowPatch(
            posA=tuple(p1),
            posB=tuple(p2),
            connectionstyle=f"arc3,rad={rad}",
            arrowstyle=arrowstyle,
            color=style.color,
            linestyle=_MPL_LINESTYLE.get(style.linestyle, style.linestyle),
            linewidth=style.linewidth,
            mutation_scale=12,
            shrinkA=4,
            shrinkB=4,
            zorder=3 if style.arrow else 2,
        )
        ax.add_patch(arrow)

    def _draw_self_loop(
        self,
        ax: plt.Axes,
        center: np.ndarray,
        style: PropagatorStyle,
        loop_index: int = 0,
        n_loops: int = 1,
        away_from: np.ndarray | None = None,
    ) -> None:
        """Draw a circular self-loop (radius 0.5) attached to *center*.

        The loop points away from *away_from* (straight up when it is
        ``None`` or coincides with *center*); multiple loops on the same
        node are fanned out in steps of 60°. Directed styles get a
        small arrowhead on the circle.
        """
        if away_from is not None:
            base_dir = center - away_from
            norm = np.linalg.norm(base_dir)
            if norm > 1e-8:
                base_dir = base_dir / norm
            else:
                base_dir = np.array([0.0, 1.0])
        else:
            base_dir = np.array([0.0, 1.0])
        if n_loops > 1:
            spread = np.pi / 3
            angle_offset = (loop_index - (n_loops - 1) / 2) * spread
            cos_a, sin_a = np.cos(angle_offset), np.sin(angle_offset)
            # Rotate base_dir by angle_offset.
            direction = np.array([
                base_dir[0] * cos_a - base_dir[1] * sin_a,
                base_dir[0] * sin_a + base_dir[1] * cos_a,
            ])
        else:
            direction = base_dir
        loop_radius = 0.5
        # Offset slightly beyond the radius so the circle clears the node.
        loop_center = center + (loop_radius + 0.05) * direction
        circle = plt.Circle(
            tuple(loop_center),
            loop_radius,
            fill=False,
            color=style.color,
            linestyle=_MPL_LINESTYLE.get(style.linestyle, style.linestyle),
            linewidth=style.linewidth,
            zorder=2,
        )
        ax.add_patch(circle)
        if style.arrow:
            angle = np.arctan2(direction[1], direction[0])
            arrow_angle = angle + np.pi / 3
            tip = loop_center + loop_radius * np.array(
                [np.cos(arrow_angle), np.sin(arrow_angle)])
            tail = loop_center + loop_radius * np.array(
                [np.cos(arrow_angle + 0.2), np.sin(arrow_angle + 0.2)])
            ax.annotate(
                "", xy=tuple(tip), xytext=tuple(tail),
                arrowprops=dict(arrowstyle="-|>", color=style.color,
                                lw=style.linewidth),
            )

    def _draw_node(
        self,
        ax: plt.Axes,
        point: np.ndarray,
        style: NodeStyle,
        zorder: int = 5,
    ) -> None:
        """Draw a single node marker at *point*.

        Unknown shapes fall back to a circle; an ``edge_color`` of
        ``"none"`` means "use the fill colour for the outline".
        """
        marker = _MPL_MARKER.get(style.shape, "o")
        face = style.color if style.fill else "none"
        edge = style.edge_color if style.edge_color != "none" else style.color
        ax.plot(
            point[0], point[1],
            marker=marker,
            markersize=style.size,
            markerfacecolor=face,
            markeredgecolor=edge,
            linestyle="None",
            zorder=zorder,
        )

    def _draw_label(
        self,
        ax: plt.Axes,
        point: np.ndarray,
        text: str,
        style: Any,  # LabelStyle, loosely typed to avoid an extra import
        center: np.ndarray,
    ) -> None:
        """Annotate *point* with *text*, offset away from *center*.

        Empty text draws nothing. The offset is in points
        (``style.offset_pt``); an optional white rounded box sits
        behind the text when ``style.bbox`` is set.
        """
        if not text:
            return
        offset = label_offset(point, center, distance=style.offset_pt)
        bbox = (
            dict(
                boxstyle="round,pad=0.15", facecolor="white",
                edgecolor="none", alpha=style.bbox_alpha,
            )
            if style.bbox else None
        )
        ax.annotate(
            text, point,
            textcoords="offset points", xytext=offset,
            ha="center", va="center",
            fontsize=style.fontsize,
            fontweight="bold" if style.bold else "normal",
            bbox=bbox,
            zorder=6,
        )

    def _draw_legend(
        self,
        ax: plt.Axes,
        g: nx.MultiGraph,
        style: RenderStyle,
    ) -> None:
        """Draw a per-axes legend for the propagator kinds used in *g*."""
        edge_kinds = {data.get("kind") for _, _, data in g.edges(data=True)}
        handles = self._legend_handles(edge_kinds, style)
        if handles:
            ax.legend(
                handles=handles,
                loc=style.legend_loc,
                fontsize=style.legend_fontsize,
                framealpha=0.8,
                edgecolor="none",
            )

    def _legend_handles(
        self,
        edge_kinds: set[Any],
        style: RenderStyle,
    ) -> list[plt.Line2D]:
        """Build legend proxy lines for kinds in *edge_kinds*.

        Kinds without a ``legend_label`` are skipped; order follows
        ``style.propagators``.
        """
        handles = []
        for kind, prop in style.propagators.items():
            if kind not in edge_kinds:
                continue
            label = prop.legend_label
            if not label:
                continue
            handles.append(plt.Line2D(
                [0], [0],
                color=prop.color,
                linestyle=_MPL_LINESTYLE.get(prop.linestyle, prop.linestyle),
                lw=prop.linewidth,
                label=label,
            ))
        return handles

    def _set_title(
        self,
        ax: plt.Axes,
        title: str,
        title_kwargs: Mapping[str, Any] | None,
        style: RenderStyle,
    ) -> None:
        """Set the axes title unless it is empty or titles are disabled.

        A non-positive ``style.title_fontsize`` suppresses titles
        entirely; otherwise *title_kwargs* override the defaults
        (``fontsize=style.title_fontsize``, ``pad=0.0``).
        """
        if not title or style.title_fontsize <= 0:
            return
        kw: dict[str, Any] = {"fontsize": style.title_fontsize, "pad": 0.0}
        if title_kwargs:
            kw.update(dict(title_kwargs))
        ax.set_title(title, **kw)

    # ------------------------------------------------------------------
    # Backwards-compat helpers (kept in case external code uses them)
    # ------------------------------------------------------------------
    @staticmethod
    def _neighbor_center(
        g: nx.MultiGraph, node: str, pos: dict[str, np.ndarray],
    ) -> np.ndarray:
        """Thin wrapper around :func:`render_layout.neighbor_center`."""
        return neighbor_center(g, node, pos)

    @staticmethod
    def _label_offset(
        point: np.ndarray, center: np.ndarray, distance: float = 18,
    ) -> tuple[float, float]:
        """Thin wrapper around :func:`render_layout.label_offset`."""
        return label_offset(point, center, distance=distance)
# ----------------------------------------------------------------------
# Helpers
# ----------------------------------------------------------------------
def _matplotlib_linestyle_to_name(ls: str | None) -> str | None:
    """Map a matplotlib linestyle code back to its backend-agnostic name.

    Inverse of ``_MPL_LINESTYLE``. The legacy ``PROP_STYLES`` class
    attribute uses ``"-"`` / ``"--"``; the new style API uses
    ``"solid"`` / ``"dashed"``. This helper bridges the two so
    ``DiagramRenderer.PROP_STYLES["C"]["linestyle"] = "-."`` still works.

    Args:
        ls: A matplotlib linestyle code, an already-translated name,
            or ``None``.

    Returns:
        ``None`` for ``None``; the matching name for a known code; any
        other value is passed through unchanged.
    """
    if ls is None:
        return None
    # Derive the inverse from the forward table so the two mappings
    # cannot drift apart when a new linestyle is added.
    by_code = {code: name for name, code in _MPL_LINESTYLE.items()}
    return by_code.get(ls, ls)
def _legend_panel_index(n: int, ncols: int, nrows: int) -> int | None:
    """Pick the grid cell that should host the shared legend.

    The upper-right cell (flat index ``ncols - 1``) is used whenever the
    grid has more than one column and more cells than diagrams;
    otherwise there is no spare cell and ``None`` is returned.
    """
    has_spare_cell = ncols * nrows > n
    if ncols > 1 and has_spare_cell:
        return ncols - 1
    return None
def _draw_shared_legend_panel(
    ax: plt.Axes,
    handles: Sequence[plt.Line2D],
    style: RenderStyle,
) -> None:
    """Turn a spare grid cell into a frameless, centred legend panel.

    The axes decorations are switched off first, so only the legend
    entries (stacked in a single column) remain visible.
    """
    legend_options: dict[str, Any] = dict(
        loc="center",
        ncol=1,
        fontsize=style.legend_fontsize,
        frameon=False,
        handlelength=2.4,
        borderaxespad=0.0,
    )
    ax.axis("off")
    ax.legend(handles=handles, **legend_options)
def _relax_vertical_spacing(
    fig: plt.Figure,
    axes: Sequence[plt.Axes],
    *,
    min_gap_px: float = 8.0,
    max_hspace: float = 0.8,
    step: float = 0.06,
    max_iterations: int = 8,
) -> None:
    """Increase row spacing until rendered tight-bboxes no longer touch.

    Each iteration redraws the canvas, measures the smallest vertical
    gap between neighbouring rows of *axes* and, if it is below
    *min_gap_px*, widens ``hspace`` by *step* (capped at *max_hspace*).

    Args:
        fig: Figure whose subplot spacing is adjusted in place.
        axes: Axes to measure; one or zero axes is a no-op.
        min_gap_px: Required gap between rows, in display pixels.
        max_hspace: Upper bound for ``hspace``; stop once reached.
        step: Amount added to ``hspace`` per iteration.
        max_iterations: Maximum number of redraw/measure rounds (each
            one is a full canvas draw, so this bounds the cost).
    """
    if len(axes) <= 1:
        return
    for _ in range(max_iterations):
        # Tight bboxes are only accurate after a draw.
        fig.canvas.draw()
        gap = _minimum_row_gap_px(fig, axes)
        if gap is None or gap >= min_gap_px:
            return
        current = fig.subplotpars.hspace or 0.0
        if current >= max_hspace:
            return
        fig.subplots_adjust(hspace=min(max_hspace, current + step))
def _minimum_row_gap_px(
    fig: plt.Figure,
    axes: Sequence[plt.Axes],
) -> float | None:
    """Smallest vertical distance, in display pixels, between adjacent rows.

    Axes are grouped into rows by the bottom edge of their subplot
    position (rounded to 4 decimals). Each row's vertical extent is the
    union of its members' tight bounding boxes; the gap between two
    neighbouring rows is the upper row's lowest point minus the lower
    row's highest point, so a negative value means they overlap.
    Returns ``None`` when fewer than two rows have a measurable bbox.
    """
    renderer = fig.canvas.get_renderer()
    extents: dict[float, tuple[float, float]] = {}
    for axis in axes:
        tight = axis.get_tightbbox(renderer)
        if tight is None:
            continue
        key = round(axis.get_position().y0, 4)
        if key in extents:
            low, high = extents[key]
            extents[key] = (min(low, tight.y0), max(high, tight.y1))
        else:
            extents[key] = (tight.y0, tight.y1)
    if len(extents) < 2:
        return None
    # Top row first (larger y0 means higher on the figure).
    top_down = [extents[key] for key in sorted(extents, reverse=True)]
    return min(
        upper[0] - lower[1]
        for upper, lower in zip(top_down, top_down[1:])
    )