"""Core Wick contraction engine.
Enumerates all valid pairings (contractions) of field operators and evaluates
each pairing as a product of propagators.
"""
from __future__ import annotations
from dataclasses import dataclass
from itertools import combinations, permutations
from math import factorial
from typing import Iterator, Optional, Sequence
from .expressions import Expr, Product, Propagator, Sum, ZERO
from .fields import FieldOperator, FieldType
from .propagators import contract_pair
# One complete contraction, written as a tuple of
# ``(operator_index, operator_index)`` pairs.
Pairing = tuple[tuple[int, int], ...]
# Memo for "do these R edges contain a directed cycle?".  Keys are the
# frozenset of ``(spatial_left, spatial_right)`` edges; values are the
# boolean answer.  The public ``wick_contract*`` entry points clear it
# at the start of every call.
_r_cycle_cache: dict[frozenset[tuple[str, str]], bool] = {}
[docs]
def generate_all_pairings(indices: list[int]) -> Iterator[Pairing]:
    """Yield every perfect matching of *indices*.

    The first index is matched with each later index in turn, and the
    leftover indices are matched recursively.  For ``2n`` indices this
    produces ``(2n-1)!! = 1*3*5*...*(2n-1)`` pairings, each a tuple of
    ``n`` pairs.  An empty input yields a single empty pairing; an odd
    number of indices yields nothing.
    """
    count = len(indices)
    if count % 2 == 1:
        # An odd number of items cannot be fully paired.
        return
    if not count:
        yield ()
        return
    head = indices[0]
    for pos in range(1, count):
        others = indices[1:pos] + indices[pos + 1 :]
        for tail in generate_all_pairings(others):
            yield ((head, indices[pos]),) + tail
[docs]
def generate_valid_pairings(
    phi_indices: list[int],
    psi_indices: list[int],
    operators: Sequence[FieldOperator] | None = None,
    ito: bool = False,
) -> Iterator[Pairing]:
    """Generate only non-vanishing pairings (no psi-psi contractions).

    Exploits the MSR structure: each psi must pair with a phi (producing R),
    and the remaining phi's pair among themselves (producing C).  This
    avoids generating the many zero-valued psi-psi pairings.

    When *operators* is given and *ito* is true, two prunings are applied
    before any C pairing is built for an R assignment:

    * a phi-psi pair at the same spatial point is skipped
      (R(x,x)=0 under Itô);
    * an assignment whose phi->psi spatial edges contain a directed cycle
      is skipped (causal vanishing).  The answer is memoised in
      ``_r_cycle_cache`` keyed by the frozenset of spatial edges.

    Args:
        phi_indices: Positions (in the operator list) of the phi operators.
        psi_indices: Positions of the psi operators.
        operators: Full operator list.  Only ``spatial_arg`` is read, and
            only when pruning is enabled.
        ito: Enable the Itô / causal-cycle pruning described above.

    Yields:
        Pairings whose first ``len(psi_indices)`` pairs are the R pairs
        ``(phi_index, psi_index)`` in psi order, followed by the C pairs.
    """
    n_phi = len(phi_indices)
    n_psi = len(psi_indices)
    # Every psi needs its own phi partner, and the leftover phi's must be
    # able to pair among themselves.
    if n_psi > n_phi or (n_phi - n_psi) % 2 != 0:
        return
    if n_psi == 0:
        # All physical fields: pair them all.
        yield from generate_all_pairings(phi_indices)
        return

    do_ito_prune = ito and operators is not None
    if do_ito_prune:
        # Hoisted out of the permutation loop: spatial args never change.
        psi_spatial = [operators[i].spatial_arg for i in psi_indices]
        phi_spatial = [operators[i].spatial_arg for i in phi_indices]

    # Choose which phi's pair with the psi's.
    for subset in combinations(range(n_phi), n_psi):
        chosen_phis = [phi_indices[k] for k in subset]
        chosen_set = set(subset)  # O(1) membership for the complement
        remaining_phis = [
            phi_indices[k] for k in range(n_phi) if k not in chosen_set
        ]
        if do_ito_prune:
            chosen_spatial = [phi_spatial[k] for k in subset]

        # All ways to match psi's to the chosen phi's.
        for perm in permutations(range(n_psi)):
            if do_ito_prune:
                edges = [
                    (chosen_spatial[perm[j]], psi_spatial[j])
                    for j in range(n_psi)
                ]
                # Itô: R(x, x) = 0.
                if any(left == right for left, right in edges):
                    continue
                # Causality: a directed cycle of R edges vanishes.  Reuse
                # the module's shared detector instead of a private copy.
                r_edges = frozenset(edges)
                has_cycle = _r_cycle_cache.get(r_edges)
                if has_cycle is None:
                    has_cycle = _has_r_cycle_edges(list(r_edges))
                    _r_cycle_cache[r_edges] = has_cycle
                if has_cycle:
                    continue

            r_pairs = tuple(
                (chosen_phis[perm[j]], psi_indices[j]) for j in range(n_psi)
            )
            # All ways to pair the remaining phi's among themselves.
            for c_pairing in generate_all_pairings(remaining_phis):
                yield r_pairs + c_pairing
