#!/usr/bin/env python3
"""
NanoCERN Knowledge Reactor - Enhanced CLI Interface
Constraint-based reasoning engine using Knowledge Units (KUs).
"""

import argparse
import json
import struct
import sys
from pathlib import Path
from collections import defaultdict

def load_ku_binary(ku_file: Path) -> dict:
    """Read a binary ``.ku`` file and return the JSON document embedded in it.

    Every byte before the first ``{`` is treated as header and skipped. The
    remainder is decoded as UTF-8, silently dropping undecodable bytes, and
    parsed as JSON.

    Raises:
        ValueError: If the file contains no ``{`` byte at all.
        json.JSONDecodeError: If the text after the header is not valid JSON.
    """
    with open(ku_file, 'rb') as handle:
        raw = handle.read()

    # partition() yields an empty separator when no '{' exists in the file.
    _header, brace, tail = raw.partition(b'{')
    if not brace:
        raise ValueError(f"No JSON found in {ku_file}")

    payload = (brace + tail).decode('utf-8', errors='ignore')
    return json.loads(payload)

def load_ku_json(ku_file: Path) -> dict:
    """Load a Knowledge Unit from a JSON text file.

    The file is always decoded as UTF-8 rather than the platform's locale
    encoding, and a leading byte-order mark is tolerated.

    Args:
        ku_file: Path of the JSON file to read.

    Returns:
        The parsed JSON document.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # 'utf-8-sig' reads plain UTF-8 unchanged and strips a BOM if present;
    # the json parser rejects text that still starts with a BOM.
    with open(ku_file, 'r', encoding='utf-8-sig') as f:
        return json.load(f)

def load_ku(ku_file: Path) -> dict:
    """Load a Knowledge Unit, choosing the reader from the file suffix.

    Files ending in ``.ku`` go through the binary reader; every other suffix
    is read as plain JSON.
    """
    reader = load_ku_binary if ku_file.suffix == '.ku' else load_ku_json
    return reader(ku_file)

def validate_ku(ku: dict) -> bool:
    """Report whether ``ku`` has the minimal Knowledge Unit structure.

    A valid KU is a dict containing the keys ``id``, ``domain`` and
    ``invariant``. The values stored under those keys are not inspected.

    Args:
        ku: The decoded KU document.

    Returns:
        True when ``ku`` is a dict with all required keys, False otherwise
        (including when ``ku`` is not a dict at all).
    """
    # A decoded JSON document may be a list, string, number or None. Without
    # this guard a list holding the field names would pass the membership
    # test below, and None or a number would raise TypeError.
    if not isinstance(ku, dict):
        return False

    required_fields = ('id', 'domain', 'invariant')
    return all(field in ku for field in required_fields)

def load_all_kus(atoms_dir: Path) -> list:
    """Load every ``*.ku`` file found directly inside ``atoms_dir``.

    Loading is best-effort: a file that fails to load is reported on stderr
    and skipped, so one bad file does not abort the whole scan.

    Returns:
        The list of successfully loaded KUs, in the order glob yields them.
    """
    loaded = []
    for path in atoms_dir.glob('*.ku'):
        try:
            unit = load_ku(path)
        except Exception as err:
            print(f"⚠️  Failed to load {path.name}: {err}", file=sys.stderr)
        else:
            loaded.append(unit)
    return loaded

def list_kus(atoms_dir: Path, domain_filter: str = None):
    """Print every Knowledge Unit in ``atoms_dir``, ordered by domain then id.

    Args:
        atoms_dir: Directory scanned for KU files.
        domain_filter: When given, only KUs whose domain equals it
            (case-insensitively) are listed.
    """
    units = load_all_kus(atoms_dir)

    if domain_filter:
        units = list(filter(
            lambda unit: unit.get('domain', '').lower() == domain_filter.lower(),
            units,
        ))

    rule = '=' * 80
    print(f"\n{rule}")
    print(f"KNOWLEDGE UNITS ({len(units)} total)")
    if domain_filter:
        print(f"Domain filter: {domain_filter}")
    print(f"{rule}\n")

    def order_key(unit):
        return (unit.get('domain', ''), unit.get('id', ''))

    for unit in sorted(units, key=order_key):
        domain = unit.get('domain', 'unknown')
        ku_id = unit.get('id', 'UNKNOWN')
        invariant = unit.get('invariant', {})

        # An invariant is either a dict (name/expression) or a bare value.
        summary = (
            invariant.get('name', invariant.get('expression', ''))
            if isinstance(invariant, dict)
            else str(invariant)
        )

        # Keep the listing compact: long invariants are cut at 70 characters.
        if len(summary) > 70:
            summary = summary[:70] + '...'

        print(f"[{domain.upper():15s}] {ku_id}")
        print(f"  {summary}")
        print()

def show_stats(atoms_dir: Path):
    """Print summary statistics for the KU library in ``atoms_dir``.

    Reports the total number of KUs, the average of all positive numeric
    ``confidence`` values, and a per-domain count with a bar chart, most
    populous domain first.

    Args:
        atoms_dir: Directory scanned for KU files.
    """
    kus = load_all_kus(atoms_dir)
    total = len(kus)

    domains = defaultdict(int)
    confidence_sum = 0
    confidence_count = 0

    for ku in kus:
        domains[ku.get('domain', 'unknown')] += 1

        conf = ku.get('confidence', 0)
        # A null or textual confidence cannot be compared with 0 and used to
        # abort the whole report with a TypeError; such values are skipped,
        # exactly like missing or non-positive ones.
        if isinstance(conf, (int, float)) and conf > 0:
            confidence_sum += conf
            confidence_count += 1

    avg_confidence = confidence_sum / confidence_count if confidence_count > 0 else 0

    print(f"\n{'='*80}")
    print("NANOCERN KU LIBRARY STATISTICS")
    print(f"{'='*80}\n")
    print(f"📊 Total KUs: {total}")
    print(f"📈 Average Confidence: {avg_confidence:.2f}")
    print(f"\n🏷️  KUs by Domain:")

    # The loop only runs when at least one KU was counted, so total > 0 here.
    for domain, count in sorted(domains.items(), key=lambda x: x[1], reverse=True):
        percentage = count / total * 100
        bar = '█' * int(percentage / 2)  # one bar cell per 2 percentage points
        print(f"  {domain:15s}: {count:4d} {bar} ({percentage:5.1f}%)")

    print(f"\n{'='*80}\n")

def show_domains(atoms_dir: Path):
    """Print every domain present in ``atoms_dir`` with its KU count.

    KUs without a ``domain`` key are grouped under ``unknown``. Domains are
    listed in sorted order.

    Args:
        atoms_dir: Directory scanned for KU files.
    """
    kus = load_all_kus(atoms_dir)

    # Tally in one pass with a single default for a missing domain. Using
    # 'unknown' for the listing but '' for the count would report KUs that
    # lack a domain as "unknown: 0".
    counts = defaultdict(int)
    for ku in kus:
        counts[ku.get('domain', 'unknown')] += 1

    print(f"\n{'='*80}")
    print(f"AVAILABLE DOMAINS ({len(counts)} total)")
    print(f"{'='*80}\n")

    for domain in sorted(counts):
        print(f"  {domain:20s}: {counts[domain]:4d} KUs")

    print(f"\n{'='*80}\n")

def search_kus(atoms_dir: Path, query: str):
    """Print the KUs whose id, domain or invariant text contains ``query``.

    Matching is a case-insensitive substring test. For dict invariants both
    the ``name`` and the ``expression`` are searched.

    Args:
        atoms_dir: Directory scanned for KU files.
        query: Keyword to look for.
    """
    units = load_all_kus(atoms_dir)
    needle = query.lower()

    def is_match(unit) -> bool:
        # Build all three searchable fields first, then test them in order.
        unit_id = unit.get('id', '').lower()
        unit_domain = unit.get('domain', '').lower()
        inv = unit.get('invariant', {})
        if isinstance(inv, dict):
            inv_blob = (inv.get('name', '') + ' ' + inv.get('expression', '')).lower()
        else:
            inv_blob = str(inv).lower()
        return any(needle in field for field in (unit_id, unit_domain, inv_blob))

    hits = [unit for unit in units if is_match(unit)]

    rule = '=' * 80
    print(f"\n{rule}")
    print(f"SEARCH RESULTS: '{query}' ({len(hits)} matches)")
    print(f"{rule}\n")

    for hit in hits:
        domain = hit.get('domain', 'unknown')
        ku_id = hit.get('id', 'UNKNOWN')
        inv = hit.get('invariant', {})

        # Prefer the invariant's name, falling back to its expression.
        label = (
            inv.get('name', inv.get('expression', ''))
            if isinstance(inv, dict)
            else str(inv)
        )

        print(f"[{domain.upper():15s}] {ku_id}")
        print(f"  {label[:100]}")
        print()

# Operator prefixes accepted in string thresholds, each paired with the test
# that signals a VIOLATION of that operator. Two-character operators are
# listed first so that '>=' is never mistaken for '>'.
_THRESHOLD_VIOLATIONS = (
    ('>=', lambda actual, limit: actual < limit),
    ('<=', lambda actual, limit: actual > limit),
    ('==', lambda actual, limit: actual != limit),
    ('!=', lambda actual, limit: actual == limit),
    ('>', lambda actual, limit: actual <= limit),
    ('<', lambda actual, limit: actual >= limit),
)


def _threshold_violated(actual, threshold) -> bool:
    """Return True when ``actual`` fails the ``threshold`` of one condition."""
    if isinstance(threshold, str):
        for symbol, violated in _THRESHOLD_VIOLATIONS:
            if not threshold.startswith(symbol):
                continue
            operand = threshold[len(symbol):]
            try:
                limit = float(operand)
            except ValueError:
                # Non-numeric operand (e.g. '==solid'): compare as text
                # instead of crashing on the float conversion.
                limit = operand.strip()
            try:
                return violated(actual, limit)
            except TypeError:
                # Values that cannot be ordered against the limit (a string
                # against a number, None, ...) cannot satisfy the condition.
                return True
        # A string without an operator prefix is not evaluated.
        return False

    if isinstance(threshold, (int, float)):
        return actual != threshold

    # Other threshold types (lists, objects, null) are not evaluated.
    return False


def check_constraint(ku_file: Path, state: dict) -> bool:
    """Check a state against the ``applies_if`` conditions of a KU and print a report.

    Each ``applies_if`` entry maps a condition name to a threshold. A
    threshold is either a plain number (the state value must equal it) or a
    string starting with one of ``>=``, ``<=``, ``>``, ``<``, ``==``, ``!=``
    followed by the operand. A condition missing from ``state`` counts as a
    violation.

    Args:
        ku_file: Path of the KU file to load.
        state: Mapping of condition names to observed values.

    Returns:
        True when no condition is violated, False otherwise.
    """
    ku = load_ku(ku_file)

    print(f"\n{'='*80}")
    # .get() keeps the report working for KUs that lack an id.
    print(f"CONSTRAINT CHECK: {ku.get('id', 'UNKNOWN')}")
    print(f"{'='*80}\n")

    invariant = ku.get('invariant', {})
    if isinstance(invariant, dict):
        inv_text = invariant.get('name', invariant.get('expression', ''))
    else:
        inv_text = str(invariant)

    print(f"Invariant: {inv_text}\n")

    # Check applicability conditions
    applies_if = ku.get('applies_if', {})
    violations = []

    for condition, threshold in applies_if.items():
        if condition not in state:
            print(f"  {condition}: NOT PROVIDED")
            violations.append(condition)
            continue

        actual = state[condition]
        print(f"  {condition}: {actual} (expected: {threshold})")
        if _threshold_violated(actual, threshold):
            violations.append(condition)

    print()
    if violations:
        print("❌ CONSTRAINT VIOLATED")
        print(f"   Failed conditions: {', '.join(violations)}")

        failure_modes = ku.get('failure_modes', {})
        if isinstance(failure_modes, dict):
            modes = failure_modes.get('conditions', [])
        else:
            modes = failure_modes

        # A single textual failure mode must not be joined character by character.
        if isinstance(modes, str):
            modes = [modes]

        if modes:
            print(f"   Failure modes: {', '.join(str(m) for m in modes)}")
    else:
        print("✅ CONSTRAINT SATISFIED")
    print()

    return not violations

def main(argv=None):
    """Parse command-line arguments and run the requested command.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read the process arguments, so ``main()`` behaves as
            before.

    The process exits with status 1, after printing a message on stderr,
    when a required argument is missing, ``--state`` is not a JSON object,
    a KU file cannot be loaded, or validation fails.
    """
    parser = argparse.ArgumentParser(
        description='NanoCERN Knowledge Reactor - Constraint-based reasoning',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # List all Knowledge Units
  nanocern list
  
  # List KUs in specific domain
  nanocern list --domain physics
  
  # Show statistics
  nanocern stats
  
  # Show available domains
  nanocern domains
  
  # Search for KUs
  nanocern search "gravity"
  
  # Check constraint against state
  nanocern check atoms/PHYS-1_F.ku --state '{"Value >": 1}'
  
  # Validate KU structure
  nanocern validate atoms/PHYS-1_F.ku
        """
    )

    parser.add_argument('command',
                       choices=['list', 'check', 'validate', 'stats', 'domains', 'search'],
                       help='Command to execute')
    parser.add_argument('ku_file', nargs='?', help='Knowledge Unit file or search query')
    parser.add_argument('--state', type=str, help='JSON state to check')
    parser.add_argument('--domain', type=str, help='Filter by domain')
    parser.add_argument('--atoms-dir', type=Path, default=Path('atoms'),
                       help='Directory containing KU files')

    args = parser.parse_args(argv)

    def fail(message):
        # Diagnostics go to stderr so they never mix with command output.
        print(message, file=sys.stderr)
        sys.exit(1)

    if args.command == 'list':
        list_kus(args.atoms_dir, args.domain)

    elif args.command == 'stats':
        show_stats(args.atoms_dir)

    elif args.command == 'domains':
        show_domains(args.atoms_dir)

    elif args.command == 'search':
        if not args.ku_file:
            fail("Error: search query required")
        search_kus(args.atoms_dir, args.ku_file)

    elif args.command == 'validate':
        if not args.ku_file:
            fail("Error: ku_file required for validate command")
        try:
            ku = load_ku(Path(args.ku_file))
        except (OSError, ValueError) as exc:
            # Missing/unreadable file or undecodable content: report it
            # instead of dumping a traceback.
            fail(f"Error: cannot load {args.ku_file}: {exc}")
        if validate_ku(ku):
            print(f"✅ {args.ku_file} is valid")
        else:
            print(f"❌ {args.ku_file} is invalid")
            sys.exit(1)

    elif args.command == 'check':
        if not args.ku_file or not args.state:
            fail("Error: ku_file and --state required for check command")
        # Validate the state before touching the KU file so a typo in the
        # JSON is reported clearly.
        try:
            state = json.loads(args.state)
        except json.JSONDecodeError as exc:
            fail(f"Error: --state is not valid JSON: {exc}")
        if not isinstance(state, dict):
            fail("Error: --state must be a JSON object")
        try:
            check_constraint(Path(args.ku_file), state)
        except (OSError, ValueError) as exc:
            fail(f"Error: cannot check {args.ku_file}: {exc}")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()
