"""Perturbative expansion driver.
Computes <O>_S = sum_{n=0}^{N} (-1)^n / n! * <O * S_int^n>_{S_0}
using Wick's theorem for each order.
"""
from __future__ import annotations
from dataclasses import dataclass
from math import factorial
from typing import TYPE_CHECKING
import numpy as np
from .action import Action
from .expressions import (
ZERO,
Expr,
ImaginaryUnit,
IntegralOver,
KroneckerDelta,
Product,
Propagator,
Rational,
Sum,
SumOverIndex,
Symbol,
_latex_index,
apply_response_phase,
)
from .fields import FieldOperator
from .indices import IndexContext
from .propagators import contract_pair
from .simplify import (
_apply_index_sub,
_canonical_diagram_form,
_match_propagators_after_spatial,
collect_by_diagram,
diagonal_propagators,
simplify,
)
from .vertices import VertexInstance
from .wick import Pairing, SpatialSignature, wick_contract, wick_contract_spatial
if TYPE_CHECKING:
from .evaluate import DiagramIntegrand, SpatialStructure
from matplotlib.figure import Figure
[docs]
class PerturbativeResult:
    """Result of a perturbative calculation, organised by order.

    Holds the simplified symbolic expression for every perturbative
    order, their total, and the Feynman-diagram records produced while
    expanding.

    Attributes:
        order_terms: Perturbative order *n* -> simplified expression at
            that order.
        total: Sum of all non-zero order contributions.
        diagrams_by_order: Perturbative order *n* -> list of
            :class:`DiagramInfo` records at that order.
        diagram_terms_by_order: Perturbative order *n* -> list of
            :class:`DiagramTerm` contributions. An empty mapping when the
            caller supplied ``None`` or an empty dict.
    """

    def __init__(
        self,
        order_terms: dict[int, Expr],
        total: Expr,
        diagrams_by_order: dict[int, list[DiagramInfo]],
        diagram_terms_by_order: dict[int, list[DiagramTerm]] | None = None,
    ) -> None:
        self.order_terms = order_terms
        self.total = total
        self.diagrams_by_order = diagrams_by_order
        if not diagram_terms_by_order:
            # Normalise both ``None`` and an empty mapping to a fresh dict.
            diagram_terms_by_order = {}
        self.diagram_terms_by_order: dict[int, list[DiagramTerm]] = (
            diagram_terms_by_order
        )

    def order(self, n: int) -> Expr:
        """Return the contribution at perturbative order *n*.

        Args:
            n: The perturbative order (0, 1, 2, ...).

        Returns:
            The simplified expression stored for order *n*, or ``ZERO``
            when that order has no entry.
        """
        return self.order_terms.get(n, ZERO)

    def diagram_terms(self, order: int) -> list[DiagramTerm]:
        """Return the structured diagram contributions at *order*.

        Each :class:`DiagramTerm` bundles the propagators, coupling
        coefficient, prefactor and index structure required to evaluate
        one Feynman diagram numerically. These are populated when
        ``collect_topology=True`` (the default).

        Args:
            order: The perturbative order.

        Returns:
            The list of :class:`DiagramTerm` at that order, or an empty
            list when none were recorded.
        """
        return self.diagram_terms_by_order.get(order, [])

    def to_latex(self) -> str:
        """Render the whole result as LaTeX, one line per order.

        When diagram terms exist for an order (the default
        ``collect_topology=True`` path) the order is shown as a sum of
        fully wrapped diagram contributions: summation indices,
        integration variables, response phase and prefactor are all
        explicit. Otherwise the raw symbolic expression is used, and
        orders whose expression is zero are omitted.

        Returns:
            A multi-line string made of ``O(n): <latex>`` lines.
        """
        rendered: list[str] = []
        for n in sorted(self.order_terms):
            terms = self.diagram_terms_by_order.get(n, [])
            if terms:
                # A single-element join emits no separator, so the same
                # separator serves the one-term and many-term cases.
                body = " \\\\\n + ".join(term.to_latex() for term in terms)
                rendered.append(f"O({n}): {body}")
                continue
            expr = self.order_terms[n]
            if _is_zero(expr):
                continue
            rendered.append(f"O({n}): {expr.to_latex()}")
        return "\n".join(rendered)

    def draw_diagrams(self, order: int | None = None, **kwargs) -> "Figure | None":
        """Draw the Feynman diagrams with matplotlib.

        Topologically identical diagrams are drawn once, annotated with a
        ``×N`` multiplicity label when *N* > 1.

        Args:
            order: If given, draw only diagrams at this perturbative
                order; otherwise draw every order, in ascending order.
            **kwargs: Forwarded to :class:`~sft_wick.drawing.DiagramRenderer`
                (e.g. ``figsize``). Grid-level options understood by
                :meth:`~sft_wick.drawing.DiagramRenderer.draw_all`
                (``ncols``, ``suptitle``, ``suptitle_kwargs``,
                ``subtitle_fn``, ``shared_legend``, ``wspace``,
                ``hspace``, ``show``, ``external_labels``) are routed
                there instead.

        Returns:
            The matplotlib ``Figure`` holding the rendered diagrams, or
            ``None`` when there is nothing to draw. Returning the figure
            makes the call display reliably as the last expression in a
            notebook and lets scripts call ``fig.savefig(...)``.
        """
        from .diagrams import FeynmanDiagram
        from .drawing import DiagramRenderer

        grid_option_names = (
            "ncols",
            "suptitle",
            "suptitle_kwargs",
            "subtitle_fn",
            "shared_legend",
            "wspace",
            "hspace",
            "show",
            "external_labels",
        )
        # Split grid-level options off before building the renderer.
        grid_kwargs = {
            name: kwargs.pop(name) for name in grid_option_names if name in kwargs
        }
        renderer = DiagramRenderer(**kwargs)

        if order is None:
            diagrams = [
                info
                for o in sorted(self.diagrams_by_order)
                for info in self.diagrams_by_order[o]
            ]
        else:
            diagrams = self.diagrams_by_order.get(order, [])
        if not diagrams:
            print("No diagrams to draw.")
            return None

        # Build every graph first, then bucket by canonical form. Plain
        # dicts keep first-seen order, so representatives stay stable.
        feynman = [info.to_feynman_diagram() for info in diagrams]
        groups: dict[tuple, list[FeynmanDiagram]] = {}
        for fd in feynman:
            groups.setdefault(fd.canonical_form(), []).append(fd)

        representatives = [members[0] for members in groups.values()]
        counts = [len(members) for members in groups.values()]
        return renderer.draw_all(
            representatives,
            multiplicities=counts,
            **grid_kwargs,
        )

    def __repr__(self) -> str:
        return self.to_latex()
[docs]
class DiagramInfo:
    """Deferred-construction record describing one Feynman diagram.

    Only the ingredients of the diagram are stored here. The
    :class:`~sft_wick.diagrams.FeynmanDiagram` graph is assembled on
    demand by :meth:`to_feynman_diagram`, so generating many diagrams
    does not pay the graph-construction cost up front.

    Attributes:
        observable_ops: Field operators forming the observable.
        vertex_instances: Instantiated vertices contributing to this diagram.
        pairing: The Wick contraction pairing (tuple of index pairs).
        coefficient: Rational prefactor for this diagram.
        order: The perturbative order at which this diagram appears.
    """

    def __init__(
        self,
        observable_ops: list[FieldOperator],
        vertex_instances: list[VertexInstance],
        pairing: Pairing,
        coefficient: Rational,
        order: int,
    ) -> None:
        (
            self.observable_ops,
            self.vertex_instances,
            self.pairing,
            self.coefficient,
            self.order,
        ) = (observable_ops, vertex_instances, pairing, coefficient, order)

    def to_feynman_diagram(self):
        """Build the :class:`~sft_wick.diagrams.FeynmanDiagram` for this record.

        Returns:
            The graph produced by ``FeynmanDiagram.from_pairing`` from the
            stored observable operators, vertex instances and pairing.
        """
        from .diagrams import FeynmanDiagram

        build = FeynmanDiagram.from_pairing
        return build(self.observable_ops, self.vertex_instances, self.pairing)
