#!/usr/bin/env python3
"""
N4 — NEWTON FROM MONOGAMY
==========================
Monogamy Gravity, Folio V numerical experiment.

CLAIM UNDER TEST (Postulate 3 + the geometry reading):
    The force between two static bodies is the gradient of the
    entanglement budget: placing two defects in the vacuum depletes the
    vacuum's mutual-information account between them, and the
    interaction ENERGY tracks the depleted ACCOUNT. If geometry is the
    ledger, E_int(r) and the budget deficit dI(r) must carry the same
    r-dependence — log-like in 1+1D, power-law in 3+1D.

GATE (written before the run):
    If E_int(r) and the vacuum MI depletion carry manifestly different
    r-dependences (one flat where the other falls), the budget-gradient
    reading of force fails and Folio V cannot ship as physics.

PROTOCOL ("two masses on the books"):
    A static body = a pinned site: K_ii += M^2 at site i (a local mass
    term — the simplest stress-energy insertion that keeps the state
    Gaussian and the analysis exact).
      E(r)      : ground-state energy (1/2) Tr K^{1/2} with both pins,
                  vs single-pin and vacuum references:
                  E_int(r) = E_2(r) - 2 E_1 + E_0
      I_AB(r)   : mutual information between small regions around the two
                  pins, in the pinned ground state (I_pin) and in the
                  unpinned vacuum (I_vac); the budget deficit is
                  dI(r) = I_pin(r) - I_vac(r).
    Sweep separation r. Compare exponents.

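Plain numpy. Both the 1D chain and the 3D lattice are solved exactly by full
dense eigendecomposition of the coupling matrix; the r-dependence is
summarised by log-log (power-law) slopes.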
"""

import numpy as np
import json
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)


def entropy_from_XP(X, P, region):
    """Von Neumann entropy of a Gaussian state restricted to ``region``.

    X and P are the position and momentum two-point matrices of the whole
    system and ``region`` lists the site indices to keep. The symplectic
    eigenvalues nu are the square roots of the eigenvalues of X_A P_A
    (floored at 1/2), and the entropy is

        S = sum_k (nu_k + 1/2) ln(nu_k + 1/2) - (nu_k - 1/2) ln(nu_k - 1/2),

    where modes with nu_k - 1/2 <= 1e-12 contribute nothing to the second term.
    """
    sub = np.ix_(region, region)
    product = X[sub] @ P[sub]
    # Floor at 1/4 so round-off below the pure-state bound cannot give nu < 1/2.
    symplectic = np.sqrt(np.clip(np.linalg.eigvals(product).real, 0.25, None))
    plus = symplectic + 0.5
    minus = symplectic - 0.5
    gain = np.sum(plus * np.log(plus))
    # np.where evaluates both branches, so the log argument is clipped away from 0.
    loss = np.sum(np.where(minus > 1e-12,
                           minus * np.log(np.clip(minus, 1e-300, None)),
                           0.0))
    return float(gain - loss)


def solve(K):
    """Diagonalise the coupling matrix K and return the ground state.

    With omega_k = sqrt(lambda_k) for the eigenvalues lambda_k of K (floored
    at 1e-30 before the root), returns the tuple (E, X, P):
        E = (1/2) sum_k omega_k       -- ground-state energy, a Python float
        X = U diag(1/(2 omega)) U^T   -- position correlator
        P = U diag(omega/2) U^T       -- momentum correlator
    """
    eigvals, vecs = np.linalg.eigh(K)
    # Floor tiny or round-off-negative eigenvalues so sqrt and 1/omega stay finite.
    omega = np.sqrt(np.maximum(eigvals, 1e-30))
    energy = 0.5 * float(omega.sum())
    # Broadcasting scales the eigenvector columns; cheaper than forming diag().
    x_corr = (vecs * (0.5 / omega)) @ vecs.T
    p_corr = (vecs * (0.5 * omega)) @ vecs.T
    return energy, x_corr, p_corr


# ================================================================ 1+1D
def run_1d(N=400, m=1e-3, M2=10.0, half=2, seps=(20, 30, 40, 60, 80, 120, 160)):
    """Two pinned sites on a periodic 1+1D harmonic chain.

    The chain's coupling matrix K0 has K_ii = m^2 + 2 and K_{i,i+-1} = -1
    (periodic). A pin adds M2 to one diagonal entry. The single-pin
    reference sits at c = N//2. For each separation r the two pins sit at
    a = c - r//2 and b = a + r (indices mod N), and the function records

        E_int(r)     = E_2(r) - 2 E_1 + E_0
        I_vac, I_pin = mutual information between the (2*half+1)-site
                       windows centred on a and b, in the vacuum and in the
                       pinned ground state
        dI(r)        = I_pin - I_vac

    It prints a table, then fits log|E_int| and log|dI| linearly in log r.

    Returns:
        dict(dim=1, rows=[dict(r, Eint, I_vac, I_pin, dI), ...],
             expE=<log-log slope of |E_int|>, expI=<log-log slope of |dI|>)

    Raises:
        ValueError: if half < 0, if fewer than two separations are given, if
            a separation is not positive, or if a separation makes the two
            windows share a site (directly or around the ring). Shared sites
            would be duplicated in A + B and make the joint entropy, and
            hence the MI, meaningless.
    """
    if half < 0:
        raise ValueError(f"half must be >= 0, got {half}")
    seps = tuple(seps)
    if len(seps) < 2:
        raise ValueError(
            f"need at least two separations for the log-log fit, got {len(seps)}")

    c = N // 2
    width = 2 * half + 1
    # Lay out and validate every pin pair before the O(N^3) diagonalisations,
    # so a bad separation fails fast. Indices are taken mod N because the
    # chain is a ring; this also keeps windows that cross the seam in range.
    layout = []
    for r in seps:
        if r <= 0:
            raise ValueError(f"separations must be positive, got r={r}")
        a = (c - r // 2) % N
        b = (c + (r + 1) // 2) % N
        A = [(a + k) % N for k in range(-half, half + 1)]
        B = [(b + k) % N for k in range(-half, half + 1)]
        if len(set(A) | set(B)) != 2 * width:
            raise ValueError(
                f"separation r={r}: windows of {width} sites around the pins "
                f"overlap on a ring of N={N} sites")
        layout.append((r, a, b, A, B))

    print(f"\n=== 1+1D chain: N={N}, m={m}, pin strength M^2={M2}, "
          f"region half-width={half}")
    # Periodic nearest-neighbour coupling. Subtracting rolled identities
    # accumulates both bonds when N == 2, where plain assignment would not.
    eye = np.eye(N)
    K0 = (m * m + 2.0) * eye - np.roll(eye, 1, axis=1) - np.roll(eye, -1, axis=1)
    E0, Xv, Pv = solve(K0)

    K1 = K0.copy()
    K1[c, c] += M2
    E1, _, _ = solve(K1)

    def mutual_info(X, P, A, B):
        """I(A:B) = S_A + S_B - S_AB for the Gaussian state with correlators X, P."""
        return (entropy_from_XP(X, P, A) + entropy_from_XP(X, P, B)
                - entropy_from_XP(X, P, A + B))

    rows = []
    print(f"{'r':>5} {'E_int(r)':>13} {'I_AB vac':>10} {'I_AB pin':>10} {'dI(r)':>11}")
    for r, a, b, A, B in layout:
        K2 = K0.copy()
        K2[a, a] += M2
        K2[b, b] += M2
        E2, X2, P2 = solve(K2)
        Eint = E2 - 2 * E1 + E0
        i_vac, i_pin = mutual_info(Xv, Pv, A, B), mutual_info(X2, P2, A, B)
        dI = i_pin - i_vac
        rows.append(dict(r=r, Eint=Eint, I_vac=i_vac, I_pin=i_pin, dI=dI))
        print(f"{r:5d} {Eint:13.4e} {i_vac:10.5f} {i_pin:10.5f} {dI:+11.4e}")

    # Log-log slopes (power-law check). Exact zeros are clipped to 1e-300 to
    # keep the log finite. NOTE: the module docstring expects log-like
    # behaviour in 1+1D, where this slope is only a local descriptor.
    log_r = np.log(np.array([q['r'] for q in rows], float))
    eint = np.abs(np.array([q['Eint'] for q in rows]))
    di = np.abs(np.array([q['dI'] for q in rows]))
    pE = np.polyfit(log_r, np.log(np.clip(eint, 1e-300, None)), 1)[0]
    pI = np.polyfit(log_r, np.log(np.clip(di, 1e-300, None)), 1)[0]
    print(f"--- power-law exponents:  E_int ~ r^{pE:+.2f},   dI ~ r^{pI:+.2f}")
    return dict(dim=1, rows=rows, expE=float(pE), expI=float(pI))


# ================================================================ 3+1D
def run_3d(L=16, m=1e-2, M2=10.0, seps=(3, 4, 5, 6, 7), R=1.5):
    """Two pinned sites on a periodic L^3 harmonic lattice.

    K_ii = m^2 + 6, and K_ij = -1 for each periodic nearest-neighbour bond.
    A pin adds M2 to one diagonal entry. One pin sits at (c, c, c) with
    c = L//2. For each separation r the second pin sits at (c + r, c, c).
    Regions A and B are the sites within periodic Euclidean distance R of
    each pin. For every r the function records E_int = E_2 - 2 E_1 + E_0,
    the mutual information I(A:B) in the vacuum (I_vac) and in the pinned
    ground state (I_pin), and dI = I_pin - I_vac. It prints a table and
    log-log slopes.

    Cost: every solve() is a dense eigendecomposition of an L^3 x L^3 matrix
    (4096 x 4096 at the default L=16), run len(seps) + 2 times.

    Args:
        R: radius of the ball regions. The default 1.5 matches the value
            previously hard-coded here (19-site balls for L >= 3).

    Returns:
        dict(dim=3, rows=[dict(r, Eint, I_vac, I_pin, dI), ...],
             expE=<log-log slope of |E_int|>, expI=<log-log slope of |dI|>)

    Raises:
        ValueError: if fewer than two separations are given, if a separation
            is not positive, or if the two balls share a site. Shared sites
            would be duplicated in A + B and corrupt the joint entropy.
    """
    seps = tuple(seps)
    if len(seps) < 2:
        raise ValueError(
            f"need at least two separations for the log-log fit, got {len(seps)}")
    N = L ** 3
    print(f"\n=== 3+1D lattice: {L}^3 = {N} sites, m={m}, pin strength M^2={M2}, "
          f"region radius R={R}")

    def idx(x, y, z):
        """Flat (row-major) index of site (x, y, z); coordinates taken mod L."""
        return (x % L) * L * L + (y % L) * L + (z % L)

    axis = np.arange(L)

    def ball(cx, cy, cz):
        """Sorted flat indices of sites within periodic distance R of (cx, cy, cz)."""
        def fold(c0):
            d = (axis - c0) % L
            return np.minimum(d, L - d)  # shortest periodic offset on one axis
        dx, dy, dz = fold(cx), fold(cy), fold(cz)
        d2 = dx[:, None, None] ** 2 + dy[None, :, None] ** 2 + dz[None, None, :] ** 2
        # C-order flattening of the (x, y, z) grid matches idx().
        return np.flatnonzero(d2 <= R * R).tolist()

    c = L // 2
    i1 = idx(c, c, c)
    A = ball(c, c, c)
    # Validate every separation before the expensive diagonalisations.
    layout = []
    for r in seps:
        if r <= 0:
            raise ValueError(f"separations must be positive, got r={r}")
        B = ball(c + r, c, c)
        if set(A) & set(B):
            raise ValueError(
                f"separation r={r}: balls of radius R={R} around the pins "
                f"overlap on a periodic lattice of side L={L}")
        layout.append((r, idx(c + r, c, c), B))

    K0 = np.zeros((N, N))
    for x in range(L):
        for y in range(L):
            for z in range(L):
                i = idx(x, y, z)
                K0[i, i] = m * m + 6.0
                for dx, dy, dz in ((1, 0, 0), (0, 1, 0), (0, 0, 1)):
                    j = idx(x + dx, y + dy, z + dz)
                    # `-=` so bonds that coincide on tiny lattices accumulate.
                    K0[i, j] -= 1.0
                    K0[j, i] -= 1.0
    E0, Xv, Pv = solve(K0)

    K1 = K0.copy()
    K1[i1, i1] += M2
    E1, _, _ = solve(K1)

    def mutual_info(X, P, B):
        """I(A:B) = S_A + S_B - S_AB for the Gaussian state with correlators X, P."""
        return (entropy_from_XP(X, P, A) + entropy_from_XP(X, P, B)
                - entropy_from_XP(X, P, A + B))

    rows = []
    print(f"{'r':>5} {'E_int(r)':>13} {'I_AB vac':>10} {'I_AB pin':>10} {'dI(r)':>11}")
    for r, i2, B in layout:
        K2 = K0.copy()
        K2[i1, i1] += M2
        K2[i2, i2] += M2
        E2, X2, P2 = solve(K2)
        Eint = E2 - 2 * E1 + E0
        i_vac, i_pin = mutual_info(Xv, Pv, B), mutual_info(X2, P2, B)
        dI = i_pin - i_vac
        rows.append(dict(r=r, Eint=Eint, I_vac=i_vac, I_pin=i_pin, dI=dI))
        print(f"{r:5d} {Eint:13.4e} {i_vac:10.5f} {i_pin:10.5f} {dI:+11.4e}")

    # Log-log slopes; exact zeros are clipped to 1e-300 to keep the log finite.
    log_r = np.log(np.array([q['r'] for q in rows], float))
    eint = np.abs(np.array([q['Eint'] for q in rows]))
    di = np.abs(np.array([q['dI'] for q in rows]))
    pE = np.polyfit(log_r, np.log(np.clip(eint, 1e-300, None)), 1)[0]
    pI = np.polyfit(log_r, np.log(np.clip(di, 1e-300, None)), 1)[0]
    print(f"--- power-law exponents:  E_int ~ r^{pE:+.2f},   dI ~ r^{pI:+.2f}")
    return dict(dim=3, rows=rows, expE=float(pE), expI=float(pI))


if __name__ == '__main__':
    from pathlib import Path

    print("N4 — Newton from monogamy: two pinned masses, energy vs account.")
    d1 = run_1d()
    d3 = run_3d()
    print("\n" + "=" * 64)
    print("VERDICT")
    print(f"1+1D: E_int exponent {d1['expE']:+.2f} vs account exponent {d1['expI']:+.2f}")
    print(f"3+1D: E_int exponent {d3['expE']:+.2f} vs account exponent {d3['expI']:+.2f}")

    # The ledger location is specific to one machine. If its directory is
    # missing, write beside this script instead of losing the (expensive)
    # results to a FileNotFoundError at the very end of the run.
    out_path = Path('/Users/antoine/agi/ledger/n4_results.json')
    if not out_path.parent.is_dir():
        out_path = Path(__file__).resolve().parent / 'n4_results.json'
    with out_path.open('w') as fh:  # closed deterministically, even on error
        json.dump(dict(d1=d1, d3=d3), fh, indent=1)
    print(f"results -> {out_path}")
