#!/usr/bin/env python3
"""
Extract text from PowerPoint (PPTX) presentations.
"""

import sys
import json
from pathlib import Path

def _iter_shape_texts(shapes):
    """Yield the stripped, non-empty text of each shape in *shapes*.

    Group shapes and table graphic frames expose no ``text`` attribute of
    their own, so they are descended into: groups recursively through their
    member shapes, tables cell by cell in row order.
    """
    for shape in shapes:
        if hasattr(shape, "shapes"):
            # Group shape: the text lives in its member shapes.
            yield from _iter_shape_texts(shape.shapes)
        elif getattr(shape, "has_table", False):
            for row in shape.table.rows:
                for cell in row.cells:
                    cell_text = cell.text.strip()
                    if cell_text:
                        yield cell_text
        elif hasattr(shape, "text"):
            text = shape.text.strip()
            if text:
                yield text


def extract_pptx_text(pptx_path):
    """Extract all text from a PPTX file.

    Args:
        pptx_path: Path to the presentation to read.

    Returns:
        A dict with the keys ``"filename"`` (base name of *pptx_path*),
        ``"total_slides"`` and ``"slides"``. Each entry of ``"slides"`` is a
        dict with ``"slide_number"`` (1-based), ``"title"`` (stripped title
        text, ``""`` when the slide has no title shape), ``"text_content"``
        (list of stripped, non-empty text items from every shape other than
        the title, including shapes inside groups and table cells) and
        ``"shapes"`` (number of top-level shapes on the slide).

        Returns ``None`` when python-pptx is not installed or the file
        cannot be processed; the reason is printed to stderr.
    """
    try:
        from pptx import Presentation
    except ImportError:
        print("ERROR: python-pptx library not installed.", file=sys.stderr)
        print("Install with: pip install python-pptx", file=sys.stderr)
        return None

    try:
        prs = Presentation(pptx_path)
        presentation_data = {
            "filename": Path(pptx_path).name,
            "total_slides": len(prs.slides),
            "slides": [],
        }

        for number, slide in enumerate(prs.slides, start=1):
            slide_data = {
                "slide_number": number,
                "title": "",
                "text_content": [],
                "shapes": len(slide.shapes),
            }

            # Exclude the title by shape identity rather than by comparing
            # text: a text comparison misses titles with surrounding
            # whitespace and wrongly drops other shapes that repeat the title.
            title_shape = slide.shapes.title
            title_id = None
            if title_shape is not None:
                slide_data["title"] = title_shape.text.strip()
                title_id = title_shape.shape_id

            body_shapes = [
                shape for shape in slide.shapes if shape.shape_id != title_id
            ]
            slide_data["text_content"] = list(_iter_shape_texts(body_shapes))

            presentation_data["slides"].append(slide_data)

        return presentation_data

    except Exception as e:
        # Deliberately broad: callers rely on a None return for any
        # unreadable or malformed file rather than on a specific exception.
        print(f"ERROR processing {pptx_path}: {e}", file=sys.stderr)
        return None

def main():
    """Run the command-line interface.

    Usage: ``extract_pptx_text.py <presentation.pptx> [output.json]``.

    With an output path, the extracted data is written there as UTF-8 JSON.
    Without one, a per-slide summary (title plus up to three text items,
    each truncated to 80 characters) is printed to stdout.

    Usage and error messages are written to stderr so that stdout carries
    only results. The process exits with status 1 when arguments are
    missing, the input file does not exist, extraction fails, or the output
    file cannot be written.
    """
    if len(sys.argv) < 2:
        print(
            "Usage: python3 extract_pptx_text.py <presentation.pptx> [output.json]",
            file=sys.stderr,
        )
        sys.exit(1)

    pptx_path = sys.argv[1]
    output_path = sys.argv[2] if len(sys.argv) > 2 else None

    if not Path(pptx_path).exists():
        print(f"ERROR: File not found: {pptx_path}", file=sys.stderr)
        sys.exit(1)

    data = extract_pptx_text(pptx_path)
    if not data:
        # Extraction failed; signal it through the exit status.
        sys.exit(1)

    if output_path:
        try:
            with open(output_path, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2, ensure_ascii=False)
        except OSError as e:
            # Report an unwritable destination instead of a raw traceback.
            print(f"ERROR writing {output_path}: {e}", file=sys.stderr)
            sys.exit(1)
        print(f"Text extracted and saved to: {output_path}")
        return

    # No output file requested: print a human-readable summary.
    print(f"Presentation: {data['filename']}")
    print(f"Total slides: {data['total_slides']}")
    print("\nSlide Summary:")
    for slide in data["slides"]:
        title = slide["title"] or "No title"
        print(f"  Slide {slide['slide_number']}: {title}")
        items = slide["text_content"]
        for text in items[:3]:  # Show first 3 text items
            preview = f"{text[:80]}..." if len(text) > 80 else text
            print(f"    - {preview}")
        if len(items) > 3:
            print(f"    ... and {len(items) - 3} more items")
        print()

# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()