[docs]
@dataclass(frozen=True)
class DiagramTerm:
    """A single Feynman diagram's contribution, structured for numerical evaluation.

    The full contribution is::

        rational_prefactor × response_phase_factor × coupling_sum × ∏ propagators

    summed over ``summation_indices`` and integrated over ``integration_vars``.

    Attributes:
        propagators: Tuple of propagators forming the diagram.
        coupling_sum: Symbolic coupling expression (sum of permuted
            couplings), **without** the rational prefactor.
        rational_prefactor: The ``(-1)^n / n! × multinomial`` coefficient.
        integration_vars: Spatial variables to integrate over.
        summation_indices: ``(index_name, dimension)`` pairs for component
            index summations.
        n_response: Number of R propagators (determines the response phase).
        equal_time_aliases: ``(label, representative)`` time aliases; see
            the field comment below.
        r_absorbed_pairs: ``(partner_label, leg_label)`` R-propagators whose
            factor is absorbed upstream; see the field comment below.
    """

    propagators: tuple[Propagator, ...]
    coupling_sum: Expr
    rational_prefactor: Rational
    integration_vars: tuple[str, ...]
    summation_indices: tuple[tuple[str, int], ...]
    n_response: int
    # Maps each non-representative internal spatial label to its
    # canonical time representative for equal_time NonLocalVertex
    # vertices. Default empty tuple means "no time aliasing"; the
    # m legs of each contributing vertex integrate independently
    # (the original sft-wick contract). When non-empty, the named
    # labels share a single time integration variable while keeping
    # independent spatial labels --- this is the equal-shell
    # cumulant case (see ``NonLocalVertex.equal_time``).
    equal_time_aliases: tuple[tuple[str, str], ...] = ()
    # ``(partner_label, leg_label)`` pairs identifying R-propagators
    # whose factor has been absorbed into an upstream ``κ^(m)_R``
    # callable (``NonLocalVertex(already_R_contracted=True)``). The
    # propagator stays in ``self.propagators`` so direction-group
    # union-find continues to identify the leg with its partner, but
    # the integrand R-product loop skips its R-factor (replaces by 1),
    # the leg time is aliased onto the partner via
    # ``equal_time_aliases``, and per-leg coupling lookups receive the
    # partner coordinates instead of the κ leg's own. Empty tuple ⇒
    # no R-absorption (original sft-wick contract).
    r_absorbed_pairs: tuple[tuple[str, str], ...] = ()

    # --- index bookkeeping -------------------------------------------------

    def _propagator_index_names(self) -> set[str]:
        """Return the non-empty component-index names used by any propagator."""
        names: set[str] = set()
        for p in self.propagators:
            if p.index_left:
                names.add(p.index_left)
            if p.index_right:
                names.add(p.index_right)
        return names

    @property
    def propagator_indices(self) -> tuple[tuple[str, int], ...]:
        """Summation indices that appear in at least one propagator."""
        used = self._propagator_index_names()
        return tuple(item for item in self.summation_indices if item[0] in used)

    @property
    def coupling_only_indices(self) -> tuple[tuple[str, int], ...]:
        """Summation indices that appear in no propagator (summed inside the coupling)."""
        used = self._propagator_index_names()
        return tuple(item for item in self.summation_indices if item[0] not in used)

    def spatial_topology(self) -> list[tuple[str, str, str]]:
        """Return ``(kind, spatial_left, spatial_right)`` for each propagator."""
        return [
            (p.kind, p.spatial_left, p.spatial_right) for p in self.propagators
        ]

    # --- numeric coefficients ---------------------------------------------

    def response_phase_factor(self) -> complex:
        """Return ``(-i)^n_response``.

        The value is a ``float`` (``1.0`` or ``-1.0``) for even
        ``n_response`` and a ``complex`` (``-1j`` or ``1j``) for odd
        ``n_response``. The evaluation methods rely on this distinction to
        choose a real or complex output dtype.
        """
        return [1.0, -1j, -1.0, 1j][self.n_response % 4]

    def _scaled_prefactor(self) -> complex:
        """Return ``(-i)^n_response × rational_prefactor`` as a Python number.

        This is a ``float`` when ``n_response`` is even and a ``complex``
        when it is odd (see :meth:`response_phase_factor`).
        """
        rp = self.rational_prefactor
        return self.response_phase_factor() * (rp.numerator / rp.denominator)

    def evaluate_coupling(
        self,
        coupling_values: dict,
        fixed_indices: dict[str, int] | None = None,
    ) -> "np.ndarray":
        """Substitute numeric coupling tensor values and return an array.

        The output is indexed by **propagator indices** only. For each
        combination of propagator index values the coupling sum is
        evaluated by summing over **coupling-only indices** (indices that
        appear in the coupling expression but in no propagator). This
        matches the natural contraction order: fix the propagator legs,
        then sum the vertex factors internally.

        Args:
            coupling_values: ``{name: array}`` mapping coupling names to
                NumPy arrays. For a rank-3 coupling ``F``, the array
                shape should be ``(N, N, N)`` where *N* is the number
                of field components.
            fixed_indices: Optional ``{index_name: int_value}`` for
                indices pinned by external constraints (e.g. observable
                component indices like ``{'1': 0}``).

        Returns:
            NumPy array with shape ``(dim_i0, dim_i1, ...)`` for the
            propagator indices, with ``rational_prefactor`` and the MSR
            response phase ``(-i)^n_response`` already applied. Returns
            a 0-d array when there are no propagator indices. The result
            is complex whenever ``n_response % 4 in {1, 3}`` or the
            coupling tensors are complex-valued.
        """
        pref = self._scaled_prefactor()
        base_map: dict[str, int] = dict(fixed_indices) if fixed_indices else {}
        # Complex output iff the phase is imaginary or any tensor is complex.
        dtype = complex if (
            isinstance(pref, complex)
            or any(np.iscomplexobj(v) for v in coupling_values.values())
        ) else float
        prop_idx = self.propagator_indices  # outer axes of result
        coup_idx = self.coupling_only_indices  # summed internally
        prop_shape = tuple(dim for _, dim in prop_idx)
        coup_shape = tuple(dim for _, dim in coup_idx)

        def _sum_coupling(prop_map: dict[str, int]) -> "np.number":
            """Evaluate coupling_sum with prop_map fixed, summing over coup_idx."""
            if not coup_idx:
                return dtype(_eval_symbolic(self.coupling_sum, coupling_values, prop_map))
            total = dtype(0)
            for cidx in np.ndindex(*coup_shape):
                index_map = {**prop_map,
                             **{name: v for (name, _), v in zip(coup_idx, cidx)}}
                total += dtype(_eval_symbolic(self.coupling_sum, coupling_values, index_map))
            return total

        if not prop_idx:
            # No propagator indices -> scalar (0-d) result.
            return np.array(pref * _sum_coupling(base_map), dtype=dtype)
        result = np.zeros(prop_shape, dtype=dtype)
        for pidx in np.ndindex(*prop_shape):
            prop_map = {**base_map,
                        **{name: v for (name, _), v in zip(prop_idx, pidx)}}
            result[pidx] = _sum_coupling(prop_map)
        return pref * result

    def evaluate_coupling_batched(
        self,
        coupling_values_batched: dict,
        n_samples: int,
        fixed_indices: dict[str, int] | None = None,
    ) -> "np.ndarray":
        """Vectorised counterpart of :meth:`evaluate_coupling`.

        Same contract as :meth:`evaluate_coupling`, but every per-sample
        ``Symbol`` value carries a leading sample axis of length
        ``n_samples``. Returns an array whose shape is
        ``(n_samples,) + prop_shape`` (or just ``(n_samples,)`` when
        there are no propagator indices).

        Args:
            coupling_values_batched: ``{name: ndarray}`` -- arrays may
                be either ``(n_samples, *kappa_shape)`` (dynamic /
                per-sample) or ``(*kappa_shape,)`` (static, broadcast
                across the sample axis).
            n_samples: Length of the sample axis.
            fixed_indices: Optional ``{index_name: int_value}`` for
                indices pinned by external constraints (e.g. observable
                component indices like ``{'a': 0}``).

        Returns:
            Complex / float array. When this :class:`DiagramTerm` has
            propagator indices, the returned shape is
            ``(n_samples,) + prop_shape``. Otherwise it is
            ``(n_samples,)``.

        Raises:
            NotImplementedError: If the symbolic ``coupling_sum``
                contains a node type the batched evaluator cannot
                handle. Callers MUST catch this and fall back to a
                per-sample loop over :meth:`evaluate_coupling`.
        """
        pref = self._scaled_prefactor()
        base_map: dict[str, int] = dict(fixed_indices) if fixed_indices else {}
        # Complex output iff the phase is imaginary or any array is complex.
        is_complex = isinstance(pref, complex) or any(
            np.iscomplexobj(v) for v in coupling_values_batched.values()
        )
        dtype = complex if is_complex else float
        prop_idx = self.propagator_indices
        coup_idx = self.coupling_only_indices
        prop_shape = tuple(dim for _, dim in prop_idx)
        coup_shape = tuple(dim for _, dim in coup_idx)

        def _sum_coupling_batched(prop_map: dict[str, int]) -> np.ndarray:
            """Evaluate ``coupling_sum`` for a fixed ``prop_map``, summing
            over the coupling-only indices. Returns ``(n_samples,)``."""
            if not coup_idx:
                val = _eval_symbolic_batched(
                    self.coupling_sum,
                    coupling_values_batched,
                    prop_map,
                    n_samples,
                )
                # Static (sample-independent) values broadcast along the
                # sample axis.
                return np.broadcast_to(
                    np.asarray(val, dtype=dtype), (n_samples,)
                ).astype(dtype, copy=False)
            total = np.zeros(n_samples, dtype=dtype)
            for cidx in np.ndindex(*coup_shape):
                index_map = {
                    **prop_map,
                    **{name: v for (name, _), v in zip(coup_idx, cidx)},
                }
                val = _eval_symbolic_batched(
                    self.coupling_sum,
                    coupling_values_batched,
                    index_map,
                    n_samples,
                )
                total = total + np.asarray(val, dtype=dtype)
            return total

        if not prop_idx:
            val = _sum_coupling_batched(base_map)
            return (pref * val).astype(dtype, copy=False)
        result = np.zeros((n_samples,) + prop_shape, dtype=dtype)
        for pidx in np.ndindex(*prop_shape):
            prop_map = {
                **base_map,
                **{name: v for (name, _), v in zip(prop_idx, pidx)},
            }
            result[(slice(None),) + pidx] = _sum_coupling_batched(prop_map)
        return (pref * result).astype(dtype, copy=False)

    # --- structural transforms --------------------------------------------

    def apply_diagonal(
        self,
        *,
        diag_R: bool = False,
        diag_C: bool = False,
        iso_R: bool = False,
        iso_C: bool = False,
    ) -> "DiagramTerm":
        """Return a new term with diagonal (and optionally isotropic) propagator constraints applied.

        Eliminates summation indices that are pinned by diagonal
        propagators and substitutes index equalities into the coupling
        expression. The delta constraint collapses the double sum into
        a single sum, so no extra factor is applied.

        With ``iso_R=True`` (implies ``diag_R=True``), the remaining equal
        component indices are stripped from R propagators entirely, expressing
        the assumption that all diagonal R entries are the same scalar R.
        A surviving summation index that no longer appears anywhere after
        stripping is kept, because summing over it still contributes a
        factor of its dimension.

        Args:
            diag_R: Enforce diagonal response propagators.
            diag_C: Enforce diagonal correlation propagators.
            iso_R: Enforce isotropic (index-free) response propagators.
            iso_C: Enforce isotropic (index-free) correlation propagators.

        Returns:
            A new :class:`DiagramTerm` with reduced summation indices, or
            ``self`` when no constraint applies.
        """
        diag_R = diag_R or iso_R
        diag_C = diag_C or iso_C
        if not diag_R and not diag_C:
            return self

        sum_idx_set = {name for name, _ in self.summation_indices}

        # Each diagonal propagator with distinct left/right indices imposes
        # the equality index_left == index_right.
        constraints: list[tuple[str, str]] = []
        for p in self.propagators:
            if p.index_left is None or p.index_right is None:
                continue
            if not ((p.kind == "R" and diag_R) or (p.kind == "C" and diag_C)):
                continue
            if p.index_left != p.index_right:
                constraints.append((p.index_left, p.index_right))
        if not constraints:
            return self

        # Union-find. External (non-summation) indices are preferred as
        # roots, so any class containing an external index is rooted there.
        parent: dict[str, str] = {}

        def find(x: str) -> str:
            while x in parent:
                x = parent[x]
            return x

        for left, right in constraints:
            rl, rr = find(left), find(right)
            if rl == rr:
                continue
            if rr in sum_idx_set and rl not in sum_idx_set:
                parent[rr] = rl
            elif rl in sum_idx_set and rr not in sum_idx_set:
                parent[rl] = rr
            else:
                parent[rr] = rl

        # Sorted iteration keeps ``sub`` (and hence the order of the external
        # deltas built from it) independent of string hash randomisation.
        sub: dict[str, str] = {}
        for idx in sorted({k for pair in constraints for k in pair}):
            root = find(idx)
            if idx != root:
                sub[idx] = root
        if not sub:
            return self

        eliminated = {idx for idx in sub if idx in sum_idx_set}

        # External (non-summation) index pairs that get merged by the
        # diag_R / diag_C constraint must retain a KroneckerDelta factor in
        # the coupling expression. For a SUMMATION index ``i`` merged into
        # another, the rename is exact because ``sum_i delta_{i, j} f(i) =
        # f(j)``. An OBSERVABLE index ``a`` (e.g. the labels in
        # ``phi_a(x) phi_b(y)``) is pinned by the caller via
        # ``fixed_indices`` and is *not* summed, so dropping
        # ``delta_{a, b}`` would lose the constraint that the cross-pair
        # (a != b) contribution vanishes. This mirrors
        # ``simplify.diagonal_propagators``.
        external_deltas: list[KroneckerDelta] = []
        seen_pairs: set[frozenset[str]] = set()
        for idx, root in sub.items():
            if idx in sum_idx_set or root in sum_idx_set:
                continue
            pair = frozenset({idx, root})
            if pair in seen_pairs:
                continue
            seen_pairs.add(pair)
            i1, i2 = sorted(pair)
            external_deltas.append(KroneckerDelta(i1, i2))

        # Rewrite propagator indices through the substitution.
        new_props = []
        for p in self.propagators:
            il = sub.get(p.index_left, p.index_left) if p.index_left else p.index_left
            ir = sub.get(p.index_right, p.index_right) if p.index_right else p.index_right
            new_props.append(
                Propagator(p.kind, il, ir, p.spatial_left, p.spatial_right)
            )

        # Substitute merged indices into the coupling FIRST, then prepend the
        # external-external deltas. Adding the deltas after the substitution
        # keeps their ORIGINAL (pre-merge) index names, which are the names
        # the caller's ``fixed_indices`` pin at evaluation time.
        substituted_coupling = simplify(_apply_index_sub(self.coupling_sum, sub))
        if external_deltas:
            if isinstance(substituted_coupling, Rational) \
                    and substituted_coupling.numerator == 1 \
                    and substituted_coupling.denominator == 1:
                # Plain unity coupling -- replace it with the delta product.
                if len(external_deltas) == 1:
                    new_coupling = external_deltas[0]
                else:
                    new_coupling = Product(tuple(external_deltas))
            elif isinstance(substituted_coupling, Product):
                new_coupling = Product(
                    tuple(external_deltas) + substituted_coupling.factors
                )
            else:
                new_coupling = Product(
                    tuple(external_deltas) + (substituted_coupling,)
                )
        else:
            new_coupling = substituted_coupling

        # Drop eliminated summation indices.
        new_sum_indices = tuple(
            (name, dim) for name, dim in self.summation_indices
            if name not in eliminated
        )

        # Isotropic: strip equal component indices from R/C propagators.
        if iso_R or iso_C:
            new_props = [
                Propagator(p.kind, None, None, p.spatial_left, p.spatial_right)
                if (
                    p.index_left is not None
                    and p.index_left == p.index_right
                    and ((p.kind == "R" and iso_R) or (p.kind == "C" and iso_C))
                )
                else p
                for p in new_props
            ]

        # Canonical rename: surviving summation indices -> i_0, i_1, ... in
        # appearance order. Propagators are scanned first (left-to-right,
        # left index before right index), then the coupling sum.
        surviving = [name for name, _ in new_sum_indices]
        surviving_set = set(surviving)
        ordered: list[str] = []
        ordered_set: set[str] = set()

        def _visit(idx) -> None:
            if idx and idx in surviving_set and idx not in ordered_set:
                ordered.append(idx)
                ordered_set.add(idx)

        for p in new_props:
            for idx in (p.index_left, p.index_right):
                _visit(idx)
        for idx in _collect_index_order(new_coupling, surviving_set):
            _visit(idx)
        # A surviving summation index referenced by neither the propagators
        # nor the coupling (e.g. one whose only occurrences were stripped by
        # iso_R / iso_C) still contributes a factor of its dimension, so it
        # is kept (last) instead of silently dropping that factor.
        for idx in surviving:
            _visit(idx)

        rename = {old: f"i_{j}" for j, old in enumerate(ordered)}
        if rename:
            new_coupling = _apply_index_sub(new_coupling, rename)
            new_props = [_apply_index_sub(p, rename) for p in new_props]
        dim_map = dict(new_sum_indices)
        new_sum_indices = tuple((rename[name], dim_map[name]) for name in ordered)

        return DiagramTerm(
            propagators=tuple(new_props),
            coupling_sum=new_coupling,
            rational_prefactor=self.rational_prefactor,
            integration_vars=self.integration_vars,
            summation_indices=new_sum_indices,
            n_response=self.n_response,
            equal_time_aliases=self.equal_time_aliases,
            r_absorbed_pairs=self.r_absorbed_pairs,
        )

    # --- LaTeX rendering ----------------------------------------------------

    def to_latex(self) -> str:
        r"""Full LaTeX representation including summations, integrals, and phase.

        Renders the complete contribution::

            \sum_{i_0} ... \int dy_0 ... coeff × (coupling) × propagators

        where *coeff* combines the rational prefactor and the MSR phase
        ``(-i)^{n_R}`` into a single coefficient.
        """
        body_parts: list[str] = []
        # Combined coefficient: rational_prefactor × (-i)^n_response.
        coeff_latex = _coeff_latex(self.rational_prefactor, self.n_response)
        if coeff_latex:
            body_parts.append(coeff_latex)
        body_parts.append(f"({self.coupling_sum.to_latex()})")
        body_parts.extend(p.to_latex() for p in self.propagators)
        body = " ".join(body_parts)
        # Wrap with integrals, then sums; iterating in reverse makes the
        # first listed variable / index the outermost.
        for var in reversed(self.integration_vars):
            body = rf"\int \mathrm{{d}}{_latex_index(var)}\, {body}"
        for name, dim in reversed(self.summation_indices):
            body = rf"\sum_{{{_latex_index(name)}=1}}^{{{dim}}} {body}"
        return body

    @staticmethod
    def _format_number(value: complex) -> str:
        """Format a numeric coefficient compactly for LaTeX output.

        Real values within 1e-10 of an integer print as that integer;
        other real values use ``.4g``; values with a non-zero imaginary
        part print as ``(<value:.4g>)``.
        """
        if value.imag == 0:
            real = value.real
            if abs(real - round(real)) < 1e-10:
                return str(int(round(real)))
            return f"{real:.4g}"
        return f"({value:.4g})"

    def to_latex_evaluated(
        self,
        coupling_values: dict,
        fixed_indices: dict[str, int] | None = None,
    ) -> str:
        r"""LaTeX with numerically evaluated coupling coefficients.

        Like :meth:`to_latex`, but replaces the symbolic coupling sum
        with the numerical coefficient array obtained from
        :meth:`evaluate_coupling`. Summation indices that survive in
        the propagators are expanded into explicit terms (1-based index
        values); coupling-only indices are already contracted.

        Args:
            coupling_values: ``{name: numpy_array}`` for coupling tensors.
            fixed_indices: Optional pinned observable indices.

        Returns:
            LaTeX string with numerical coefficients, or ``"0"`` when
            every expanded term vanishes.
        """
        coeff = self.evaluate_coupling(coupling_values, fixed_indices)
        prop_idx = self.propagator_indices
        prop_latex = " ".join(p.to_latex() for p in self.propagators)
        int_wrap = "".join(
            rf"\int \mathrm{{d}}{_latex_index(var)}\, "
            for var in reversed(self.integration_vars)
        )

        if not prop_idx:
            # Scalar coefficient; a unit magnitude is shown as a bare sign.
            coeff_str = self._format_number(complex(coeff))
            if coeff_str == "1":
                coeff_str = ""
            elif coeff_str == "-1":
                coeff_str = "-"
            return f"{int_wrap}{coeff_str} {prop_latex}"

        # Array coefficient: one explicit term per non-negligible entry.
        idx_names = [name for name, _ in prop_idx]
        dims = [dim for _, dim in prop_idx]
        terms: list[str] = []
        for pidx in np.ndindex(*dims):
            c = complex(coeff[pidx])
            if abs(c) < 1e-14:
                continue
            c_str = self._format_number(c)
            # Substitute 1-based index values into the propagators.
            sub = {name: str(val + 1) for name, val in zip(idx_names, pidx)}
            prop_parts = []
            for p in self.propagators:
                il = sub.get(p.index_left, p.index_left)
                ir = sub.get(p.index_right, p.index_right)
                sl = _latex_index(p.spatial_left)
                sr = _latex_index(p.spatial_right)
                if il and ir:
                    prop_parts.append(f"{p.kind}_{{{il}{ir}}}({sl}, {sr})")
                elif il or ir:
                    prop_parts.append(f"{p.kind}_{{{il or ir}}}({sl}, {sr})")
                else:
                    prop_parts.append(f"{p.kind}({sl}, {sr})")
            terms.append(f"{c_str}\\," + "\\,".join(prop_parts))
        if not terms:
            return "0"
        # Fold "+ -" produced by negative coefficients into "- ".
        body = " + ".join(terms).replace("+ -", "- ")
        return f"{int_wrap}{body}"

    def __repr__(self) -> str:
        return self.to_latex()

    # --- Thin delegation methods for numerical evaluation pipeline ---

    def analyze_spatial(self) -> "SpatialStructure":
        """Analyze R-propagator connectivity and time orderings.

        Returns a :class:`~sft_wick.evaluate.SpatialStructure` describing
        direction identification groups, causal time orderings, and
        surviving integration variables.
        """
        from .evaluate import analyze_spatial

        return analyze_spatial(self)

    def build_integrand(
        self,
        coupling_values: dict,
        fixed_indices: dict[str, int] | None = None,
    ) -> "DiagramIntegrand":
        """Build a :class:`~sft_wick.evaluate.DiagramIntegrand` for numerical evaluation.

        Combines coupling coefficient evaluation (Step 1) with spatial
        structure analysis (Step 2) into a ready-to-evaluate object.

        Args:
            coupling_values: dict mapping coupling-symbol name → value.
                Each value is either:

                - a **numeric** ``numpy.ndarray`` of shape ``(N,)*rank``
                  (e.g. a spacetime-independent local-vertex coefficient
                  ``F``, or a placeholder for a non-local tensor whose
                  spacetime dependence is trivial),
                - a **callable** ``fn(n_list, t_list) → ndarray`` for
                  spacetime-dependent couplings (e.g. demo2's
                  ``κ^{(3)}(x₁,t₁; x₂,t₂; x₃,t₃)``). Here ``n_list``
                  and ``t_list`` are length-``order`` sequences of the
                  vertex's ψ-leg positions and times.

                When a callable's symbol appears in this diagram's
                coupling sum, the integrand enters a **per-QMC-sample**
                evaluation path (see
                :class:`~sft_wick.evaluate.DynamicCouplingPromise`);
                otherwise the fast static path is used.
            fixed_indices: Optional pinned index values for observable
                component labels (e.g. ``{'a': 1, 'b': 1}``).

        Returns:
            A :class:`DiagramIntegrand` that can evaluate the integrand at
            specific time/direction coordinates.

        Raises:
            ValueError: If a referenced callable coupling's symbol carries
                no ``spatial_args`` in this diagram's coupling sum.
        """
        from .evaluate import (
            DiagramIntegrand,
            DynamicCouplingPromise,
            analyze_spatial,
        )

        spatial = analyze_spatial(self)
        fi = dict(fixed_indices) if fixed_indices else {}

        # Callable (spacetime-dependent) coupling values.
        dynamic_values = {
            name: fn for name, fn in coupling_values.items()
            if callable(fn) and not isinstance(fn, np.ndarray)
        }
        # Only callables whose symbol this diagram actually references force
        # the dynamic path. E.g. the order-0 ``C(x,y)`` diagram never needs a
        # K non-local vertex declared for higher orders.
        referenced = _collect_symbol_names(self.coupling_sum)
        active_dynamic = {
            name: fn for name, fn in dynamic_values.items()
            if name in referenced
        }

        if not active_dynamic:
            # Strip unreferenced callables (they would crash the static
            # evaluator) and evaluate statically.
            static_cv = {
                name: v for name, v in coupling_values.items()
                if name not in dynamic_values
            }
            coeff = self.evaluate_coupling(static_cv, fi)
            return DiagramIntegrand(
                diagram_term=self, spatial=spatial, coupling_array=coeff,
                fixed_indices=fi,
            )

        # Dynamic path: each active callable needs its symbol's spatial_args
        # to know which ψ-leg coordinates to pass at QMC time.
        spatial_args_by_name = _collect_symbol_spatial_args(self.coupling_sum)
        for name in active_dynamic:
            if name not in spatial_args_by_name:
                raise ValueError(
                    f"coupling_values['{name}'] is callable and the "
                    f"symbol '{name}' appears in this diagram's "
                    f"coupling_sum, but has no spatial_args — "
                    f"cannot determine ψ-leg coordinates. This "
                    f"typically means '{name}' is being used as a "
                    f"local (zero-leg) coupling; pass it as an "
                    f"ndarray instead."
                )
        static_values = {
            name: np.asarray(v) for name, v in coupling_values.items()
            if name not in dynamic_values
        }
        promise = DynamicCouplingPromise(
            diagram_term=self,
            static_values=static_values,
            dynamic_values=dict(active_dynamic),
            spatial_args_by_name=spatial_args_by_name,
            fixed_indices=fi,
        )
        # Placeholder so code that reads ``ig.coupling_array`` before the
        # dynamic branch is taken does not fail; the per-sample path fills
        # in the actual values.
        prop_shape = tuple(dim for _, dim in self.propagator_indices)
        placeholder = np.zeros(prop_shape, dtype=complex)
        return DiagramIntegrand(
            diagram_term=self, spatial=spatial,
            coupling_array=placeholder,
            fixed_indices=fi,
            dynamic_coupling=promise,
        )
