# Source code for sft_wick.workflow.expansion

"""``Expansion`` — diagram-level view of the perturbative expansion.

Everything a user wants to do between ``compute_moment`` and the final
numeric result: inspect diagrams, classify by vertex composition,
draw, render LaTeX, integrate point-by-point or as a sweep.
"""

from __future__ import annotations

import itertools
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Iterable, Mapping


from sft_wick.evaluate import integrate_moment
from sft_wick.expressions import Expr, Product, Symbol, Sum, Rational


@dataclass(frozen=True)
class Expansion:
    """Result of :meth:`System.expand`.

    Opaque to construct; use the accessors below.

    Attributes:
        system: the :class:`System` this expansion was built from.
        dts_by_order: ``{order: [DiagramTerm, ...]}``.
        orders: tuple of computed orders (sorted).
        observable_repr: hashable description of the observable.
        raw_result: the underlying :class:`PerturbativeResult`
            (advanced users).
    """

    system: Any  # sft_wick.workflow.System (forward decl avoids cycle)
    dts_by_order: Mapping[int, list]
    orders: tuple
    observable_repr: tuple
    raw_result: Any

    # --------------------------------------------------------------- #
    # Inspection
    # --------------------------------------------------------------- #

    def diagrams(self, order: int) -> list:
        """All :class:`DiagramTerm` objects at this order (a fresh list)."""
        return list(self.dts_by_order[order])

    def summary(self) -> dict[int, dict[str, Any]]:
        """Per-order diagram count plus two histograms.

        Returns:
            ``{order: {"n_diagrams": int, "by_vertex_type": {label: count},
            "by_n_cross_C": {n_cross: count}}}`` where ``label`` comes from
            :meth:`_vertex_type_label` and ``n_cross`` from
            :func:`count_cross_group_c`.
        """
        out: dict[int, dict[str, Any]] = {}
        for o in self.orders:
            dts = self.dts_by_order[o]
            vtypes: dict[str, int] = defaultdict(int)
            ncross: dict[int, int] = defaultdict(int)
            for dt in dts:
                vtypes[self._vertex_type_label(dt)] += 1
                ncross[count_cross_group_c(dt)] += 1
            out[o] = {
                "n_diagrams": len(dts),
                "by_vertex_type": dict(vtypes),
                "by_n_cross_C": dict(ncross),
            }
        return out

    def by_vertex_type(self, order: int) -> dict[str, list]:
        """Group this order's diagrams by vertex composition label.

        Label format: the concatenation of the *unique sorted*
        coupling-symbol names appearing in the diagram's
        ``coupling_sum``.  E.g.:

        - A diagram using only the local ``F`` vertex → ``"F"``.
        - Order-2 diagram mixing local ``F`` and non-local ``K``
          (demo2 FK channel) → ``"FK"``.
        - Pure-``K`` diagram → ``"K"``.

        The label is a *set* of vertex types, not a multiset — this
        matches demo2's classification convention.
        """
        groups: dict[str, list] = defaultdict(list)
        for dt in self.dts_by_order[order]:
            groups[self._vertex_type_label(dt)].append(dt)
        return dict(groups)

    def latex(self, order: int) -> str:
        """Concatenated LaTeX of every diagram at this order.

        Returns ``"0"`` when the order has no diagrams.
        """
        parts = [dt.to_latex() for dt in self.dts_by_order[order]]
        return r"\; + \;".join(parts) if parts else "0"

    def plot(self, order: int, i: int = 0, **kwargs):
        """Render the ``i``-th diagram at ``order`` via the package's
        :class:`DiagramRenderer`.

        Args:
            order: Perturbative order to look up in :attr:`raw_result`.
            i: Index of the diagram within that order.
            **kwargs: Renderer-level kwargs (``figsize``, ``style``,
                ``external_label_fn``, ``vertex_label_fn``,
                ``label_format``) plus per-call ``draw`` kwargs (``ax``,
                ``title``, ``external_labels``, ``vertex_labels``,
                ``positions``, ``show_legend``).

        Returns:
            The matplotlib :class:`~matplotlib.figure.Figure` containing
            the rendered diagram.
        """
        from sft_wick import DiagramRenderer

        renderer_keys = {
            "figsize",
            "style",
            "external_label_fn",
            "vertex_label_fn",
            "label_format",
        }
        # Split constructor-level kwargs from the ones ``draw`` receives.
        renderer_kwargs = {
            k: kwargs.pop(k) for k in list(kwargs) if k in renderer_keys
        }
        fd = self.raw_result.diagrams_by_order[order][i].to_feynman_diagram()
        renderer = DiagramRenderer(**renderer_kwargs)
        ax = renderer.draw(fd, **kwargs)
        return ax.figure

    # --------------------------------------------------------------- #
    # Numerical evaluation
    # --------------------------------------------------------------- #

    def evaluate(
        self,
        propagators,
        *,
        positions: dict[str, Any],
        t_final: float,
        component_pair: tuple = (0, 0),
        orders: Iterable[int] | None = None,
        vertex_types: Iterable[str] | None = None,
        integrate_over: Any = None,
        method: str = "qmc_vectorized",
        n_samples: int = 2 ** 13,
        seed: int | None = 42,
        n_jobs: int = 1,
        n_gauss: int = 8,
    ):
        """Integrate the expansion at a single
        ``(positions, t_final, component_pair)`` point.

        Returns a :class:`Result`.

        Args:
            propagators: :class:`Propagators` from
                :meth:`System.propagators`.
            positions: ``{spatial_arg: x_value}`` mapping — e.g.
                ``{"x": 0.0, "y": 0.5}``.
            t_final: upper time bound for external-time integration
                (``lambda_f``).
            component_pair: ``(a, b)`` component indices for the
                observable (e.g. ``(1, 1)``).  For scalar observables
                use ``(0, 0)``.
            orders: subset of the expansion's orders; ``None`` uses all.
            vertex_types: subset of the vertex-composition labels (e.g.
                ``{"F"}``, ``{"FK"}``) to include — labels match
                :meth:`by_vertex_type` keys.  ``None`` ⇒ all.  A bare
                string is treated as a single label (``"FK"`` selects
                the ``"FK"`` channel, not ``"F"`` and ``"K"``).  Useful
                for computing a single channel, or for skipping channels
                that require a bespoke integrator (e.g. non-local K
                whose coupling is spacetime-dependent; see demo2).
            integrate_over: Controls which **external** points have
                their time integrated.

                - ``None`` (default — **physics observable**): all
                  externals held fixed at ``t_final``.  Matches the
                  equal-time correlator ``⟨φ(t_f) · φ(t_f)⟩`` that is
                  compared to MC data and demo notebooks.
                - ``"all"``: all externals integrated over
                  ``[t_min, t_final]`` — the time-integrated moment
                  ``⟨∫φ(t)dt · ∫φ(t')dt'⟩``.  Natural e.g. for
                  weak-lensing line-of-sight integrals.
                - Iterable of external-point names: mixed — those listed
                  are integrated, others fixed.  E.g. ``{"x"}`` for a
                  source integrated along the line of sight × a detector
                  field at ``t_final``.
            method: time-integrator selector.

                Recommended choice depends on the diagram's number of
                internal time-integration variables
                ``d = len(time_integration_vars)`` and integrand
                smoothness:

                .. list-table::
                   :header-rows: 1
                   :widths: 28 38 34

                   * - Method
                     - Best for
                     - Trade-off
                   * - ``'qmc_vectorized'`` (default)
                     - ``d >= 6`` / non-smooth integrand
                     - ``~ 1/sqrt(n_samples)`` bias
                   * - ``'gauss_legendre'``
                     - ``d <= 5`` smooth (the typical sft-wick case)
                     - exponential convergence in ``n_gauss``; cost
                       ``n_gauss^d``
                   * - ``'nquad'``
                     - Adaptive 1-3D fallback
                     - slow; raises ``NotImplementedError`` on
                       dynamic-coupling
                   * - ``'qmc'`` / ``'qmc_scalar'``
                     - Compatibility / debugging
                     - slow Python loop

                See :doc:`/user_guide/workflow` "Choosing an integrator"
                for the full decision matrix and worked examples.
            n_samples, seed: forwarded to the integrator (QMC only).
            n_jobs: forwarded to ``integrate_diagrams``, which chooses
                between the sequential and joblib loky paths.
            n_gauss: nodes per dimension for ``method='gauss_legendre'``
                (default 8 — exact for polynomials up to degree 15).
                Cost scales as ``n_gauss^d``; bump to 12-20 for stiff
                integrands at large ``t_final``.
        """
        from .result import Result

        orders_list = (
            sorted(set(int(o) for o in orders))
            if orders is not None else list(self.orders)
        )
        vtype_filter = self._normalise_vertex_types(vertex_types)
        coupling_values = self.system.build_coupling_values()
        fi = _component_indices(component_pair, self.observable_repr)

        # Collect tasks (diagram_term + metadata) up-front, in stable order,
        # then dispatch the whole batch to ``integrate_diagrams`` which handles
        # the sequential vs joblib loky path internally (n_jobs=1 stays serial,
        # bit-identical to the pre-refactor loop).
        tasks: list[tuple[int, int, str, Any]] = []
        for order in orders_list:
            for i, dt in enumerate(self.dts_by_order[order]):
                vtype = self._vertex_type_label(dt)
                if vtype_filter is not None and vtype not in vtype_filter:
                    continue
                tasks.append((order, i, vtype, dt))
        diagram_terms = [task[3] for task in tasks]

        from sft_wick.evaluate import integrate_diagrams

        # The batch total is ignored: totals are re-accumulated below
        # together with the per-order / per-vertex-type breakdowns.
        _total, details = integrate_diagrams(
            diagram_terms,
            coupling_values=coupling_values,
            lambda_f=t_final,
            cache=propagators.cache,
            method=method,
            n_samples=n_samples,
            seed=seed,
            fixed_indices=fi,
            n_jobs=n_jobs,
            positions=positions,
            integrate_over=integrate_over,
            n_gauss=n_gauss,
        )

        per_diagram = []
        per_order: dict[int, float] = defaultdict(float)
        per_vtype: dict[str, float] = defaultdict(float)
        total = 0.0
        for (order, i, vtype, dt), (val, err) in zip(tasks, details):
            per_diagram.append({
                "order": order,
                "diagram_idx": i,
                "vertex_type": vtype,
                "n_cross_C": count_cross_group_c(dt),
                "value": val,
                "error": err,
            })
            per_order[order] += val
            per_vtype[vtype] += val
            total += val

        return Result(
            total=total,
            by_order=dict(per_order),
            by_vertex_type=dict(per_vtype),
            per_diagram=per_diagram,
            positions=dict(positions),
            t_final=t_final,
            component_pair=tuple(component_pair),
            n_samples=n_samples,
            seed=seed,
        )

    def sweep(
        self,
        propagators,
        *,
        positions_grid: dict[str, list],
        t_final_grid: list,
        component_pairs: Iterable[tuple] = ((0, 0),),
        orders: Iterable[int] | None = None,
        vertex_types: Iterable[str] | None = None,
        integrate_over: Any = None,
        method: str = "qmc_vectorized",
        n_samples: int = 2 ** 13,
        seed: int | None = 42,
        n_jobs: int = 1,
        evaluate_n_jobs: int = 1,
        n_gauss: int = 8,
    ):
        """Cartesian-product sweep over positions, t_final, and component
        pairs.

        Args:
            propagators: forwarded unchanged to :meth:`evaluate`.
            positions_grid: ``{spatial_arg: [list of values]}``.  Each
                key's list is swept independently; result is the full
                Cartesian product.  E.g.
                ``{"x": [0.0], "y": [0.0, 0.5, 1.0]}``.
            t_final_grid: list of upper time bounds.
            component_pairs: list of ``(a, b)`` component index tuples.
            orders: subset of the expansion's orders; ``None`` uses all.
            vertex_types: optional filter — same semantics as in
                :meth:`evaluate`; only diagrams whose
                :meth:`_vertex_type_label` lies in this set are
                integrated.  ``None`` ⇒ all channels.
            integrate_over: forwarded unchanged to :meth:`evaluate`.
            n_jobs: parallelise over Cartesian-product grid points
                (``positions × t_final × component_pairs``).  ``1``
                (default) preserves the original sequential behaviour —
                bit-identical when seed is fixed.  ``-1`` uses all CPU
                cores via joblib loky.
            evaluate_n_jobs: parallelise over diagrams **inside** each
                grid point's :meth:`evaluate` call.  Mutually exclusive
                with ``n_jobs > 1`` (nested loky pools are not
                supported); the dispatcher raises if both are set.  Use
                ``n_jobs > 1`` when the sweep grid is large; use
                ``evaluate_n_jobs > 1`` when each grid point has many
                diagrams (typical at orders >= 2).
            method, n_samples, seed, n_gauss: integrator knobs -- see
                :meth:`evaluate` for the recommendation matrix and
                :doc:`/user_guide/workflow` "Choosing an integrator".
                ``'gauss_legendre'`` with ``n_gauss=8`` is the right
                default for ``d ≤ 5`` smooth integrands (exponential
                convergence, deterministic, no seed).

        Returns:
            :class:`SweepResult` with a pandas-friendly tidy table.

        Raises:
            ValueError: if both ``n_jobs`` and ``evaluate_n_jobs``
                differ from ``1``.
        """
        from .result import SweepResult

        if int(n_jobs) != 1 and int(evaluate_n_jobs) != 1:
            raise ValueError(
                "Specify exactly one of {n_jobs, evaluate_n_jobs} > 1; "
                "nested joblib loky pools are not supported."
            )

        orders_list = (
            sorted(set(int(o) for o in orders))
            if orders is not None else list(self.orders)
        )
        # Materialise once: ``evaluate`` runs per grid point, and a
        # one-shot iterable would be empty from the second point onwards.
        vtype_filter = self._normalise_vertex_types(vertex_types)

        pos_keys = list(positions_grid.keys())
        # Flatten the Cartesian product to a list of grid-point tasks.
        grid_tasks = self._build_grid_tasks(
            positions_grid, t_final_grid, component_pairs
        )

        def _eval_grid_point(task):
            positions, t_f, comp = task
            res = self.evaluate(
                propagators,
                positions=positions,
                t_final=t_f,
                component_pair=comp,
                orders=orders_list,
                vertex_types=vtype_filter,
                integrate_over=integrate_over,
                method=method,
                n_samples=n_samples,
                seed=seed,
                n_jobs=evaluate_n_jobs,
                n_gauss=n_gauss,
            )
            return positions, t_f, comp, res

        if int(n_jobs) == 1 or len(grid_tasks) <= 2:
            # Sequential — bit-identical to the pre-refactor nested loops.
            results = [_eval_grid_point(t) for t in grid_tasks]
        else:
            from joblib import Parallel, delayed

            results = Parallel(n_jobs=n_jobs, backend="loky")(
                delayed(_eval_grid_point)(t) for t in grid_tasks
            )

        rows = []
        for positions, t_f, (a, b), res in results:
            # Hashable normalisation: d-dim vector positions arrive as
            # ``list`` or ``np.ndarray``; pandas ``groupby`` (used in
            # :meth:`SweepResult.totals`) factorises group keys via a
            # hash table, which rejects list-typed cells with
            # ``TypeError: unhashable type: 'list'``.  Coerce here so
            # downstream aggregation works for both scalar and d-dim
            # positions.
            hashable_positions = {
                k: self._hashable_position(v) for k, v in positions.items()
            }
            for pd_row in res.per_diagram:
                rows.append({
                    **hashable_positions,
                    "t_final": t_f,
                    "a": a,
                    "b": b,
                    **pd_row,
                })

        return SweepResult(rows=rows, position_keys=tuple(pos_keys))

    # --------------------------------------------------------------- #
    # Internal helpers
    # --------------------------------------------------------------- #

    @staticmethod
    def _vertex_type_label(dt) -> str:
        """Extract the sorted unique coupling-symbol names from
        ``dt.coupling_sum`` and join them into a single string (matches
        demo2's convention: ``{'F'}`` → ``"F"``, ``{'F', 'K'}`` →
        ``"FK"``)."""
        names = _collect_symbol_names(dt.coupling_sum)
        if not names:
            return ""  # order-0 trivial
        return "".join(sorted(names))

    @staticmethod
    def _normalise_vertex_types(vertex_types):
        """Return the vertex-type filter as a ``frozenset``, or ``None``.

        A bare string is wrapped as a single label rather than split
        into characters; any other iterable is consumed exactly once so
        the result can be reused across calls.
        """
        if vertex_types is None:
            return None
        if isinstance(vertex_types, str):
            return frozenset((vertex_types,))
        return frozenset(vertex_types)

    @staticmethod
    def _build_grid_tasks(positions_grid, t_final_grid, component_pairs):
        """Flatten the sweep grid into ``(positions, t_final, (a, b))``
        tuples, positions outermost and component pairs innermost.

        Every axis is materialised first so that one-shot iterables
        (generators) contribute to every grid point instead of being
        exhausted after the first pass.
        """
        pos_keys = list(positions_grid)
        t_finals = list(t_final_grid)
        # The ``a, b`` unpacking enforces 2-element pairs.
        pairs = [(a, b) for a, b in component_pairs]
        position_dicts = [
            dict(zip(pos_keys, combo))
            for combo in itertools.product(
                *(positions_grid[k] for k in pos_keys)
            )
        ]
        return list(itertools.product(position_dicts, t_finals, pairs))

    @staticmethod
    def _hashable_position(value):
        """Coerce one position value into a hashable equivalent.

        Objects exposing ``tolist`` (NumPy arrays *and* NumPy scalars)
        are converted first; lists/tuples become tuples recursively, so
        nested sequences end up as nested tuples.  Scalars pass through
        unchanged.
        """
        if hasattr(value, "tolist"):
            # ``tolist`` on a NumPy scalar / 0-d array yields a plain
            # Python scalar, which must not be fed to ``tuple``.
            value = value.tolist()
        if isinstance(value, (list, tuple)):
            return tuple(Expansion._hashable_position(v) for v in value)
        return value
