# Source code for sft_wick.render_layout

"""Position computation for Feynman-diagram nodes (backend-agnostic).

This module owns the geometric layout of a :class:`FeynmanDiagram`:

1. external nodes are placed on a circle (or, for the common
   two-external case, mirrored on the x-axis),
2. interaction vertices are seeded around the origin and refined by
   ``networkx.spring_layout`` with externals pinned,
3. a post-pass enforces a minimum vertex separation,
4. for two-external layouts, vertices are clamped horizontally so
   the externals stay outermost.

Both rendering backends (:mod:`sft_wick.drawing` and
:mod:`sft_wick.drawing_tikz`) call :func:`compute_layout` and then
translate the resulting positions to their own coordinate
conventions.

The function accepts a ``manual_positions`` map so users can pin
arbitrary nodes — particularly useful when the spring layout
produces an aesthetically poor result and the user wants full
control.
"""

from __future__ import annotations

from collections.abc import Mapping

import networkx as nx
import numpy as np

from .diagrams import FeynmanDiagram
from .render_style import LayoutParams


def compute_layout(
    diagram: FeynmanDiagram,
    params: LayoutParams = LayoutParams(),
    manual_positions: Mapping[str, tuple[float, float]] | None = None,
) -> dict[str, np.ndarray]:
    """Return ``{node_id: np.array([x, y])}`` for every node.

    Args:
        diagram: The diagram to lay out.
        params: Layout parameters (radius, spring-layout knobs,
            separation thresholds, …).
        manual_positions: Optional pin map. Any node listed here is
            placed at the given coordinate and excluded from
            spring-layout updates; ids that are not nodes of
            ``diagram.graph`` are ignored. Externals not listed keep
            the slot they would occupy with no pins at all (the
            mirrored ``(±ext_radius, 0)`` pair for two-external
            diagrams, otherwise evenly spaced on a circle), so pinning
            one external never moves the others. Vertices not listed
            follow the standard spring + min-distance pipeline.

    Returns:
        A dict mapping node-id to a length-2 numpy array. Empty
        diagrams produce an empty dict; single-node diagrams return one
        entry, at its pinned coordinate if one was given and at the
        origin otherwise.
    """
    g = diagram.graph
    ext = diagram.external_nodes
    verts = diagram.vertex_nodes

    n_nodes = g.number_of_nodes()
    if n_nodes == 0:
        return {}
    if n_nodes == 1:
        only = list(g.nodes())[0]
        if manual_positions and only in manual_positions:
            return {only: np.array(manual_positions[only], dtype=float)}
        return {only: np.array([0.0, 0.0])}

    # Pins first: everything below only ever writes *unpinned* nodes.
    pos: dict[str, np.ndarray] = {}
    pinned: set[str] = set()
    if manual_positions:
        for node_id, xy in manual_positions.items():
            if node_id in g.nodes:
                pos[node_id] = np.array(xy, dtype=float)
                pinned.add(node_id)

    _place_externals(ext, pinned, pos, params.ext_radius)
    _place_vertices(g, ext, verts, pinned, pos, params)

    # Min-distance enforcement between every pair of vertices that
    # are *not* both pinned (we never push pinned nodes).
    _enforce_min_distance(pos, verts, pinned, params.min_vertex_dist)

    # Two-external clamp: keep vertices strictly between the externals.
    if len(ext) == 2 and verts:
        _clamp_between_externals(pos, ext, verts, pinned)

    # Disconnected diagrams can otherwise collapse visually because
    # spring_layout has no edges tying their components apart.
    if not pinned and params.component_gap > 0:
        _separate_connected_components(g, pos, params.component_gap)

    # Normalise the bounding box so every diagram has a comparable
    # visual extent. Skipped whenever the caller pinned any node —
    # rescaling would move pins away from their intended coordinates.
    if params.normalize_bbox and pos and not pinned:
        _normalize_bbox(pos, params.target_extent)
    return pos