def _collect_r_absorbed_pairs(
    props: tuple[Propagator, ...],
    vertex_instances,
) -> tuple[tuple[tuple[str, str], ...], tuple[tuple[str, str], ...]]:
    """Find R-propagators attached to a leg of an ``already_R_contracted`` vertex.

    A vertex instance counts as absorbing when its ``vertex`` has a truthy
    ``already_R_contracted`` attribute; all of its ``spatial_variables``
    are then absorbed legs. For every R-propagator touching such a leg,
    the leg is aliased onto its partner (collapsing the leg's time and
    direction integration), and the propagator is recorded as
    ``(partner, leg)`` so the integrand evaluator can skip its R-factor.

    :func:`~sft_wick.propagators.contract_pair` produces
    ``R(spatial_left, spatial_right) = ⟨φ(spatial_left) ψ(spatial_right)⟩``,
    so the κ-leg normally sits at ``spatial_right``. A leg found at
    ``spatial_left`` is handled defensively as well. Only the first
    R-propagator seen for each leg is recorded.

    Returns:
        ``(r_absorbed_pairs, leg_to_partner_aliases)``, each a sorted tuple.
        The first holds ``(partner, leg)`` pairs for
        ``DiagramTerm.r_absorbed_pairs``. The second holds the matching
        ``(leg, partner)`` pairs to merge into ``equal_time_aliases``.
        Both are empty when no vertex is ``already_R_contracted``.
    """
    absorbed_legs: set[str] = {
        label
        for vi in vertex_instances
        if getattr(vi.vertex, "already_R_contracted", False)
        for label in vi.spatial_variables
    }
    if not absorbed_legs:
        return (), ()

    claimed: set[str] = set()
    partner_leg: list[tuple[str, str]] = []
    for prop in props:
        if prop.kind != "R":
            continue
        if prop.spatial_right in absorbed_legs:
            leg, partner = prop.spatial_right, prop.spatial_left
        elif prop.spatial_left in absorbed_legs:
            leg, partner = prop.spatial_left, prop.spatial_right
        else:
            continue
        # In a valid Wick contraction each leg pairs once; guard anyway.
        if leg in claimed:
            continue
        claimed.add(leg)
        partner_leg.append((partner, leg))

    pairs = tuple(sorted(partner_leg))
    aliases = tuple(sorted((leg, partner) for partner, leg in partner_leg))
    return pairs, aliases