def _has_r_cycle(propagators: list[Propagator]) -> bool:
    """Report whether the R propagators form a directed cycle.

    A chain R(a1,a2) R(a2,a3) ... R(an,a1) vanishes by causality: since R
    is retarded it would demand t1 > t2 > ... > tn > t1.  Non-R
    propagators are ignored, and so are equal-point R's (those are the
    Itô rule's business, not this check's).
    """
    successors: dict[str, list[str]] = {}
    for prop in propagators:
        if prop.kind != "R" or prop.spatial_left == prop.spatial_right:
            continue
        successors.setdefault(prop.spatial_left, []).append(prop.spatial_right)
    if not successors:
        return False

    # Iterative depth-first search.  state[node] == 1 means "on the
    # current path", 2 means "fully explored"; absent means unvisited.
    state: dict[str, int] = {}
    for root in successors:
        if root in state:
            continue
        state[root] = 1
        stack = [(root, iter(successors[root]))]
        while stack:
            node, children = stack[-1]
            for child in children:
                mark = state.get(child)
                if mark == 1:
                    return True  # back edge -> cycle
                if mark is None:
                    state[child] = 1
                    stack.append((child, iter(successors.get(child, ()))))
                    break
            else:
                state[node] = 2
                stack.pop()
    return False
[docs]
def evaluate_pairing(
    operators: Sequence[FieldOperator],
    pairing: Pairing,
    ito: bool = True,
) -> Optional[tuple[Expr, list[Propagator]]]:
    r"""Turn one complete pairing into a product of propagators.

    Each ``(i, j)`` pair is contracted with :func:`contract_pair`.  The
    result is ``(expression, propagators)``: the expression is the single
    propagator when there is exactly one, otherwise a ``Product`` of all
    of them.  ``None`` is returned when the pairing vanishes:

    - some pair contracts to nothing (e.g. :math:`\psi`\ --\ :math:`\psi`,
      or equal-point :math:`R(x,x)=0` when *ito* is ``True``);
    - the pairing is empty;
    - with *ito* ``True``, the R propagators' spatial arguments form a
      directed causal loop such as :math:`R(a,b)\,R(b,a)=0`.

    Args:
        operators: The full list of field operators.
        pairing: Tuple of ``(i, j)`` index pairs into *operators*.
        ito: If ``True``, apply the Itô prescription and causal
            vanishing rules.
    """
    props: list[Propagator] = []
    for left, right in pairing:
        contracted = contract_pair(operators[left], operators[right], ito=ito)
        if contracted is None:
            return None
        props.append(contracted)
    if not props:
        return None

    if ito:
        # Loop detection, memoised on the set of distinct-point R edges.
        key = frozenset(
            (p.spatial_left, p.spatial_right)
            for p in props
            if p.kind == "R" and p.spatial_left != p.spatial_right
        )
        if key:
            has_cycle = _r_cycle_cache.get(key)
            if has_cycle is None:
                has_cycle = _r_cycle_cache[key] = _has_r_cycle(props)
            if has_cycle:
                return None

    expression = props[0] if len(props) == 1 else Product(tuple(props))
    return expression, props
