"""Custom symbolic expression tree for Wick contraction results.
All concrete expression types (the subclasses of ``Expr``) are frozen dataclasses, so they are immutable and hashable.
Uses exact rational arithmetic via fractions.Fraction.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from fractions import Fraction
from math import gcd
from typing import Sequence
class Expr(ABC):
    """Base class for all symbolic expressions.

    Subclasses implement ``to_latex``, ``__eq__`` and ``__hash__``.  The
    operator overloads defined here build ``Sum`` / ``Product`` trees.  Plain
    ``int`` and ``Fraction`` operands are coerced to ``Rational`` so that no
    raw Python number ends up inside an expression tree (where calling
    ``to_latex`` on it would fail).  Any other operand type yields
    ``NotImplemented``, letting Python raise the usual ``TypeError``.

    Note: ``@dataclass`` subclasses declared without ``repr=False`` get a
    generated ``__repr__`` that takes precedence over the LaTeX
    ``__repr__`` defined here.
    """

    @abstractmethod
    def to_latex(self) -> str:
        """Return a LaTeX rendering of this expression."""
        ...

    def __add__(self, other: object) -> Expr:
        if isinstance(other, Expr):
            return Sum(_flatten_sum([self, other]))
        if isinstance(other, (int, Fraction)):
            # Exact zero is the additive identity (mirrors __radd__).
            if other == 0:
                return self
            return Sum((self, Rational.from_number(other)))
        return NotImplemented

    def __radd__(self, other: object) -> Expr:
        if isinstance(other, (int, Fraction)):
            # ``0 + expr`` is the first step of builtin ``sum(exprs)``.
            if other == 0:
                return self
            return Sum((Rational.from_number(other), self))
        return NotImplemented

    def __mul__(self, other: object) -> Expr:
        if isinstance(other, Expr):
            return Product(_flatten_product([self, other]))
        if isinstance(other, (int, Fraction)):
            # Scalars commute, so keep the numeric coefficient in front,
            # exactly as __rmul__ does.
            return Product((Rational.from_number(other), self))
        return NotImplemented

    def __rmul__(self, other: object) -> Expr:
        if isinstance(other, (int, Fraction)):
            return Product((Rational.from_number(other), self))
        return NotImplemented

    def __neg__(self) -> Product:
        return Product((Rational(-1, 1), self))

    def __sub__(self, other: object) -> Expr:
        if not isinstance(other, (Expr, int, Fraction)):
            return NotImplemented
        return self + (-other)

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Structural equality; must be consistent with ``__hash__``."""
        ...

    @abstractmethod
    def __hash__(self) -> int:
        """Hash consistent with ``__eq__``."""
        ...

    def __repr__(self) -> str:
        return self.to_latex()