def _collect_symbol_names(expr: Expr) -> set[str]:
    """Return the names of every :class:`~sft_wick.expressions.Symbol` in *expr*.

    The traversal descends into ``Product.factors`` and ``Sum.terms``,
    treats ``Rational`` and ``Symbol`` nodes as leaves, and for any other
    node follows whichever of the ``expr`` / ``body`` / ``integrand``
    attributes holds an ``Expr``.

    :meth:`DiagramTerm.build_integrand` uses this to decide whether a
    callable coupling value is actually referenced by a particular diagram,
    so only diagrams containing the non-local vertex take the dynamic path.
    """
    names: set[str] = set()
    pending: list = [expr]
    while pending:
        node = pending.pop()
        if isinstance(node, Symbol):
            names.add(node.name)
        elif isinstance(node, Rational):
            pass
        elif isinstance(node, Product):
            pending.extend(node.factors)
        elif isinstance(node, Sum):
            pending.extend(node.terms)
        else:
            for attr in ("expr", "body", "integrand"):
                child = getattr(node, attr, None)
                if isinstance(child, Expr):
                    pending.append(child)
    return names
def _collect_symbol_spatial_args(expr: Expr) -> dict[str, tuple[str, ...]]:
    """Map each Symbol name in *expr* to the ``spatial_args`` of its first occurrence.

    Used by :meth:`DiagramTerm.build_integrand` to determine, for a
    non-local vertex coupling passed as a callable, which spatial labels
    (e.g. ``y_0_0``, ``y_0_1``, ``y_0_2``) correspond to that vertex's
    ψ-legs, so that at QMC time the per-sample ``(n_list, t_list)`` can be
    reconstructed and fed into the callable.

    The tree is walked depth-first, left to right (pre-order), with an
    explicit stack. Symbols whose ``spatial_args`` are empty are skipped
    and do not block a later occurrence of the same name. When a name
    occurs several times, the first occurrence with non-empty
    ``spatial_args`` wins. Per the original design, all Symbols sharing a
    name within one :class:`DiagramTerm` derive from the same
    :class:`VertexInstance`, so their spatial args are expected to agree.

    Returns:
        ``{name: spatial_args_tuple}``.
    """
    found: dict[str, tuple[str, ...]] = {}
    stack: list[Expr] = [expr]
    while stack:
        node = stack.pop()
        if isinstance(node, Symbol):
            if node.spatial_args and node.name not in found:
                found[node.name] = tuple(node.spatial_args)
            continue
        if isinstance(node, Rational):
            continue
        if isinstance(node, Product):
            children = tuple(node.factors)
        elif isinstance(node, Sum):
            children = tuple(node.terms)
        else:
            # Other Expr subclasses may wrap children on common attr names.
            children = tuple(
                child
                for child in (
                    getattr(node, attr_name, None)
                    for attr_name in ("expr", "body", "integrand")
                )
                if isinstance(child, Expr)
            )
        # Push reversed so children pop left-to-right (pre-order), which
        # preserves "first occurrence" semantics.
        stack.extend(reversed(children))
    return found