[docs]
def wick_contract(
    operators: Sequence[FieldOperator],
    ito: bool = True,
) -> tuple[Expr, list[Pairing]]:
    r"""Apply Wick's theorem to a product of field operators.

    Args:
        operators: Sequence of field operators to contract.
        ito: If ``True``, apply the Itô prescription
            :math:`\Theta(0)=0`: response propagators at equal spatial
            points vanish, i.e. :math:`R(x,x)=0`.

    Returns:
        ``(expression, surviving_pairings)``.  *expression* is ``ZERO``
        when nothing survives, the lone term when exactly one pairing
        survives, and otherwise a ``Sum`` of the per-pairing products.
        *surviving_pairings* lists the non-zero pairings in generation
        order.
    """
    _r_cycle_cache.clear()
    total = len(operators)
    if total == 0 or total % 2:
        return ZERO, []

    # Split operator positions into physical (phi) and response (psi).
    phi_indices: list[int] = []
    psi_indices: list[int] = []
    for idx, op in enumerate(operators):
        bucket = phi_indices if op.field_type == FieldType.PHYSICAL else psi_indices
        bucket.append(idx)

    # Each psi needs a phi; the leftover phi's must pair among themselves.
    surplus = len(phi_indices) - len(psi_indices)
    if surplus < 0 or surplus % 2:
        return ZERO, []

    terms: list[Expr] = []
    survivors: list[Pairing] = []
    candidates = generate_valid_pairings(
        phi_indices, psi_indices, operators=operators, ito=ito
    )
    for pairing in candidates:
        outcome = evaluate_pairing(operators, pairing, ito=ito)
        if outcome is None:
            continue
        expr, _props = outcome
        terms.append(expr)
        survivors.append(pairing)

    if not terms:
        return ZERO, []
    if len(terms) == 1:
        return terms[0], survivors
    return Sum(tuple(terms)), survivors
# Order-independent description of a diagram's edges: a sorted tuple of
# ``(kind, spatial_left, spatial_right)`` triples, with symmetric C edges
# stored in canonical (sorted-endpoint) form.
SpatialSignature = tuple[tuple[str, str, str], ...]
def _canonical_edge(kind: str, sl: str, sr: str) -> tuple[str, str, str]:
    """Return the canonical ``(kind, left, right)`` triple for an edge.

    C propagators are symmetric, so their endpoints are placed in sorted
    order.  Any other kind is returned exactly as given.
    """
    if kind != "C":
        return (kind, sl, sr)
    lo, hi = (sl, sr) if sl <= sr else (sr, sl)
    return ("C", lo, hi)
[docs]
def wick_contract_grouped(
    operators: Sequence[FieldOperator],
    ito: bool = True,
) -> tuple[
    dict[SpatialSignature, list[tuple[list[Propagator], Pairing]]],
    list[Pairing],
]:
    r"""Apply Wick's theorem and bucket the survivors by spatial topology.

    Works like :func:`wick_contract`, but instead of building one flat
    Sum of Products it groups each surviving pairing under the sorted,
    canonical tuple of its ``(kind, spatial_left, spatial_right)`` edges.
    Pairings sharing a key describe the same Feynman diagram and differ
    only in how component indices are routed.

    Returns:
        ``(groups, all_pairings)``.  *groups* maps each spatial signature
        to a list of ``(propagator_list, pairing)`` entries, and
        *all_pairings* is the flat list of surviving pairings in
        generation order.
    """
    _r_cycle_cache.clear()
    total = len(operators)
    if total == 0 or total % 2:
        return {}, []

    phi_indices: list[int] = []
    psi_indices: list[int] = []
    for idx, op in enumerate(operators):
        bucket = phi_indices if op.field_type == FieldType.PHYSICAL else psi_indices
        bucket.append(idx)

    surplus = len(phi_indices) - len(psi_indices)
    if surplus < 0 or surplus % 2:
        return {}, []

    groups: dict[SpatialSignature, list[tuple[list[Propagator], Pairing]]] = {}
    kept: list[Pairing] = []
    candidates = generate_valid_pairings(
        phi_indices, psi_indices, operators=operators, ito=ito
    )
    for pairing in candidates:
        outcome = evaluate_pairing(operators, pairing, ito=ito)
        if outcome is None:
            continue
        _, props = outcome
        kept.append(pairing)
        signature = tuple(
            sorted(
                _canonical_edge(p.kind, p.spatial_left, p.spatial_right)
                for p in props
            )
        )
        groups.setdefault(signature, []).append((props, pairing))
    return groups, kept