# ---------------------------------------------------------------------------
# Scalar / numeric expressions
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Rational(Expr):
    """Exact rational number, kept in lowest terms with a positive denominator.

    Arithmetic between exact numbers (``Rational``, ``int`` or ``Fraction``)
    is folded into a single ``Rational``.  Combining with any other ``Expr``
    builds a ``Sum`` / ``Product`` node instead.

    Raises:
        ZeroDivisionError: if constructed with ``denominator == 0``.
    """

    numerator: int
    denominator: int = 1

    def __post_init__(self) -> None:
        if self.denominator == 0:
            raise ZeroDivisionError("Rational denominator cannot be zero")
        # Normalize: keep the denominator positive and reduce to lowest terms.
        # gcd(0, d) == |d|, so every zero normalizes to 0/1.
        d = gcd(abs(self.numerator), abs(self.denominator))
        sign = 1 if self.denominator > 0 else -1
        # The dataclass is frozen, so the normalized fields are written directly.
        object.__setattr__(self, "numerator", sign * self.numerator // d)
        object.__setattr__(self, "denominator", sign * self.denominator // d)

    @classmethod
    def from_number(cls, n: int | Fraction) -> Rational:
        """Build a ``Rational`` from an ``int`` or a ``Fraction``."""
        if isinstance(n, Fraction):
            return cls(n.numerator, n.denominator)
        return cls(n, 1)

    @property
    def is_zero(self) -> bool:
        """True when the value is exactly 0."""
        return self.numerator == 0

    @property
    def is_one(self) -> bool:
        """True when the value is exactly 1."""
        return self.numerator == 1 and self.denominator == 1

    def to_fraction(self) -> Fraction:
        """Return the value as a ``fractions.Fraction``."""
        return Fraction(self.numerator, self.denominator)

    def to_latex(self) -> str:
        if self.denominator == 1:
            return str(self.numerator)
        if self.numerator < 0:
            return rf"-\frac{{{-self.numerator}}}{{{self.denominator}}}"
        return rf"\frac{{{self.numerator}}}{{{self.denominator}}}"

    @staticmethod
    def _exact(value: object) -> Fraction | None:
        """Return *value* as a ``Fraction`` if it is an exact number, else None."""
        if isinstance(value, Rational):
            return value.to_fraction()
        if isinstance(value, (int, Fraction)):
            return Fraction(value)
        return None

    def __mul__(self, other: object) -> Expr:
        exact = self._exact(other)
        if exact is not None:
            return Rational.from_number(self.to_fraction() * exact)
        if isinstance(other, Expr):
            return Product(_flatten_product([self, other]))
        return NotImplemented

    def __rmul__(self, other: object) -> Expr:
        # Reached for ``int * Rational`` and ``Fraction * Rational``.
        exact = self._exact(other)
        if exact is not None:
            return Rational.from_number(exact * self.to_fraction())
        return NotImplemented

    def __add__(self, other: object) -> Expr:
        exact = self._exact(other)
        if exact is not None:
            return Rational.from_number(self.to_fraction() + exact)
        if isinstance(other, Expr):
            return Sum(_flatten_sum([self, other]))
        return NotImplemented

    def __radd__(self, other: object) -> Expr:
        # Reached for ``int + Rational`` (including builtin sum()'s ``0 + r``).
        exact = self._exact(other)
        if exact is not None:
            return Rational.from_number(exact + self.to_fraction())
        return NotImplemented

    def __neg__(self) -> Rational:
        return Rational(-self.numerator, self.denominator)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Rational):
            return self.numerator == other.numerator and self.denominator == other.denominator
        return NotImplemented

    def __hash__(self) -> int:
        return hash((Rational, self.numerator, self.denominator))


ZERO = Rational(0, 1)
ONE = Rational(1, 1)
MINUS_ONE = Rational(-1, 1)
# ---------------------------------------------------------------------------
# Named symbols (coupling constants, etc.)
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Symbol(Expr):
    """Named symbol with optional component indices and spatial arguments.

    Examples:
        Symbol('F', ('i', 'j', 'k'))             -> F_{ijk}
        Symbol('K', ('i', 'j'), ('y_0', 'y_1'))  -> K_{ij}(y_{0}, y_{1})
    """

    name: str
    indices: tuple[str, ...] = ()
    spatial_args: tuple[str, ...] = ()

    def _key(self) -> tuple:
        """Fields that define identity, in declaration order."""
        return (self.name, self.indices, self.spatial_args)

    def to_latex(self) -> str:
        pieces = [self.name]
        if self.indices:
            subscript = "".join(map(_latex_index, self.indices))
            pieces.append(f"_{{{subscript}}}")
        if self.spatial_args:
            arguments = ", ".join(map(_latex_index, self.spatial_args))
            pieces.append(f"({arguments})")
        return "".join(pieces)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Symbol):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        return hash((Symbol,) + self._key())
# ---------------------------------------------------------------------------
# Propagators
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class Propagator(Expr):
    """Two-point function ``C_{ij}(x, x')`` or ``R_{ij}(x, x')``.

    ``kind`` is rendered verbatim as the head symbol.  For scalar fields
    ``index_left`` and ``index_right`` are ``None``; component indices are
    rendered only when both are set.

    Convention for R: the physical field's index/position is always on the
    left, i.e. ``R_{ij}(x, x')`` means ``<phi_i(x) psi_j(x')>_{S_0}``.
    """

    kind: str  # 'C' or 'R'
    index_left: str | None
    index_right: str | None
    spatial_left: str
    spatial_right: str

    def _key(self) -> tuple:
        """Fields that define identity, in declaration order."""
        return (self.kind, self.index_left, self.index_right,
                self.spatial_left, self.spatial_right)

    def to_latex(self) -> str:
        has_indices = self.index_left is not None and self.index_right is not None
        subscript = (
            f"_{{{_latex_index(self.index_left)}{_latex_index(self.index_right)}}}"
            if has_indices
            else ""
        )
        left = _latex_index(self.spatial_left)
        right = _latex_index(self.spatial_right)
        return f"{self.kind}{subscript}({left}, {right})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Propagator):
            return NotImplemented
        return self._key() == other._key()

    def __hash__(self) -> int:
        return hash((Propagator,) + self._key())