[docs]
def compute_moment(
    observable: list[FieldOperator],
    action: Action,
    order: int,
    ito: bool = True,
    response_phase: bool = True,
    collect_topology: bool = True,
    diag_R: bool = False,
    diag_C: bool = False,
    iso_R: bool = False,
    iso_C: bool = False,
) -> PerturbativeResult:
    r"""Compute the perturbative expansion of an observable.

    Evaluates

    .. math::

        \langle \mathcal{O} \rangle_S
        = \sum_{n=0}^{N} \frac{(-1)^n}{n!}\,
        \langle \mathcal{O}\, S_{\mathrm{int}}^{\,n} \rangle_{S_0}

    up to the requested perturbative order. In the MSR formalism the
    partition function :math:`Z = 1`, so there is no denominator.

    Args:
        observable: List of field operators defining the observable
            :math:`\mathcal{O}`.
        action: The interaction action :math:`S_{\mathrm{int}}`.
        order: Maximum perturbative order *N* to compute.
        ito: If ``True``, apply the Itô prescription
            :math:`\Theta(0)=0`: the response propagator vanishes at
            equal spatial points, :math:`R(x,x)=0`, and causal
            R-loops are eliminated.
        response_phase: If ``True``, multiply each term by
            :math:`(-\mathrm{i})^n` where *n* is the number of
            response propagators in that term, implementing the
            convention :math:`\langle\phi\,\psi\rangle =
            -\mathrm{i}\,R`.
        collect_topology: If ``True``, group terms that share the
            same propagator spatial structure and factor out the
            propagators, summing the coupling coefficients with
            appropriately permuted indices.
        diag_R: If ``True``, enforce diagonal response propagators
            :math:`R_{ij} = \delta_{ij} R`, eliminating one summation
            index per R propagator.
        diag_C: If ``True``, enforce diagonal correlation propagators
            :math:`C_{ij} = \delta_{ij} C`, eliminating one summation
            index per C propagator.
        iso_R: If ``True`` (implies ``diag_R``), further assume all
            diagonal R entries are equal, :math:`R_{ii} = R`. The
            component index is dropped from R propagators entirely.
        iso_C: If ``True`` (implies ``diag_C``), the same for C.

    Returns:
        A :class:`PerturbativeResult` containing order-by-order
        expressions, a combined total, Feynman diagram information and
        per-order :class:`DiagramTerm` records.

    Notes:
        :class:`DiagramTerm` records are built at order 0 (for pairings
        that yield at least one propagator) and, for ``n >= 1``, only on
        the ``collect_topology=True`` path. With
        ``collect_topology=False`` the records for ``n >= 1`` are empty.
    """
    # iso_* implies diag_*.
    diag_R = diag_R or iso_R
    diag_C = diag_C or iso_C
    order_terms: dict[int, Expr] = {}
    diagrams_by_order: dict[int, list[DiagramInfo]] = {}
    dt_by_order: dict[int, list[DiagramTerm]] = {}
    for n in range(order + 1):
        sign = (-1) ** n
        fact = factorial(n)
        order_exprs: list[Expr] = []
        order_diagrams: list[DiagramInfo] = []
        order_dterms: list[DiagramTerm] = []
        if n == 0:
            # Zeroth order: just <O>_{S_0}
            wick_result, pairings = wick_contract(observable, ito=ito)
            order_exprs.append(wick_result)
            for p in pairings:
                order_diagrams.append(
                    DiagramInfo(observable, [], p, Rational(1), 0)
                )
                # Build DiagramTerm for order 0: no vertices, hence no
                # integration variables or summation indices.
                props = []
                for i, j in p:
                    pr = contract_pair(observable[i], observable[j], ito=ito)
                    # Keep only genuine propagators; any other result of
                    # contract_pair (presumably a zero) is dropped.
                    if isinstance(pr, Propagator):
                        props.append(pr)
                if props:
                    order_dterms.append(DiagramTerm(
                        propagators=tuple(props),
                        coupling_sum=Rational(1),
                        rational_prefactor=Rational(1),
                        integration_vars=(),
                        summation_indices=(),
                        n_response=sum(1 for pr in props if pr.kind == "R"),
                    ))
        else:
            # Expand S_int^n using multinomial theorem
            for vertex_seq, multinomial_coeff in action.all_vertex_combinations(n):
                # Fresh index context per vertex combination so index and
                # spatial labels are generated independently for each term.
                idx_ctx = IndexContext()
                # Instantiate each vertex copy
                vertex_instances = [
                    VertexInstance.instantiate(v, idx_ctx, copy_id=k)
                    for k, v in enumerate(vertex_seq)
                ]
                # Collect all operators: observable + vertex fields
                all_ops = list(observable)
                for vi in vertex_instances:
                    all_ops.extend(vi.field_operators)
                # Build prefactor: (-1)^n / n! * multinomial_coeff
                prefactor = Rational(sign * multinomial_coeff, fact)
                # Build pure coupling (without prefactor) for DiagramTerm
                coupling_syms = [vi.coupling_symbol for vi in vertex_instances]
                if len(coupling_syms) > 1:
                    pure_coupling: Expr = Product(tuple(coupling_syms))
                elif coupling_syms:
                    pure_coupling = coupling_syms[0]
                else:
                    pure_coupling = Rational(1)
                # Full coupling product (with prefactor) for the expression
                coupling_product = Product(
                    tuple([prefactor] + coupling_syms)
                )
                # Collect integration variables (vertex spatial points)
                integration_vars: frozenset[str] = frozenset(
                    var
                    for vi in vertex_instances
                    for var in vi.spatial_variables
                )
                # Collect equal-time alias map from any equal_time
                # NonLocalVertex instances; each maps a non-representative
                # leg label → the canonical representative whose time
                # variable is integrated. Empty when no equal_time vertex
                # is present (back-compat).
                _eq_time_alias_pairs: list[tuple[str, str]] = []
                for vi in vertex_instances:
                    for k, v in vi.equal_time_aliases or ():
                        _eq_time_alias_pairs.append((k, v))
                eq_time_aliases_tuple = tuple(sorted(_eq_time_alias_pairs))
                # Build summation index info for DiagramTerm: each vertex
                # component index paired with the component count of the
                # first field operator that carries it.
                sum_indices: list[tuple[str, int]] = []
                for vi in vertex_instances:
                    for comp_idx in vi.component_indices:
                        for op in vi.field_operators:
                            if op.component_index == comp_idx:
                                sum_indices.append(
                                    (comp_idx, op.field.n_components)
                                )
                                break
                int_vars_sorted = tuple(sorted(integration_vars))
                if collect_topology:
                    # --- Hybrid: spatial topology + component routing ---
                    spatial_results = wick_contract_spatial(
                        all_ops, ito=ito, vertex_points=integration_vars,
                    )
                    if not spatial_results:
                        continue
                    groups: dict[
                        SpatialSignature,
                        list[tuple[list[Propagator], Pairing]],
                    ] = {}
                    pairings: list[Pairing] = []
                    for sig, (ref_props, mult, rep_pairing) in spatial_results.items():
                        # Expand each spatial topology into every distinct
                        # operator-level routing of component indices.
                        routings = _enumerate_component_routings(
                            ref_props, rep_pairing, all_ops, integration_vars,
                        )
                        groups[sig] = routings
                        pairings.extend(p for _, p in routings)
                    internal_indices: set[str] = set()
                    for vi in vertex_instances:
                        internal_indices.update(vi.component_indices)
                    inner = _collect_grouped_wick(
                        groups, pure_coupling,
                        internal_indices, integration_vars,
                    )
                    if _is_zero(inner):
                        continue
                    term = Product((prefactor, inner))
                    # Extract DiagramTerm records from inner
                    for dt_props, dt_coupling in _extract_diagram_records(
                        inner
                    ):
                        r_absorbed_pairs, leg_aliases = (
                            _collect_r_absorbed_pairs(
                                dt_props, vertex_instances,
                            )
                        )
                        # Union of equal-time aliases and absorbed-leg
                        # aliases, sorted for a deterministic order.
                        merged_aliases = (
                            tuple(sorted(
                                set(eq_time_aliases_tuple) | set(leg_aliases)
                            ))
                            if leg_aliases else eq_time_aliases_tuple
                        )
                        order_dterms.append(DiagramTerm(
                            propagators=dt_props,
                            coupling_sum=dt_coupling,
                            rational_prefactor=prefactor,
                            integration_vars=int_vars_sorted,
                            summation_indices=tuple(sum_indices),
                            n_response=sum(
                                1 for p in dt_props if p.kind == "R"
                            ),
                            equal_time_aliases=merged_aliases,
                            r_absorbed_pairs=r_absorbed_pairs,
                        ))
                else:
                    # --- Operator-level Wick contraction ---
                    # NOTE: this path builds no DiagramTerm records.
                    wick_result, pairings = wick_contract(all_ops, ito=ito)
                    if _is_zero(wick_result):
                        continue
                    term = Product((coupling_product, wick_result))
                # Wrap with integrals over internal spatial variables
                for vi in vertex_instances:
                    for var in vi.spatial_variables:
                        term = IntegralOver(var, term)
                # Wrap with summations over internal component indices
                for vi in vertex_instances:
                    for comp_idx in vi.component_indices:
                        field_for_idx = None
                        for op in vi.field_operators:
                            if op.component_index == comp_idx:
                                field_for_idx = op.field
                                break
                        if field_for_idx is not None:
                            term = SumOverIndex(
                                comp_idx, field_for_idx.n_components, term
                            )
                order_exprs.append(term)
                for p in pairings:
                    order_diagrams.append(
                        DiagramInfo(
                            observable,
                            vertex_instances,
                            p,
                            prefactor,
                            n,
                        )
                    )
        if order_exprs:
            raw = order_exprs[0] if len(order_exprs) == 1 else Sum(tuple(order_exprs))
            simplified = simplify(raw)
            if not collect_topology:
                # Operator-level path: group by diagram after simplification
                simplified = collect_by_diagram(simplified)
            # (Hybrid path: _collect_grouped_wick already produced coupling sums)
            if diag_R or diag_C:
                simplified = diagonal_propagators(
                    simplified, diag_R=diag_R, diag_C=diag_C, iso_R=iso_R, iso_C=iso_C,
                )
            order_terms[n] = (
                apply_response_phase(simplified) if response_phase
                else simplified
            )
        else:
            order_terms[n] = ZERO
        # Keep the DiagramTerm records consistent with the symbolic
        # diagonal/isotropic assumptions applied above.
        if diag_R or diag_C:
            order_dterms = [
                dt.apply_diagonal(diag_R=diag_R, diag_C=diag_C, iso_R=iso_R, iso_C=iso_C)
                for dt in order_dterms
            ]
        diagrams_by_order[n] = order_diagrams
        dt_by_order[n] = order_dterms
    # Total
    all_terms = [order_terms[n] for n in range(order + 1) if not _is_zero(order_terms[n])]
    if not all_terms:
        total = ZERO
    elif len(all_terms) == 1:
        total = all_terms[0]
    else:
        total = simplify(Sum(tuple(all_terms)))
    # NOTE(review): each order term above already had apply_response_phase
    # applied, so this second application is only correct if
    # apply_response_phase is idempotent. Confirm against expressions.py.
    if response_phase:
        total = apply_response_phase(total)
    return PerturbativeResult(
        order_terms, total, diagrams_by_order, dt_by_order,
    )
def _collect_grouped_wick(
    groups: dict[SpatialSignature, list[tuple[list[Propagator], Pairing]]],
    pure_coupling: Expr,
    internal_indices: set[str],
    integration_vars: frozenset[str],
) -> Expr:
    """Pre-collect Wick contraction results grouped by spatial signature.

    Instead of building a huge Sum and relying on collect_by_diagram,
    this function:

    1. Computes component-index permutations within each spatial group.
    2. Merges spatial groups that share the same canonical diagram form
       (i.e. are related by integration-variable relabeling).
    3. Returns a compact Sum with one term per distinct Feynman diagram.

    The ``pure_coupling`` should be the coupling expression **without**
    the rational prefactor (so that ``_extract_diagram_records`` can
    cleanly separate coupling from prefactor).

    Args:
        groups: For each spatial signature, the list of
            ``(propagators, pairing)`` routings. The first entry of each
            list serves as that group's reference.
        pure_coupling: Product of vertex coupling symbols (no prefactor).
        internal_indices: Vertex component indices that may be permuted.
        integration_vars: Vertex spatial labels that may be relabeled.

    Returns:
        ``ZERO`` for empty input. Otherwise a single
        ``Product(coupling_sum, propagators)``, or a ``Sum`` of such
        products. Entries that cannot be matched onto their canonical
        group's reference are emitted as separate terms rather than
        merged.
    """
    from collections import defaultdict
    if not groups:
        return ZERO
    # --- Phase 1: Within each spatial group, collect component-index perms ---
    # Each spatial group's pairings share the same propagator spatial
    # positions but differ in component indices.
    spatial_collected: list[tuple[list[Propagator], list[dict[str, str]]]] = []
    for sig, group_entries in groups.items():
        ref_props = group_entries[0][0]
        perms: list[dict[str, str]] = [{}]  # identity for reference
        for props, _pairing in group_entries[1:]:
            # Perm maps this routing's index names onto the reference's.
            comp_perm = _fast_component_match(
                ref_props, props, internal_indices,
            )
            if comp_perm is not None:
                perms.append(comp_perm)
            else:
                # Fallback: treat as its own 1-element group
                spatial_collected.append((props, [{}]))
        spatial_collected.append((ref_props, perms))
    # --- Phase 2: Group by canonical diagram form ---
    # _canonical_diagram_form returns a hashable canonical key plus a
    # mapping from this entry's labels to canonical labels (as used below).
    canonical_groups: dict[
        tuple[tuple[str, str, str], ...],
        list[tuple[list[Propagator], list[dict[str, str]], dict[str, str]]],
    ] = defaultdict(list)
    for ref_props, perms in spatial_collected:
        canon, mapping = _canonical_diagram_form(ref_props, integration_vars)
        canonical_groups[canon].append((ref_props, perms, mapping))
    # --- Phase 3: Merge and build expression ---
    result_terms: list[Expr] = []
    for canon, entries in canonical_groups.items():
        # The first entry of each canonical group is the reference whose
        # propagators are written out; others are relabeled onto it.
        ref_props_0, perms_0, mapping_0 = entries[0]
        ref_inv_0 = {v: k for k, v in mapping_0.items()}
        all_coupling_terms: list[Expr] = []
        for ref_props, perms, mapping in entries:
            # Spatial relabeling: map this entry's vars to the canonical ref
            spatial_perm: dict[str, str] = {}
            for orig, canon_name in mapping.items():
                target = ref_inv_0.get(canon_name, canon_name)
                if orig != target:
                    spatial_perm[orig] = target
            # Component-index matching between this entry's ref props and
            # the canonical group's ref props
            if ref_props is ref_props_0 and not spatial_perm:
                cross_comp_perm: dict[str, str] = {}
            else:
                cross_result = _match_propagators_after_spatial(
                    ref_props_0, ref_props, spatial_perm, internal_indices,
                )
                if cross_result is None:
                    # Cannot merge — add each within-group perm separately
                    for wp in perms:
                        permuted = _apply_perm_to_coupling(pure_coupling, wp)
                        prop_expr = (
                            Product(tuple(ref_props)) if len(ref_props) > 1
                            else ref_props[0]
                        )
                        result_terms.append(Product((permuted, prop_expr)))
                    continue
                cross_comp_perm = cross_result
            cross_full = {**spatial_perm, **cross_comp_perm}
            # Compose cross-group perm with each within-group perm:
            # apply wp first, then cross_full (labels not in wp pass
            # straight through to cross_full).
            for wp in perms:
                total_perm: dict[str, str] = {}
                for k, v in wp.items():
                    total_perm[k] = cross_full.get(v, v)
                for k, v in cross_full.items():
                    if k not in total_perm:
                        total_perm[k] = v
                # Remove identity mappings
                total_perm = {k: v for k, v in total_perm.items() if k != v}
                all_coupling_terms.append(
                    _apply_perm_to_coupling(pure_coupling, total_perm)
                )
        # Build: (sum of permuted couplings) × (reference propagators)
        if len(all_coupling_terms) == 1:
            coupling_expr: Expr = all_coupling_terms[0]
        else:
            # Fast path: check if all are identical
            if all(t == all_coupling_terms[0] for t in all_coupling_terms[1:]):
                n_terms = len(all_coupling_terms)
                coupling_expr = Product(
                    (Rational(n_terms, 1), all_coupling_terms[0])
                )
            else:
                # Hash-based dedup instead of full simplify: identical
                # permuted couplings become count × coupling, kept in
                # first-seen order.
                term_counts: dict[Expr, int] = {}
                for t in all_coupling_terms:
                    term_counts[t] = term_counts.get(t, 0) + 1
                deduped: list[Expr] = []
                for t, count in term_counts.items():
                    if count == 1:
                        deduped.append(t)
                    else:
                        deduped.append(Product((Rational(count, 1), t)))
                coupling_expr = deduped[0] if len(deduped) == 1 else Sum(tuple(deduped))
        prop_expr = (
            Product(tuple(ref_props_0)) if len(ref_props_0) > 1
            else ref_props_0[0]
        )
        result_terms.append(Product((coupling_expr, prop_expr)))
    if not result_terms:
        return ZERO
    if len(result_terms) == 1:
        return result_terms[0]
    return Sum(tuple(result_terms))
