#!/usr/bin/env python3
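"""
SPARC Full Galaxy Analysis: ISL vs MOND vs Newtonian (baryons only)
Batch-processes the SPARC rotation-curve files (175 galaxies in the full
sample) to test the ISL modularity-overhead hypothesis.

Key test: is α_ISL ≈ 0.1 kpc⁻¹ universal across all galaxy types?
Falsification criterion: if the relative scatter of α_ISL (std / mean,
reported as "variance_percent") exceeds 20%, ISL needs revision.
"""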

import numpy as np
import glob
import json
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import List, Dict, Tuple

@dataclass
class GalaxyFit:
    """Results from fitting a single galaxy.

    Instances are built by ``fit_galaxy`` and serialized with ``asdict`` in
    ``analyze_all_galaxies``. ISL and MOND reduced chi-square values divide by
    ``n_points - 2`` (two fitted parameters each); the Newtonian one divides
    by ``n_points - 1`` (disk M/L only). When fewer than 3 usable points
    remain, every numeric fit field is 0 and ``fit_quality`` is "failed".
    """
    name: str  # file stem with "_rotmod" removed
    distance_mpc: float  # distance parsed from the file header (Mpc)
    n_points: int  # number of data points that survived filtering
    
    # ISL Model Results
    isl_ml_disk: float  # best-fit disk mass-to-light scaling (multiplies V_disk^2)
    isl_rmod_kpc: float  # best-fit modularity radius r_mod (kpc)
    isl_alpha_isl: float  # Derived: α_ISL = 1 / r_mod, in units of 1/kpc
    isl_chi2: float  # minimum chi-square found on the ISL parameter grid
    isl_red_chi2: float  # isl_chi2 / (n_points - 2)
    
    # MOND Model Results (for comparison)
    mond_a0: float  # best-fit MOND acceleration scale, in units of 1e-10 m/s^2
    mond_chi2: float  # minimum chi-square found on the MOND parameter grid
    mond_red_chi2: float  # mond_chi2 / (n_points - 2)
    
    # Newtonian-only (no dark matter, no ISL)
    newton_chi2: float  # minimum chi-square over the disk M/L grid
    newton_red_chi2: float  # newton_chi2 / (n_points - 1)
    
    # Quality flags
    # fit_quality is derived from isl_red_chi2 only: < 1.5 "excellent",
    # < 3.0 "good", < 10.0 "poor", otherwise "failed"; it is also "failed"
    # when there were too few data points to fit.
    fit_quality: str  # "excellent", "good", "poor", "failed"
    notes: str  # free-text note, e.g. the ISL vs MOND reduced chi-square pair

def load_sparc_galaxy(filepath: str) -> Tuple[str, float, np.ndarray]:
    """
    Load a SPARC galaxy rotation curve file.

    The galaxy name is the file stem with ``_rotmod`` removed. If the first
    line contains ``Distance`` it must look like ``# Distance = <value> ...``
    and ``<value>`` is used; otherwise 10.0 Mpc is used as a fallback.
    Data rows are read with ``#`` treated as a comment marker.

    Args:
        filepath: path to a ``*_rotmod.dat`` file.

    Returns:
        galaxy_name: str
        distance_mpc: float
        data: 2-D np.ndarray, one row per radius, whose first six columns are
            [Rad, Vobs, errV, Vgas, Vdisk, Vbul] (extra columns are kept).

    Raises:
        ValueError: if the distance header is malformed or unparseable, or the
            data block has fewer than six columns.
    """
    path = Path(filepath)
    galaxy_name = path.stem.replace('_rotmod', '')

    with open(path, 'r', encoding='utf-8') as f:
        first_line = f.readline()

    distance_mpc = 10.0  # Default fallback when the header has no distance
    if 'Distance' in first_line:
        _, sep, value = first_line.partition('=')
        tokens = value.split()
        if not sep or not tokens:
            raise ValueError(
                f"Malformed distance header in {filepath!r}: {first_line.strip()!r}"
            )
        distance_mpc = float(tokens[0])

    # ndmin=2 keeps a single-row file 2-D so callers can always use data[:, k];
    # without it loadtxt squeezes one row down to a 1-D array.
    data = np.loadtxt(path, comments='#', ndmin=2)
    if data.shape[1] < 6:
        raise ValueError(
            f"{filepath!r}: expected at least 6 data columns, found {data.shape[1]}"
        )

    return galaxy_name, distance_mpc, data

def isl_model(r: np.ndarray, v_gas: np.ndarray, v_disk: np.ndarray,
              v_bul: np.ndarray, ml_disk: float, r_mod: float) -> np.ndarray:
    """
    Predict a rotation curve under the ISL modified-gravity ansatz.

    The baryonic (Newtonian) budget

        V_N^2 = ML * V_disk^2 + V_gas^2 + V_bul^2

    is boosted linearly with radius,

        V^2(r) = V_N^2 * (1 + r / r_mod),

    where r_mod is the modularity radius (ISL overhead scale), expressed in
    the same length unit as r. Only the disk term is scaled by ``ml_disk``.

    NOTE(review): every component is squared, so any sign carried by an input
    velocity (e.g. a negative gas term) is discarded — confirm this is intended.

    Returns:
        Predicted circular speed, element-wise, in the units of the inputs.
    """
    disk_term = ml_disk * np.square(v_disk)
    baryonic_sq = disk_term + np.square(v_gas) + np.square(v_bul)
    radial_boost = 1.0 + r / r_mod
    return np.sqrt(baryonic_sq * radial_boost)

def mond_model(r: np.ndarray, v_gas: np.ndarray, v_disk: np.ndarray,
               v_bul: np.ndarray, ml_disk: float, a0: float) -> np.ndarray:
    """
    MOND (Modified Newtonian Dynamics) rotation-curve model.

    MOND relates the true acceleration g to the Newtonian one g_N via
    g_N = mu(g / a0) * g. With the "simple" interpolation function
    mu(x) = x / (1 + x) this inverts in closed form to

        g = g_N / 2 + sqrt(g_N**2 / 4 + g_N * a0),

    and the circular speed is V = sqrt(g * r). Because g >= g_N, MOND never
    predicts less rotation than the baryons alone (the previous form
    V^2 = V_N^2 * mu(g_N / a0) did, which is the wrong direction).

    Args:
        r: radii, assumed in kpc (SPARC convention) — must be > 0.
        v_gas, v_disk, v_bul: baryonic velocity components, assumed in km/s.
        ml_disk: disk mass-to-light scaling applied to v_disk**2.
        a0: MOND acceleration scale in units of 1e-10 m/s^2 (≈1.2 canonical).

    Returns:
        Predicted circular speed, in the units of the input velocities.
    """
    # Convert a0 from 1e-10 m/s^2 to (km/s)^2 / kpc:
    # 1e-10 m/s^2 * 3.0856775814913673e19 m/kpc / 1e6 (m/s)^2 per (km/s)^2
    a0_internal = a0 * 3.0856775814913673e3

    v_newton_sq = (ml_disk * v_disk**2) + v_gas**2 + v_bul**2
    g_newton = v_newton_sq / r  # Newtonian acceleration, (km/s)^2 / kpc

    # Closed-form inverse of the simple mu; evaluates to 0 (not NaN) when
    # there are no baryons, because both terms vanish.
    g_mond = 0.5 * g_newton + np.sqrt(0.25 * g_newton**2 + g_newton * a0_internal)
    return np.sqrt(g_mond * r)

def newton_model(v_gas: np.ndarray, v_disk: np.ndarray,
                 v_bul: np.ndarray, ml_disk: float) -> np.ndarray:
    """
    Baryons-only Newtonian rotation curve (no dark matter, no modified gravity).

    Computes sqrt(ML * V_disk^2 + V_gas^2 + V_bul^2) element-wise; only the
    disk component is scaled by the mass-to-light factor ``ml_disk``.
    """
    disk_sq = ml_disk * np.square(v_disk)
    return np.sqrt(disk_sq + np.square(v_gas) + np.square(v_bul))

def fit_galaxy(filepath: str, verbose: bool = False) -> GalaxyFit:
    """
    Fit a single galaxy with ISL, MOND, and Newtonian models.

    Each model is fit by brute-force grid search minimizing
    chi^2 = sum(((V_obs - V_pred) / errV)^2):

      * ISL:    disk M/L in [0.3, 3.0] (40 steps) x r_mod in [1, 100] (50 steps)
      * MOND:   disk M/L in [0.3, 3.0] (40 steps) x a0 in [0.5, 2.0] (30 steps,
                units of 1e-10 m/s^2)
      * Newton: disk M/L in [0.3, 3.0] (40 steps) only

    Args:
        filepath: path to a SPARC ``*_rotmod.dat`` file.
        verbose: print per-model fit summaries.

    Returns:
        A GalaxyFit. If fewer than 3 usable points remain after filtering, all
        numeric fit fields are 0 and fit_quality is "failed".
    """
    galaxy_name, distance_mpc, data = load_sparc_galaxy(filepath)
    
    if verbose:
        print(f"\nFitting {galaxy_name} (distance: {distance_mpc:.2f} Mpc, {len(data)} points)")
    
    # Keep only rows where every model input is finite, the radius is positive
    # (MOND divides by r) and the observed speed and its error are positive
    # (err_v is the chi^2 denominator). A single NaN in Vgas/Vdisk/Vbul would
    # otherwise make every chi^2 NaN and silently defeat the grid search.
    columns = data[:, :6]
    valid = (
        np.all(np.isfinite(columns), axis=1)
        & (columns[:, 0] > 0)
        & (columns[:, 1] > 0)
        & (columns[:, 2] > 0)
    )
    rad, v_obs, err_v, v_gas, v_disk, v_bul = columns[valid].T
    
    n_points = len(rad)
    
    if n_points < 3:
        return GalaxyFit(
            name=galaxy_name, distance_mpc=distance_mpc, n_points=n_points,
            isl_ml_disk=0, isl_rmod_kpc=0, isl_alpha_isl=0, isl_chi2=0, isl_red_chi2=0,
            mond_a0=0, mond_chi2=0, mond_red_chi2=0,
            newton_chi2=0, newton_red_chi2=0,
            fit_quality="failed", notes="Insufficient data points"
        )
    
    def chi2_of(v_pred: np.ndarray) -> float:
        """Chi-square of a predicted curve against the observed one."""
        return np.sum(((v_obs - v_pred) / err_v)**2)
    
    # Disk mass-to-light grid shared by all three models.
    ml_grid = np.linspace(0.3, 3.0, 40)
    
    # === ISL Model Fit ===
    # NOTE: alpha = 1/r_mod is therefore confined to [0.01, 1]; a best fit on
    # the grid edge means alpha is limited by the grid, not by the data.
    rmod_grid = np.linspace(1.0, 100.0, 50)
    
    best_isl_chi2 = float('inf')
    best_isl_params = (1.0, 10.0)
    for ml in ml_grid:
        for rmod in rmod_grid:
            chi2 = chi2_of(isl_model(rad, v_gas, v_disk, v_bul, ml, rmod))
            if chi2 < best_isl_chi2:
                best_isl_chi2 = chi2
                best_isl_params = (ml, rmod)
    
    isl_ml, isl_rmod = best_isl_params
    isl_red_chi2 = best_isl_chi2 / (n_points - 2)  # two fitted parameters
    isl_alpha_isl = 1.0 / isl_rmod
    
    # === MOND Model Fit ===
    a0_grid = np.linspace(0.5, 2.0, 30)  # a0 in units of 1e-10 m/s²
    
    best_mond_chi2 = float('inf')
    best_mond_params = (1.0, 1.2)
    for ml in ml_grid:
        for a0 in a0_grid:
            chi2 = chi2_of(mond_model(rad, v_gas, v_disk, v_bul, ml, a0))
            if chi2 < best_mond_chi2:
                best_mond_chi2 = chi2
                best_mond_params = (ml, a0)
    
    _, mond_a0 = best_mond_params
    mond_red_chi2 = best_mond_chi2 / (n_points - 2)  # two fitted parameters
    
    # === Newtonian Model (ML fit only) ===
    best_newton_chi2 = float('inf')
    for ml in ml_grid:
        chi2 = chi2_of(newton_model(v_gas, v_disk, v_bul, ml))
        if chi2 < best_newton_chi2:
            best_newton_chi2 = chi2
    
    newton_red_chi2 = best_newton_chi2 / (n_points - 1)  # one fitted parameter
    
    # Fit quality is judged on the ISL fit alone.
    if isl_red_chi2 < 1.5:
        quality = "excellent"
    elif isl_red_chi2 < 3.0:
        quality = "good"
    elif isl_red_chi2 < 10.0:
        quality = "poor"
    else:
        quality = "failed"
    
    notes = f"ISL vs MOND: {isl_red_chi2:.2f} vs {mond_red_chi2:.2f}"
    
    if verbose:
        print(f"  ISL: χ²_red = {isl_red_chi2:.3f}, r_mod = {isl_rmod:.1f} kpc, α_ISL = {isl_alpha_isl:.4f}")
        print(f"  MOND: χ²_red = {mond_red_chi2:.3f}, a0 = {mond_a0:.2f}e-10 m/s²")
        print(f"  Newton: χ²_red = {newton_red_chi2:.3f}")
    
    return GalaxyFit(
        name=galaxy_name,
        distance_mpc=distance_mpc,
        n_points=n_points,
        isl_ml_disk=isl_ml,
        isl_rmod_kpc=isl_rmod,
        isl_alpha_isl=isl_alpha_isl,
        isl_chi2=best_isl_chi2,
        isl_red_chi2=isl_red_chi2,
        mond_a0=mond_a0,
        mond_chi2=best_mond_chi2,
        mond_red_chi2=mond_red_chi2,
        newton_chi2=best_newton_chi2,
        newton_red_chi2=newton_red_chi2,
        fit_quality=quality,
        notes=notes
    )

def analyze_all_galaxies(data_dir: str = "/home/shri/Desktop/MATHTRUTH/sparc_data",
                         output_dir: str = "/home/shri/Desktop/MATHTRUTH/cosmic_synthesis/reports",
                         verbose: bool = True) -> List[GalaxyFit]:
    """
    Batch process all SPARC galaxies.

    Fits every ``*_rotmod.dat`` file in ``data_dir`` (in sorted order) and
    writes the per-galaxy results to
    ``<output_dir>/SPARC_FULL_ANALYSIS_RESULTS.json``. ``output_dir`` is
    created if it does not exist.

    A file that raises during fitting is reported and skipped, so one bad
    file does not abort the batch.

    Args:
        data_dir: directory containing the SPARC rotation-curve files.
        output_dir: directory for the JSON results file.
        verbose: print a progress line every 10 files.

    Returns:
        List of GalaxyFit for every file that was fitted without an exception.
    """
    galaxy_files = sorted(str(p) for p in Path(data_dir).glob("*_rotmod.dat"))
    n_files = len(galaxy_files)
    
    print(f"Found {n_files} SPARC galaxies")
    print("=" * 70)
    
    results = []
    
    for i, filepath in enumerate(galaxy_files, 1):
        try:
            results.append(fit_galaxy(filepath, verbose=False))
        except Exception as e:
            # Deliberate best effort: report the failure and keep going.
            print(f"ERROR fitting {Path(filepath).stem}: {e}")
        
        # Report only after the i-th file has actually been handled.
        if verbose and i % 10 == 0:
            print(f"\nProgress: {i}/{n_files} galaxies processed")
    
    # Save results (create the reports directory on a fresh checkout).
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    output_file = out_path / "SPARC_FULL_ANALYSIS_RESULTS.json"
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump([asdict(r) for r in results], f, indent=2)
    
    print(f"\n✅ Results saved to {output_file}")
    
    return results

def generate_statistical_summary(results: List[GalaxyFit]) -> Dict:
    """
    Generate statistical summary of ISL universality test.

    Two populations are used:

    * α_ISL statistics and ``quality_distribution`` use galaxies whose
      ``fit_quality`` is not "failed".
    * ``model_comparison`` uses every galaxy that was actually fitted
      (``n_points >= 3``) and has a finite reduced χ² for all three models.
      ``fit_quality`` comes from the ISL χ²_red alone ("failed" when it is
      >= 10). Filtering the head-to-head on it would therefore drop exactly
      the galaxies ISL fits worst, biasing wins and averages toward ISL.

    ``variance_percent`` is the coefficient of variation, 100 * std / mean
    (population std). Statistics with no eligible galaxies are reported as
    NaN instead of raising.

    Args:
        results: per-galaxy fits.

    Returns:
        Summary dict of plain Python numbers and strings.
    """
    nan = float("nan")

    def mean_or_nan(values) -> float:
        # np.mean([]) warns and returns nan; make the empty case explicit.
        return float(np.mean(values)) if len(values) else nan

    valid_results = [r for r in results if r.fit_quality != "failed"]
    compared = [
        r for r in results
        if r.n_points >= 3
        and np.all(np.isfinite([r.isl_red_chi2, r.mond_red_chi2, r.newton_red_chi2]))
    ]

    # --- α_ISL universality statistics ---
    alpha_isl_values = np.array([r.isl_alpha_isl for r in valid_results], dtype=float)
    if alpha_isl_values.size:
        mean_alpha = float(np.mean(alpha_isl_values))
        std_alpha = float(np.std(alpha_isl_values))
        alpha_min = float(np.min(alpha_isl_values))
        alpha_max = float(np.max(alpha_isl_values))
        alpha_median = float(np.median(alpha_isl_values))
    else:
        mean_alpha = std_alpha = alpha_min = alpha_max = alpha_median = nan
    # α_ISL = 1/r_mod > 0 for real fits, so a zero mean only means "undefined".
    variance_pct = (std_alpha / mean_alpha) * 100 if mean_alpha else nan

    passes_test = bool(variance_pct < 20.0)  # NaN compares False
    if np.isnan(variance_pct):
        verdict = "INCONCLUSIVE (no valid α_ISL statistics)"
    elif passes_test:
        verdict = "UNIVERSAL"
    else:
        verdict = "NON-UNIVERSAL (ISL needs revision)"

    # --- Model comparison: count strict wins (ties count for nobody) ---
    isl_wins = sum(1 for r in compared if r.isl_red_chi2 < r.mond_red_chi2 and r.isl_red_chi2 < r.newton_red_chi2)
    mond_wins = sum(1 for r in compared if r.mond_red_chi2 < r.isl_red_chi2 and r.mond_red_chi2 < r.newton_red_chi2)
    newton_wins = sum(1 for r in compared if r.newton_red_chi2 < r.isl_red_chi2 and r.newton_red_chi2 < r.mond_red_chi2)
    isl_win_rate = isl_wins / len(compared) if compared else nan

    summary = {
        "total_galaxies": len(results),
        "valid_fits": len(valid_results),
        "failed_fits": len(results) - len(valid_results),
        "alpha_isl_statistics": {
            "mean": mean_alpha,
            "std_dev": std_alpha,
            "variance_percent": float(variance_pct),
            "min": alpha_min,
            "max": alpha_max,
            "median": alpha_median
        },
        "universality_test": {
            "variance_threshold": 20.0,
            "actual_variance": float(variance_pct),
            "passes_test": passes_test,
            "verdict": verdict
        },
        "model_comparison": {
            "n_compared": len(compared),
            "isl_wins": isl_wins,
            "mond_wins": mond_wins,
            "newton_wins": newton_wins,
            "isl_win_rate": float(isl_win_rate),
            "avg_chi2_isl": mean_or_nan([r.isl_red_chi2 for r in compared]),
            "avg_chi2_mond": mean_or_nan([r.mond_red_chi2 for r in compared]),
            "avg_chi2_newton": mean_or_nan([r.newton_red_chi2 for r in compared])
        },
        "quality_distribution": {
            "excellent": sum(1 for r in valid_results if r.fit_quality == "excellent"),
            "good": sum(1 for r in valid_results if r.fit_quality == "good"),
            "poor": sum(1 for r in valid_results if r.fit_quality == "poor")
        }
    }
    
    return summary

def main():
    """
    Run the full SPARC batch analysis and report the results.

    Fits every galaxy via ``analyze_all_galaxies`` (which writes
    SPARC_FULL_ANALYSIS_RESULTS.json), builds the statistical summary, writes
    it to SPARC_STATISTICAL_SUMMARY.json, and prints a digest. Both files go
    to the same output directory, which is created if missing.
    """
    # Single source of truth for where both reports are written.
    output_dir = Path("/home/shri/Desktop/MATHTRUTH/cosmic_synthesis/reports")

    print("=" * 70)
    print("SPARC FULL GALAXY ANALYSIS: ISL vs MOND vs Newtonian (baryons only)")
    print("Testing ISL Modularity Overhead Universality")
    print("=" * 70)
    
    output_dir.mkdir(parents=True, exist_ok=True)
    
    # Analyze all galaxies
    results = analyze_all_galaxies(output_dir=str(output_dir), verbose=True)
    results_file = output_dir / "SPARC_FULL_ANALYSIS_RESULTS.json"
    
    # Generate statistical summary
    summary = generate_statistical_summary(results)
    
    # Save summary
    summary_file = output_dir / "SPARC_STATISTICAL_SUMMARY.json"
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    
    stats = summary['alpha_isl_statistics']
    test = summary['universality_test']
    comparison = summary['model_comparison']
    
    print("\n" + "=" * 70)
    print("STATISTICAL SUMMARY")
    print("=" * 70)
    print(f"\nTotal Galaxies: {summary['total_galaxies']}")
    print(f"Valid Fits: {summary['valid_fits']}")
    print(f"Failed Fits: {summary['failed_fits']}")
    
    print("\n📊 α_ISL Statistics:")
    print(f"  Mean: {stats['mean']:.4f}")
    print(f"  Std Dev: {stats['std_dev']:.4f}")
    print(f"  Variance: {stats['variance_percent']:.2f}%")
    print(f"  Range: [{stats['min']:.4f}, {stats['max']:.4f}]")
    
    print("\n🎯 Universality Test:")
    print("  Threshold: <20% variance")
    print(f"  Actual: {test['actual_variance']:.2f}%")
    print(f"  Verdict: {test['verdict']}")
    
    print("\n🏆 Model Comparison (χ² wins):")
    print(f"  ISL: {comparison['isl_wins']} ({comparison['isl_win_rate']*100:.1f}%)")
    print(f"  MOND: {comparison['mond_wins']}")
    print(f"  Newton: {comparison['newton_wins']}")
    
    print("\n📈 Average χ²_red:")
    print(f"  ISL: {comparison['avg_chi2_isl']:.3f}")
    print(f"  MOND: {comparison['avg_chi2_mond']:.3f}")
    print(f"  Newton: {comparison['avg_chi2_newton']:.3f}")
    
    print("\n" + "=" * 70)
    print("✅ Full analysis complete!")
    print(f"📁 Results: {results_file}")
    print(f"📁 Summary: {summary_file}")
    print("=" * 70)

# Run the full analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