def _external_slot(index: int, n_ext: int, radius: float) -> np.ndarray:
    """Return the unpinned position of the ``index``-th of ``n_ext`` externals."""
    if n_ext == 2:
        # The two-external case is mirrored on the x-axis (left, right)
        # rather than placed at the circle's bottom/top poles.
        return np.array([radius if index else -radius, 0.0])
    angle = 2 * np.pi * index / n_ext - np.pi / 2
    return np.array([radius * np.cos(angle), radius * np.sin(angle)])


def _place_externals(
    ext: list[str],
    pinned: set[str],
    pos: dict[str, np.ndarray],
    radius: float,
) -> None:
    """In-place: put every unpinned external at its default slot.

    Slots depend only on the external's index, never on which other
    externals are pinned, so adding pins doesn't reshuffle the rest.
    """
    n_ext = len(ext)
    for i, node in enumerate(ext):
        if node not in pinned:
            pos[node] = _external_slot(i, n_ext, radius)


def _place_vertices(
    g: nx.MultiGraph,
    ext: list[str],
    verts: list[str],
    pinned: set[str],
    pos: dict[str, np.ndarray],
    params: LayoutParams,
) -> None:
    """In-place: seed and (when externals exist) spring-refine unpinned vertices.

    Expects ``pos`` to already hold every pinned node and every external.
    """
    free_verts = [v for v in verts if v not in pinned]
    if not free_verts:
        return  # no vertices at all, or every vertex is pinned

    n_vert = len(verts)
    if n_vert == 1:
        pos[verts[0]] = np.array([0.0, 0.0])
        return

    if len(ext) == 0:
        # No externals to anchor a spring layout: use a unit circle,
        # keeping each vertex's slot independent of which are pinned.
        for i, v in enumerate(verts):
            if v in pinned:
                continue
            angle = 2 * np.pi * i / n_vert - np.pi / 2
            pos[v] = np.array([np.cos(angle), np.sin(angle)])
        return

    # Spring layout with externals + pins fixed; free vertices start on
    # a small ring around the origin.
    init_pos = dict(pos)
    n_free = len(free_verts)
    for i, v in enumerate(free_verts):
        angle = 2 * np.pi * i / n_free + np.pi / 4
        init_pos.setdefault(
            v,
            np.array([0.5 * np.cos(angle), 0.5 * np.sin(angle)]),
        )
    fixed_nodes = list(pinned | set(ext))
    spring_pos = nx.spring_layout(
        g,
        pos=init_pos,
        fixed=fixed_nodes or None,
        k=params.spring_k,
        iterations=params.spring_iterations,
        seed=params.spring_seed,
    )
    for v in free_verts:
        pos[v] = spring_pos[v]


def _clamp_between_externals(
    pos: dict[str, np.ndarray],
    ext: list[str],
    verts: list[str],
    pinned: set[str],
) -> None:
    """In-place: clamp unpinned vertices' x into the central band between two externals.

    The band excludes 35 % of the externals' horizontal span on each
    side, so the externals stay outermost.
    """
    x_a = pos[ext[0]][0]
    x_b = pos[ext[1]][0]
    x_lo = min(x_a, x_b)
    x_hi = max(x_a, x_b)
    if not (x_hi > x_lo):
        return  # externals share an x coordinate: no horizontal band
    inset = 0.35 * (x_hi - x_lo)
    for v in verts:
        if v in pinned:
            continue
        pos[v] = np.array([
            np.clip(pos[v][0], x_lo + inset, x_hi - inset),
            pos[v][1],
        ])