def _enumerate_component_routings(
ref_props: list[Propagator],
rep_pairing: Pairing,
all_ops: list[FieldOperator],
vertex_points: frozenset[str],
) -> list[tuple[list[Propagator], Pairing]]:
"""Enumerate all component-index routings for a given spatial topology.
Given a spatial topology (from ``wick_contract_spatial``), permute field
operators at each vertex point among same-type edge slots to recover all
distinct operator-level pairings within this topology.
Returns a list of ``(propagator_list, pairing)`` suitable for feeding
into ``_collect_grouped_wick``.
"""
from itertools import permutations, product as cartesian_product
n_edges = len(rep_pairing)
# Step 1: Build edge slot structure from the representative pairing.
# For each edge, record which operator fills the left and right slot,
# and classify each slot by (spatial_point, field_type).
# slots_at_point[spatial_point][field_type] = [(edge_idx, side), ...]
slots_at_point: dict[str, dict[str, list[tuple[int, str]]]] = {}
ops_at_slots: dict[str, dict[str, list[int]]] = {} # same structure but stores op indices
for edge_idx, (op_left, op_right) in enumerate(rep_pairing):
prop = ref_props[edge_idx]
if prop.kind == "R":
# Left = phi, right = psi
phi_pt = all_ops[op_left].spatial_arg
psi_pt = all_ops[op_right].spatial_arg
slots_at_point.setdefault(phi_pt, {}).setdefault("phi", []).append(
(edge_idx, "left")
)
ops_at_slots.setdefault(phi_pt, {}).setdefault("phi", []).append(op_left)
slots_at_point.setdefault(psi_pt, {}).setdefault("psi", []).append(
(edge_idx, "right")
)
ops_at_slots.setdefault(psi_pt, {}).setdefault("psi", []).append(op_right)
else:
# C edge: both sides are phi
left_pt = all_ops[op_left].spatial_arg
right_pt = all_ops[op_right].spatial_arg
slots_at_point.setdefault(left_pt, {}).setdefault("phi", []).append(
(edge_idx, "left")
)
ops_at_slots.setdefault(left_pt, {}).setdefault("phi", []).append(op_left)
slots_at_point.setdefault(right_pt, {}).setdefault("phi", []).append(
(edge_idx, "right")
)
ops_at_slots.setdefault(right_pt, {}).setdefault("phi", []).append(op_right)
# Step 2: At each vertex point, enumerate permutations of operators
# among same-type slots. Observable points are fixed.
per_point_perms: list[list[dict[tuple[int, str], int]]] = []
point_keys: list[tuple[str, str]] = [] # (point, field_type)
for point in sorted(slots_at_point.keys()):
if point not in vertex_points:
continue # Observable point — no permutation
for ftype in sorted(slots_at_point[point].keys()):
slots = slots_at_point[point][ftype]
ops = ops_at_slots[point][ftype]
if len(ops) <= 1:
continue # Only one operator — no permutation needed
# Enumerate all permutations of ops among slots
point_perms: list[dict[tuple[int, str], int]] = []
for perm_ops in permutations(ops):
mapping: dict[tuple[int, str], int] = {}
for slot, new_op in zip(slots, perm_ops):
mapping[slot] = new_op
point_perms.append(mapping)
per_point_perms.append(point_perms)
point_keys.append((point, ftype))
# Step 3: Cartesian product of per-point permutations
if not per_point_perms:
# No permutations possible — only the reference pairing
return [(list(ref_props), rep_pairing)]
# Build base assignment: slot → operator (from reference pairing)
base_assign: dict[tuple[int, str], int] = {}
for point in slots_at_point:
for ftype in slots_at_point[point]:
for slot, op_idx in zip(
slots_at_point[point][ftype], ops_at_slots[point][ftype]
):
base_assign[slot] = op_idx
seen_pairings: set[tuple[tuple[int, int], ...]] = set()
results: list[tuple[list[Propagator], Pairing]] = []
for combo in cartesian_product(*per_point_perms):
# Merge all per-point slot reassignments into the base
assign = dict(base_assign)
for point_mapping in combo:
assign.update(point_mapping)
# Build new pairing and propagators from the assignment
new_pairs: list[tuple[int, int]] = []
new_props: list[Propagator] = []
for edge_idx in range(n_edges):
prop = ref_props[edge_idx]
if prop.kind == "R":
left_op = assign[(edge_idx, "left")]
right_op = assign[(edge_idx, "right")]
else:
left_op = assign[(edge_idx, "left")]
right_op = assign[(edge_idx, "right")]
new_pairs.append((left_op, right_op))
# Build propagator with the new component indices
ol = all_ops[left_op]
or_ = all_ops[right_op]
new_props.append(Propagator(
kind=prop.kind,
index_left=ol.component_index,
index_right=or_.component_index,
spatial_left=ol.spatial_arg,
spatial_right=or_.spatial_arg,
))
# De-duplicate: canonicalize the pairing
canon_pairing = tuple(sorted(
tuple(sorted(pair)) for pair in new_pairs
))
if canon_pairing in seen_pairings:
continue
seen_pairings.add(canon_pairing)
results.append((new_props, tuple(new_pairs)))
return results
def _fast_component_match(
ref_props: list[Propagator],
other_props: list[Propagator],
internal_indices: set[str],
) -> dict[str, str] | None:
"""Fast component-index matching for props with identical spatial structure.
Unlike the general ``_match_propagators_after_spatial``, this assumes
spatial positions are already identical (no spatial perm). It groups
propagators by their exact ``(kind, spatial_left, spatial_right)``
tuple and only tries permutations within tied groups.
"""
from collections import defaultdict
from itertools import permutations as iterperms
# Group both lists by exact spatial key (including C directionality)
ref_by_key: dict[tuple, list[int]] = defaultdict(list)
other_by_key: dict[tuple, list[int]] = defaultdict(list)
for i, p in enumerate(ref_props):
key = (p.kind, p.spatial_left, p.spatial_right)
ref_by_key[key].append(i)
for i, p in enumerate(other_props):
key = (p.kind, p.spatial_left, p.spatial_right)
other_by_key[key].append(i)
# Also handle C symmetry: C(x,y) matches C(y,x)
# First try exact match; if keys don't align, try with C flipped
if set(ref_by_key.keys()) != set(other_by_key.keys()):
# Re-group other with C-flipped keys
other_by_key_flip: dict[tuple, list[tuple[int, bool]]] = defaultdict(list)
for i, p in enumerate(other_props):
if p.kind == "C":
# Try canonical key
ckey = ("C", min(p.spatial_left, p.spatial_right),
max(p.spatial_left, p.spatial_right))
flipped = p.spatial_left > p.spatial_right
other_by_key_flip[ckey].append((i, flipped))
else:
other_by_key_flip[(p.kind, p.spatial_left, p.spatial_right)].append((i, False))
ref_by_key_canon: dict[tuple, list[tuple[int, bool]]] = defaultdict(list)
for i, p in enumerate(ref_props):
if p.kind == "C":
ckey = ("C", min(p.spatial_left, p.spatial_right),
max(p.spatial_left, p.spatial_right))
flipped = p.spatial_left > p.spatial_right
ref_by_key_canon[ckey].append((i, flipped))
else:
ref_by_key_canon[(p.kind, p.spatial_left, p.spatial_right)].append((i, False))
if set(ref_by_key_canon.keys()) != set(other_by_key_flip.keys()):
return None
# Use canonical matching with flip tracking
perm: dict[str, str] = {}
for key in ref_by_key_canon:
ri_list = ref_by_key_canon[key]
oi_list = other_by_key_flip[key]
if len(ri_list) != len(oi_list):
return None
if len(ri_list) == 1:
ri, r_flip = ri_list[0]
oi, o_flip = oi_list[0]
actual_flip = r_flip != o_flip
if not _try_add_index_perm(
perm, ref_props[ri], other_props[oi],
actual_flip, internal_indices
):
return None
else:
# Try all permutations of the group
found = False
for op in iterperms(oi_list):
test_perm = dict(perm)
ok = True
for (ri, r_flip), (oi, o_flip) in zip(ri_list, op):
actual_flip = r_flip != o_flip
if not _try_add_index_perm(
test_perm, ref_props[ri], other_props[oi],
actual_flip, internal_indices
):
ok = False
break
if ok:
perm = test_perm
found = True
break
if not found:
return None
return perm
# Fast path: exact key match (no C flipping needed)
perm = {}
for key in ref_by_key:
ri_list = ref_by_key[key]
oi_list = other_by_key[key]
if len(ri_list) != len(oi_list):
return None
if len(ri_list) == 1:
ri, oi = ri_list[0], oi_list[0]
if not _try_add_index_perm(
perm, ref_props[ri], other_props[oi], False, internal_indices
):
return None
else:
found = False
for op in iterperms(oi_list):
test_perm = dict(perm)
ok = True
for ri, oi in zip(ri_list, op):
if not _try_add_index_perm(
test_perm, ref_props[ri], other_props[oi],
False, internal_indices
):
ok = False
break
if ok:
perm = test_perm
found = True
break
if not found:
return None
return perm
def _try_add_index_perm(
perm: dict[str, str],
ref_prop: Propagator,
other_prop: Propagator,
flipped: bool,
internal_indices: set[str],
) -> bool:
"""Try to add component-index mappings from other_prop to ref_prop."""
if flipped:
pairs = [
(ref_prop.index_left, other_prop.index_right),
(ref_prop.index_right, other_prop.index_left),
]
else:
pairs = [
(ref_prop.index_left, other_prop.index_left),
(ref_prop.index_right, other_prop.index_right),
]
for ref_idx, other_idx in pairs:
if ref_idx is None and other_idx is None:
continue
if ref_idx == other_idx:
continue
if other_idx not in internal_indices or ref_idx not in internal_indices:
return False
if other_idx in perm:
if perm[other_idx] != ref_idx:
return False
else:
perm[other_idx] = ref_idx
return True
def _apply_perm_to_coupling(coupling: Expr, perm: dict[str, str]) -> Expr:
"""Apply index permutation to a coupling product."""
if not perm:
return coupling
if isinstance(coupling, Product):
return Product(tuple(
_apply_perm_to_coupling(f, perm) for f in coupling.factors
))
if isinstance(coupling, Symbol):
new_indices = tuple(perm.get(i, i) for i in coupling.indices)
new_spatial = tuple(perm.get(s, s) for s in coupling.spatial_args)
if new_indices == coupling.indices and new_spatial == coupling.spatial_args:
return coupling
return Symbol(coupling.name, new_indices, new_spatial)
return coupling
def _extract_diagram_records(
    expr: Expr,
) -> list[tuple[tuple[Propagator, ...], Expr]]:
    """Split ``_collect_grouped_wick`` output into ``(propagators, coupling_sum)`` records.

    Each top-level term (or *expr* itself when it is not a ``Sum``) is
    treated as one diagram:

    * a ``Product`` has its ``Propagator`` factors separated from the
      remaining (coupling) factors. ``Product`` auto-flattens, so the two
      kinds may be interleaved. Products without any propagator are
      skipped, and an empty coupling part becomes ``Rational(1)``.
    * a bare ``Propagator`` yields a record with coupling ``Rational(1)``.
    * any other term is ignored.

    Records whose propagator tuples are identical (same propagators in
    the same order) are merged by summing their couplings. Records keep
    the order in which each propagator tuple first appeared.
    """
    if _is_zero(expr):
        return []
    top_terms = list(expr.terms) if isinstance(expr, Sum) else [expr]
    couplings_by_props: dict[tuple[Propagator, ...], list[Expr]] = {}
    for term in top_terms:
        if isinstance(term, Product):
            prop_factors = [f for f in term.factors if isinstance(f, Propagator)]
            if not prop_factors:
                continue
            other_factors = [f for f in term.factors if not isinstance(f, Propagator)]
            if not other_factors:
                coupling: Expr = Rational(1)
            elif len(other_factors) == 1:
                coupling = other_factors[0]
            else:
                coupling = Product(tuple(other_factors))
            key = tuple(prop_factors)
        elif isinstance(term, Propagator):
            key = (term,)
            coupling = Rational(1)
        else:
            continue
        couplings_by_props.setdefault(key, []).append(coupling)
    return [
        (key, couplings[0] if len(couplings) == 1 else Sum(tuple(couplings)))
        for key, couplings in couplings_by_props.items()
    ]