# ---------------------------------------------------------------------------
# Spatial-level Wick contraction engine
# ---------------------------------------------------------------------------
[docs]
@dataclass(frozen=True)
class SpatialTopology:
    """One Wick contraction described purely in terms of spatial points.

    Component indices are abstracted away.  ``multiplicity`` records how
    many operator-level pairings collapse onto this spatial picture.
    Instances are immutable (``frozen=True``) and therefore hashable.
    """

    # R edges, each ``(phi_point, psi_point)``, kept sorted.
    r_edges: tuple[tuple[str, str], ...]
    # C edges, each in canonical ``(min, max)`` form, kept sorted.
    c_edges: tuple[tuple[str, str], ...]
    # Number of equivalent operator-level pairings.
    multiplicity: int
def _has_r_cycle_edges(r_edges: list[tuple[str, str]]) -> bool:
    """Report whether the directed ``(source, target)`` edges contain a cycle.

    Self-loops (``source == target``) are ignored.  Detection uses Kahn's
    topological elimination: repeatedly remove nodes with no incoming
    edges.  If any node is left over, those leftover nodes lie on or feed
    a directed cycle.
    """
    out_edges: dict[str, list[str]] = {}
    indegree: dict[str, int] = {}
    for src, dst in r_edges:
        if src == dst:
            continue
        out_edges.setdefault(src, []).append(dst)
        indegree.setdefault(src, 0)
        indegree[dst] = indegree.get(dst, 0) + 1
    if not out_edges:
        return False

    ready = [node for node, deg in indegree.items() if deg == 0]
    removed = 0
    while ready:
        node = ready.pop()
        removed += 1
        for dst in out_edges.get(node, ()):
            indegree[dst] -= 1
            if indegree[dst] == 0:
                ready.append(dst)
    return removed < len(indegree)
def _enumerate_r_assignments(
    psi_list: list[str],
    phi_capacity: dict[str, int],
    ito: bool,
    idx: int = 0,
    r_so_far: list[tuple[str, str]] | None = None,
) -> Iterator[list[tuple[str, str]]]:
    """Yield every spatial-level way of giving each ψ a φ partner.

    ``psi_list[idx]`` is the spatial point of the ψ being placed.  It may
    take any φ point that still has capacity.  Under Itô, the φ point must
    also differ from the ψ point.  ``phi_capacity`` is decremented while
    a branch is explored and restored afterwards.

    ψ's at the same point are interchangeable.  To avoid repeats, a ψ
    whose point equals the previous ψ's point only considers φ points
    that are not smaller than the one the previous ψ used.  Complete
    assignments whose R edges form a directed cycle are dropped.  Each
    survivor is yielded as a fresh list of ``(phi_point, psi_point)``
    edges.
    """
    chosen = [] if r_so_far is None else r_so_far
    if idx == len(psi_list):
        if not _has_r_cycle_edges(chosen):
            yield list(chosen)
        return

    target = psi_list[idx]
    same_as_previous = idx > 0 and psi_list[idx - 1] == target
    floor = chosen[-1][0] if same_as_previous else None

    for source in sorted(phi_capacity):
        blocked = (
            phi_capacity[source] <= 0
            or (ito and source == target)
            or (floor is not None and source < floor)
        )
        if blocked:
            continue
        phi_capacity[source] -= 1
        chosen.append((source, target))
        yield from _enumerate_r_assignments(
            psi_list, phi_capacity, ito, idx + 1, chosen
        )
        chosen.pop()
        phi_capacity[source] += 1