# ---------------------------------------------------------------------------
# Composite expressions
# ---------------------------------------------------------------------------
@dataclass(frozen=True, init=False)
class Sum(Expr):
    """Sum of expressions.

    Nested ``Sum`` arguments are spliced in at construction, so ``terms``
    never contains another ``Sum``.
    """

    terms: tuple[Expr, ...]

    def __init__(self, terms: Sequence[Expr]) -> None:
        # Frozen dataclass: the field has to be set via object.__setattr__.
        object.__setattr__(self, "terms", _flatten_sum(terms))

    def to_latex(self) -> str:
        if not self.terms:
            return "0"
        parts: list[str] = []
        for i, t in enumerate(self.terms):
            s = t.to_latex()
            # A term that already renders with a leading minus needs no "+".
            if i > 0 and not s.startswith("-"):
                parts.append("+ " + s)
            else:
                parts.append(s)
        return " ".join(parts)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Sum):
            return self.terms == other.terms
        return NotImplemented

    def __hash__(self) -> int:
        return hash((Sum, self.terms))

    def __add__(self, other: object) -> Sum:
        if isinstance(other, Expr):
            # The constructor splices in other's terms if it is itself a Sum.
            return Sum(self.terms + (other,))
        if isinstance(other, (int, Fraction)):
            # Exact zero is the additive identity; other numbers become a
            # Rational term so no raw Python number enters the tree.
            if other == 0:
                return self
            return Sum(self.terms + (Rational.from_number(other),))
        return NotImplemented
@dataclass(frozen=True, init=False)
class Product(Expr):
    """Product of expressions.

    Nested ``Product`` arguments are spliced in at construction, so
    ``factors`` never contains another ``Product``.
    """

    factors: tuple[Expr, ...]

    def __init__(self, factors: Sequence[Expr]) -> None:
        # Frozen dataclass: the field has to be set via object.__setattr__.
        object.__setattr__(self, "factors", _flatten_product(factors))

    def to_latex(self) -> str:
        if not self.factors:
            return "1"
        parts = []
        for f in self.factors:
            s = f.to_latex()
            # A multi-term Sum must be parenthesized to keep its meaning.
            if isinstance(f, Sum) and len(f.terms) > 1:
                s = r"\left(" + s + r"\right)"
            parts.append(s)
        return " ".join(parts)

    def __eq__(self, other: object) -> bool:
        if isinstance(other, Product):
            return self.factors == other.factors
        return NotImplemented

    def __hash__(self) -> int:
        return hash((Product, self.factors))

    def __mul__(self, other: object) -> Product:
        if isinstance(other, Expr):
            # The constructor splices in other's factors if it is a Product.
            return Product(self.factors + (other,))
        if isinstance(other, (int, Fraction)):
            # Scalars commute; keep the numeric coefficient in front, as
            # ``number * expr`` does.
            return Product((Rational.from_number(other),) + self.factors)
        return NotImplemented
# ---------------------------------------------------------------------------
# Index/spatial wrappers
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class SumOverIndex(Expr):
    """Summation over a component index: ``sum_{i=1}^{N} body``.

    ``dimension`` is the upper limit *N*; the rendered lower limit is 1.
    """

    index_name: str
    dimension: int
    body: Expr

    def to_latex(self) -> str:
        idx = _latex_index(self.index_name)
        head = rf"\sum_{{{idx}=1}}^{{{self.dimension}}}"
        return f"{head} {self.body.to_latex()}"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, SumOverIndex):
            return NotImplemented
        return (self.index_name, self.dimension, self.body) == (
            other.index_name, other.dimension, other.body)

    def __hash__(self) -> int:
        key = (self.index_name, self.dimension, self.body)
        return hash((SumOverIndex,) + key)
@dataclass(frozen=True)
class IntegralOver(Expr):
    """Integration over a spatial variable: ``integral d(variable) body``."""

    variable: str
    body: Expr

    def to_latex(self) -> str:
        var = _latex_index(self.variable)
        return r"\int \mathrm{d}" + var + r"\, " + self.body.to_latex()

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, IntegralOver):
            return NotImplemented
        return (self.variable, self.body) == (other.variable, other.body)

    def __hash__(self) -> int:
        return hash((IntegralOver, self.variable, self.body))