def _collect_index_order(expr: Expr, summation_set: set[str]) -> list[str]:
    """List the summation indices used in *expr*, in depth-first first-appearance order.

    Only ``Symbol`` indices and the (non-empty) left/right indices of
    ``Propagator`` nodes are considered. ``Product`` / ``Sum`` children are
    visited left to right, and other node types are not descended into.
    Each index from *summation_set* appears at most once in the result.
    """
    ordered: list[str] = []
    already: set[str] = set()
    stack: list[Expr] = [expr]
    while stack:
        node = stack.pop()
        if isinstance(node, Symbol):
            candidates = tuple(node.indices)
        elif isinstance(node, Propagator):
            candidates = tuple(
                idx for idx in (node.index_left, node.index_right) if idx
            )
        elif isinstance(node, Product):
            stack.extend(reversed(tuple(node.factors)))
            continue
        elif isinstance(node, Sum):
            stack.extend(reversed(tuple(node.terms)))
            continue
        else:
            continue
        for idx in candidates:
            if idx in summation_set and idx not in already:
                already.add(idx)
                ordered.append(idx)
    return ordered
def _eval_symbolic(
expr: Expr,
symbol_values: dict,
index_map: dict[str, int],
) -> float:
"""Recursively evaluate a symbolic expression with concrete values.
Args:
expr: The symbolic expression.
symbol_values: ``{name: numpy_array}`` mapping coupling names
to numeric arrays.
index_map: ``{index_name: int_value}`` mapping component index
names to concrete integer values.
Returns:
The numeric value as a float or complex.
"""
if isinstance(expr, Rational):
return expr.numerator / expr.denominator
if isinstance(expr, Symbol):
arr = symbol_values[expr.name]
if expr.indices:
def _resolve(i: str) -> int:
if i in index_map:
return index_map[i]
# Literal observable component (e.g. '1', '2') — 1-indexed convention
try:
return int(i) - 1
except ValueError:
raise KeyError(
f"Index '{i}' not found in index_map and is not a literal integer. "
"Pass it via the fixed_indices argument of evaluate_coupling()."
) from None
idx = tuple(_resolve(i) for i in expr.indices)
val = arr[idx]
else:
val = arr
return complex(val) if np.iscomplexobj(arr) else float(val)
if isinstance(expr, Product):
result: complex | float = 1.0
for f in expr.factors:
result *= _eval_symbolic(f, symbol_values, index_map)
return result
if isinstance(expr, Sum):
return sum(
_eval_symbolic(t, symbol_values, index_map) for t in expr.terms
)
if isinstance(expr, KroneckerDelta):
# Resolve each index against ``index_map``; literal observable
# indices ('0', '1', ...) are treated as ints. The delta is
# 1 iff the two indices resolve to the same value, else 0.
def _resolve(i: str) -> int:
if i in index_map:
return int(index_map[i])
try:
return int(i)
except ValueError:
raise KeyError(
f"KroneckerDelta index {i!r} not found in index_map "
f"and is not a literal integer. Pass it via the "
f"fixed_indices argument of evaluate_coupling()."
) from None
return 1.0 if _resolve(expr.index1) == _resolve(expr.index2) else 0.0
if isinstance(expr, ImaginaryUnit):
return 1j
raise TypeError(f"Cannot numerically evaluate {type(expr).__name__}")
def _eval_symbolic_batched(
expr: Expr,
symbol_values_batched: dict,
index_map: dict[str, int],
n_samples: int,
) -> "np.ndarray":
"""Vectorised counterpart of :func:`_eval_symbolic`.
Walks the symbolic expression tree once and returns an array of
shape ``(n_samples,)`` instead of a scalar. ``Symbol`` values in
``symbol_values_batched`` may carry a leading sample axis (so the
array shape is ``(n_samples,) + (N,) * rank``); static (non-batched)
arrays of shape ``(N,) * rank`` are also accepted and broadcast
automatically. All other node types collapse to scalars and are
promoted to ``(n_samples,)`` only when the parent ``Product`` /
``Sum`` mixes them with a batched child.
For node types that cannot be cheaply vectorised
(``Propagator``, ``IntegralOver``, ``SumOverIndex``, ``DiracDelta``)
a :class:`NotImplementedError` is raised so the caller can fall back
to the scalar per-sample loop. These nodes never appear in the
``coupling_sum`` of a :class:`DiagramTerm` (couplings are strictly
products / sums of ``Symbol``, ``KroneckerDelta``, ``Rational`` and
``ImaginaryUnit`` after :func:`apply_response_phase`), so the
fallback only triggers for unusual user inputs.
Args:
expr: The symbolic expression (typically a ``DiagramTerm``'s
``coupling_sum``).
symbol_values_batched: ``{name: ndarray}`` mapping coupling
symbol names to either ``(n_samples, *kappa_shape)`` arrays
(dynamic / per-sample) or ``(*kappa_shape,)`` arrays
(static, broadcast across the sample axis).
index_map: ``{index_name: int_value}`` for component indices
pinned by the caller (e.g. propagator-index outer loop or
``fixed_indices``).
n_samples: Length of the sample axis.
Returns:
A complex or float ``(n_samples,)`` array.
"""
if isinstance(expr, Rational):
return expr.numerator / expr.denominator # scalar; broadcast later
if isinstance(expr, ImaginaryUnit):
return 1j # scalar; broadcast later
if isinstance(expr, Symbol):
arr = symbol_values_batched[expr.name]
if not expr.indices:
# Scalar symbol -- just return as-is.
return arr
# Resolve component indices to integers (literal '1' → 0 etc.)
def _resolve(i: str) -> int:
if i in index_map:
return index_map[i]
try:
return int(i) - 1
except ValueError:
raise KeyError(
f"Index '{i}' not found in index_map and is not a "
f"literal integer. Pass it via the fixed_indices "
f"argument of evaluate_coupling()."
) from None
idx = tuple(_resolve(i) for i in expr.indices)
# Detect whether ``arr`` carries a leading sample axis.
# rank of the static tensor is len(expr.indices); compare
# ``arr.ndim`` to decide whether to prepend ``slice(None)``.
if arr.ndim == len(idx) + 1:
# Batched: arr shape is (n_samples,) + (N,)*rank.
return arr[(slice(None),) + idx]
if arr.ndim == len(idx):
# Static: shape is (N,)*rank — return scalar; caller's
# Product / Sum logic will broadcast.
return arr[idx]
raise ValueError(
f"Symbol {expr.name!r} has rank {len(idx)} but the supplied "
f"array has shape {arr.shape}; expected either rank or "
f"rank+1 (with leading sample axis of length {n_samples})."
)
if isinstance(expr, KroneckerDelta):
def _resolve(i: str) -> int:
if i in index_map:
return int(index_map[i])
try:
return int(i)
except ValueError:
raise KeyError(
f"KroneckerDelta index {i!r} not found in index_map "
f"and is not a literal integer. Pass it via the "
f"fixed_indices argument of evaluate_coupling()."
) from None
return 1.0 if _resolve(expr.index1) == _resolve(expr.index2) else 0.0
if isinstance(expr, Product):
result: object = 1.0
for f in expr.factors:
result = result * _eval_symbolic_batched(
f, symbol_values_batched, index_map, n_samples,
)
return result
if isinstance(expr, Sum):
terms = [
_eval_symbolic_batched(
t, symbol_values_batched, index_map, n_samples,
)
for t in expr.terms
]
if not terms:
return 0.0
total = terms[0]
for t in terms[1:]:
total = total + t
return total
raise NotImplementedError(
f"_eval_symbolic_batched: node type {type(expr).__name__} "
"is not vectorisable; caller should fall back to the scalar "
"per-sample loop."
)
def _is_zero(expr: Expr) -> bool:
    """Report whether *expr* is a literal ``Rational`` flagged as zero.

    Only ``Rational`` nodes are inspected; every other node type yields
    ``False`` without further analysis (no simplification is attempted).
    """
    if not isinstance(expr, Rational):
        return False
    return expr.is_zero
