#!/usr/bin/env python3
"""
Retail Store People Counter with OpenCV
=======================================
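A computer vision system for analyzing retail store videos to:
- Count customers entering and exiting by detecting crossings of two horizontal lines
- Track customer centroids from frame to frame
- Estimate per-customer dwell time (per-zone dwell time is not implemented yet)
- Write summary statistics to a JSON report (heatmap output is not implemented yet)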

Requirements:
- OpenCV (cv2)
- NumPy
- Optional: imutils (not imported by this script)

Installation:
    pip install opencv-python numpy imutils
"""

import cv2
import numpy as np
import argparse
import time
import json
import csv
from datetime import datetime
from collections import defaultdict, deque
import os
from pathlib import Path

class RetailPeopleCounter:
    """
    Main class for retail store people counting and tracking.

    For each processed frame:
      - ``detect_people`` finds foreground blobs with MOG2 background subtraction.
      - ``update_trackers`` associates blob centroids with existing tracks by
        greedy nearest-centroid matching.
      - ``check_crossings`` updates entry/exit counts when a track crosses one
        of the two horizontal lines.

    Dwell times are measured on the video timeline (frame index / FPS) when the
    FPS is known, and fall back to wall-clock time otherwise. Counters and
    tracks are not reset between ``process_video`` calls.
    """

    def __init__(self, config=None):
        """
        Initialize the people counter with configuration.

        Args:
            config (dict): Optional overrides merged over the defaults below.
        """
        # Default configuration
        self.config = {
            # Video processing
            'skip_frames': 2,        # process every Nth frame (values < 1 act as 1)
            'min_area': 500,         # contour area bounds (pixels) for a person blob
            'max_area': 50000,
            # Threshold applied to the MOG2 foreground mask. With
            # detectShadows=True, MOG2 marks shadow pixels as 127 and foreground
            # as 255, so a value above 127 keeps shadows out of detections.
            'fg_threshold': 200,

            # Tracking
            'max_disappeared': 50,   # processed frames a track may go unmatched
            'max_distance': 50,      # max centroid jump (pixels) to match a track
            'track_history': 30,     # centroids kept per track (at least 2 are kept)

            # Zones and lines (fractions of the frame height)
            'entry_line_y': 0.3,
            'exit_line_y': 0.7,
            'dwell_zones': [],       # NOTE: not used by this class yet

            # Display
            'show_video': True,
            'show_tracks': True,
            'show_zones': True,      # NOTE: not used by this class yet

            # Output -- NOTE: these flags are not consulted by this class yet
            'output_video': False,
            'output_stats': True,
            'output_heatmap': True,
        }

        if config:
            self.config.update(config)

        # Line-crossing counters
        self.total_entries = 0
        self.total_exits = 0
        self.current_count = 0      # entries minus exits; can go negative
        self.max_concurrent = 0

        # Tracking: object_id -> deque of recent centroids / unmatched-frame count
        self.trackers = {}
        self.disappeared = {}
        self.next_object_id = 0

        # Dwell time tracking (timestamps come from _timestamp())
        self.entry_times = {}
        self.zone_times = {}        # object_id -> {zone_id: times}; not filled by this class yet
        self.dwell_stats = defaultdict(list)

        # Video properties
        self.video_width = 0
        self.video_height = 0
        self.fps = 0
        self.current_frame = 0      # 1-based index of the frame being processed

        # Background subtractor
        self.bg_subtractor = cv2.createBackgroundSubtractorMOG2(
            history=500, varThreshold=16, detectShadows=True
        )

    def detect_people(self, frame):
        """
        Detect person-like foreground blobs in a frame using background subtraction.

        Args:
            frame: BGR image as returned by ``cv2.VideoCapture.read``.

        Returns:
            list[tuple[int, int, int, int]]: ``(x, y, w, h)`` boxes whose contour
            area lies in [min_area, max_area] and whose width/height ratio lies
            in [0.2, 1.5].
        """
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        blurred = cv2.GaussianBlur(gray, (7, 7), 0)

        fg_mask = self.bg_subtractor.apply(blurred)
        # Threshold above the MOG2 shadow value so shadows are not merged into
        # (and do not inflate) the foreground blobs.
        fg_threshold = self.config.get('fg_threshold', 200)
        _, thresh = cv2.threshold(fg_mask, fg_threshold, 255, cv2.THRESH_BINARY)

        # Close small holes, then remove speckle noise.
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
        thresh = cv2.morphologyEx(thresh, cv2.MORPH_OPEN, kernel)

        contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        boxes = []
        for contour in contours:
            area = cv2.contourArea(contour)
            if area < self.config['min_area'] or area > self.config['max_area']:
                continue

            x, y, w, h = cv2.boundingRect(contour)
            # People are taller than wide; reject very wide or very thin blobs.
            aspect_ratio = w / float(h)
            if aspect_ratio < 0.2 or aspect_ratio > 1.5:
                continue

            boxes.append((x, y, w, h))

        return boxes

    def update_trackers(self, boxes):
        """
        Update object trackers with new detections.

        Matching is greedy. Tracks are visited in order of their distance to
        their nearest detection, and each claims that nearest detection if it
        is still unclaimed and within ``max_distance``. A track whose nearest
        detection was already claimed stays unmatched for this frame.

        Args:
            boxes (list[tuple[int, int, int, int]]): ``(x, y, w, h)`` detections.

        Returns:
            dict: ``self.trackers`` (object_id -> deque of centroid arrays).
        """
        if len(boxes) == 0:
            for object_id in list(self.disappeared):
                self.disappeared[object_id] += 1
                if self.disappeared[object_id] > self.config['max_disappeared']:
                    self._remove_tracker(object_id)
            return self.trackers

        centroids = np.zeros((len(boxes), 2), dtype="int")
        for i, (x, y, w, h) in enumerate(boxes):
            centroids[i] = (x + w // 2, y + h // 2)

        if len(self.trackers) == 0:
            for centroid in centroids:
                self._register_new_tracker(centroid)
            return self.trackers

        object_ids = list(self.trackers.keys())
        object_centroids = np.array([self.trackers[obj_id][-1] for obj_id in object_ids])
        # D[i, j] = distance from track i's last centroid to detection j.
        D = np.linalg.norm(object_centroids[:, np.newaxis] - centroids, axis=2)

        rows = D.min(axis=1).argsort()
        cols = D.argmin(axis=1)[rows]

        used_rows = set()
        used_cols = set()

        for row, col in zip(rows, cols):
            if row in used_rows or col in used_cols:
                continue
            if D[row, col] > self.config['max_distance']:
                continue

            object_id = object_ids[row]
            self.trackers[object_id].append(centroids[col])
            self.disappeared[object_id] = 0
            used_rows.add(row)
            used_cols.add(col)

        for row in set(range(D.shape[0])).difference(used_rows):
            object_id = object_ids[row]
            self.disappeared[object_id] += 1
            if self.disappeared[object_id] > self.config['max_disappeared']:
                self._remove_tracker(object_id)

        for col in set(range(D.shape[1])).difference(used_cols):
            self._register_new_tracker(centroids[col])

        return self.trackers

    def _timestamp(self):
        """
        Return the current time in seconds used for dwell measurements.

        Uses the video timeline (``current_frame / fps``) when the FPS is known,
        so dwell times do not depend on how fast frames are processed. Otherwise
        falls back to wall-clock ``time.time()``.
        """
        if self.fps and self.fps > 0:
            return self.current_frame / self.fps
        return time.time()

    def _register_new_tracker(self, centroid):
        """Register a new object tracker starting at ``centroid``."""
        object_id = self.next_object_id
        # Bounded history: only the most recent centroids are ever read, and
        # an unbounded list would grow for the lifetime of each track.
        history = max(2, int(self.config.get('track_history', 30)))
        self.trackers[object_id] = deque([centroid], maxlen=history)
        self.disappeared[object_id] = 0
        self.entry_times[object_id] = self._timestamp()
        self.zone_times[object_id] = {}
        self.next_object_id += 1

    def _remove_tracker(self, object_id):
        """
        Remove an object tracker and record its dwell times.

        The total dwell time runs from registration until removal. Removal
        happens only after ``max_disappeared`` unmatched frames, so the dwell
        time includes that unmatched tail.
        """
        entry_time = self.entry_times.pop(object_id, None)
        if entry_time is not None:
            self.dwell_stats['total'].append(self._timestamp() - entry_time)

        for zone_id, times in self.zone_times.pop(object_id, {}).items():
            if len(times) >= 2:
                self.dwell_stats[zone_id].append(times[-1] - times[0])

        self.trackers.pop(object_id, None)
        self.disappeared.pop(object_id, None)

    def check_crossings(self, frame_num):
        """
        Check whether tracked objects crossed the entry or exit line since their last update.

        An entry is counted when a centroid's y moves from below the entry line
        to on or above it; image y grows downward. An exit is counted when y
        moves from above the exit line to on or below it. There is no
        per-object "already counted" state, so a track that crosses a line
        repeatedly is counted each time.

        Args:
            frame_num (int): Frame index, used only in the log messages.
        """
        entry_line = int(self.video_height * self.config['entry_line_y'])
        exit_line = int(self.video_height * self.config['exit_line_y'])

        for object_id, centroids in list(self.trackers.items()):
            if len(centroids) < 2:
                continue

            prev_centroid = centroids[-2]
            curr_centroid = centroids[-1]

            if prev_centroid[1] > entry_line and curr_centroid[1] <= entry_line:
                self.total_entries += 1
                self.current_count += 1
                self.max_concurrent = max(self.max_concurrent, self.current_count)
                print(f"Frame {frame_num}: Person {object_id} entered")

            elif prev_centroid[1] < exit_line and curr_centroid[1] >= exit_line:
                self.total_exits += 1
                self.current_count -= 1
                print(f"Frame {frame_num}: Person {object_id} exited")

    def _process_frame(self, frame, frame_num):
        """
        Run detection, tracking, crossing checks and optional display for one frame.

        Returns:
            bool: False if the user pressed 'q' in the display window, True otherwise.
        """
        self.current_frame = frame_num
        boxes = self.detect_people(frame)
        self.update_trackers(boxes)
        self.check_crossings(frame_num)

        if self.config['show_video']:
            display_frame = self.draw_overlays(frame, frame_num)
            cv2.imshow('Retail People Counter', display_frame)
            if cv2.waitKey(1) & 0xFF == ord('q'):
                return False
        return True

    def process_video(self, video_path):
        """
        Process a video file and generate statistics.

        Every ``skip_frames``-th frame goes through detection, tracking and
        line-crossing checks. A JSON report is written when the loop ends:
        at end of video, or when 'q' is pressed in the display window.

        Args:
            video_path (str): Path to a video readable by ``cv2.VideoCapture``.

        Returns:
            bool: False if the video cannot be opened, True otherwise.
        """
        print(f"Processing video: {video_path}")

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Error: Cannot open video {video_path}")
            return False

        self.video_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        self.video_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep the exact (possibly fractional, e.g. 29.97) rate: it converts
        # frame indices into seconds for dwell times.
        self.fps = cap.get(cv2.CAP_PROP_FPS)

        print(f"Video: {self.video_width}x{self.video_height}, FPS: {self.fps:.2f}")

        # A non-positive value would make the modulo below raise ZeroDivisionError.
        skip_frames = max(1, int(self.config['skip_frames']))
        frame_count = 0
        start_time = time.time()

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                frame_count += 1

                if frame_count % skip_frames == 0:
                    if not self._process_frame(frame, frame_count):
                        break

                # Checked for every frame read, not just processed ones, so
                # progress is reported even when skip_frames does not divide 100.
                if frame_count % 100 == 0:
                    print(f"Processed {frame_count} frames - Current: {self.current_count}")
        finally:
            cap.release()
            # Only touch HighGUI when windows were used; headless OpenCV builds
            # do not implement the window functions.
            if self.config['show_video']:
                cv2.destroyAllWindows()

        processing_time = time.time() - start_time
        print(f"\nProcessing complete!")
        print(f"Total frames: {frame_count}")
        print(f"Processing time: {processing_time:.2f} seconds")

        self.generate_report(video_path, frame_count, processing_time)

        return True

    def draw_overlays(self, frame, frame_num):
        """
        Draw the entry/exit lines, track IDs and running counts on a copy of the frame.

        Args:
            frame: BGR image to annotate (not modified).
            frame_num (int): Frame index shown in the stats text.

        Returns:
            The annotated copy of ``frame``.
        """
        overlay = frame.copy()

        entry_line = int(self.video_height * self.config['entry_line_y'])
        exit_line = int(self.video_height * self.config['exit_line_y'])

        cv2.line(overlay, (0, entry_line), (self.video_width, entry_line),
                 (0, 255, 0), 2)
        cv2.line(overlay, (0, exit_line), (self.video_width, exit_line),
                 (0, 0, 255), 2)

        cv2.putText(overlay, "ENTRY", (10, entry_line - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
        cv2.putText(overlay, "EXIT", (10, exit_line + 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

        if self.config['show_tracks']:
            for object_id, centroids in self.trackers.items():
                if len(centroids) > 0:
                    # Centroids are NumPy integers; pass plain ints to OpenCV.
                    cx, cy = (int(v) for v in centroids[-1])
                    cv2.circle(overlay, (cx, cy), 5, (0, 0, 255), -1)
                    cv2.putText(overlay, str(object_id), (cx - 10, cy - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 2)

        stats_text = f"Frame: {frame_num} | Current: {self.current_count} | Entries: {self.total_entries} | Exits: {self.total_exits}"
        cv2.putText(overlay, stats_text, (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        return overlay

    def generate_report(self, video_path, total_frames, processing_time):
        """
        Build the analysis report, save it as JSON and print a summary.

        Args:
            video_path (str): Path of the analysed video; only its basename is stored.
            total_frames (int): Number of frames read from the video.
            processing_time (float): Wall-clock seconds spent processing.

        Returns:
            str: Path of the JSON report written under ``analysis_results/``.
        """
        dwell = self.dwell_stats.get('total', [])
        # processing_time can be 0 for an empty or extremely short video.
        processing_fps = round(total_frames / processing_time, 2) if processing_time > 0 else 0.0

        report = {
            "analysis_info": {
                "video_file": os.path.basename(video_path),
                "analysis_date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                "processing_time_seconds": round(processing_time, 2),
                "frames_processed": total_frames,
                "processing_fps": processing_fps
            },
            "video_properties": {
                "width": self.video_width,
                "height": self.video_height,
                "fps": round(self.fps, 2),
                "entry_line": int(self.video_height * self.config['entry_line_y']),
                "exit_line": int(self.video_height * self.config['exit_line_y'])
            },
            "people_counting_results": {
                "total_customers_entered": self.total_entries,
                "total_customers_exited": self.total_exits,
                "customers_in_store_at_end": self.current_count,
                "maximum_concurrent_customers": self.max_concurrent,
                "total_objects_tracked": self.next_object_id
            },
            "dwell_time_statistics": {
                # Only tracks that were removed (disappeared) contribute; tracks
                # still alive when the video ends have no dwell time recorded.
                "total_customers_with_dwell_time": len(dwell),
                "average_dwell_time_seconds": float(np.mean(dwell)) if dwell else 0,
                "median_dwell_time_seconds": float(np.median(dwell)) if dwell else 0,
                "min_dwell_time_seconds": float(np.min(dwell)) if dwell else 0,
                "max_dwell_time_seconds": float(np.max(dwell)) if dwell else 0
            }
        }

        # Save report
        output_dir = "analysis_results"
        os.makedirs(output_dir, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        report_path = os.path.join(output_dir, f"analysis_report_{timestamp}.json")

        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump(report, f, indent=2)

        print(f"\nReport saved to: {report_path}")

        # Print summary
        print("\n" + "="*50)
        print("ANALYSIS SUMMARY")
        print("="*50)
        print(f"Total customers entered: {self.total_entries}")
        print(f"Total customers exited: {self.total_exits}")
        print(f"Maximum concurrent customers: {self.max_concurrent}")
        print(f"Customers in store at end: {self.current_count}")
        print(f"Average dwell time: {report['dwell_time_statistics']['average_dwell_time_seconds']:.1f} seconds")
        print(f"Processing speed: {report['analysis_info']['processing_fps']:.1f} FPS")
        print("="*50)

        return report_path

def main(argv=None):
    """
    Command-line entry point: parse arguments, run the counter, report status.

    Args:
        argv (list[str] | None): Arguments to parse; ``None`` means argparse's
            default of ``sys.argv[1:]``.

    Returns:
        int: 0 if the video was processed, 1 if it could not be opened.
    """
    parser = argparse.ArgumentParser(description='Retail Store People Counter')
    parser.add_argument('video', help='Path to input video file')
    parser.add_argument('--skip-frames', type=int, default=2, help='Process every Nth frame')
    parser.add_argument('--no-display', action='store_true', help='Disable video display')
    parser.add_argument('--entry-line', type=float, default=0.3, help='Entry line position (0-1)')
    parser.add_argument('--exit-line', type=float, default=0.7, help='Exit line position (0-1)')

    args = parser.parse_args(argv)

    # Validate up front so bad values fail with a usage message (exit code 2)
    # instead of a ZeroDivisionError or off-screen lines mid-run.
    if args.skip_frames < 1:
        parser.error('--skip-frames must be >= 1')
    for flag, value in (('--entry-line', args.entry_line), ('--exit-line', args.exit_line)):
        if not 0.0 <= value <= 1.0:
            parser.error(f'{flag} must be between 0 and 1')

    config = {
        'skip_frames': args.skip_frames,
        'show_video': not args.no_display,
        'entry_line_y': args.entry_line,
        'exit_line_y': args.exit_line,
    }

    print("Retail Store People Counter")
    print("="*50)

    counter = RetailPeopleCounter(config)
    return 0 if counter.process_video(args.video) else 1

if __name__ == "__main__":
    # Propagate main()'s status (None/0 = success) as the process exit code.
    raise SystemExit(main())