"""``System`` — the top-level user-facing spec for a stochastic-field
problem. Lowers high-level specification objects to raw
``PropagatorModel`` / ``Field`` / ``Vertex`` / ``Action`` / etc.
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Callable, Iterable
import numpy as np
from sft_wick.action import Action
from sft_wick.evaluate import PropagatorModel
from sft_wick.fields import Field, FieldOperator, reset_uid_counter
from sft_wick.vertices import Vertex
from .specs import (
FieldSpec,
GaussianNoise,
LinearOp,
LocalVertex,
NonLocalVertex,
)
# Grammar for string observables such as ``"phi_a(x)"``.
#
# Named groups:
#   name    -- one or more ASCII letters only; digits and underscores are
#              not part of the character class, so the first ``_`` always
#              ends the field name.
#   comp    -- optional; introduced by ``_``, starts with a letter and may
#              continue with letters or digits.
#   spatial -- the parenthesised argument: letters, digits and underscores.
#
# Leading / trailing whitespace and whitespace just inside the parentheses
# are tolerated.  The pattern is compiled with ``re.VERBOSE``, so the
# whitespace and ``#`` comments inside the literal are not matched.
_OBS_PATTERN = re.compile(
r"""
^\s*
(?P<name>[a-zA-Z]+?) # field name (e.g. 'phi'),
# non-greedy so 'phi_a'
# parses as name='phi'
# + comp='a'
(?:_(?P<comp>[a-zA-Z][a-zA-Z0-9]*))? # optional component index
\(\s*(?P<spatial>[a-zA-Z0-9_]+)\s*\) # (spatial_arg)
\s*$
""",
re.VERBOSE,
)
@dataclass(frozen=True)
class System:
    """High-level specification of a stochastic field-theory problem.

    Combines all the physics inputs the user needs to describe in a
    single immutable object.  Use :meth:`expand` and
    :meth:`propagators` to get to the computational layer.

    Args:
        field: :class:`FieldSpec` — the physical field (φ) spec; the
            response field (ψ) is derived automatically.
        linear: :class:`LinearOp` variant — defines R.  Required
            unless ``explicit_R`` is supplied.
        noise: :class:`GaussianNoise` — κ² (+ optional σ²) defines C
            together with R.
        vertices: iterable of :class:`LocalVertex` — F^(n) local
            interactions.  May be empty (linear theory).  Stored as a
            tuple.
        nonlocal_vertices: iterable of :class:`NonLocalVertex` — κ^(m)
            for m ≥ 3 non-Gaussian driving contributions.  May be
            empty.  Stored as a tuple.
        t_min: lower time bound for propagator integrations.
        explicit_R: escape hatch — if set, overrides ``linear``.  The
            structured alternative is the :class:`ExplicitR` LinearOp
            variant; both work, but ``ExplicitR`` is preferred for
            serialisation and YAML configs.
    """

    field: FieldSpec
    linear: LinearOp
    noise: GaussianNoise
    vertices: tuple[LocalVertex, ...] = field(default_factory=tuple)
    nonlocal_vertices: tuple[NonLocalVertex, ...] = field(default_factory=tuple)
    t_min: float = 0.0
    explicit_R: Callable | None = None

    def __post_init__(self) -> None:
        """Freeze the vertex collections into tuples.

        Tuples keep the instance hashable; materialising one-shot
        iterables (generators) also makes repeated iteration safe.
        ``None`` is left untouched.
        """
        for attr in ("vertices", "nonlocal_vertices"):
            value = getattr(self, attr)
            if value is not None and not isinstance(value, tuple):
                # The dataclass is frozen, so bypass its __setattr__ guard.
                object.__setattr__(self, attr, tuple(value))

    # --------------------------------------------------------------- #
    # Derived properties
    # --------------------------------------------------------------- #
    @property
    def n_components(self) -> int:
        """Number of field components, taken from ``self.field``."""
        return self.field.n_components

    @property
    def homogeneity(self) -> str:
        """Inferred from the noise κ² spec (``'translation'`` / …)."""
        return self.noise.kappa2.homogeneity

    @property
    def iso_R(self) -> bool:
        """``is_iso_R`` flag of the ``linear`` spec."""
        # NOTE(review): this reads ``self.linear`` even when ``explicit_R``
        # overrides it, so ``linear`` must still expose ``is_iso_R`` in
        # that case — confirm this is intended.
        return self.linear.is_iso_R

    # --------------------------------------------------------------- #
    # Lowering to raw API objects
    # --------------------------------------------------------------- #
    def build_propagator_model(self, diag_C: bool = True) -> PropagatorModel:
        """Lower the spec to a :class:`PropagatorModel`.

        Args:
            diag_C: when ``True`` (default), the numerical C propagator is
                represented as a diagonal vector ``(n, N)``.  When ``False``,
                the full ``(n, N, N)`` matrix is preserved -- required for
                observables that probe cross-component C entries
                (e.g. lensing kappa-gamma cross-correlation).  Off-diagonal
                support is only meaningful in combination with
                ``Propagators.build(c_closed_form_only=True)`` because the
                spline-table paths build only diagonal entries.

        Returns:
            The :class:`PropagatorModel` built from this spec.
        """
        # ``explicit_R`` wins over the structured ``linear`` spec.
        if self.explicit_R is not None:
            R_time = self.explicit_R
        else:
            R_time = self.linear.build_R_callable()

        kappa2 = self.noise.kappa2.build_callable(self.n_components)
        sigma2 = (
            None
            if self.noise.sigma2 is None
            else self.noise.sigma2.build_callable(self.n_components)
        )
        return PropagatorModel(
            R_time=R_time,
            kappa2=kappa2,
            n_components=self.n_components,
            iso_R=self.iso_R,
            diag_C=diag_C,
            t_min=self.t_min,
            sigma2=sigma2,
        )

    def build_fields(self) -> tuple[Field, Field]:
        """Return ``(phi, psi)`` :class:`Field` objects."""
        phi = Field(self.field.name, "physical",
                    n_components=self.n_components)
        psi = Field(self.field.response_name, "response",
                    n_components=self.n_components)
        return phi, psi

    def build_action(self) -> Action:
        """Lower vertices to a raw :class:`Action`."""
        phi, psi = self.build_fields()
        raw_vertices: list[Vertex] = []
        for lv in self.vertices:
            rank = np.asarray(lv.coupling).ndim
            # Convention: first axis = ψ, rest = φ
            fields = [psi] + [phi] * (rank - 1)
            raw_vertices.append(
                Vertex(fields=fields, coupling=lv.name, local=True)
            )
        for nv in self.nonlocal_vertices:
            raw_vertices.append(
                Vertex(
                    fields=[psi] * nv.order,
                    coupling=nv.name,
                    local=False,
                    equal_time=nv.equal_time,
                    already_R_contracted=nv.already_R_contracted,
                )
            )
        return Action(vertices=raw_vertices)

    def build_coupling_values(self) -> dict[str, Any]:
        """Dict passed to ``DiagramTerm.evaluate_coupling`` /
        ``build_integrand``.

        The MSR prefactors are applied here:

        - **Local** ``F^(n)``: multiplied by ``-i`` (``F_MSR = -i F``).
        - **Non-local** ``κ^(m)``: multiplied by ``-(i^m) / m!``
          (``K_MSR = -(i^m)/m! κ^(m)``).

        Users pass the **bare** physical tensors when constructing
        :class:`LocalVertex` / :class:`NonLocalVertex`; this method
        is the single point of truth for the MSR convention.  For a
        callable non-local coupling the returned dict value is a
        wrapped callable that applies the factor at evaluation time.

        Returns:
            Mapping ``vertex name -> msr_coupling``.  Local vertices are
            inserted first, so a non-local vertex with the same name
            overwrites the local entry.
        """
        cv: dict[str, Any] = {lv.name: lv.msr_coupling for lv in self.vertices}
        cv.update((nv.name, nv.msr_coupling) for nv in self.nonlocal_vertices)
        return cv

    # --------------------------------------------------------------- #
    # Public operations
    # --------------------------------------------------------------- #
    def expand(
        self,
        observable: Iterable,
        orders: Iterable[int],
        *,
        response_phase: bool = True,
        ito: bool = True,
        collect_topology: bool = True,
        iso_R: bool | None = None,
        diag_R: bool = True,
        diag_C: bool = True,
        iso_C: bool = False,
        cache_path: Any = None,
    ) -> "Expansion":
        """Run the perturbative expansion and return an
        :class:`Expansion` object.

        Args:
            observable: iterable of either ``FieldOperator`` objects
                or strings like ``"phi_a(x)"``.  Strings are parsed
                as ``field_compIndex(spatialArg)``.
            orders: iterable of perturbative orders to compute; a single
                integer is accepted as shorthand for a one-element
                iterable.  Duplicates are dropped and the orders sorted.
            response_phase, ito, collect_topology, iso_R, diag_R, diag_C,
                iso_C: forwarded to
                :func:`~sft_wick.perturbation.compute_moment`.  When
                ``iso_R`` is ``None`` the value is inferred from
                ``self.linear``.
            cache_path: directory (or file) for on-disk caching.
                ``None`` disables caching (prints a one-shot
                reminder).

        Raises:
            ValueError: if ``orders`` is empty.
        """
        # Validate the cheap argument first so a bad call fails with a
        # clear message before any heavier work.
        if isinstance(orders, (int, np.integer)):
            orders = (orders,)
        orders_list = sorted({int(o) for o in orders})
        if not orders_list:
            raise ValueError(
                "`orders` must contain at least one perturbative order."
            )

        from sft_wick.perturbation import compute_moment
        from .cache import load_or_compute
        from .expansion import Expansion

        iso_R_val = self.iso_R if iso_R is None else iso_R
        obs_ops, obs_repr = _parse_observable(observable, self)
        spec_key = {
            "system_hash": _system_spec_key(self),
            "observable": obs_repr,
            "orders": tuple(orders_list),
            "flags": (
                response_phase, ito, collect_topology,
                iso_R_val, diag_R, diag_C, iso_C,
            ),
        }
        # One expansion up to the highest order serves every requested order.
        max_order = orders_list[-1]

        def _compute():
            reset_uid_counter()
            action = self.build_action()
            result = compute_moment(
                obs_ops,
                action,
                order=max_order,
                response_phase=response_phase,
                ito=ito,
                collect_topology=collect_topology,
                diag_R=diag_R,
                diag_C=diag_C,
                iso_R=iso_R_val,
                iso_C=iso_C,
            )
            return {
                "dts_by_order": {
                    o: result.diagram_terms(o) for o in orders_list
                },
                "raw_result": result,
            }

        payload = load_or_compute(
            cache_path, spec_key, _compute,
            operation_name="expansion",
        )
        return Expansion(
            system=self,
            dts_by_order=payload["dts_by_order"],
            orders=tuple(orders_list),
            observable_repr=obs_repr,
            raw_result=payload["raw_result"],
        )

    def propagators(
        self,
        t_max: float,
        n_grid_t: int = 60,
        *,
        homogeneity: str | None = None,
        r_max: float | None = None,
        n_grid_r: int | None = None,
        n_grid_cos: int | None = None,
        x_max: float | None = None,
        n_grid_x: int | None = None,
        n_jobs: int = 1,
        c_closed_form: Callable | None = None,
        cache_path: Any = None,
        interp_method: str = "linear",
        c_closed_form_only: bool = False,
        c_closed_form_vectorized: bool = False,
        c_method: str = "dblquad",
        c_n_gauss: int = 20,
        diag_C: bool = True,
    ) -> "Propagators":
        """Build a :class:`Propagators` object with a precomputed
        spatial table matching ``self.homogeneity``.

        By default (all grid args ``None``) the cache enters **lazy
        mode** for the inferred homogeneity — recommended for moment
        calculations at a small fixed set of external positions.

        Every argument is forwarded unchanged to
        :meth:`Propagators.build`.

        Args:
            t_max: upper time bound.
            n_grid_t: number of time grid points.
            homogeneity: override the inferred homogeneity
                (``'translation' | 'rotation' | 'general'``).
            r_max, n_grid_r: translation full-grid parameters
                (both given ⇒ full 3-D spline; either ``None`` ⇒
                lazy).
            n_grid_cos: rotation full-grid parameter (given ⇒ full
                3-D spline over ``cos θ``).
            x_max, n_grid_x: general full-grid parameters
                (both given ⇒ full 4-D spline).
            n_jobs: parallel workers for grid build (``-1`` = all
                cores, via joblib).
            c_closed_form: optional callable
                ``(n1, t1, n2, t2) -> (N, N)`` returning the full C
                matrix directly.  When provided, the spline-table
                builder bypasses ``scipy.integrate.dblquad`` — use
                this when the user knows C in closed form (e.g. OU
                kernel).
            cache_path: directory (or file) for on-disk caching of
                the constructed :class:`PropagatorCache`.
            interp_method: ``RegularGridInterpolator`` method for the
                full-grid C tables.  ``'linear'`` (default) is monotone
                and safe for steep tails; ``'cubic'`` gives O(h⁴) on
                smooth grids.  Forwarded to :class:`PropagatorCache`.
            c_closed_form_only: when True (with ``c_closed_form`` set),
                skip every spline and route C lookups directly through
                the user's c_fn -- machine-precision agreement with
                the analytical form.  Forwarded to
                :meth:`Propagators.build`.
            c_closed_form_vectorized: c_fn accepts batched arrays and
                returns ``(n, N, N)`` (only meaningful with
                ``c_closed_form_only=True``).
            c_method: Quadrature method for the inner C-propagator
                ``∫ R κ² R`` integral when the cache builds its table.

                - ``'dblquad'`` (default) -- adaptive Gauss-Kronrod;
                  robust on any κ² but slow (10-80 ms / cell).
                - ``'gauss_legendre'`` -- tensor-product GL with a
                  diagonal-aware sub-region split.  Recommended for
                  any κ² that is **piecewise analytic** -- the
                  package's exponential / Gaussian / Legendre
                  kernels all qualify.  18-100× faster than dblquad
                  at machine precision (``c_n_gauss=20``).  Not the
                  right tool for non-smooth κ² (discontinuous,
                  oscillatory at high frequency, integrable
                  singularities like ``1/√t``).

                Ignored when ``c_closed_form_only=True`` (the user's
                C_fn is exact and bypasses both quadrature paths).
                See :doc:`/user_guide/workflow` "Choosing an
                integrator" for the full decision matrix.
            c_n_gauss: Per-dim GL node count for
                ``c_method='gauss_legendre'`` (default 20 → machine
                precision on smooth kernels).  Cost ``c_n_gauss²`` per
                sub-region.
            diag_C: when ``True`` (default), the numerical C propagator is
                represented as a diagonal vector ``(n, N)`` -- the
                bit-identical behaviour for every pre-existing caller.
                When ``False``, the full ``(n, N, N)`` matrix is preserved
                so observables can read off-diagonal entries
                ``C[a, b]`` with ``a != b`` (e.g. the lensing
                κ-γ₊ cross-correlation).  Only meaningful with
                ``c_closed_form_only=True``: spline-table paths build
                diagonal entries only.
        """
        from .propagators import Propagators

        return Propagators.build(
            self,
            t_max=t_max,
            n_grid_t=n_grid_t,
            homogeneity=homogeneity,
            r_max=r_max,
            n_grid_r=n_grid_r,
            n_grid_cos=n_grid_cos,
            x_max=x_max,
            n_grid_x=n_grid_x,
            n_jobs=n_jobs,
            c_closed_form=c_closed_form,
            cache_path=cache_path,
            interp_method=interp_method,
            c_closed_form_only=c_closed_form_only,
            c_closed_form_vectorized=c_closed_form_vectorized,
            c_method=c_method,
            c_n_gauss=c_n_gauss,
            diag_C=diag_C,
        )
# =========================================================================
# Helpers
# =========================================================================
def _parse_observable(observable, system: System):
    """Normalise ``observable`` into a tuple of FieldOperators.

    Accepts:

    - ``("phi_a(x)", "phi_b(y)")`` — string form; each string is
      ``name_component(spatial)``.
    - a single string such as ``"phi_a(x)"`` — shorthand for a
      one-element observable.
    - an iterable of ``FieldOperator`` objects (advanced).

    Args:
        observable: the observable specification (see above).
        system: the :class:`System` whose physical field the strings
            refer to; supplies the field name and component count.

    Returns:
        ``(obs_ops, obs_repr)`` where ``obs_ops`` is a tuple of field
        operators and ``obs_repr`` is a hashable tuple of
        ``(name, component, spatial)`` entries describing the observable
        for caching.

    Raises:
        ValueError: if a string cannot be parsed, names a field other
            than the system's, or omits the component index on a
            multi-component system.
        TypeError: if the entries are neither all strings nor all
            ``FieldOperator`` objects.
    """
    # A bare string is itself iterable (character by character); without
    # this guard ``"phi(x)"`` would be parsed as 'p', 'h', 'i', ... and
    # fail with a misleading message.
    if isinstance(observable, str):
        observable = (observable,)
    ops = list(observable)

    if all(isinstance(o, str) for o in ops):
        phi, _psi = system.build_fields()
        out = []
        repr_list = []
        for spec in ops:
            m = _OBS_PATTERN.match(spec)
            if m is None:
                raise ValueError(
                    f"Cannot parse observable '{spec}'. Expected form "
                    f"'name_component(spatial)' like 'phi_a(x)'."
                )
            name = m.group("name")
            comp = m.group("comp")
            spatial = m.group("spatial")
            if name != system.field.name:
                raise ValueError(
                    f"Observable field '{name}' does not match the "
                    f"system's field '{system.field.name}'."
                )
            if comp is None:
                # An index-free operator is only unambiguous for a
                # single-component field.
                if system.n_components != 1:
                    raise ValueError(
                        f"Observable '{spec}' omits a component index "
                        f"but the system has n_components="
                        f"{system.n_components}. Use e.g. "
                        f"'{name}_a({spatial})'."
                    )
                out.append(phi(spatial))
                repr_list.append((name, None, spatial))
            else:
                out.append(phi(comp, spatial))
                repr_list.append((name, comp, spatial))
        return tuple(out), tuple(repr_list)

    if all(isinstance(o, FieldOperator) for o in ops):
        repr_list = tuple(
            (o.field.name,
             tuple(o.component_indices) if o.component_indices else None,
             o.spatial_arg)
            for o in ops
        )
        return tuple(ops), repr_list

    raise TypeError(
        "observable must be either an iterable of 'name_comp(spatial)' "
        "strings or an iterable of FieldOperator objects; got mixed/"
        "unknown types."
    )
def _system_spec_key(system: System) -> Any:
    """Canonicalised representation of a ``System`` for caching.

    The linear operator and the noise kernels are captured by their
    ``repr()``.  Each local vertex contributes its name, the shape of
    its coupling and a SHA-256 digest of the coupling values, so two
    couplings that differ only in sign or in the ordering of their
    entries no longer collapse onto the same key.  Non-local vertices
    contribute only their structural flags, not their coupling.

    Args:
        system: the :class:`System` (or any object exposing the same
            attributes) to summarise.

    Returns:
        A dict of plain Python values describing the system.
    """
    # NOTE(review): ``explicit_R`` and the non-local couplings are not
    # part of the key; confirm that callers which depend on them do not
    # rely on this key alone.
    import hashlib

    def _coupling_fingerprint(coupling):
        # Normalise to complex128 so the digest does not depend on
        # whether the same values were supplied as int, float or complex.
        arr = np.asarray(coupling, dtype=np.complex128)
        return arr.shape, hashlib.sha256(arr.tobytes()).hexdigest()

    return {
        "field": (system.field.name, system.field.n_components),
        "linear": repr(system.linear),
        "noise_kappa2_type": type(system.noise.kappa2).__name__,
        "noise_kappa2_repr": repr(system.noise.kappa2),
        "noise_sigma2_repr": repr(system.noise.sigma2),
        "vertices": tuple(
            (v.name, *_coupling_fingerprint(v.coupling))
            for v in system.vertices
        ),
        "nonlocal_vertices": tuple(
            (v.name, v.order, v.equal_time, v.already_R_contracted)
            for v in system.nonlocal_vertices
        ),
        "t_min": system.t_min,
    }