def _enumerate_c_pairings(
    remaining: dict[str, int],
    _min_cross_partner: str | None = None,
) -> Iterator[list[tuple[str, str]]]:
    """Enumerate the distinct ways to pair the remaining φ-slots into C edges.

    ``remaining`` maps each spatial point to its number of unpaired φ
    operators.  It is decremented while a branch is explored and restored
    afterwards, so once the generator is exhausted it holds its original
    counts.

    Each spatial C-pairing (multiset of edges) is yielded exactly once.
    The generator always expands the lexicographically first point with a
    free slot and emits that point's edges in a fixed canonical order:

    1. all of its self-loops first, then
    2. its cross edges, with non-decreasing partner points.

    ``_min_cross_partner`` is set only while cross edges from the same
    first point are still being emitted.  In that state, partners smaller
    than the previous one are skipped.  Self-loops are forbidden too,
    because canonically they must come before any cross edge.
    """
    first_pt = next((pt for pt in sorted(remaining) if remaining[pt] > 0), None)
    if first_pt is None:
        yield []
        return

    remaining[first_pt] -= 1

    # Self-loop: only before this point has emitted a cross edge.  Allowing
    # it afterwards re-yielded the same edge multiset in another order,
    # e.g. {a: 3, b: 1} produced both [aa, ab] and [ab, aa].
    if _min_cross_partner is None and remaining[first_pt] > 0:
        remaining[first_pt] -= 1
        for sub in _enumerate_c_pairings(remaining):
            yield [(first_pt, first_pt)] + sub
        remaining[first_pt] += 1

    # Cross edge to another point, partners in non-decreasing order.
    for other_pt in sorted(remaining):
        if other_pt == first_pt or remaining[other_pt] <= 0:
            continue
        if _min_cross_partner is not None and other_pt < _min_cross_partner:
            continue
        remaining[other_pt] -= 1
        # If first_pt still has slots, the next call keeps expanding it
        # (it remains the smallest point with capacity), so enforce the
        # non-decreasing partner rule there.
        next_min = other_pt if remaining[first_pt] > 0 else None
        for sub in _enumerate_c_pairings(remaining, next_min):
            yield [(first_pt, other_pt)] + sub
        remaining[other_pt] += 1

    remaining[first_pt] += 1
def _compute_multiplicity(
    c_edges: list[tuple[str, str]],
    vertex_phi_counts: dict[str, int],
) -> int:
    r"""Count the operator-level pairings behind one spatial topology.

    Formula: :math:`\frac{\prod_v m_v!}{2^{N_{\text{self}}} \cdot \prod_e k_e!}`

    * *m_v*: number of φ operators at vertex *v* (keys of
      ``vertex_phi_counts``).  Points absent from that dict, such as
      observables, contribute no factorial.
    * *N_self*: C self-loops, divided out per vertex as ``2**n_self``.
    * *k_e*: number of parallel C edges joining the same (unordered) pair
      of points.  They are indistinguishable, so each group with
      *k_e* > 1 divides by *k_e*!.

    All arithmetic is integer floor division, applied vertex by vertex
    and then group by group in first-seen order.
    """
    loops_at: dict[str, int] = {}
    parallel: dict[tuple[str, str], int] = {}
    for left, right in c_edges:
        if left == right:
            loops_at[left] = loops_at.get(left, 0) + 1
        pair_key = (min(left, right), max(left, right))
        parallel[pair_key] = parallel.get(pair_key, 0) + 1

    result = 1
    for point, n_phi in vertex_phi_counts.items():
        result *= factorial(n_phi) // 2 ** loops_at.get(point, 0)
    for k in parallel.values():
        if k > 1:
            result //= factorial(k)
    return result