# ---------------------------------------------------------------------------
# Delta functions
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class KroneckerDelta(Expr):
    """Kronecker delta ``delta_{ij}`` over component indices.

    Index order is ignored by both ``__eq__`` and ``__hash__``, so
    ``KroneckerDelta('i', 'j') == KroneckerDelta('j', 'i')``.
    """

    index1: str
    index2: str

    def _index_set(self) -> frozenset[str]:
        """The unordered pair of indices used for equality and hashing."""
        return frozenset((self.index1, self.index2))

    def to_latex(self) -> str:
        first = _latex_index(self.index1)
        second = _latex_index(self.index2)
        return r"\delta_{" + first + second + "}"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, KroneckerDelta):
            return NotImplemented
        return self._index_set() == other._index_set()

    def __hash__(self) -> int:
        return hash((KroneckerDelta, self._index_set()))
@dataclass(frozen=True)
class DiracDelta(Expr):
    """Dirac delta ``delta(x - y)`` for spatial arguments.

    Argument order is ignored by both ``__eq__`` and ``__hash__``, so
    ``DiracDelta('x', 'y') == DiracDelta('y', 'x')``.
    """

    arg1: str
    arg2: str

    def to_latex(self) -> str:
        # Escape spatial names through _latex_index, as Propagator and Symbol
        # do, so that e.g. 'y_10' renders as y_{10} instead of y_1 followed
        # by a bare 0.
        left = _latex_index(self.arg1)
        right = _latex_index(self.arg2)
        return rf"\delta({left} - {right})"

    def __eq__(self, other: object) -> bool:
        if isinstance(other, DiracDelta):
            return {self.arg1, self.arg2} == {other.arg1, other.arg2}
        return NotImplemented

    def __hash__(self) -> int:
        return hash((DiracDelta, frozenset({self.arg1, self.arg2})))
# ---------------------------------------------------------------------------
# Imaginary unit
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class ImaginaryUnit(Expr):
    r"""The imaginary unit :math:`\mathrm{i}`.

    Used to represent the phase factor :math:`(-\mathrm{i})^n` that arises
    from the MSR convention :math:`\langle\phi\,\psi\rangle \propto -\mathrm{i}\,R`.
    The class has no fields: all instances compare equal and share one hash.
    """

    def to_latex(self) -> str:
        return r"\mathrm{i}"

    def __eq__(self, other: object) -> bool:
        return True if isinstance(other, ImaginaryUnit) else NotImplemented

    def __hash__(self) -> int:
        return hash(ImaginaryUnit)


I = ImaginaryUnit()
"""Module-level constant for the imaginary unit."""
# ---------------------------------------------------------------------------
# Response phase utility
# ---------------------------------------------------------------------------
def _count_r_deep(expr: Expr) -> int:
    """Count R propagators recursively through the expression tree.

    Products add up their factors' counts; ``IntegralOver`` and
    ``SumOverIndex`` use their body's count.  For a **Sum** the count is only
    well defined when every term has the same number of R's (within a single
    vertex combination all surviving pairings have the same R count).

    Returns:
        The number of R propagators, ``0`` for any other leaf, or ``-1`` if
        some Sum inside *expr* has terms that disagree.  The ``-1`` sentinel
        is propagated unchanged through Products and wrappers so it is never
        mixed into a real count.
    """
    if isinstance(expr, Propagator):
        return 1 if expr.kind == "R" else 0
    if isinstance(expr, Product):
        total = 0
        for factor in expr.factors:
            count = _count_r_deep(factor)
            if count < 0:
                # Adding the -1 sentinel into the total would produce a bogus
                # (and possibly non-negative) count.
                return -1
            total += count
        return total
    if isinstance(expr, Sum) and expr.terms:
        first = _count_r_deep(expr.terms[0])
        if all(_count_r_deep(t) == first for t in expr.terms[1:]):
            return first
        return -1  # terms disagree
    if isinstance(expr, (IntegralOver, SumOverIndex)):
        return _count_r_deep(expr.body)
    return 0