def _coeff_latex(pref: Rational, n_response: int) -> str:
    r"""Format ``pref × (-i)^{n_response}`` as one LaTeX coefficient token.

    The sign produced by the phase is folded into the rational so that
    the output is a single clean token, e.g. ``\frac{\mathrm{i}}{6}``,
    instead of two adjacent tokens such as ``\mathrm{i} -\frac{1}{6}``.

    Args:
        pref: Rational prefactor; only ``.numerator`` and ``.denominator``
            are read.
        n_response: Exponent of ``(-i)``.

    Returns:
        ``"0"`` for a zero coefficient, ``""`` for exactly ``+1``, and
        ``"-1"`` for exactly ``-1``. Otherwise returns the signed LaTeX
        number, using ``\frac{..}{..}`` when the denominator is not 1.
    """
    from fractions import Fraction

    # Powers of (-i) repeat with period 4: +1, -i, -1, +i.
    # Odd powers are imaginary; the real sign is negative for k = 1, 2.
    k = n_response % 4
    imaginary = k in (1, 3)
    real_sign = -1 if k in (1, 2) else 1

    value = Fraction(pref.numerator * real_sign, pref.denominator)
    magnitude = abs(value.numerator)
    denom = value.denominator
    if magnitude == 0:
        return "0"
    prefix = "-" if value < 0 else ""

    # Numerator text: plain integer, bare "i", or "<n>\,i".
    if not imaginary:
        top = str(magnitude)
    elif magnitude == 1:
        top = r"\mathrm{i}"
    else:
        top = str(magnitude) + r"\,\mathrm{i}"

    if denom != 1:
        return prefix + r"\frac{" + top + "}{" + str(denom) + "}"
    if not imaginary and magnitude == 1:
        # A unit real coefficient is omitted entirely when positive;
        # a negative one is rendered explicitly as "-1".
        return "-1" if prefix else ""
    return prefix + top
# =====================================================================
# Fast numerical path: nauty + einsum
# =====================================================================
[docs]
def _numerical_order0_terms(
    observable: list[FieldOperator],
    ito: bool,
) -> list[DiagramTerm]:
    """Build the unconstrained order-0 terms ``<O>_{S_0}``, one per Wick pairing.

    Pairings whose contractions produce no ``Propagator`` at all are
    skipped. Diagonal constraints are NOT applied here; the caller does
    that once for every order.
    """
    terms: list[DiagramTerm] = []
    if len(observable) < 2:
        return terms
    _, pairings = wick_contract(observable, ito=ito)
    for pairing in pairings:
        props = []
        for i, j in pairing:
            pr = contract_pair(observable[i], observable[j], ito=ito)
            # Only genuine propagators are kept; other contraction
            # results are dropped from the term.
            if isinstance(pr, Propagator):
                props.append(pr)
        if not props:
            continue
        terms.append(DiagramTerm(
            propagators=tuple(props),
            coupling_sum=Rational(1),
            rational_prefactor=Rational(1),
            integration_vars=(),
            summation_indices=(),
            n_response=sum(1 for pr in props if pr.kind == "R"),
        ))
    return terms


def _numerical_vertex_terms(
    observable: list[FieldOperator],
    vertex_seq,
    prefactor: Rational,
    ito: bool,
    n_jobs: int,
) -> list[DiagramTerm]:
    """Build the DiagramTerms contributed by one vertex combination.

    Enumerates spatial topologies and canonicalises each with nauty. It
    then expands every topology into its component routings and emits one
    ``DiagramTerm`` per routing. Diagonal constraints are NOT applied here.
    """
    from collections import defaultdict

    from .simplify import _canonical_key_nauty

    idx_ctx = IndexContext()
    vis = [
        VertexInstance.instantiate(v, idx_ctx, copy_id=k)
        for k, v in enumerate(vertex_seq)
    ]
    all_ops = list(observable)
    for vi in vis:
        all_ops.extend(vi.field_operators)
    integration_vars = frozenset(
        var for vi in vis for var in vi.spatial_variables
    )
    # Equal-time alias pairs declared by the vertices, sorted so that the
    # tuple stored on each term is deterministic.
    eq_time_aliases = tuple(sorted(
        (k, v)
        for vi in vis
        for k, v in (vi.equal_time_aliases or ())
    ))

    spatial_results = wick_contract_spatial(
        all_ops, ito=ito, vertex_points=integration_vars,
    )
    if not spatial_results:
        return []

    # Summation indices: each vertex component index paired with the
    # component count of the first field operator that carries it.
    sum_indices: list[tuple[str, int]] = []
    for vi in vis:
        for comp_idx in vi.component_indices:
            for op in vi.field_operators:
                if op.component_index == comp_idx:
                    sum_indices.append((comp_idx, op.field.n_components))
                    break
    summation_indices = tuple(sum_indices)
    int_vars_sorted = tuple(sorted(integration_vars))
    topo_list = list(spatial_results.values())

    def _process_topology(ref_props, _mult, rep_pairing):
        """Return (nauty canonical key, component routings) for one topology."""
        # NOTE(review): the topology multiplicity (_mult) is not used;
        # presumably _enumerate_component_routings accounts for it — confirm.
        key = _canonical_key_nauty(ref_props, integration_vars)
        routings = [
            props
            for props, _pairing in _enumerate_component_routings(
                ref_props, rep_pairing, all_ops, integration_vars,
            )
        ]
        return key, routings

    # Parallelise only when it is likely to pay off and joblib is available.
    # Only the import is guarded, so errors raised by worker code propagate
    # instead of being mistaken for a missing dependency.
    topo_results = None
    if n_jobs != 1 and len(topo_list) > 100:
        try:
            from joblib import Parallel, delayed
        except ImportError:
            pass
        else:
            topo_results = Parallel(n_jobs=n_jobs)(
                delayed(_process_topology)(rp, m, p)
                for rp, m, p in topo_list
            )
    if topo_results is None:
        topo_results = [
            _process_topology(rp, m, p) for rp, m, p in topo_list
        ]

    # Group routings by canonical key so that isomorphic topologies are
    # adjacent in the output. Terms are not merged here. A key is inserted
    # only when it first brings a non-empty routing list, which fixes the
    # output order.
    canon_groups: dict[bytes, list] = defaultdict(list)
    for key, routings in topo_results:
        if routings:
            canon_groups[key].extend(routings)
    if not canon_groups:
        return []

    # The coupling is a property of the vertices, built from their static
    # field ordering. It is the same for every routing and every key.
    coupling_syms: list[Expr] = [
        Symbol(
            name=vi.vertex.coupling,
            indices=tuple(op.component_index for op in vi.field_operators),
            spatial_args=(),
        )
        for vi in vis
    ]
    coupling_expr: Expr = (
        Product(tuple(coupling_syms))
        if len(coupling_syms) > 1
        else coupling_syms[0]
    )

    terms: list[DiagramTerm] = []
    for routing_group in canon_groups.values():
        for props in routing_group:
            props_tuple = tuple(props)
            r_absorbed_pairs, leg_aliases = _collect_r_absorbed_pairs(
                props_tuple, vis,
            )
            merged_aliases = (
                tuple(sorted(set(eq_time_aliases) | set(leg_aliases)))
                if leg_aliases
                else eq_time_aliases
            )
            terms.append(DiagramTerm(
                propagators=props_tuple,
                coupling_sum=coupling_expr,
                rational_prefactor=prefactor,
                integration_vars=int_vars_sorted,
                summation_indices=summation_indices,
                n_response=sum(1 for p in props_tuple if p.kind == "R"),
                equal_time_aliases=merged_aliases,
                r_absorbed_pairs=r_absorbed_pairs,
            ))
    return terms


def compute_moment_numerical(
    observable: list[FieldOperator],
    action: Action,
    order: int,
    coupling_values: dict,
    fixed_indices: dict[str, int] | None = None,
    ito: bool = True,
    diag_R: bool = False,
    diag_C: bool = False,
    iso_R: bool = False,
    iso_C: bool = False,
    n_jobs: int = 1,
) -> dict[int, list[DiagramTerm]]:
    r"""Compute diagram terms using nauty graph isomorphism.

    A faster alternative to :func:`compute_moment`. It canonicalises
    diagram topologies with the nauty graph isomorphism algorithm (via
    ``pynauty``) instead of the :math:`k!` brute-force canonical form
    search.

    The nauty canonicalization and component-routing steps can run in
    parallel with ``joblib`` (install with ``pip install
    sft-wick[parallel]``). Parallelism is used only when ``n_jobs != 1``
    and a vertex combination yields more than 100 topologies. If
    ``joblib`` cannot be imported, the work runs sequentially.

    Each order-``n`` term, for ``n >= 1``, carries the rational prefactor
    ``(-1)^n * mc / n!``, where ``mc`` is the multiplicity that
    ``action.all_vertex_combinations(n)`` reports for the vertex
    combination. Diagonal/isotropic constraints are applied exactly once
    to every term, at every order.

    Args:
        observable: Field operators defining the observable.
        action: The interaction action.
        order: Maximum perturbative order.
        coupling_values: ``{name: numpy_array}`` mapping coupling names
            to numeric tensors.
        fixed_indices: ``{index_label: int_value}`` for observable
            component indices (e.g. ``{'a': 1, 'b': 1}``).
            NOTE(review): not consulted by this function's body; kept
            for signature compatibility — confirm intended.
        ito: Apply Itô prescription (default True).
        diag_R: Enforce diagonal R propagators.
        diag_C: Enforce diagonal C propagators.
        iso_R: Isotropic R (implies ``diag_R``).
        iso_C: Isotropic C (implies ``diag_C``).
        n_jobs: Number of parallel workers for nauty canonicalization
            and component routing. ``1`` = sequential (default),
            ``-1`` = use all CPU cores. Requires ``joblib``.

    Returns:
        ``{order_n: [DiagramTerm, ...]}`` for each order with
        non-zero contributions. Use :meth:`DiagramTerm.build_integrand`
        for numerical evaluation.
    """
    diag_R = diag_R or iso_R
    diag_C = diag_C or iso_C
    result: dict[int, list[DiagramTerm]] = {}

    for n in range(order + 1):
        if n == 0:
            # Zeroth order: <O>_{S_0}
            integrands = _numerical_order0_terms(observable, ito)
        else:
            sign = (-1) ** n
            fact = factorial(n)
            integrands = []
            for vertex_seq, mc in action.all_vertex_combinations(n):
                integrands.extend(_numerical_vertex_terms(
                    observable,
                    vertex_seq,
                    Rational(sign * mc, fact),
                    ito,
                    n_jobs,
                ))

        if not integrands:
            continue
        # Single application of the diagonal constraints for every order.
        # Order-0 terms used to be constrained twice.
        if diag_R or diag_C:
            integrands = [
                dt.apply_diagonal(
                    diag_R=diag_R, diag_C=diag_C,
                    iso_R=iso_R, iso_C=iso_C,
                )
                for dt in integrands
            ]
        result[n] = integrands
    return result