"""
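Repositories for database operations on batch data and products.
BatchRepository handles batch item inserts, approvals/rejections, statistics and
batch-related queries; ProductRepository handles products and per-product ingredient tables.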
"""

from typing import Dict, List, Optional, Any, Tuple
import logging
from datetime import datetime

from app.repositories.base import BaseRepository, TimestampMixin
from app.core.config import get_settings
from app.core.exceptions import (
    BatchNotFoundError,
    ValidationError,
    DatabaseError
)

logger = logging.getLogger(__name__)

# Insertable columns of the `data_test` table. Compared with the main `data` table it has
# no supplier, week, shift, units, status, serial_no, tick_on_order_of_addition,
# analyst_status, tl_status, approval_date, analyst_approval_date or tl_approval_date.
DATA_TEST_COLUMNS = frozenset((
    "product", "items", "batch_no",
    "expected_weight", "actual_weight",
    "scan_time", "scan_date",
    "operator", "analyst", "teamLead",
    "mdn", "date_manufacturer", "expiry_date",
    "factory", "comment",
))

# Columns sent on batch insert for the main `data` table: every `data_test` column plus the
# extras below. Approval/rejection bookkeeping columns (e.g. the *_approved_by /
# *_rejected_by columns written when a batch is approved or rejected) exist on the table
# but are deliberately absent here, so inserts never set them.
DATA_TABLE_COLUMNS = DATA_TEST_COLUMNS | frozenset((
    "supplier", "week", "shift", "units", "status", "serial_no",
    "tick_on_order_of_addition", "analyst_status", "tl_status",
    "approval_date", "analyst_approval_date", "tl_approval_date",
))


class BatchRepository(BaseRepository, TimestampMixin):
    """Repository for batch data operations.

    The backing table is ``get_settings().batch_data_table`` (stripped), falling back to
    ``"data"`` when that setting is empty or blank. Inserts into ``data`` / ``data_test``
    are restricted to ``DATA_TABLE_COLUMNS`` / ``DATA_TEST_COLUMNS`` respectively.
    """

    # Item fields that create_batch_items requires to be present and truthy.
    _REQUIRED_ITEM_FIELDS = (
        'product', 'items', 'batch_no', 'actual_weight', 'operator', 'analyst', 'teamLead',
    )

    # Approval types accepted by approve_batch / reject_batch. Each one prefixes its
    # columns, e.g. analyst_status / analyst_approved_by, tl_status / tl_approved_by.
    _APPROVAL_TYPES = ("analyst", "tl")

    # SQL fragments for find_processed_data_paginated's approval_status filter.
    # "all" / None (or any other value) applies no status filter.
    _APPROVAL_STATUS_FILTERS = {
        "approved": "analyst_status = '1' AND tl_status = '1'",
        "not_approved": "NOT (analyst_status = '1' AND tl_status = '1')",
        "rejected": "(analyst_status IN ('-1', '2') OR tl_status IN ('-1', '2'))",
        "pending": "(analyst_status = '0' OR (analyst_status = '1' AND tl_status = '0'))",
    }

    def __init__(self):
        table = (get_settings().batch_data_table or "data").strip() or "data"
        super().__init__(table)
        self.primary_key = "id"
        if table == "data_test":
            logger.info("[BATCH_REPO] Using test table data_test (config batch_data_table=data_test)")

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _join_conditions(conditions: List[str]) -> str:
        """AND together SQL fragments, parenthesizing each so an OR inside one cannot
        escape and bypass the other filters."""
        return " AND ".join(f"({condition})" for condition in conditions)

    @staticmethod
    def _add_area_filter(
        areas: Optional[List[str]],
        conditions: List[str],
        params: List[Any],
    ) -> None:
        """Append a ``factory IN (...)`` filter for *areas*; no-op when *areas* is empty/None."""
        if areas:
            placeholders = ",".join(["%s"] * len(areas))
            conditions.append(f"factory IN ({placeholders})")
            params.extend(areas)

    @staticmethod
    def _page_offset(page: int, limit: int) -> int:
        """Return the OFFSET for 1-based *page*; pages below 1 are treated as page 1
        (a negative OFFSET would be rejected by the database)."""
        return max(page - 1, 0) * limit

    @staticmethod
    def _batch_where(
        serial_no: str,
        operator: str,
        shift: str,
        batch_no: Optional[str],
    ) -> Dict[str, Any]:
        """Build the WHERE mapping identifying one batch.

        The batch is (serial_no, operator, shift). A non-blank *batch_no* (stripped,
        converted to str) narrows it to a single submission.
        """
        where: Dict[str, Any] = {'serial_no': serial_no, 'operator': operator, 'shift': shift}
        if batch_no is not None:
            cleaned = str(batch_no).strip()
            if cleaned:
                where['batch_no'] = cleaned
        return where

    # ------------------------------------------------------------------ inserts

    def create_batch_items(self, batch_items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        Create multiple batch items in a single transaction.

        All items are validated before the transaction is opened, so an invalid item
        causes no database work. Each item is shallow-copied before defaults are applied,
        so the caller's dicts are not mutated.

        Args:
            batch_items: List of batch item data. Every item must have truthy values for
                product, items, batch_no, actual_weight, operator, analyst and teamLead.

        Returns:
            List of created batch items: the item data with defaults applied
            (expected_weight="", comment=None, status="1", analyst_status="0",
            tl_status="0", units="kg" when missing) and ``id`` set when the cursor
            reports a ``lastrowid``. Keys that are not inserted (filtered out by the
            table's column set) are still present in the returned dicts.

        Raises:
            ValidationError: If a required field is missing or falsy on any item.
            DatabaseError: If the insert fails for any other reason.
        """
        try:
            # Materialize once: the items are iterated twice (validate, then insert).
            items = list(batch_items)

            # Validate everything up front so a bad item never starts a transaction.
            for source in items:
                for field in self._REQUIRED_ITEM_FIELDS:
                    if not source.get(field):
                        raise ValidationError(f"Field '{field}' is required", field)

            # Restrict inserted keys to the configured table's schema; any other table
            # name inserts every key as-is.
            allowed_columns = {
                "data": DATA_TABLE_COLUMNS,
                "data_test": DATA_TEST_COLUMNS,
            }.get(self.table_name)

            created_items: List[Dict[str, Any]] = []
            with self.db.transaction() as (cursor, _conn):
                for idx, source in enumerate(items):
                    # Copy so defaults and the generated id do not leak into the caller's dict.
                    item_data = dict(source)
                    # Never overwrite expected_weight/comment if already set; only set when missing.
                    item_data.setdefault("expected_weight", "")
                    item_data.setdefault("comment", None)
                    # Default workflow state for a freshly scanned item.
                    item_data.setdefault("status", "1")
                    item_data.setdefault("analyst_status", "0")
                    item_data.setdefault("tl_status", "0")
                    item_data.setdefault("units", "kg")

                    # Log at debug to avoid noisy production logs
                    logger.debug(
                        "[BATCH_REPO] insert item[%s]: items=%s expected_weight=%s comment=%s",
                        idx, item_data.get("items"), item_data.get("expected_weight"), item_data.get("comment"),
                    )
                    if idx == 0:
                        logger.debug("[BATCH_REPO] INSERT columns: %s", list(item_data.keys()))

                    if allowed_columns is None:
                        insert_data = item_data
                    else:
                        insert_data = {k: v for k, v in item_data.items() if k in allowed_columns}

                    query, params = self.query_builder.insert(self.table_name, insert_data)
                    cursor.execute(query, params)

                    # Record the auto-increment id when the driver reports one.
                    if cursor.lastrowid:
                        item_data['id'] = cursor.lastrowid

                    created_items.append(item_data)

                logger.info("Created %d batch items", len(created_items))
                return created_items

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Error creating batch items: %s", e)
            raise DatabaseError(f"Failed to create batch items: {e}") from e

    # ------------------------------------------------------------------ queries

    def get_recent_data_table_rows(self, limit: int = 20) -> List[Dict[str, Any]]:
        """
        Fetch recent rows from the data table (for debugging: expected_weight, comment, etc.).

        Args:
            limit: Maximum number of rows to return.

        Returns:
            Rows ordered by id DESC (empty list when there are none).

        Raises:
            DatabaseError: If the query fails.
        """
        try:
            rows = self.find_all(where=None, order_by='id DESC', limit=limit)
            return rows or []
        except Exception as e:
            logger.error("Error fetching recent data table rows: %s", e)
            raise DatabaseError(f"Failed to fetch recent data: {e}") from e

    def find_batch_items(
        self,
        serial_no: Optional[str] = None,
        shift: Optional[str] = None,
        operator: Optional[str] = None,
        product: Optional[str] = None,
        batch_no: Optional[str] = None,
        status: Optional[str] = None,
        limit: Optional[int] = None
    ) -> List[Dict[str, Any]]:
        """
        Find batch items with various filters (equality match; falsy filters are ignored).

        Args:
            serial_no: Serial number filter
            shift: Shift filter
            operator: Operator email filter
            product: Product name filter
            batch_no: Batch number filter
            status: Status filter
            limit: Maximum number of results

        Returns:
            List of matching batch items, newest scan first.

        Raises:
            DatabaseError: If the query fails.
        """
        try:
            filters = {
                'serial_no': serial_no,
                'shift': shift,
                'operator': operator,
                'product': product,
                'batch_no': batch_no,
                'status': status,
            }
            where_conditions = {column: value for column, value in filters.items() if value}

            return self.find_all(
                where=where_conditions,
                order_by="scan_date DESC, scan_time DESC",
                limit=limit
            )

        except Exception as e:
            logger.error("Error finding batch items: %s", e)
            raise DatabaseError(f"Failed to find batch items: {e}") from e

    def find_items_for_approval(
        self,
        user_email: str,
        user_role: str,
        approval_type: str = "analyst",
        areas: Optional[List[str]] = None,
        page: int = 1,
        limit: int = 20
    ) -> Dict[str, Any]:
        """
        Find items pending approval for a user, based on role and areas, with pagination.

        Args:
            user_email: User's email address
            user_role: User's role (analyst, team_lead, head_of_analyst, head_of_team_lead, admin, super_admin)
            approval_type: Type of approval (analyst, tl)
            areas: User's areas of operation, matched against the ``factory`` column
                (only applied for head roles; empty/None means no area restriction)
            page: Page number (starts from 1; values below 1 are treated as 1)
            limit: Items per page

        Returns:
            Dict with ``items`` (the requested page), ``total`` (matching row count) and
            ``batch_count`` (distinct serial_no|operator|shift|batch_no combinations).
            A role/type combination with no access returns empty results.

        Raises:
            DatabaseError: If any query fails.
        """
        try:
            conditions: List[str] = []
            params: List[Any] = []

            if approval_type == "analyst":
                if user_role == "analyst":
                    # Regular analyst: only items assigned to them
                    conditions.append("analyst = %s AND analyst_status = '0'")
                    params.append(user_email)
                elif user_role == "head_of_analyst":
                    # Head of analyst: all analyst-pending items in their areas
                    conditions.append("analyst_status = '0'")
                    self._add_area_filter(areas, conditions, params)
                elif user_role in ("admin", "super_admin"):
                    # Admin/Super admin: all analyst-pending items (view only)
                    conditions.append("analyst_status = '0'")

            elif approval_type == "tl":
                if user_role == "team_lead":
                    # Regular team lead: only items assigned to them that are analyst-approved
                    conditions.append("teamLead = %s AND analyst_status = '1' AND tl_status = '0'")
                    params.append(user_email)
                elif user_role == "head_of_team_lead":
                    # Head of team lead: all TL-pending items in their areas
                    conditions.append("analyst_status = '1' AND tl_status = '0'")
                    self._add_area_filter(areas, conditions, params)
                elif user_role in ("admin", "super_admin"):
                    # Admin/Super admin: all TL-pending items (view only)
                    conditions.append("analyst_status = '1' AND tl_status = '0'")

            if not conditions:
                return {"items": [], "total": 0, "batch_count": 0}

            where_clause = self._join_conditions(conditions)
            query_params = tuple(params)

            # Total item count
            count_query = f"SELECT COUNT(*) as total FROM `{self.table_name}` WHERE {where_clause}"
            count_result = self.db.execute_query(count_query, query_params, fetch_one=True)
            total = count_result['total'] if count_result else 0

            # Distinct batch count (serial_no + operator + shift + batch_no = one submission/batch)
            batch_count_query = f"""
                SELECT COUNT(DISTINCT CONCAT(COALESCE(serial_no,''), '|', COALESCE(operator,''), '|', COALESCE(shift,''), '|', COALESCE(batch_no,''))) as batch_count
                FROM `{self.table_name}` WHERE {where_clause}
            """
            batch_result = self.db.execute_query(batch_count_query, query_params, fetch_one=True)
            batch_count = batch_result['batch_count'] if batch_result else 0

            # Requested page of items
            items_query = f"""
                SELECT * FROM `{self.table_name}`
                WHERE {where_clause}
                ORDER BY scan_date DESC, scan_time DESC
                LIMIT %s OFFSET %s
            """
            items_params = query_params + (limit, self._page_offset(page, limit))
            items = self.db.execute_query(items_query, items_params, fetch_all=True)

            return {
                "items": items or [],
                "total": total,
                "batch_count": batch_count,
            }

        except Exception as e:
            logger.error("Error finding items for approval: %s", e)
            raise DatabaseError(f"Failed to find items for approval: {e}") from e

    # ------------------------------------------------------------------ approvals

    def approve_batch(
        self,
        serial_no: str,
        operator: str,
        shift: str,
        approval_type: str,
        approved_by: str,
        batch_no: Optional[str] = None
    ) -> int:
        """
        Approve all items in a batch.

        The batch is identified by (serial_no, operator, shift). When ``batch_no`` is
        non-blank it is added to the match, so only that one submission is updated.

        Args:
            serial_no: Batch serial number.
            operator: Operator email.
            shift: Shift name.
            approval_type: ``"analyst"`` or ``"tl"``. It selects which
                ``<type>_status`` / ``<type>_approved_by`` / ``<type>_approved_at``
                columns are written.
            approved_by: Identifier stored in ``<type>_approved_by``.
            batch_no: Optional batch number narrowing the match.

        Returns:
            Number of rows updated.

        Raises:
            DatabaseError: If the update fails, including for an unsupported approval_type.
        """
        try:
            if approval_type not in self._APPROVAL_TYPES:
                # Without this guard an empty SET clause would reach the query builder.
                raise ValueError(f"Unsupported approval_type: {approval_type!r}")

            update_data = {
                f'{approval_type}_status': '1',
                f'{approval_type}_approved_by': approved_by,
                f'{approval_type}_approved_at': datetime.now(),
            }
            where = self._batch_where(serial_no, operator, shift, batch_no)

            query, params = self.query_builder.update(self.table_name, update_data, where)
            rows_affected = self.db.execute_update(query, params)
            logger.info(
                "Batch %s (%s, %s%s): %s items approved (%s) by %s",
                serial_no, operator, shift,
                f", batch_no={where['batch_no']}" if 'batch_no' in where else "",
                rows_affected, approval_type, approved_by,
            )
            return rows_affected
        except Exception as e:
            logger.error("Error approving batch: %s", e)
            raise DatabaseError(f"Failed to approve batch: {e}") from e

    def reject_batch(
        self,
        serial_no: str,
        operator: str,
        shift: str,
        approval_type: str,
        rejected_by: str,
        reason: Optional[str] = None,
        batch_no: Optional[str] = None
    ) -> int:
        """
        Reject all items in a batch. When batch_no is non-blank, only that submission is rejected.

        Args:
            serial_no: Batch serial number.
            operator: Operator email.
            shift: Shift name.
            approval_type: ``"analyst"`` or ``"tl"``. It selects which ``<type>_status``
                (set to ``'-1'``), ``<type>_rejected_by`` / ``<type>_rejected_at`` and
                ``<type>_rejection_reason`` columns are written.
            rejected_by: Identifier stored in ``<type>_rejected_by``.
            reason: Optional rejection reason; only written when truthy.
            batch_no: Optional batch number narrowing the match.

        Returns:
            Number of rows updated.

        Raises:
            DatabaseError: If the update fails, including for an unsupported approval_type.
        """
        try:
            if approval_type not in self._APPROVAL_TYPES:
                # Without this guard an empty SET clause would reach the query builder.
                raise ValueError(f"Unsupported approval_type: {approval_type!r}")

            update_data = {
                f'{approval_type}_status': '-1',
                f'{approval_type}_rejected_by': rejected_by,
                f'{approval_type}_rejected_at': datetime.now(),
            }
            if reason:
                update_data[f'{approval_type}_rejection_reason'] = reason
            where = self._batch_where(serial_no, operator, shift, batch_no)

            query, params = self.query_builder.update(self.table_name, update_data, where)
            rows_affected = self.db.execute_update(query, params)
            logger.info(
                "Batch %s (%s, %s%s): %s items rejected (%s) by %s",
                serial_no, operator, shift,
                f", batch_no={where['batch_no']}" if 'batch_no' in where else "",
                rows_affected, approval_type, rejected_by,
            )
            return rows_affected
        except Exception as e:
            logger.error("Error rejecting batch: %s", e)
            raise DatabaseError(f"Failed to reject batch: {e}") from e

    # ------------------------------------------------------------------ reporting

    def get_batch_statistics(
        self,
        date_from: Optional[str] = None,
        date_to: Optional[str] = None,
        product: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Get batch processing statistics.

        Args:
            date_from: Start date filter (YYYY-MM-DD), inclusive on scan_date
            date_to: End date filter (YYYY-MM-DD), inclusive on scan_date
            product: Product filter

        Returns:
            Dict with ``total_items``, ``items_by_status`` (pending_analyst / pending_tl /
            approved / rejected / unknown -> count), ``items_by_product`` and
            ``items_by_operator`` (top 10 operators by count).

        Raises:
            DatabaseError: If any query fails.
        """
        try:
            conditions: List[str] = []
            params: List[Any] = []

            if date_from:
                conditions.append("scan_date >= %s")
                params.append(date_from)
            if date_to:
                conditions.append("scan_date <= %s")
                params.append(date_to)
            if product:
                conditions.append("product = %s")
                params.append(product)

            where_clause = f" WHERE {self._join_conditions(conditions)}" if conditions else ""
            query_params = tuple(params)
            table = f"`{self.table_name}`"

            stats: Dict[str, Any] = {}

            # Total items
            count_query = f"SELECT COUNT(*) as count FROM {table}{where_clause}"
            count_row = self.db.execute_query(count_query, query_params, fetch_one=True)
            stats['total_items'] = count_row['count'] if count_row else 0

            # Items by approval state. The CASE is aliased approval_state, not "status":
            # the table has its own `status` column, and MySQL resolves GROUP BY names
            # against table columns before select aliases, so "GROUP BY status" would
            # group by that column instead of this CASE.
            status_query = f"""
                SELECT
                    CASE
                        WHEN analyst_status = '0' THEN 'pending_analyst'
                        WHEN analyst_status = '1' AND tl_status = '0' THEN 'pending_tl'
                        WHEN analyst_status = '1' AND tl_status = '1' THEN 'approved'
                        WHEN analyst_status IN ('-1', '2') OR tl_status IN ('-1', '2') THEN 'rejected'
                        ELSE 'unknown'
                    END as approval_state,
                    COUNT(*) as count
                FROM {table}{where_clause}
                GROUP BY approval_state
            """
            status_rows = self.db.execute_query(status_query, query_params, fetch_all=True) or []
            stats['items_by_status'] = {row['approval_state']: row['count'] for row in status_rows}

            # Items by product
            product_query = f"""
                SELECT product, COUNT(*) as count
                FROM {table}{where_clause}
                GROUP BY product
                ORDER BY count DESC
            """
            product_rows = self.db.execute_query(product_query, query_params, fetch_all=True) or []
            stats['items_by_product'] = {row['product']: row['count'] for row in product_rows}

            # Top 10 operators
            operator_query = f"""
                SELECT operator, COUNT(*) as count
                FROM {table}{where_clause}
                GROUP BY operator
                ORDER BY count DESC
                LIMIT 10
            """
            operator_rows = self.db.execute_query(operator_query, query_params, fetch_all=True) or []
            stats['items_by_operator'] = {row['operator']: row['count'] for row in operator_rows}

            return stats

        except Exception as e:
            logger.error("Error getting batch statistics: %s", e)
            raise DatabaseError(f"Failed to get batch statistics: {e}") from e

    def find_processed_data_paginated(
        self,
        user_role: str,
        user_email: str,
        user_areas: List[str],
        serial_no: Optional[str] = None,
        shift: Optional[str] = None,
        operator: Optional[str] = None,
        approval_status: Optional[str] = None,
        page: int = 1,
        limit: int = 20
    ) -> Dict[str, Any]:
        """
        Find processed batch data with role-based filtering and pagination.

        Role visibility:
        - admin / super_admin: all rows.
        - head_of_analyst / head_of_team_lead: rows whose factory is in user_areas.
          With no areas, no area restriction is applied.
          NOTE(review): confirm that "no areas = everything" is intended.
        - analyst / team_lead: rows assigned to them (analyst / teamLead column).
        - operator: their own rows.
        - any other role: nothing.

        Args:
            user_role: User's role
            user_email: User's email
            user_areas: User's areas of operation
            serial_no: Optional serial number filter
            shift: Optional shift filter
            operator: Optional operator filter
            approval_status: Optional filter: approved, not_approved, rejected or pending.
                "all" / None applies no status filter.
            page: Page number (starts from 1; values below 1 are treated as 1)
            limit: Items per page

        Returns:
            Dict with ``items`` (the requested page) and ``total`` (matching row count)

        Raises:
            DatabaseError: If any query fails.
        """
        try:
            conditions: List[str] = []
            params: List[Any] = []

            # Role-based visibility
            if user_role in ('admin', 'super_admin'):
                pass  # Admins can see all data.
            elif user_role in ('head_of_analyst', 'head_of_team_lead'):
                self._add_area_filter(user_areas, conditions, params)
            elif user_role == 'analyst':
                conditions.append("analyst = %s")
                params.append(user_email)
            elif user_role == 'team_lead':
                conditions.append("teamLead = %s")
                params.append(user_email)
            elif user_role == 'operator':
                conditions.append("operator = %s")
                params.append(user_email)
            else:
                # Default: no access
                conditions.append("1 = 0")

            # Optional equality filters
            if serial_no:
                conditions.append("serial_no = %s")
                params.append(serial_no)
            if shift:
                conditions.append("shift = %s")
                params.append(shift)
            if operator:
                conditions.append("operator = %s")
                params.append(operator)

            # Approval status filter. The fragments contain OR, and _join_conditions
            # parenthesizes each one so they cannot bypass the role filter above.
            status_condition = self._APPROVAL_STATUS_FILTERS.get(approval_status)
            if status_condition:
                conditions.append(status_condition)

            where_clause = self._join_conditions(conditions) if conditions else "1 = 1"
            query_params = tuple(params)

            # Total count
            count_query = f"SELECT COUNT(*) as total FROM `{self.table_name}` WHERE {where_clause}"
            count_result = self.db.execute_query(count_query, query_params, fetch_one=True)
            total = count_result['total'] if count_result else 0

            # Requested page of items
            items_query = f"""
                SELECT * FROM `{self.table_name}`
                WHERE {where_clause}
                ORDER BY scan_date DESC, scan_time DESC
                LIMIT %s OFFSET %s
            """
            items_params = query_params + (limit, self._page_offset(page, limit))
            items = self.db.execute_query(items_query, items_params, fetch_all=True)

            return {
                "items": items or [],
                "total": total
            }

        except Exception as e:
            logger.error("Error finding processed data: %s", e)
            raise DatabaseError(f"Failed to find processed data: {e}") from e

    def find_processed_data(
        self,
        serial_no: str,
        shift: str,
        operator: str
    ) -> List[Dict[str, Any]]:
        """
        Find fully approved (analyst and TL) batch data for reporting (legacy method).

        Args:
            serial_no: Serial number
            shift: Shift name
            operator: Operator email

        Returns:
            List of processed batch items ordered by scan_time ascending.

        Raises:
            DatabaseError: If the query fails.
        """
        try:
            # NOTE(review): this legacy query always reads the `data` table, not
            # self.table_name. Per DATA_TEST_COLUMNS, data_test lacks serial_no/shift/
            # analyst_status/tl_status, so switching it would need a schema check first.
            query = """
                SELECT * FROM `data`
                WHERE serial_no = %s AND shift = %s AND operator = %s
                AND analyst_status = '1' AND tl_status = '1'
                ORDER BY scan_time ASC
            """

            result = self.db.execute_query(query, (serial_no, shift, operator), fetch_all=True)
            return result or []

        except Exception as e:
            logger.error("Error finding processed data: %s", e)
            raise DatabaseError(f"Failed to find processed data: {e}") from e

    def update_batch_item(
        self,
        item_id: int,
        update_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Update a batch item.

        Args:
            item_id: Batch item ID
            update_data: Data to update

        Returns:
            Updated batch item (as returned by the base repository's ``update``)

        Raises:
            BatchNotFoundError: If no item with ``item_id`` exists.
            DatabaseError: If the lookup or update fails.
        """
        try:
            if not self.find_by_id(item_id):
                raise BatchNotFoundError(str(item_id))

            updated_item = self.update(update_data, {'id': item_id})

            logger.info("Batch item %s updated", item_id)
            return updated_item

        except BatchNotFoundError:
            raise
        except Exception as e:
            logger.error("Error updating batch item %s: %s", item_id, e)
            raise DatabaseError(f"Failed to update batch item: {e}") from e


class ProductRepository(BaseRepository):
    """Repository for product-related operations.

    Product metadata is read from the ``products`` table. Each product's ingredients
    are read from a separate table whose name is the product name.
    """

    # Columns that update_product_ingredient is allowed to write.
    _INGREDIENT_UPDATE_COLUMNS = frozenset({
        'ingredient', 'expected_weight', 'weight_variance', 'to_be_measured', 'mdn', 'batch_no',
    })

    def __init__(self):
        super().__init__("products")
        self.primary_key = "id"

    @staticmethod
    def _validate_table_name(product_name: str) -> None:
        """Reject product names that are unsafe to interpolate as a backtick-quoted table name.

        Only characters for which ``str.isalnum()`` is true, plus underscores, are accepted.
        This excludes backticks and every other SQL metacharacter.

        Raises:
            ValidationError: If ``product_name`` is not a str, is empty, or contains
                other characters.
        """
        if not isinstance(product_name, str) or not product_name.replace('_', '').isalnum():
            raise ValidationError("Invalid product name", "product_name")

    def get_distinct_sections(self) -> List[str]:
        """
        Get distinct section (warehouse/area) values from products table.
        Used for user area-of-operation assignment (same as PHP get_all_sections).

        Best-effort: any failure is logged as a warning and an empty list is returned.

        Returns:
            Stripped, non-empty section names, ordered by section.
        """
        try:
            query = """
                SELECT DISTINCT section FROM `products`
                WHERE section IS NOT NULL AND TRIM(section) != ''
                ORDER BY section
            """
            rows = self.db.execute_query(query, fetch_all=True) or []
            return [str(row["section"]).strip() for row in rows if row.get("section")]
        except Exception as e:
            logger.warning("Could not fetch distinct sections (products.section): %s", e)
            return []

    def find_by_name(self, product_name: str) -> Optional[Dict[str, Any]]:
        """Find product by name (matches the ``product_name`` column); None if absent."""
        return self.find_one({'product_name': product_name})

    def get_product_ingredients(self, product_name: str) -> List[Dict[str, Any]]:
        """
        Get ingredients for a product from its dedicated table.

        Args:
            product_name: Product name (also the table name)

        Returns:
            Rows with ingredient, expected_weight, weight_variance, to_be_measured,
            mdn and batch_no (empty list when the table has no rows).

        Raises:
            ValidationError: If ``product_name`` is not a safe table name.
            DatabaseError: If the query fails.
        """
        try:
            # The name is interpolated as an identifier, so validate it to prevent SQL injection.
            self._validate_table_name(product_name)

            query = (
                "SELECT ingredient, expected_weight, weight_variance, to_be_measured, mdn, batch_no "
                f"FROM `{product_name}`"
            )
            result = self.db.execute_query(query, fetch_all=True)
            return result or []

        except ValidationError:
            # Consistent with update_product_ingredient: bad input is not a database error.
            raise
        except Exception as e:
            logger.error("Error getting product ingredients for %s: %s", product_name, e)
            raise DatabaseError(f"Failed to get product ingredients: {e}") from e

    def update_product_ingredient(
        self,
        product_name: str,
        ingredient_id: int,
        update_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """
        Update a product ingredient.

        Args:
            product_name: Product name (table name)
            ingredient_id: Ingredient row id
            update_data: Column -> value mapping. Keys must be among ingredient,
                expected_weight, weight_variance, to_be_measured, mdn, batch_no.

        Returns:
            The updated row as re-read from the table, or ``update_data`` if the
            re-read returns nothing.

        Raises:
            ValidationError: If the table name is unsafe, ``update_data`` is empty or has
                a disallowed column, or no row was updated.
            DatabaseError: If the update or re-read fails.
        """
        try:
            self._validate_table_name(product_name)

            if not update_data:
                # An empty SET clause would be invalid SQL.
                raise ValidationError("No fields to update", "update_data")
            for column in update_data:
                if column not in self._INGREDIENT_UPDATE_COLUMNS:
                    raise ValidationError(f"Column '{column}' is not allowed for update", column)

            # Column names are whitelisted above, so interpolating them is safe.
            set_clause = ", ".join(f"`{column}` = %s" for column in update_data)
            params = (*update_data.values(), ingredient_id)
            query = f"UPDATE `{product_name}` SET {set_clause} WHERE id = %s"

            rows_affected = self.db.execute_update(query, params)

            if rows_affected == 0:
                # NOTE(review): if the backend is MySQL, an UPDATE that writes identical
                # values reports 0 affected rows unless the connection uses CLIENT_FOUND_ROWS,
                # so a no-op update is also rejected here. Confirm this is intended.
                raise ValidationError("No rows were updated", "ingredient_id")

            # Return updated data
            select_query = f"SELECT * FROM `{product_name}` WHERE id = %s"
            result = self.db.execute_query(select_query, (ingredient_id,), fetch_one=True)

            return result or update_data

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Error updating product ingredient: %s", e)
            raise DatabaseError(f"Failed to update product ingredient: {e}") from e


# Global repository instances, created once when this module is imported.
# BatchRepository() calls get_settings() in its constructor to pick the batch table,
# so importing this module requires the settings to be loadable.
batch_repository = BatchRepository()
product_repository = ProductRepository()