def _normalize_bbox( pos: dict[str, np.ndarray], target_extent: tuple[float, float], ) -> None: """Re-centre and uniformly scale ``pos`` in-place so its bounding box fits within ``target_extent`` without distortion. Aspect ratio is preserved: the scale factor is the smaller of ``target_w / current_w`` and ``target_h / current_h``, so a short-and-wide diagram gets stretched horizontally to the target width while a tall-and-narrow one is bounded by the target height. """ pts = np.array(list(pos.values()), dtype=float) centroid = pts.mean(axis=0) pts_centred = pts - centroid half_w = float(np.abs(pts_centred[:, 0]).max()) half_h = float(np.abs(pts_centred[:, 1]).max()) target_half_w = target_extent[0] / 2.0 target_half_h = target_extent[1] / 2.0 if half_w < 1e-8 and half_h < 1e-8: # All nodes coincident — nothing meaningful to scale. return if half_w < 1e-8: scale = target_half_h / half_h elif half_h < 1e-8: scale = target_half_w / half_w else: scale = min(target_half_w / half_w, target_half_h / half_h) for k in pos: pos[k] = (pos[k] - centroid) * scale def _enforce_min_distance( pos: dict[str, np.ndarray], verts: list[str], pinned: set[str], min_dist: float, max_iter: int = 10, ) -> None: """In-place: push every pair of free vertices at least ``min_dist`` apart.""" if min_dist <= 0: return for _ in range(max_iter): changed = False for iv, v1 in enumerate(verts): for v2 in verts[iv + 1:]: d = float(np.linalg.norm(pos[v1] - pos[v2])) if d == 0: if v2 not in pinned: pos[v2] = pos[v2] + np.array([min_dist, 0.0]) changed = True elif v1 not in pinned: pos[v1] = pos[v1] + np.array([-min_dist, 0.0]) changed = True continue if d >= min_dist: continue direction = (pos[v2] - pos[v1]) / d push = (min_dist - d) / 2 + 0.05 if v1 not in pinned: pos[v1] = pos[v1] - push * direction changed = True if v2 not in pinned: pos[v2] = pos[v2] + push * direction changed = True if not changed: break def _separate_connected_components( g: nx.MultiGraph, pos: dict[str, np.ndarray], 
component_gap: float, ) -> None: """Move disconnected components apart horizontally in-place.""" components = [list(c) for c in nx.connected_components(g)] if len(components) <= 1: return boxes = [] for comp in components: pts = np.array([pos[n] for n in comp], dtype=float) boxes.append({ "nodes": comp, "min_x": float(pts[:, 0].min()), "max_x": float(pts[:, 0].max()), "center_x": float(pts[:, 0].mean()), }) boxes.sort(key=lambda b: b["center_x"]) old_center = np.mean(np.array(list(pos.values()), dtype=float), axis=0) current_right = boxes[0]["max_x"] for box in boxes[1:]: shift = max(0.0, current_right + component_gap - box["min_x"]) if shift: for node in box["nodes"]: pos[node] = pos[node] + np.array([shift, 0.0]) box["min_x"] += shift box["max_x"] += shift box["center_x"] += shift current_right = max(current_right, box["max_x"]) new_center = np.mean(np.array(list(pos.values()), dtype=float), axis=0) delta = new_center - old_center for node in pos: pos[node] = pos[node] - delta
def neighbor_center(
    g: nx.MultiGraph,
    node: str,
    pos: Mapping[str, np.ndarray],
) -> np.ndarray:
    """Return the mean position of ``node``'s neighbours, self excluded.

    The result is used to orient self-loops. When ``node`` has no
    neighbour other than itself, the point one unit below it is
    returned instead, so a self-loop on an isolated tadpole points
    straight up.
    """
    neighbour_points = [pos[other] for other in g.neighbors(node) if other != node]
    if neighbour_points:
        return np.mean(np.array(neighbour_points), axis=0)
    return pos[node] + np.array([0.0, -1.0])
def label_offset(
    point: np.ndarray,
    center: np.ndarray,
    distance: float = 18.0,
) -> tuple[float, float]:
    """Return an offset of length ``distance`` pointing from ``center`` towards ``point``.

    The offset is expressed in points. Both backends use it to push
    labels outward, away from the diagram's interior where edges are
    likely to run. When ``point`` (almost) coincides with ``center`` the
    offset points straight up: ``(0.0, distance)``.
    """
    delta = point - center
    length = float(np.linalg.norm(delta))
    if length < 1e-8:
        return (0.0, distance)
    ux, uy = delta / length
    return (float(ux * distance), float(uy * distance))