[docs]
def wick_contract_spatial(
    operators: Sequence[FieldOperator],
    ito: bool = True,
    vertex_points: frozenset[str] | None = None,
) -> dict[SpatialSignature, tuple[list[Propagator], int, Pairing]]:
    r"""Spatial-level Wick contraction.

    Enumerates spatial topologies (not individual operator-level pairings)
    and computes a multiplicity for each.  This avoids the combinatorial
    explosion from component-index routing.

    Args:
        operators: Sequence of field operators.
        ito: Apply the Itô prescription.
        vertex_points: Frozenset of spatial args that belong to vertices
            (not observables).  Needed for multiplicity computation.  When
            omitted, no point contributes a vertex factorial.

    Returns:
        Dict mapping each :data:`SpatialSignature` to
        ``(reference_propagators, multiplicity, representative_pairing)``.
        The multiplicity already accounts for every operator-level pairing
        of that topology.  If the same signature is reached twice, it is
        recorded once, never summed.
    """
    _r_cycle_cache.clear()
    n = len(operators)
    if n == 0 or n % 2 != 0:
        return {}

    # Bucket operator indices by spatial point, separately for phi and psi.
    phi_at: dict[str, list[int]] = {}
    psi_at: dict[str, list[int]] = {}
    for idx, op in enumerate(operators):
        bucket = phi_at if op.field_type == FieldType.PHYSICAL else psi_at
        bucket.setdefault(op.spatial_arg, []).append(idx)

    phi_counts = {s: len(ops) for s, ops in phi_at.items()}
    psi_counts = {s: len(ops) for s, ops in psi_at.items()}
    total_phi = sum(phi_counts.values())
    total_psi = sum(psi_counts.values())
    if total_psi > total_phi or (total_phi - total_psi) % 2 != 0:
        return {}

    # Flat psi-list for R-assignment enumeration (sorted for determinism).
    psi_list = [s for s in sorted(psi_counts) for _ in range(psi_counts[s])]

    # Vertex phi-counts for multiplicity (observable points excluded).
    # NOTE(review): with vertex_points omitted, any topology containing
    # k > 1 parallel C edges gets multiplicity 1 // k! == 0 -- confirm
    # callers always pass vertex_points.
    if vertex_points is None:
        vertex_points = frozenset()
    vertex_phi_counts = {
        s: c for s, c in phi_counts.items() if s in vertex_points
    }

    results: dict[SpatialSignature, tuple[list[Propagator], int, Pairing]] = {}
    phi_cap = dict(phi_counts)
    for r_assign in _enumerate_r_assignments(psi_list, phi_cap, ito):
        # phi capacity left over for C edges after this R-assignment.
        remaining = dict(phi_counts)
        for phi_s, _psi_s in r_assign:
            remaining[phi_s] -= 1

        for c_pairing in _enumerate_c_pairings(remaining):
            sig_edges = [("R", phi_s, psi_s) for phi_s, psi_s in r_assign]
            sig_edges.extend(_canonical_edge("C", s1, s2) for s1, s2 in c_pairing)
            sig: SpatialSignature = tuple(sorted(sig_edges))
            if sig in results:
                # The multiplicity depends only on the C-edge multiset
                # fixed by `sig`, and it already counts every
                # operator-level pairing of this topology.  A repeat is
                # an enumeration artefact; adding its multiplicity would
                # double-count.
                continue
            mult = _compute_multiplicity(c_pairing, vertex_phi_counts)
            props, pairing = _representative_pairing(
                operators, phi_at, psi_at, r_assign, c_pairing
            )
            results[sig] = (props, mult, pairing)
    return results


def _representative_pairing(
    operators: Sequence[FieldOperator],
    phi_at: dict[str, list[int]],
    psi_at: dict[str, list[int]],
    r_assign: list[tuple[str, str]],
    c_pairing: list[tuple[str, str]],
) -> tuple[list[Propagator], Pairing]:
    """Pick concrete operators that realise one spatial topology.

    Operators at each spatial point are consumed in their original order,
    R edges first and then C edges.  Returns the matching propagators and
    the operator-index pairing.
    """
    phi_used = {s: 0 for s in phi_at}
    psi_used = {s: 0 for s in psi_at}

    def take(bucket: dict[str, list[int]], used: dict[str, int], point: str) -> int:
        # Next unused operator index at `point`.
        op_idx = bucket[point][used[point]]
        used[point] += 1
        return op_idx

    props: list[Propagator] = []
    pairs: list[tuple[int, int]] = []
    for phi_s, psi_s in r_assign:
        phi_idx = take(phi_at, phi_used, phi_s)
        psi_idx = take(psi_at, psi_used, psi_s)
        phi_op = operators[phi_idx]
        psi_op = operators[psi_idx]
        props.append(Propagator(
            kind="R",
            index_left=phi_op.component_index,
            index_right=psi_op.component_index,
            spatial_left=phi_op.spatial_arg,
            spatial_right=psi_op.spatial_arg,
        ))
        pairs.append((phi_idx, psi_idx))
    for s1, s2 in c_pairing:
        idx1 = take(phi_at, phi_used, s1)
        idx2 = take(phi_at, phi_used, s2)
        op1 = operators[idx1]
        op2 = operators[idx2]
        props.append(Propagator(
            kind="C",
            index_left=op1.component_index,
            index_right=op2.component_index,
            spatial_left=op1.spatial_arg,
            spatial_right=op2.spatial_arg,
        ))
        pairs.append((idx1, idx2))
    return props, tuple(pairs)