def apply_response_phase(expr: Expr) -> Expr:
    r"""Multiply each term by :math:`(-\mathrm{i})^n` where *n* is the
    number of response propagators *R* in that term.

    This implements the MSR convention
    :math:`\langle\phi(a)\,\psi(b)\rangle = -\mathrm{i}\,R(a,b)`.

    The phase is **absorbed into the existing rational coefficient**
    so that the factored form of the expression is preserved.  For
    example, :math:`(-\mathrm{i})^3 = \mathrm{i}` applied to a term
    with prefactor :math:`-\tfrac{1}{6}` yields
    :math:`-\tfrac{\mathrm{i}}{6}`.

    Sums are processed term by term; ``IntegralOver`` and ``SumOverIndex``
    are processed through their body.  Inside a Product, a factor whose
    terms carry different numbers of R's cannot share one overall phase, so
    that factor is phased term by term instead.  This is exact because
    :math:`(-\mathrm{i})^{a+b} = (-\mathrm{i})^a\,(-\mathrm{i})^b`.
    """
    if isinstance(expr, Sum):
        return Sum(tuple(apply_response_phase(t) for t in expr.terms))
    if isinstance(expr, Product):
        return _apply_phase_to_product(expr)
    if isinstance(expr, Propagator) and expr.kind == "R":
        # A bare R is a one-R term: (-i)^1 R.
        return Product((Rational(-1, 1), ImaginaryUnit(), expr))
    if isinstance(expr, IntegralOver):
        return IntegralOver(expr.variable, apply_response_phase(expr.body))
    if isinstance(expr, SumOverIndex):
        return SumOverIndex(
            expr.index_name, expr.dimension, apply_response_phase(expr.body)
        )
    return expr


def _apply_phase_to_product(expr: Product) -> Expr:
    """Apply the (-i)^n response phase to a single Product term."""
    factors: list[Expr] = []
    n_r = 0
    changed = False
    for f in expr.factors:
        count = _count_r_deep(f)
        if count < 0:
            # This factor's terms have different R counts, so no single phase
            # can be pulled out of it; phase each of its terms separately.
            factors.append(apply_response_phase(f))
            changed = True
        else:
            factors.append(f)
            n_r += count

    r = n_r % 4
    if r == 0:
        # (-i)^(4k) == 1: nothing to absorb at this level.
        return Product(tuple(factors)) if changed else expr

    # Find the existing Rational coefficient (if any).
    rat_idx = next(
        (i for i, f in enumerate(factors) if isinstance(f, Rational)),
        None,
    )
    # (-i)^n phase rules applied to existing coefficient c:
    #   r=1: (-i) -> coeff becomes -c, insert i
    #   r=2: (-1) -> coeff becomes -c
    #   r=3: (+i) -> coeff stays c, insert i
    if r in (1, 2):
        if rat_idx is not None:
            old = factors[rat_idx]
            factors[rat_idx] = Rational(-old.numerator, old.denominator)
        else:
            factors.insert(0, Rational(-1, 1))
            rat_idx = 0
    if r in (1, 3):
        # Put the imaginary unit right after the coefficient.
        insert_pos = (rat_idx + 1) if rat_idx is not None else 0
        factors.insert(insert_pos, ImaginaryUnit())
    # Drop Rational(1): it is the identity and only adds clutter.
    factors = [f for f in factors
               if not (isinstance(f, Rational) and f.is_one)]
    if not factors:
        return ONE
    return Product(tuple(factors))
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _latex_index(name: str) -> str:
    """Return *name* formatted as a LaTeX index.

    Everything after the first underscore becomes a braced subscript, so
    ``i_0`` -> ``i_{0}``, ``i_10`` -> ``i_{10}`` and ``y_0`` -> ``y_{0}``.
    Names without an underscore (``a``, ``b``) are returned unchanged.
    """
    head, separator, tail = name.partition("_")
    if not separator:
        return name
    return f"{head}_{{{tail}}}"
def _flatten_sum(terms: Sequence[Expr]) -> tuple[Expr, ...]:
    """Splice the terms of any top-level ``Sum`` in *terms* into a flat tuple.

    Only one level is unpacked.  That suffices because every ``Sum``
    flattens its own terms when it is constructed.
    """
    flat: list[Expr] = []
    for term in terms:
        flat.extend(term.terms if isinstance(term, Sum) else (term,))
    return tuple(flat)
def _flatten_product(factors: Sequence[Expr]) -> tuple[Expr, ...]:
    """Splice the factors of any top-level ``Product`` in *factors* into a flat tuple.

    Only one level is unpacked.  That suffices because every ``Product``
    flattens its own factors when it is constructed.
    """
    return tuple(
        inner
        for factor in factors
        for inner in (factor.factors if isinstance(factor, Product) else (factor,))
    )