# Source code for sft_wick.simplify

"""Expression simplification for Wick contraction results.

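Pipeline applied by :func:`simplify`:
1. Flatten nested Sum/Product
2. Absorb rational prefactors in products
3. Eliminate zeros (and unit factors in products)
4. Term collection (combine like terms; factor order inside a product does
   not matter because propagators and symbols are compared in sorted order)

Additional passes, called separately:
- :func:`collect_by_diagram`: diagram-based collection (group terms by
  Feynman diagram isomorphism)
- :func:`diagonal_propagators`: enforce diagonal/isotropic propagators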
"""

from __future__ import annotations

from collections import defaultdict
from fractions import Fraction
from itertools import permutations

from .expressions import (
    ZERO,
    ONE,
    Expr,
    IntegralOver,
    KroneckerDelta,
    Product,
    Propagator,
    Rational,
    Sum,
    SumOverIndex,
    Symbol,
)


def simplify(expr: Expr) -> Expr:
    """Run the standard simplification pipeline on *expr*.

    The passes run in this order:

    1. flatten nested Sum/Product nodes;
    2. fold the Rational factors of every Product into a single prefactor;
    3. drop zero terms and unit factors;
    4. combine like terms of a top-level Sum.

    Args:
        expr: Expression to simplify.

    Returns:
        The simplified expression.
    """
    for simplification_pass in (
        _flatten,
        _absorb_rationals,
        _eliminate_zeros,
        _collect_terms,
    ):
        expr = simplification_pass(expr)
    return expr
def _flatten(expr: Expr) -> Expr:
    """Recursively flatten nested Sum and Product nodes.

    A Sum term that is itself a Sum is spliced into its parent, and likewise
    for Product factors.  IntegralOver and SumOverIndex bodies are flattened
    in place; every other node is returned unchanged.
    """
    if isinstance(expr, Sum):
        flat_terms: list[Expr] = []
        for t in expr.terms:
            t = _flatten(t)
            if isinstance(t, Sum):
                flat_terms.extend(t.terms)
            else:
                flat_terms.append(t)
        return Sum(tuple(flat_terms))
    if isinstance(expr, Product):
        flat_factors: list[Expr] = []
        for f in expr.factors:
            f = _flatten(f)
            if isinstance(f, Product):
                flat_factors.extend(f.factors)
            else:
                flat_factors.append(f)
        return Product(tuple(flat_factors))
    if isinstance(expr, IntegralOver):
        return IntegralOver(expr.variable, _flatten(expr.body))
    if isinstance(expr, SumOverIndex):
        return SumOverIndex(expr.index_name, expr.dimension, _flatten(expr.body))
    return expr


def _absorb_rationals(expr: Expr) -> Expr:
    """In every Product, multiply all Rational factors into one prefactor.

    A zero coefficient collapses the Product to ``ZERO``; a unit coefficient
    is dropped; a Product made only of Rationals becomes a single Rational.
    """
    if isinstance(expr, Product):
        coeff = Fraction(1)
        other_factors: list[Expr] = []
        for f in expr.factors:
            f = _absorb_rationals(f)
            if isinstance(f, Rational):
                coeff *= f.to_fraction()
            else:
                other_factors.append(f)
        if coeff == 0:
            return ZERO
        if not other_factors:
            return Rational(coeff.numerator, coeff.denominator)
        if coeff == 1:
            if len(other_factors) == 1:
                return other_factors[0]
            return Product(tuple(other_factors))
        return Product(
            (Rational(coeff.numerator, coeff.denominator), *other_factors)
        )
    if isinstance(expr, Sum):
        return Sum(tuple(_absorb_rationals(t) for t in expr.terms))
    if isinstance(expr, IntegralOver):
        return IntegralOver(expr.variable, _absorb_rationals(expr.body))
    if isinstance(expr, SumOverIndex):
        return SumOverIndex(
            expr.index_name, expr.dimension, _absorb_rationals(expr.body)
        )
    return expr


def _eliminate_zeros(expr: Expr) -> Expr:
    """Remove zero terms from sums; collapse products containing zero.

    Unit factors are also removed from products.  Wrappers whose body
    becomes zero collapse to ``ZERO``.
    """
    if isinstance(expr, Sum):
        terms = [_eliminate_zeros(t) for t in expr.terms]
        terms = [t for t in terms if not _is_zero(t)]
        if not terms:
            return ZERO
        if len(terms) == 1:
            return terms[0]
        return Sum(tuple(terms))
    if isinstance(expr, Product):
        factors = [_eliminate_zeros(f) for f in expr.factors]
        if any(_is_zero(f) for f in factors):
            return ZERO
        factors = [f for f in factors if not _is_one(f)]
        if not factors:
            return ONE
        if len(factors) == 1:
            return factors[0]
        return Product(tuple(factors))
    if isinstance(expr, IntegralOver):
        body = _eliminate_zeros(expr.body)
        if _is_zero(body):
            return ZERO
        return IntegralOver(expr.variable, body)
    if isinstance(expr, SumOverIndex):
        body = _eliminate_zeros(expr.body)
        if _is_zero(body):
            return ZERO
        return SumOverIndex(expr.index_name, expr.dimension, body)
    return expr


def _is_zero(expr: Expr) -> bool:
    """True iff *expr* is a Rational equal to zero."""
    return isinstance(expr, Rational) and expr.is_zero


def _is_one(expr: Expr) -> bool:
    """True iff *expr* is a Rational equal to one."""
    return isinstance(expr, Rational) and expr.is_one


def _propagator_key(p: Propagator) -> tuple:
    """Canonical sort key for a propagator (missing indices sort as "")."""
    return (p.kind, p.spatial_left, p.spatial_right,
            p.index_left or "", p.index_right or "")


def _term_signature(expr: Expr) -> tuple | None:
    """Extract the non-coefficient part of a term for grouping.

    A term is either:
    - A single Propagator
    - A Product of (optional Rational) * Propagators * Symbols *
      IntegralOvers * SumOverIndices

    IntegralOver / SumOverIndex factors contribute only their bound name
    (and dimension) here; their bodies are not hashable-by-contract, so
    :func:`_collect_terms` compares them separately via
    :func:`_wrapper_bodies`.

    Returns a hashable signature, or None if it can't be grouped.
    """
    if isinstance(expr, Propagator):
        return (("prop", _propagator_key(expr)),)
    if isinstance(expr, Product):
        parts: list[tuple] = []
        for f in expr.factors:
            if isinstance(f, Rational):
                continue  # coefficient, handled by _get_coefficient
            if isinstance(f, Propagator):
                parts.append(("prop", _propagator_key(f)))
            elif isinstance(f, Symbol):
                parts.append(("sym", f.name, f.indices, f.spatial_args))
            elif isinstance(f, IntegralOver):
                parts.append(("int", f.variable))
            elif isinstance(f, SumOverIndex):
                parts.append(("sum", f.index_name, f.dimension))
            else:
                return None  # can't group
        return tuple(sorted(parts))
    return None


def _wrapper_bodies(expr: Expr) -> list[Expr]:
    """Bodies of the IntegralOver/SumOverIndex factors of a Product term.

    Ordered by (kind, bound name[, dimension]) to match the sorted signature;
    ties keep factor order.  Non-Product terms have no wrapper factors.
    """
    if not isinstance(expr, Product):
        return []
    tagged: list[tuple[tuple, Expr]] = []
    for f in expr.factors:
        if isinstance(f, IntegralOver):
            tagged.append((("int", f.variable), f.body))
        elif isinstance(f, SumOverIndex):
            tagged.append((("sum", f.index_name, f.dimension), f.body))
    # Sort on the tag only so Expr bodies are never compared with "<".
    tagged.sort(key=lambda item: item[0])
    return [body for _, body in tagged]


def _get_coefficient(expr: Expr) -> Fraction:
    """Extract the rational coefficient from a term (1 if there is none)."""
    if isinstance(expr, Rational):
        return expr.to_fraction()
    if isinstance(expr, Product):
        coeff = Fraction(1)
        for f in expr.factors:
            if isinstance(f, Rational):
                coeff *= f.to_fraction()
        return coeff
    return Fraction(1)


def _set_coefficient(expr: Expr, coeff: Fraction) -> Expr:
    """Replace the rational coefficient of a term with *coeff*."""
    if isinstance(expr, Rational):
        return Rational(coeff.numerator, coeff.denominator)
    if isinstance(expr, Product):
        non_rational = [f for f in expr.factors if not isinstance(f, Rational)]
        if coeff == 1:
            if len(non_rational) == 1:
                return non_rational[0]
            return Product(tuple(non_rational))
        return Product(
            (Rational(coeff.numerator, coeff.denominator), *non_rational)
        )
    # expr has implicit coefficient 1
    if coeff == 1:
        return expr
    return Product((Rational(coeff.numerator, coeff.denominator), expr))


def _collect_terms(expr: Expr) -> Expr:
    """Combine like terms of a top-level Sum.

    Two terms are "like" when they have the same :func:`_term_signature`
    AND their IntegralOver/SumOverIndex factors have equal bodies (compared
    with ``==``).  Previously only the wrapper's bound name was compared, so
    integrals over the same variable with different integrands were merged
    and all but the first integrand were silently discarded.

    Coefficients of like terms are summed; groups summing to zero are
    dropped.  Terms without a signature are appended unchanged after the
    collected ones.  Non-Sum input is returned unchanged.
    """
    if not isinstance(expr, Sum):
        return expr

    # signature -> list of (wrapper bodies, terms sharing those bodies)
    groups: dict[tuple, list[tuple[list[Expr], list[Expr]]]] = {}
    ungroupable: list[Expr] = []
    for term in expr.terms:
        sig = _term_signature(term)
        if sig is None:
            ungroupable.append(term)
            continue
        bodies = _wrapper_bodies(term)
        buckets = groups.setdefault(sig, [])
        for bucket_bodies, bucket_terms in buckets:
            if bucket_bodies == bodies:
                bucket_terms.append(term)
                break
        else:
            buckets.append((bodies, [term]))

    collected: list[Expr] = []
    for buckets in groups.values():
        for _, terms in buckets:
            total_coeff = sum((_get_coefficient(t) for t in terms), Fraction(0))
            if total_coeff != 0:
                collected.append(_set_coefficient(terms[0], total_coeff))
    collected.extend(ungroupable)
    if not collected:
        return ZERO
    if len(collected) == 1:
        return collected[0]
    return Sum(tuple(collected))


# ---------------------------------------------------------------------------
# Diagram-based collection
# ---------------------------------------------------------------------------


def _apply_perm_to_symbol(sym: Symbol, perm: dict[str, str]) -> Symbol:
    """Relabel the indices and spatial arguments of a Symbol via *perm*."""
    new_indices = tuple(perm.get(idx, idx) for idx in sym.indices)
    new_spatial = tuple(perm.get(s, s) for s in sym.spatial_args)
    if new_indices == sym.indices and new_spatial == sym.spatial_args:
        return sym
    return Symbol(sym.name, new_indices, new_spatial)


def _apply_perm_to_expr(expr: Expr, perm: dict[str, str]) -> Expr:
    """Apply an index/spatial relabeling to the non-propagator parts of a term.

    Relabels Symbols and KroneckerDeltas, recursing through Products and
    Sums.  Rationals and any other node type are returned unchanged
    (Propagators are not relabeled here; callers keep reference propagators).
    """
    if not perm:
        return expr
    if isinstance(expr, Symbol):
        return _apply_perm_to_symbol(expr, perm)
    if isinstance(expr, KroneckerDelta):
        i1 = perm.get(expr.index1, expr.index1)
        i2 = perm.get(expr.index2, expr.index2)
        if i1 == expr.index1 and i2 == expr.index2:
            return expr
        return KroneckerDelta(i1, i2)
    if isinstance(expr, Product):
        return Product(tuple(_apply_perm_to_expr(f, perm) for f in expr.factors))
    if isinstance(expr, Sum):
        return Sum(tuple(_apply_perm_to_expr(t, perm) for t in expr.terms))
    return expr


def _canonical_edge(
    kind: str,
    spatial_left: str,
    spatial_right: str,
) -> tuple[str, str, str]:
    """Canonical form for a single propagator edge.

    C edges are symmetric: ``C(x, y) = C(y, x)``, so we sort the spatial
    arguments.  R edges are directed and kept as-is.
    """
    if kind == "C":
        return ("C", min(spatial_left, spatial_right),
                max(spatial_left, spatial_right))
    return (kind, spatial_left, spatial_right)


def _canonical_diagram_form(
    props: list[Propagator],
    integration_vars: frozenset[str],
) -> tuple[tuple[tuple[str, str, str], ...], dict[str, str]]:
    """Compute a canonical graph form for a set of propagators.

    Tries every permutation of the internal spatial points (integration
    variables that actually appear in *props*), keeping external points
    fixed.  For each permutation the edges are canonicalised with
    :func:`_canonical_edge` and sorted; the lexicographically smallest edge
    tuple wins.

    Returns:
        ``(canonical_edges, best_mapping)`` where *best_mapping* maps each
        original internal point name to the name it received in the winning
        permutation (empty when there are no internal points).
    """
    all_spatial: set[str] = set()
    for p in props:
        all_spatial.add(p.spatial_left)
        all_spatial.add(p.spatial_right)

    internal = sorted(all_spatial & integration_vars)

    # Trivial case: no internal variables to permute
    if not internal:
        form = tuple(sorted(
            _canonical_edge(p.kind, p.spatial_left, p.spatial_right)
            for p in props
        ))
        return form, {}

    best_form: tuple[tuple[str, str, str], ...] | None = None
    best_mapping: dict[str, str] = {}
    # Try all permutations of internal variables (k! where k = |internal| ≤ 4
    # in practice, so at most 24 iterations).  Signature-based pruning is
    # deliberately avoided: local 1-hop signatures only capture graph
    # automorphisms, not the cross-graph relabelings needed for canonical
    # form minimization, and can incorrectly prevent the true minimum from
    # being found.
    for perm in permutations(internal):
        mapping: dict[str, str] = dict(zip(internal, perm))
        form = tuple(sorted(
            _canonical_edge(
                p.kind,
                mapping.get(p.spatial_left, p.spatial_left),
                mapping.get(p.spatial_right, p.spatial_right),
            )
            for p in props
        ))
        if best_form is None or form < best_form:
            best_form = form
            best_mapping = mapping

    assert best_form is not None  # permutations() of a non-empty list yields
    return best_form, best_mapping


def _canonical_key_nauty(
    props: list[Propagator],
    integration_vars: frozenset[str],
) -> bytes:
    """Compute a canonical graph certificate using nauty (via pynauty).

    The diagram is encoded as a directed vertex-colored graph with dummy
    intermediate nodes for edge-type encoding:

    - **R(a, b)**: ``a → R_dummy → b`` (two directed edges)
    - **C(a, b)**: ``a ↔ C_dummy ↔ b`` (bidirectional through dummy)

    Vertex coloring gives each external node its own color while internal
    nodes and same-type dummies share a color and may be permuted.

    Falls back to :func:`_canonical_diagram_form` (stringified and encoded)
    if ``pynauty`` is not installed; the two paths produce differently
    formatted keys, so keys from different environments are not comparable.

    NOTE(review): external point names only influence node numbering; they
    are not passed to pynauty, so two diagrams differing only in the names
    of their external points may receive the same certificate.  Confirm this
    is acceptable for callers.
    """
    try:
        import pynauty
    except ImportError:
        canon, _mapping = _canonical_diagram_form(props, integration_vars)
        # Return bytes for consistency
        return str(canon).encode()

    # Collect spatial points
    all_spatial: set[str] = set()
    for p in props:
        all_spatial.add(p.spatial_left)
        all_spatial.add(p.spatial_right)

    external = sorted(all_spatial - integration_vars)
    internal = sorted(all_spatial & integration_vars)

    node_id: dict[str, int] = {}
    for i, name in enumerate(external):
        node_id[name] = i
    for i, name in enumerate(internal):
        node_id[name] = len(external) + i

    n_real = len(node_id)
    r_list = [p for p in props if p.kind == "R"]
    c_list = [p for p in props if p.kind == "C"]
    r_start = n_real
    c_start = n_real + len(r_list)
    total = n_real + len(r_list) + len(c_list)

    adj: dict[int, set[int]] = defaultdict(set)
    for k, p in enumerate(r_list):
        dummy = r_start + k
        left, right = node_id[p.spatial_left], node_id[p.spatial_right]
        adj[left].add(dummy)
        adj[dummy].add(right)
    for k, p in enumerate(c_list):
        dummy = c_start + k
        left, right = node_id[p.spatial_left], node_id[p.spatial_right]
        adj[left].add(dummy)
        adj[dummy].add(left)
        adj[right].add(dummy)
        adj[dummy].add(right)

    # Vertex coloring: each external is unique, internals share,
    # R-dummies share, C-dummies share
    coloring: list[set[int]] = []
    for i in range(len(external)):
        coloring.append({i})
    if internal:
        coloring.append(set(range(len(external), n_real)))
    if r_list:
        coloring.append(set(range(r_start, c_start)))
    if c_list:
        coloring.append(set(range(c_start, total)))

    g = pynauty.Graph(total, directed=True, vertex_coloring=coloring)
    for v, nbrs in adj.items():
        if nbrs:
            g.connect_vertex(v, sorted(nbrs))
    return pynauty.certificate(g)


def _match_propagators_after_spatial(
    ref_props: list[Propagator],
    other_props: list[Propagator],
    spatial_perm: dict[str, str],
    internal_indices: set[str],
) -> dict[str, str] | None:
    """Find the component-index permutation after spatial relabeling.

    Applies *spatial_perm* to *other_props*' spatial arguments, then matches
    them to *ref_props* by canonical edge ``(kind, spatial_left,
    spatial_right)``.  For C propagators that match with reversed spatial
    arguments (C symmetry), the component indices are swapped.  When several
    propagators share a canonical edge, every pairing is tried until one
    yields a consistent index map.

    A map is consistent when every other-index maps to exactly one
    ref-index and every ref-index is the image of exactly one other-index
    (identity pairs included).  Only indices in *internal_indices* may be
    renamed; any other index must match itself.

    Returns ``{other_index -> ref_index}`` holding only the non-identity
    renamings, or ``None`` if no consistent matching exists.
    """
    # Spatial endpoints of each other-prop after applying the relabeling.
    remapped: list[tuple[str, str]] = [
        (
            spatial_perm.get(p.spatial_left, p.spatial_left),
            spatial_perm.get(p.spatial_right, p.spatial_right),
        )
        for p in other_props
    ]

    # Bucket both sides by canonical edge (C endpoints sorted).
    ref_groups: defaultdict[tuple[str, str, str], list[int]] = defaultdict(list)
    for i, p in enumerate(ref_props):
        ref_groups[_canonical_edge(p.kind, p.spatial_left, p.spatial_right)].append(i)
    other_groups: defaultdict[tuple[str, str, str], list[int]] = defaultdict(list)
    for i, (p, (sl, sr)) in enumerate(zip(other_props, remapped)):
        other_groups[_canonical_edge(p.kind, sl, sr)].append(i)

    if ref_groups.keys() != other_groups.keys():
        return None
    if any(len(ref_groups[key]) != len(other_groups[key]) for key in ref_groups):
        return None

    # Per bucket: every pairing of other-props with ref-props, as
    # (ref_idx, other_idx, flipped).  Props in one bucket share a canonical
    # edge, so their endpoints agree up to order; a differing order can only
    # occur for (symmetric) C and means the component indices swap.
    group_options: list[list[list[tuple[int, int, bool]]]] = []
    for key in sorted(ref_groups):
        ri_list = ref_groups[key]
        options: list[list[tuple[int, int, bool]]] = []
        for oi_perm in permutations(other_groups[key]):
            options.append([
                (
                    ri,
                    oi,
                    (ref_props[ri].spatial_left, ref_props[ri].spatial_right)
                    != remapped[oi],
                )
                for ri, oi in zip(ri_list, oi_perm)
            ])
        group_options.append(options)

    def _extend(
        binding: dict[str, str],
        claimed: dict[str, str],
        option: list[tuple[int, int, bool]],
    ) -> bool:
        """Add the index pairs implied by *option*; False on any conflict."""
        for ri, oi, flipped in option:
            rp = ref_props[ri]
            op = other_props[oi]
            if flipped:
                pairs = ((rp.index_left, op.index_right),
                         (rp.index_right, op.index_left))
            else:
                pairs = ((rp.index_left, op.index_left),
                         (rp.index_right, op.index_right))
            for ref_idx, other_idx in pairs:
                if ref_idx is None and other_idx is None:
                    continue
                if ref_idx != other_idx and (
                    other_idx not in internal_indices
                    or ref_idx not in internal_indices
                ):
                    return False
                # Identity pairs are recorded too, so that e.g. i0->i0 and
                # i0->i1 (or i0->i0 and i1->i0) are detected as conflicts.
                if binding.setdefault(other_idx, ref_idx) != ref_idx:
                    return False
                if claimed.setdefault(ref_idx, other_idx) != other_idx:
                    return False
        return True

    def _search(
        group_idx: int,
        binding: dict[str, str],
        claimed: dict[str, str],
    ) -> dict[str, str] | None:
        """Depth-first search over per-bucket pairings."""
        if group_idx == len(group_options):
            return binding
        for option in group_options[group_idx]:
            trial_binding = dict(binding)
            trial_claimed = dict(claimed)
            if not _extend(trial_binding, trial_claimed, option):
                continue
            found = _search(group_idx + 1, trial_binding, trial_claimed)
            if found is not None:
                return found
        return None

    binding = _search(0, {}, {})
    if binding is None:
        return None
    # Identity pairs were only needed for conflict detection.
    return {other: ref for other, ref in binding.items() if other != ref}


# ---------------------------------------------------------------------------
# Public API: collect_by_diagram
# ---------------------------------------------------------------------------
def collect_by_diagram(expr: Expr) -> Expr:
    """Merge terms whose Feynman diagrams are isomorphic.

    Terms that describe the same diagram up to renaming of dummy
    integration variables and summation indices are combined.  The
    canonical (reference) propagators are factored out and the coupling
    coefficients of the merged terms are summed, each relabelled to match
    the reference diagram.

    Handled equivalences:

    - **Spatial relabeling**: integration variables may be permuted freely
      (e.g. ``y_0 ↔ y_1`` at second order).
    - **C symmetry**: ``C(x, y) = C(y, x)``.
    - **Coupling permutation**: the index permutation that maps a term onto
      the reference diagram is applied to its coupling symbols, giving a
      sum of permuted couplings.

    Example::

        F_{i0 i1 i2} × [R_{a i2}(x,y) C_{i0 i1}(y,y) + R_{a i1}(x,y) C_{i0 i2}(y,y)]

    becomes::

        (F_{i0 i1 i2} + F_{i0 i2 i1}) × R_{a i2}(x,y) C_{i0 i1}(y,y)
    """
    # At the root nothing is bound yet: no integration variables and no
    # summation indices are in scope.
    no_integration_vars: frozenset[str] = frozenset()
    no_summation_indices: frozenset[str] = frozenset()
    return _collect_inner(expr, no_integration_vars, no_summation_indices)
def _collect_inner(
    expr: Expr,
    integration_vars: frozenset[str],
    summation_indices: frozenset[str],
) -> Expr:
    """Recursive worker for :func:`collect_by_diagram`.

    Descends through IntegralOver / SumOverIndex / Sum nodes, accumulating
    the integration variables and summation indices bound by the wrappers it
    passes.  At a Product that contains a Sum factor, the Sum's terms are
    grouped by canonical diagram form, and each group is merged into one term
    ``(sum of relabelled couplings) × reference propagators``.

    The Symbol factors of that Product (the couplings) are distributed into
    every term of the rebuilt Sum.  That includes terms that cannot be
    grouped, such as terms without propagators; previously such terms
    silently lost their couplings.  All other outer factors stay outside.

    Args:
        expr: Expression (sub)tree to process.
        integration_vars: Spatial variables bound by enclosing IntegralOver
            nodes; these are the points that may be relabelled.
        summation_indices: Index names bound by enclosing SumOverIndex nodes.
            NOTE(review): accumulated but not consulted below; the indices
            that may be permuted are taken from the coupling symbols instead.

    Returns:
        An expression equal in value to *expr* with isomorphic diagrams merged.
    """
    if isinstance(expr, IntegralOver):
        new_ivars = integration_vars | {expr.variable}
        body = _collect_inner(expr.body, new_ivars, summation_indices)
        return IntegralOver(expr.variable, body)

    if isinstance(expr, SumOverIndex):
        new_sindices = summation_indices | {expr.index_name}
        body = _collect_inner(expr.body, integration_vars, new_sindices)
        return SumOverIndex(expr.index_name, expr.dimension, body)

    if isinstance(expr, Sum):
        return Sum(tuple(
            _collect_inner(t, integration_vars, summation_indices)
            for t in expr.terms
        ))

    if not isinstance(expr, Product):
        return expr

    # --- Product: look for an inner Sum to collect ---
    sum_idx = None
    for i, f in enumerate(expr.factors):
        if isinstance(f, Sum):
            sum_idx = i
            break
    if sum_idx is None:
        return expr

    the_sum = expr.factors[sum_idx]
    outer_factors = list(expr.factors[:sum_idx]) + list(expr.factors[sum_idx + 1:])

    # Couplings (Symbols) go into every Sum term; the rest stays outside.
    outer_symbols: list[Symbol] = []
    outer_other: list[Expr] = []
    for f in outer_factors:
        if isinstance(f, Symbol):
            outer_symbols.append(f)
        else:
            outer_other.append(f)

    # Internal indices: those that appear in coupling symbols
    internal_indices: set[str] = set()
    for sym in outer_symbols:
        internal_indices.update(sym.indices)

    if not internal_indices and not integration_vars:
        return expr  # nothing to collect

    def _attach_couplings(term: Expr) -> Expr:
        """Re-attach the outer coupling symbols to a term that is not merged."""
        if not outer_symbols:
            return term
        tail = list(term.factors) if isinstance(term, Product) else [term]
        parts: list[Expr] = [*outer_symbols, *tail]
        return parts[0] if len(parts) == 1 else Product(tuple(parts))

    # --- Group Sum terms by canonical diagram form ---
    # Pre-group by exact spatial signature: terms with identical signatures
    # share the same canonical form, so _canonical_diagram_form runs once per
    # spatial group instead of once per term.
    groups: dict[
        tuple[tuple[str, str, str], ...],
        list[tuple[list[Propagator], list[Expr], dict[str, str]]],
    ] = {}
    ungroupable: list[Expr] = []
    spatial_sig_groups: dict[
        tuple[tuple[str, str, str], ...],
        list[tuple[list[Propagator], list[Expr]]],
    ] = {}

    for term in the_sum.terms:
        if isinstance(term, Product):
            props = [f for f in term.factors if isinstance(f, Propagator)]
            non_props = [f for f in term.factors if not isinstance(f, Propagator)]
        elif isinstance(term, Propagator):
            props = [term]
            non_props = []
        else:
            ungroupable.append(_attach_couplings(term))
            continue
        if not props:
            ungroupable.append(_attach_couplings(term))
            continue
        sig = tuple(sorted(
            _canonical_edge(p.kind, p.spatial_left, p.spatial_right)
            for p in props
        ))
        spatial_sig_groups.setdefault(sig, []).append((props, non_props))

    for entries in spatial_sig_groups.values():
        # Compute canonical form once for the group representative
        canonical_form, mapping = _canonical_diagram_form(entries[0][0], integration_vars)
        for props, non_props in entries:
            groups.setdefault(canonical_form, []).append((props, non_props, mapping))

    # --- Merge each group ---
    new_sum_terms: list[Expr] = []
    for group in groups.values():
        if len(group) == 1:
            # Single term: reconstruct with coupling symbols included
            props, non_props, _ = group[0]
            parts: list[Expr] = list(outer_symbols) + non_props + props
            new_sum_terms.append(parts[0] if len(parts) == 1 else Product(tuple(parts)))
            continue

        # Multiple terms with the same canonical diagram form
        ref_props, _, ref_mapping = group[0]
        ref_inv = {v: k for k, v in ref_mapping.items()}

        coupling_terms: list[Expr] = []
        for props, non_props, mapping in group:
            # Spatial permutation: other -> canonical -> ref
            spatial_perm: dict[str, str] = {}
            for orig, canon in mapping.items():
                target = ref_inv.get(canon, canon)
                if orig != target:
                    spatial_perm[orig] = target

            # Component index permutation via propagator matching
            comp_perm = _match_propagators_after_spatial(
                ref_props, props, spatial_perm, internal_indices,
            )
            if comp_perm is None:
                # Cannot merge: keep as a separate term with its couplings
                parts = list(outer_symbols) + non_props + props
                ungroupable.append(Product(tuple(parts)) if len(parts) > 1 else parts[0])
                continue

            full_perm = {**spatial_perm, **comp_perm}
            # Relabel couplings and inner non-propagator factors onto the
            # reference diagram's labels.
            permuted = [
                _apply_perm_to_expr(f, full_perm)
                for f in [*outer_symbols, *non_props]
            ]
            if not permuted:
                coupling_terms.append(ONE)
            elif len(permuted) == 1:
                coupling_terms.append(permuted[0])
            else:
                coupling_terms.append(Product(tuple(permuted)))

        if not coupling_terms:
            continue

        # Build coupling sum
        if len(coupling_terms) == 1:
            coupling_expr: Expr = coupling_terms[0]
        elif all(isinstance(t, Rational) for t in coupling_terms):
            # Fast path: sum bare Rationals directly
            total = sum((t.to_fraction() for t in coupling_terms), Fraction(0))
            coupling_expr = Rational(total.numerator, total.denominator)
        else:
            coupling_expr = simplify(Sum(tuple(coupling_terms)))

        if _is_zero(coupling_expr):
            continue  # couplings cancelled; drop the diagram entirely

        new_sum_terms.append(Product((coupling_expr,) + tuple(ref_props)))

    new_sum_terms.extend(ungroupable)

    if not new_sum_terms:
        # Empty inner sum (or everything cancelled): the whole product is zero.
        return ZERO
    if len(new_sum_terms) == 1:
        new_inner = new_sum_terms[0]
    else:
        new_inner = Sum(tuple(new_sum_terms))

    if outer_other:
        return Product(tuple(outer_other) + (new_inner,))
    return new_inner


# Keep old name as alias for backward compatibility
collect_by_topology = collect_by_diagram


# ---------------------------------------------------------------------------
# Diagonal propagator simplification
# ---------------------------------------------------------------------------
def diagonal_propagators(
    expr: Expr,
    *,
    diag_R: bool = False,
    diag_C: bool = False,
    iso_R: bool = False,
    iso_C: bool = False,
) -> Expr:
    r"""Apply diagonal and/or isotropic constraints to propagators.

    * ``diag_R=True``: response propagators are taken to be diagonal,
      :math:`R_{ij}(x,y) = \delta_{ij}\,R(x,y)`.  The constraint
      :math:`i = j` is substituted into couplings and propagators, so each
      constrained propagator removes one summation index.
    * ``diag_C=True``: the same, for correlation propagators.
    * ``iso_R=True`` (implies ``diag_R``): all diagonal entries are further
      taken to be equal, :math:`R_{ii}(x,y) = R(x,y)`, and the remaining
      component index is removed from R propagators.
    * ``iso_C=True`` (implies ``diag_C``): the same, for C.

    When at least one constraint is requested, the constrained expression is
    passed through :func:`simplify` and its summation indices are renamed
    canonically.

    Args:
        expr: Expression to simplify.
        diag_R: Enforce diagonal response propagators.
        diag_C: Enforce diagonal correlation propagators.
        iso_R: Enforce isotropic (index-free) response propagators.
        iso_C: Enforce isotropic (index-free) correlation propagators.

    Returns:
        The constrained, simplified expression; *expr* itself when no
        constraint is requested.
    """
    # Isotropy is a stronger assumption than diagonality, so it implies it.
    want_diag_R = diag_R or iso_R
    want_diag_C = diag_C or iso_C
    if not (want_diag_R or want_diag_C):
        return expr
    walked = _diag_walk(expr, want_diag_R, want_diag_C, iso_R, iso_C, {})
    return _canonical_expr_indices(simplify(walked))
def _diag_walk(
    expr: Expr,
    diag_R: bool,
    diag_C: bool,
    iso_R: bool,
    iso_C: bool,
    sum_dims: dict[str, int],
    consumed: set[str] | None = None,
) -> Expr:
    """Recursively apply diagonal / isotropic propagator constraints.

    Args:
        expr: Expression to walk.
        diag_R, diag_C: Enforce diagonal R / C propagators.
        iso_R, iso_C: Additionally strip equal component indices from R / C.
        sum_dims: Summation-index name -> dimension, accumulated from
            ancestor :class:`SumOverIndex` wrappers.
        consumed: If given, receives the names of ancestor summation indices
            that were substituted away by a diagonal constraint (i.e. the
            sum over them collapsed against a Kronecker delta).

    A SumOverIndex whose index no longer occurs after the walk is removed.
    If the index was consumed by a substitution, the sum collapses and no
    factor is added.  Otherwise (e.g. isotropy stripped the index from a
    trace such as ``sum_i C_ii``) the summand is constant in the index, so
    it is multiplied by the integer dimension.

    NOTE(review): substitutions are applied per Product that directly holds
    propagators.  In a Product such as ``F × (t1 + t2)`` the coupling ``F``
    sits outside the Sum and is not rewritten.  Also, when an index collapses
    in some terms of a Sum but is still used in others, the enclosing
    SumOverIndex is kept for all of them.  Both look like they could give
    wrong results; confirm the expected input shape.
    """
    if isinstance(expr, SumOverIndex):
        name = expr.index_name
        inner_consumed: set[str] = set()
        new_body = _diag_walk(
            expr.body, diag_R, diag_C, iso_R, iso_C,
            {**sum_dims, name: expr.dimension}, inner_consumed,
        )
        if consumed is not None:
            # Indices bound further out must learn about their own collapses.
            consumed.update(inner_consumed - {name})
        if _expr_uses_index(new_body, name):
            return SumOverIndex(name, expr.dimension, new_body)
        if name in inner_consumed or _is_zero(new_body):
            # sum_j delta_ij f(j) = f(i): the substitution absorbed the sum.
            return new_body
        if isinstance(expr.dimension, int):
            # Summand no longer depends on the index: sum = dimension × body.
            return _multiply_into(new_body, expr.dimension)
        # Non-integer dimension: keep the (now trivial) sum rather than guess.
        return SumOverIndex(name, expr.dimension, new_body)

    if isinstance(expr, IntegralOver):
        new_body = _diag_walk(expr.body, diag_R, diag_C, iso_R, iso_C, sum_dims, consumed)
        return IntegralOver(expr.variable, new_body)

    if isinstance(expr, Sum):
        new_terms = [
            _diag_walk(t, diag_R, diag_C, iso_R, iso_C, sum_dims, consumed)
            for t in expr.terms
        ]
        new_terms = [t for t in new_terms if not _is_zero(t)]
        if not new_terms:
            return ZERO
        if len(new_terms) == 1:
            return new_terms[0]
        return Sum(tuple(new_terms))

    if isinstance(expr, Product):
        if any(isinstance(f, Propagator) for f in expr.factors):
            return _diag_product(
                expr.factors, diag_R, diag_C, iso_R, iso_C, sum_dims, consumed,
            )
        new_factors = [
            _diag_walk(f, diag_R, diag_C, iso_R, iso_C, sum_dims, consumed)
            for f in expr.factors
        ]
        return Product(tuple(new_factors))

    if isinstance(expr, Propagator):
        # A bare propagator is a one-factor product: constraints and isotropic
        # stripping are handled identically there.
        return _diag_product((expr,), diag_R, diag_C, iso_R, iso_C, sum_dims, consumed)

    return expr


def _strip_isotropic(f: Expr, iso_R: bool, iso_C: bool) -> Expr:
    """Drop equal component indices from an R/C propagator when isotropy applies."""
    if (
        isinstance(f, Propagator)
        and f.index_left is not None
        and f.index_left == f.index_right
        and ((f.kind == "R" and iso_R) or (f.kind == "C" and iso_C))
    ):
        return Propagator(f.kind, None, None, f.spatial_left, f.spatial_right)
    return f


def _diag_product(
    factors: tuple[Expr, ...],
    diag_R: bool,
    diag_C: bool,
    iso_R: bool,
    iso_C: bool,
    sum_dims: dict[str, int],
    consumed: set[str] | None = None,
) -> Expr:
    """Apply diagonal (and optionally isotropic) constraints to product factors.

    Every off-diagonal R/C propagator whose kind is constrained yields an
    index equivalence.  Equivalences are merged with a union-find that
    prefers a non-summation (external) index as the representative, and the
    resulting substitution is applied to all factors.  When two external
    indices are identified, a KroneckerDelta is appended.  Isotropic
    stripping is applied afterwards to every factor, whether or not a
    substitution took place.

    Args:
        factors: Factors of the product (at least one Propagator expected).
        diag_R, diag_C, iso_R, iso_C: See :func:`_diag_walk`.
        sum_dims: Summation indices in scope (only the keys are used here).
        consumed: If given, receives the summation indices substituted away.

    Returns:
        The rewritten product, or the single factor if only one remains.
    """
    sum_idx_set = set(sum_dims)
    constraints: list[tuple[str, str]] = [
        (f.index_left, f.index_right)
        for f in factors
        if isinstance(f, Propagator)
        and f.index_left is not None
        and f.index_right is not None
        and f.index_left != f.index_right
        and ((f.kind == "R" and diag_R) or (f.kind == "C" and diag_C))
    ]

    # Union-find for index equivalences (prefer external index as root)
    parent: dict[str, str] = {}

    def find(x: str) -> str:
        while x in parent:
            x = parent[x]
        return x

    for left, right in constraints:
        rl, rr = find(left), find(right)
        if rl == rr:
            continue
        if rr in sum_idx_set and rl not in sum_idx_set:
            parent[rr] = rl
        elif rl in sum_idx_set and rr not in sum_idx_set:
            parent[rl] = rr
        else:
            parent[rr] = rl

    # Sorted so that the order of emitted deltas is deterministic.
    sub: dict[str, str] = {}
    for idx in sorted({k for pair in constraints for k in pair}):
        root = find(idx)
        if idx != root:
            sub[idx] = root

    new_factors: list[Expr] = list(factors)
    if sub:
        new_factors = [_apply_index_sub(f, sub) for f in factors]
        if consumed is not None:
            consumed.update(idx for idx in sub if idx in sum_idx_set)
        # Add KroneckerDelta for external-external constraints
        seen_deltas: set[frozenset[str]] = set()
        for idx, root in sub.items():
            if idx not in sum_idx_set and root not in sum_idx_set:
                pair = frozenset({idx, root})
                if pair not in seen_deltas:
                    seen_deltas.add(pair)
                    a, b = sorted(pair)
                    new_factors.append(KroneckerDelta(a, b))

    if iso_R or iso_C:
        new_factors = [_strip_isotropic(f, iso_R, iso_C) for f in new_factors]

    if len(new_factors) == 1:
        return new_factors[0]
    return Product(tuple(new_factors))


def _canonical_expr_indices(expr: Expr) -> Expr:
    """Rename SumOverIndex variables to i_0, i_1, ... in DFS first-appearance order.

    Candidate names that already occur *free* in *expr* are skipped, so a
    summation index is never renamed onto (and thereby captures) a free
    index of the same name.
    """
    order: list[str] = []
    known: set[str] = set()

    def _collect(e: Expr) -> None:
        if isinstance(e, SumOverIndex):
            if e.index_name not in known:
                order.append(e.index_name)
                known.add(e.index_name)
            _collect(e.body)
        elif isinstance(e, IntegralOver):
            _collect(e.body)
        elif isinstance(e, Product):
            for child in e.factors:
                _collect(child)
        elif isinstance(e, Sum):
            for child in e.terms:
                _collect(child)

    _collect(expr)

    def _fresh_names():
        j = 0
        while True:
            candidate = f"i_{j}"
            j += 1
            # At the root, _expr_uses_index is True only for free occurrences.
            if not _expr_uses_index(expr, candidate):
                yield candidate

    names = _fresh_names()
    rename: dict[str, str] = {}
    for old in order:
        new = next(names)
        if new != old:
            rename[old] = new
    if not rename:
        return expr
    return _apply_index_sub(expr, rename)


def _apply_index_sub(expr: Expr, sub: dict[str, str]) -> Expr:
    """Apply an index-name substitution to an expression.

    Rewrites propagator component indices, Symbol indices and spatial
    arguments, KroneckerDelta indices and SumOverIndex names, recursing
    through Sum, Product, SumOverIndex and IntegralOver.  Unchanged nodes are
    returned as the same object.
    """
    if not sub:
        return expr
    if isinstance(expr, Propagator):
        il = sub.get(expr.index_left, expr.index_left) if expr.index_left else expr.index_left
        ir = sub.get(expr.index_right, expr.index_right) if expr.index_right else expr.index_right
        if il == expr.index_left and ir == expr.index_right:
            return expr
        return Propagator(expr.kind, il, ir, expr.spatial_left, expr.spatial_right)
    if isinstance(expr, Symbol):
        new_indices = tuple(sub.get(i, i) for i in expr.indices)
        new_spatial = tuple(sub.get(s, s) for s in expr.spatial_args)
        if new_indices == expr.indices and new_spatial == expr.spatial_args:
            return expr
        return Symbol(expr.name, new_indices, new_spatial)
    if isinstance(expr, Sum):
        return Sum(tuple(_apply_index_sub(t, sub) for t in expr.terms))
    if isinstance(expr, Product):
        return Product(tuple(_apply_index_sub(f, sub) for f in expr.factors))
    if isinstance(expr, KroneckerDelta):
        i1 = sub.get(expr.index1, expr.index1)
        i2 = sub.get(expr.index2, expr.index2)
        if i1 == expr.index1 and i2 == expr.index2:
            return expr
        return KroneckerDelta(i1, i2)
    if isinstance(expr, SumOverIndex):
        new_name = sub.get(expr.index_name, expr.index_name)
        new_body = _apply_index_sub(expr.body, sub)
        if new_name == expr.index_name and new_body is expr.body:
            return expr
        return SumOverIndex(new_name, expr.dimension, new_body)
    if isinstance(expr, IntegralOver):
        new_body = _apply_index_sub(expr.body, sub)
        if new_body is expr.body:
            return expr
        return IntegralOver(expr.variable, new_body)
    return expr


def _multiply_into(expr: Expr, factor: int) -> Expr:
    """Multiply *expr* by an integer, pushing the factor into inner Products.

    Descends through :class:`SumOverIndex`, :class:`IntegralOver`, and
    :class:`Sum` wrappers so that the factor lands next to the existing
    :class:`Rational` coefficient, enabling :func:`simplify` to absorb it in
    a single pass.
    """
    if isinstance(expr, Rational):
        return Rational(factor * expr.numerator, expr.denominator)
    if isinstance(expr, Product):
        return Product((Rational(factor, 1),) + expr.factors)
    if isinstance(expr, SumOverIndex):
        return SumOverIndex(
            expr.index_name, expr.dimension, _multiply_into(expr.body, factor),
        )
    if isinstance(expr, IntegralOver):
        return IntegralOver(expr.variable, _multiply_into(expr.body, factor))
    if isinstance(expr, Sum):
        return Sum(tuple(_multiply_into(t, factor) for t in expr.terms))
    return Product((Rational(factor, 1), expr))


def _expr_uses_index(expr: Expr, name: str) -> bool:
    """Check whether *expr* references index *name* freely.

    Occurrences inside a SumOverIndex that re-binds *name* are shadowed and
    do not count.
    """
    if isinstance(expr, Propagator):
        return name == expr.index_left or name == expr.index_right
    if isinstance(expr, Symbol):
        return name in expr.indices
    if isinstance(expr, Product):
        return any(_expr_uses_index(f, name) for f in expr.factors)
    if isinstance(expr, Sum):
        return any(_expr_uses_index(t, name) for t in expr.terms)
    if isinstance(expr, SumOverIndex):
        if expr.index_name == name:
            return False  # shadowed by inner sum
        return _expr_uses_index(expr.body, name)
    if isinstance(expr, IntegralOver):
        return _expr_uses_index(expr.body, name)
    if isinstance(expr, KroneckerDelta):
        return name == expr.index1 or name == expr.index2
    return False