import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Constants
# Physical value for reference: G = 4.30e-6 kpc (km/s)^2 / M_sun.
# For numerical stability the simulation uses dimensionless N-body units:
# G = 1, total mass M = 1, scale radius R = 1.
G = 1.0


class GalaxySim:
    """Toy direct-summation N-body model of a "puffy" dwarf galaxy.

    Gravity is Plummer-softened and optionally boosted by a constant factor
    ``(1 + alpha_isl)`` (the "ISL" modification).  Integration uses
    kick-drift-kick leapfrog.  ``history`` holds the time, the half-mass
    radius R50 and the 1-D velocity dispersion, all measured in the
    centre-of-mass frame; the first entry is the t = 0 state.

    Args:
        n_particles: Number of equal-mass particles (total mass 1).
        alpha_isl: Fractional gravity boost; effective G is ``G * (1 + alpha_isl)``.
        seed: Optional seed for a private RNG.  When ``None`` the global NumPy
            generator is used, so ``np.random.seed`` still controls the draw.
        scale_radius: Per-axis standard deviation of the Gaussian positions.
        velocity_dispersion: Per-axis standard deviation of the Gaussian velocities.
        softening: Plummer softening length epsilon.

    Raises:
        ValueError: If ``n_particles < 1``.
    """

    def __init__(self, n_particles=500, alpha_isl=0.0, *, seed=None,
                 scale_radius=1.0, velocity_dispersion=0.4, softening=0.1):
        if n_particles < 1:
            raise ValueError(f"n_particles must be >= 1, got {n_particles}")
        self.N = n_particles
        self.alpha = alpha_isl
        self.softening = softening

        # Seeded runs get a private generator; unseeded runs draw from the
        # global NumPy state (as before) so np.random.seed() keeps working.
        rng = np.random if seed is None else np.random.RandomState(seed)

        # Initial conditions: isotropic Gaussian blob, a rough stand-in for a
        # Plummer sphere.  Positions are drawn before velocities.
        self.pos = rng.randn(self.N, 3) * scale_radius
        self.vel = rng.randn(self.N, 3) * velocity_dispersion

        self.mass = np.ones(self.N) / self.N  # total mass = 1

        # A finite random draw carries a net offset and net momentum of order
        # sigma / sqrt(N).  Move to the centre-of-mass frame so bulk drift is
        # not mistaken for expansion.
        self.pos -= self._center_of_mass(self.pos)
        self.vel -= self._center_of_mass(self.vel)

        # Energy diagnostics.  Q = 2K/|W| = 1 is virial equilibrium;
        # Q > 2 means K + W > 0, i.e. the system is formally unbound.
        K = self.kinetic_energy()
        W = self.potential_energy()
        print(f"Init K: {K:.3f}")
        q = 2.0 * K / abs(W) if W != 0 else float("inf")
        print(f"Init W: {W:.3f} (virial ratio Q = 2K/|W|: {q:.3f})")

        # Metrics.  Compute forces now so the first half-kick in step() uses
        # real accelerations, and record the t = 0 state.
        self.time = 0.0
        self.history = {'t': [], 'r50': [], 'v_disp': []}
        self.acc = np.zeros_like(self.pos)
        self.compute_forces()
        self._record_metrics()

    def _center_of_mass(self, vectors):
        """Return the mass-weighted mean of per-particle ``(N, 3)`` vectors."""
        return np.average(vectors, axis=0, weights=self.mass)

    def _separations(self):
        """Return ``diff[i, j] = pos[i] - pos[j]`` (N, N, 3) and softened r^2 (N, N).

        Memory is O(N^2); fine for a few thousand particles.
        """
        diff = self.pos[:, np.newaxis, :] - self.pos[np.newaxis, :, :]
        r_sq = np.sum(diff**2, axis=-1) + self.softening**2
        return diff, r_sq

    def kinetic_energy(self):
        """Return total kinetic energy 0.5 * sum_i m_i |v_i|^2."""
        return float(0.5 * np.sum(self.mass * np.sum(self.vel**2, axis=1)))

    def potential_energy(self):
        """Return the softened pairwise potential energy.

        W = -G_eff * sum_{i<j} m_i m_j / sqrt(r_ij^2 + eps^2).  This is the
        potential whose gradient compute_forces() evaluates.
        """
        _, r_sq = self._separations()
        np.fill_diagonal(r_sq, np.inf)  # drop self-pairs: 1/sqrt(inf) == 0
        g_eff = G * (1.0 + self.alpha)
        pair_mass = np.outer(self.mass, self.mass)
        # Each unordered pair appears twice in the full (N, N) sum.
        return float(-0.5 * g_eff * np.sum(pair_mass / np.sqrt(r_sq)))

    def compute_forces(self):
        """Compute softened accelerations by direct O(N^2) summation.

        a_i = G_eff * sum_{j != i} m_j (r_j - r_i) / (|r_i - r_j|^2 + eps^2)^{3/2},
        with G_eff = G * (1 + alpha) (the ISL modification).  The result is
        stored in ``self.acc`` with shape (N, 3).
        """
        diff, r_sq = self._separations()
        # Exclude self-interaction; inf**-1.5 == 0 even when softening is 0.
        np.fill_diagonal(r_sq, np.inf)
        inv_r_cube = r_sq ** -1.5

        g_eff = G * (1.0 + self.alpha)
        coeff = self.mass[np.newaxis, :] * inv_r_cube  # m_j / r_ij^3
        # diff[i, j] = r_i - r_j, so the attraction on i points along -diff[i, j].
        self.acc = -g_eff * np.einsum('ij,ijk->ik', coeff, diff)

    def step(self, dt):
        """Advance one kick-drift-kick leapfrog step of size ``dt`` and record metrics."""
        self.vel += self.acc * (dt / 2.0)   # first half-kick
        self.pos += self.vel * dt           # drift
        self.compute_forces()               # forces at the new positions
        self.vel += self.acc * (dt / 2.0)   # second half-kick
        self.time += dt
        self._record_metrics()

    def _record_metrics(self):
        """Append (t, R50, 1-D velocity dispersion), measured in the COM frame."""
        rel_pos = self.pos - self._center_of_mass(self.pos)
        rel_vel = self.vel - self._center_of_mass(self.vel)
        # Masses are equal (set in __init__), so the median radius is the
        # half-mass radius.
        r50 = np.median(np.linalg.norm(rel_pos, axis=1))
        # RMS of all velocity components about the COM velocity.
        v_disp = np.sqrt(np.mean(rel_vel**2))
        self.history['t'].append(self.time)
        self.history['r50'].append(float(r50))
        self.history['v_disp'].append(float(v_disp))