# =========================================================================
# Module-level helpers
# =========================================================================
def count_cross_group_c(dt) -> int:
    """Count the ``C`` propagators of ``dt`` that straddle two direction
    groups.

    A propagator is counted when its ``kind`` is ``"C"`` and the entries
    of ``dt.analyze_spatial().direction_map`` for its ``spatial_left``
    and ``spatial_right`` endpoints differ (same helper as
    :class:`tests.test_deductive_numerics.TestSpatialAwareCache` and
    :mod:`examples.demo1.validate_phase5`).
    """
    layout = dt.analyze_spatial()
    return sum(
        1
        for prop in dt.propagators
        if prop.kind == "C"
        and layout.direction_map[prop.spatial_left]
        != layout.direction_map[prop.spatial_right]
    )
def _collect_symbol_names(expr: Expr) -> set[str]:
    """Return the set of :class:`Symbol` names (the coupling-tensor
    names) reachable from a coupling-sum expression tree.

    :class:`Rational` nodes contribute nothing.  :class:`Product` and
    :class:`Sum` nodes are descended through ``factors`` / ``terms``.
    Any other node is probed for ``expr``, ``body`` and ``integrand``
    attributes, and those that hold an :class:`Expr` are descended into.
    """

    def _children(node):
        # Containers expose their operands directly.
        if isinstance(node, Product):
            return node.factors
        if isinstance(node, Sum):
            return node.terms
        # Catch other Expr subclasses that may wrap children.
        wrapped = (
            getattr(node, attr, None)
            for attr in ("expr", "body", "integrand")
        )
        return [child for child in wrapped if isinstance(child, Expr)]

    def _names_in(node):
        if isinstance(node, Symbol):
            yield node.name
        elif not isinstance(node, Rational):
            for child in _children(node):
                yield from _names_in(child)

    return set(_names_in(expr))


def _component_indices(component_pair, observable_repr):
    """Build the ``fixed_indices`` dict for
    :meth:`DiagramTerm.build_integrand` from the observable's
    component-index names.

    For observable ``(phi_a(x), phi_b(y))`` this yields
    ``{"a": component_pair[0], "b": component_pair[1]}``.  For a scalar
    observable (no component indices), returns ``{}``.  Entries of
    ``observable_repr`` beyond ``len(component_pair)`` are skipped, and
    values are coerced with ``int``.
    """
    n_given = len(component_pair)
    return {
        comp: int(component_pair[slot])
        for slot, (_name, comp, _spatial) in enumerate(observable_repr)
        if comp is not None and slot < n_given
    }