def _run_sim(alpha_isl, n_particles, n_steps, dt, seed):
    """Build a GalaxySim from a fixed RNG seed and integrate it for ``n_steps`` steps."""
    # Seeding the global generator right before construction gives every run
    # identical initial conditions, so runs differ only in alpha.
    np.random.seed(seed)
    sim = GalaxySim(n_particles=n_particles, alpha_isl=alpha_isl)
    sim.compute_forces()  # initial accelerations for the first half-kick
    for _ in tqdm(range(n_steps)):
        sim.step(dt=dt)
    return sim


def run_comparison(n_particles=200, n_steps=200, dt=0.05, alpha_isl=0.35,
                   seed=None,
                   plot_path="cosmic_synthesis/reports/nbody_mechanism_test.png"):
    """Run a Newtonian control and an ISL-boosted run and compare their R50 evolution.

    Both runs start from the same random initial conditions (shared ``seed``)
    so that any difference is due to ``alpha_isl`` alone.  Each run's R50
    growth is judged against its own first recorded R50.  The comparison plot
    is written to ``plot_path``; parent directories are created if missing.

    Args:
        n_particles: Particles per simulation.
        n_steps: Leapfrog steps per simulation (must be >= 1).
        dt: Time step in code units.
        alpha_isl: Gravity boost used for the ISL run.
        seed: RNG seed shared by both runs.  If ``None``, one is drawn and printed
            so the run can be reproduced.
        plot_path: Output path for the R50-vs-time figure.

    Returns:
        ``(sim_newt, sim_isl)``, the two integrated simulations.

    Raises:
        ValueError: If ``n_steps < 1``.
    """
    from pathlib import Path

    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")
    if seed is None:
        seed = int(np.random.randint(0, 2**31 - 1))

    print("🚀  Running Toy N-Body Mechanism Test...")
    print("Scenario: High-Dispersion Dwarf Galaxy (Baryons Only)")
    print(f"Hypothesis: Newton (Alpha=0) explodes. ISL (Alpha={alpha_isl}) stays bound.")
    print(f"Random seed (shared by both runs): {seed}")

    # Run 1: Newtonian control (alpha = 0)
    print("\n[1/2] Running Newtonian Control...")
    sim_newt = _run_sim(0.0, n_particles, n_steps, dt, seed)

    # Run 2: ISL (same initial conditions, boosted gravity)
    print("\n[2/2] Running ISL Mechanism Test...")
    sim_isl = _run_sim(alpha_isl, n_particles, n_steps, dt, seed)

    # Validation: each run is normalized by its own first recorded R50.
    init_r50_newt = sim_newt.history['r50'][0]
    init_r50_isl = sim_isl.history['r50'][0]
    final_r50_newt = sim_newt.history['r50'][-1]
    final_r50_isl = sim_isl.history['r50'][-1]
    ratio_newt = final_r50_newt / init_r50_newt
    ratio_isl = final_r50_isl / init_r50_isl

    print("\n📊 RESULTS (Half-Mass Radius R50):")
    print(f"   Initial R50: {init_r50_newt:.2f}")
    print(f"   Newton Final: {final_r50_newt:.2f} (Ratio: {ratio_newt:.2f})")
    print(f"   ISL Final:    {final_r50_isl:.2f} (Ratio: {ratio_isl:.2f})")

    # Kill conditions: growth beyond expand_limit = unbound,
    # shrinkage to collapse_limit or below = collapsed.
    expand_limit = 1.5
    collapse_limit = 0.5
    print("\n📝 VERDICT:")

    # 1. Newton should expand.
    if ratio_newt > expand_limit:
        print("   ✅ Control Valid: Newtonian Control evaporated (Unbound).")
    else:
        print("   ⚠️ Control Warning: Newtonian system was too stable (Initial V too low?).")

    # 2. ISL should stay within (collapse_limit, expand_limit).  The boundaries
    #    are classified explicitly so ratio == expand_limit counts as evaporated.
    if ratio_isl >= expand_limit:
        print("   ❌ FAILED: ISL System also evaporated (Alpha too weak).")
    elif ratio_isl <= collapse_limit:
        print("   ❌ FAILED: ISL System collapsed (Alpha too strong).")
    else:
        print("   ✅ Mechanism Valid: ISL System remained bound & stable.")

    # Plot
    fig = plt.figure(figsize=(10, 5))
    plt.plot(sim_newt.history['t'], sim_newt.history['r50'], label='Newton (Alpha=0)', color='gray', linestyle='--')
    plt.plot(sim_isl.history['t'], sim_isl.history['r50'], label=f'ISL (Alpha={alpha_isl})', color='red', linewidth=2)
    plt.xlabel("Time (Code Units)")
    plt.ylabel("Half-Mass Radius (R50)")
    plt.title("Mechanism Test: Dwarf Stability (Baryons Only)")
    plt.legend()
    plt.grid(True, alpha=0.3)
    out_path = Path(plot_path)
    out_path.parent.mkdir(parents=True, exist_ok=True)  # savefig won't create dirs
    fig.savefig(out_path)
    plt.close(fig)
    print(f"   - Plot saved to {plot_path}")

    return sim_newt, sim_isl


if __name__ == "__main__":
    run_comparison()
