"""
Photo QR Cloud Server — SaaS Edition
Multi-user, multi-event photo hosting with admin panel and user portal.
Flask WSGI — runs natively with cPanel Passenger.
"""

import os
import sys
import uuid
import io
import time
import logging
import threading
import queue
import pymysql
import bcrypt
import jwt as pyjwt
from datetime import datetime, timedelta, timezone
from functools import wraps
from pathlib import Path
from contextlib import contextmanager
from PIL import Image

from flask import (
    Flask, request, jsonify, render_template,
    send_file, abort, session, redirect, url_for, g
)
from dotenv import load_dotenv

# ---------------------------------------------------------------------------
# Setup
# ---------------------------------------------------------------------------
# Directory containing this module; the .env file, the templates folder and
# relative upload paths are all resolved against it.
APP_DIR = Path(__file__).parent.resolve()
# Load environment variables from the .env file that sits next to this module.
load_dotenv(APP_DIR / ".env")

# Root logging goes to stderr at INFO level; `logger` is the app-wide logger.
logging.basicConfig(level=logging.INFO, stream=sys.stderr)
logger = logging.getLogger("photo-qr")

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Database connection settings (environment / .env, with local-dev defaults).
DB_HOST     = os.getenv("DB_HOST", "localhost")
DB_USER     = os.getenv("DB_USER", "root")
DB_NAME     = os.getenv("DB_NAME", "booth")
DB_PASSWORD = os.getenv("DB_PASSWORD", "")
DB_PORT     = int(os.getenv("DB_PORT", "3306"))

# Secrets. JWT_SECRET and FLASK_SECRET fall back to values derived from
# API_KEY, so a weak API_KEY weakens all three.
API_KEY       = os.getenv("API_KEY", "change-this-key")
JWT_SECRET    = os.getenv("JWT_SECRET", API_KEY + "-jwt")
FLASK_SECRET  = os.getenv("FLASK_SECRET", API_KEY + "-session")

# Credentials used only to seed the first admin row when the table is empty.
ADMIN_EMAIL    = os.getenv("ADMIN_EMAIL", "admin@example.com")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD", "changeme")

# Upload location: absolute paths are used as-is, relative ones live under APP_DIR.
_upload_dir = os.getenv("UPLOAD_DIR", "uploads")
UPLOAD_DIR  = Path(_upload_dir) if os.path.isabs(_upload_dir) else APP_DIR / _upload_dir
BASE_URL    = os.getenv("BASE_URL", "http://localhost:8000").rstrip("/")
EXPIRY_HOURS = 72   # 3 days — long enough for event guests to scan

ALLOWED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".bmp", ".tiff", ".gif"}
MAX_UPLOAD_MB = 50  # reject uploads larger than this


def _warn_on_placeholder_api_key():
    """Log a warning when the service runs with the built-in placeholder API key.

    The placeholder is public knowledge (it is in the source), and the JWT and
    session secrets are derived from it unless they are set explicitly, so
    running with it makes the upload API, user tokens and sessions forgeable.
    """
    if API_KEY == "change-this-key":
        logger.warning(
            "API_KEY is still the built-in placeholder; set API_KEY (and ideally "
            "JWT_SECRET / FLASK_SECRET) in the environment before deploying"
        )


_warn_on_placeholder_api_key()

# Create the upload directory at import time. A failure is only logged so the
# module can still be imported.
try:
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
except Exception as e:
    logger.warning(f"Could not create upload dir: {e}")

# Sub-folder of the upload directory holding per-photo `<token>.jpg` files
# (these are the files removed alongside a photo when it is deleted).
THUMB_CACHE_DIR = UPLOAD_DIR / "_thumbs"
try:
    THUMB_CACHE_DIR.mkdir(parents=True, exist_ok=True)
except Exception as e:
    logger.warning(f"Could not create thumb cache dir: {e}")

# ---------------------------------------------------------------------------
# Flask App
# ---------------------------------------------------------------------------
# The WSGI application object; templates are loaded from <APP_DIR>/templates.
app = Flask(__name__, template_folder=str(APP_DIR / "templates"))
# Key used by Flask to sign the session cookie (admin panel and user portal).
app.secret_key = FLASK_SECRET
app.config["MAX_CONTENT_LENGTH"] = MAX_UPLOAD_MB * 1024 * 1024  # reject large uploads early

# ---------------------------------------------------------------------------
# Database helpers
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# Simple connection pool — reuse connections instead of open/close each request
# ---------------------------------------------------------------------------

class ConnectionPool:
    """Small thread-safe pool of PyMySQL connections.

    Connections are opened lazily, up to ``max_size`` in total. Idle
    connections wait in a queue; ``_created`` counts every slot that is
    currently in use (idle or checked out). A slot is released again whenever
    opening a connection fails, so a temporary database outage cannot
    permanently shrink the pool.
    """

    def __init__(self, max_size=10):
        self._pool = queue.Queue(maxsize=max_size)
        self._max_size = max_size
        self._lock = threading.Lock()  # guards _created
        self._created = 0

    def _create_conn(self):
        """Open a new connection using the module-level DB_* settings."""
        return pymysql.connect(
            host=DB_HOST, user=DB_USER, password=DB_PASSWORD,
            database=DB_NAME, port=DB_PORT,
            charset="utf8mb4", cursorclass=pymysql.cursors.DictCursor,
        )

    def _open_reserved(self):
        """Open a connection for an already-reserved slot.

        If the connection cannot be opened the slot is handed back before the
        error propagates, so a later call can try again.
        """
        try:
            return self._create_conn()
        except Exception:
            with self._lock:
                self._created -= 1
            raise

    def get(self):
        """Return a live connection.

        Order of preference: an idle pooled connection, a newly opened one
        (while below ``max_size``), otherwise wait up to 10 seconds for a
        connection to be returned.

        Raises:
            queue.Empty: the pool stayed exhausted for the whole wait.
            Exception: whatever ``_create_conn`` raises when the database
                cannot be reached.
        """
        try:
            conn = self._pool.get_nowait()
        except queue.Empty:
            with self._lock:
                reserved = self._created < self._max_size
                if reserved:
                    self._created += 1
            if reserved:
                # Connect outside the lock so a slow handshake does not
                # stall every other thread.
                return self._open_reserved()
            # Pool exhausted — block until one is returned
            conn = self._pool.get(timeout=10)

        try:
            conn.ping(reconnect=True)  # check if still alive
        except Exception:
            # The connection is dead and could not be revived: drop it and
            # open a replacement in the same slot.
            try:
                conn.close()
            except Exception:
                pass  # best effort; the socket may already be gone
            return self._open_reserved()
        return conn

    def put(self, conn):
        """Return ``conn`` to the pool, closing it if the queue is already full."""
        try:
            self._pool.put_nowait(conn)
        except queue.Full:
            conn.close()


_pool = ConnectionPool(max_size=10)


def get_db():
    """Check a connection out of the shared module-level pool."""
    connection = _pool.get()
    return connection


@contextmanager
def db_connection():
    """Yield a pooled connection wrapped in a transaction.

    The transaction is committed when the ``with`` body finishes normally and
    rolled back (with the error re-raised) when it raises an ``Exception``.
    The connection goes back to the pool in either case.
    """
    connection = get_db()
    try:
        yield connection
        # Only reached when the caller's block completed without raising.
        connection.commit()
    except Exception:
        connection.rollback()
        raise
    finally:
        _pool.put(connection)


# ---------------------------------------------------------------------------
# Password helpers (must be defined BEFORE init_db runs)
# ---------------------------------------------------------------------------

def hash_password(plain: str) -> str:
    """Return a freshly salted bcrypt hash of ``plain`` as text."""
    secret = plain.encode()
    digest = bcrypt.hashpw(secret, bcrypt.gensalt())
    return digest.decode()


def check_password(plain: str, hashed: str) -> bool:
    """Tell whether ``plain`` matches the bcrypt hash ``hashed``.

    Any error while checking (malformed hash, wrong argument types, ...) is
    treated as "does not match".
    """
    try:
        matches = bcrypt.checkpw(plain.encode(), hashed.encode())
    except Exception:
        return False
    return matches


# MySQL error numbers meaning "this migration was already applied":
# 1060 = duplicate column name, 1061 = duplicate key (index) name.
_ALREADY_APPLIED_ERRNOS = {1060, 1061}


def _apply_idempotent_ddl(cur, statement):
    """Run a schema migration that may already have been applied.

    "Already exists" errors are ignored silently. Any other failure is logged
    as a warning instead of being swallowed, while start-up still continues.
    """
    try:
        cur.execute(statement)
    except Exception as exc:
        errno = exc.args[0] if exc.args else None
        if errno not in _ALREADY_APPLIED_ERRNOS:
            logger.warning("Schema migration failed (%s): %s", statement, exc)


def init_db():
    """Create the schema if needed, apply column/index migrations, seed an admin.

    Safe to run on every start: tables use ``CREATE TABLE IF NOT EXISTS`` and
    the ``ALTER``/``CREATE INDEX`` steps tolerate having been applied before.
    A default admin is inserted from ADMIN_EMAIL / ADMIN_PASSWORD only while
    the ``admins`` table is empty.
    """
    with db_connection() as conn:
        with conn.cursor() as cur:
            # Admins table (credentials stored in DB, not .env)
            cur.execute("""
                CREATE TABLE IF NOT EXISTS admins (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    email VARCHAR(255) UNIQUE NOT NULL,
                    password_hash VARCHAR(255) NOT NULL,
                    created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    INDEX idx_admin_email (email)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            """)

            # Users table
            cur.execute("""
                CREATE TABLE IF NOT EXISTS users (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    email VARCHAR(255) UNIQUE NOT NULL,
                    password_hash VARCHAR(255) NOT NULL,
                    is_active TINYINT(1) DEFAULT 1,
                    device_id VARCHAR(64) NULL,
                    last_login_at DATETIME NULL,
                    created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    INDEX idx_email (email)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            """)

            # Events table
            cur.execute("""
                CREATE TABLE IF NOT EXISTS events (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    user_id INT NOT NULL,
                    name VARCHAR(255) NOT NULL,
                    created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    INDEX idx_user (user_id),
                    FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            """)

            # Photos table (original + new columns)
            cur.execute("""
                CREATE TABLE IF NOT EXISTS photos (
                    id INT AUTO_INCREMENT PRIMARY KEY,
                    token VARCHAR(64) UNIQUE NOT NULL,
                    original_filename VARCHAR(512) NOT NULL,
                    stored_filename VARCHAR(512) NOT NULL,
                    file_size BIGINT DEFAULT 0,
                    created_at DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    expires_at DATETIME NOT NULL,
                    download_count INT DEFAULT 0,
                    is_expired TINYINT(1) DEFAULT 0,
                    user_id INT NULL,
                    event_id INT NULL,
                    INDEX idx_token (token),
                    INDEX idx_expires (expires_at),
                    INDEX idx_event (event_id),
                    INDEX idx_user (user_id)
                ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4
            """)

            # Add columns to existing tables (idempotent)
            for col, defn in [("user_id", "INT NULL"), ("event_id", "INT NULL")]:
                _apply_idempotent_ddl(cur, f"ALTER TABLE photos ADD COLUMN {col} {defn}")

            for col, defn in [("device_id", "VARCHAR(64) NULL"), ("last_login_at", "DATETIME NULL")]:
                _apply_idempotent_ddl(cur, f"ALTER TABLE users ADD COLUMN {col} {defn}")

            # Add share_token column to events (for shareable album links)
            _apply_idempotent_ddl(
                cur, "ALTER TABLE events ADD COLUMN share_token VARCHAR(64) NULL"
            )
            _apply_idempotent_ddl(
                cur, "CREATE UNIQUE INDEX idx_share_token ON events (share_token)"
            )

            # Seed default admin from .env on first run (if no admins exist)
            cur.execute("SELECT COUNT(*) as cnt FROM admins")
            if cur.fetchone()["cnt"] == 0 and ADMIN_EMAIL and ADMIN_PASSWORD:
                cur.execute(
                    "INSERT INTO admins (email, password_hash) VALUES (%s, %s)",
                    (ADMIN_EMAIL, hash_password(ADMIN_PASSWORD)),
                )
                logger.info(f"Seeded default admin: {ADMIN_EMAIL}")


# Initialise the schema at import time. A failure is logged rather than
# raised, so the module still imports when the database is unreachable.
try:
    init_db()
    logger.info("Database initialized successfully")
except Exception as e:
    logger.error(f"Database init failed: {e}")

# ---------------------------------------------------------------------------
# Auth helpers — API key (legacy) + JWT (users) + session (admin/portal)
# ---------------------------------------------------------------------------

def require_api_key(f):
    """Legacy API-key auth for desktop uploader (still supported).

    The wrapped view only runs when the ``X-API-Key`` request header equals
    the configured API_KEY; otherwise a 401 JSON error is returned.
    """
    import hmac  # local import keeps this block self-contained

    @wraps(f)
    def decorated(*args, **kwargs):
        api_key = request.headers.get("X-API-Key", "")
        # Constant-time comparison so response timing does not reveal how much
        # of the key was guessed correctly. Both sides are compared as bytes
        # because compare_digest rejects non-ASCII str arguments.
        if not hmac.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
            return jsonify({"error": "Invalid API key"}), 401
        return f(*args, **kwargs)
    return decorated


def _decode_jwt(token: str):
    """Decode and verify an HS256 token signed with JWT_SECRET.

    Returns the claims dict, or ``None`` when the token is expired or
    otherwise invalid.
    """
    try:
        claims = pyjwt.decode(token, JWT_SECRET, algorithms=["HS256"])
    except pyjwt.InvalidTokenError:
        # ExpiredSignatureError is a subclass of InvalidTokenError, so expired
        # tokens are handled here as well.
        return None
    return claims


def require_user_jwt(f):
    """JWT auth for user API endpoints — also verifies device_id matches DB.

    The token is read from the ``Authorization: Bearer`` header, falling back
    to the ``token`` query parameter. On success ``g.user_id`` and
    ``g.user_email`` are set for the wrapped view. A 401 JSON error is
    returned when:

    * the token is missing, invalid, expired, or carries no usable ``sub``;
    * the account no longer exists or has been deactivated;
    * the token is bound to a device that is no longer the one stored for the
      user (for example after an admin force-logout).
    """
    @wraps(f)
    def decorated(*args, **kwargs):
        auth = request.headers.get("Authorization", "")
        token = auth.removeprefix("Bearer ").strip()
        if not token:
            # Also accept from query string (for portal thumbnails etc.)
            token = request.args.get("token", "")
        payload = _decode_jwt(token)
        if not payload:
            return jsonify({"error": "Unauthorized"}), 401

        # A validly signed token without a numeric subject must not crash the
        # request; int() also accepts a subject that was issued as a string.
        try:
            user_id = int(payload["sub"])
        except (KeyError, TypeError, ValueError):
            return jsonify({"error": "Unauthorized"}), 401
        g.user_id = user_id
        g.user_email = payload.get("email", "")

        with db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    "SELECT device_id, is_active FROM users WHERE id = %s", (user_id,)
                )
                row = cur.fetchone()

        # Tokens are long-lived, so the account state is re-checked on every
        # request: deleted or deactivated users lose access immediately.
        if not row or not row["is_active"]:
            return jsonify({"error": "Unauthorized"}), 401

        # Verify device_id still matches DB (admin force-logout clears it)
        jwt_device = payload.get("device_id")
        if jwt_device and row["device_id"] != jwt_device:
            return jsonify({"error": "Session invalidated. Please login again.", "code": "device_mismatch"}), 401

        return f(*args, **kwargs)
    return decorated


def require_admin_session(f):
    """Only let requests with a logged-in admin session reach the view."""
    @wraps(f)
    def guarded(*args, **kwargs):
        if session.get("admin_logged_in"):
            return f(*args, **kwargs)
        # No admin session: send the browser to the admin login form.
        return redirect(url_for("admin_login"))
    return guarded


def require_portal_session(f):
    """Only let requests with a logged-in portal session reach the view.

    Exposes the session's user as ``g.user_id`` / ``g.user_email``.
    """
    @wraps(f)
    def guarded(*args, **kwargs):
        portal_user_id = session.get("portal_user_id")
        if not portal_user_id:
            return redirect(url_for("portal_login"))
        g.user_id = portal_user_id
        g.user_email = session.get("portal_email", "")
        return f(*args, **kwargs)
    return guarded


# ---------------------------------------------------------------------------
# Public routes
# ---------------------------------------------------------------------------

@app.route("/")
def root():
    """Return a small JSON banner showing that the service is up."""
    banner = {
        "app": "Photo QR Cloud Server (SaaS)",
        "status": "running",
    }
    return jsonify(banner)


# Health check with simple time-based cache (avoid DB hit on every poll)
# "result" holds the last healthy JSON payload (a plain dict), "ts" the time
# it was produced. Plain data is cached rather than a Response object so each
# request gets its own response to attach headers and cookies to.
_health_cache = {"result": None, "ts": 0}
_HEALTH_TTL = 30  # seconds

@app.route("/health")
def health_check():
    """Report database connectivity.

    A healthy result is reused for up to ``_HEALTH_TTL`` seconds. Failures
    are never cached, are logged with full details, and answer with HTTP 503
    and a generic message (the endpoint is public, so driver error text is
    not echoed back).
    """
    now = time.time()
    cached = _health_cache["result"]
    if cached and (now - _health_cache["ts"]) < _HEALTH_TTL:
        return jsonify(cached)
    try:
        with db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute("SELECT 1")
    except Exception:
        _health_cache["result"] = None
        logger.exception("Health check failed")
        return jsonify({"status": "unhealthy", "error": "database unavailable"}), 503

    payload = {"status": "healthy", "database": "connected"}
    _health_cache["result"] = payload
    _health_cache["ts"] = now
    return jsonify(payload)


# ---------------------------------------------------------------------------
# Download page (public — QR scan lands here)
# ---------------------------------------------------------------------------

@app.route("/download/<token>")
def download_page(token):
    """Render the public download page for a photo token.

    The template gets ``error="not_found"`` for an unknown token,
    ``error="expired"`` once the expiry has passed, and otherwise the file
    URL plus the number of seconds left.
    """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT * FROM photos WHERE token = %s", (token,))
        photo = cur.fetchone()

    if not photo:
        return render_template("download.html", error="not_found", photo=None)

    # Normalise the stored expiry to an aware UTC datetime before comparing.
    expiry = photo["expires_at"]
    if isinstance(expiry, str):
        expiry = datetime.fromisoformat(expiry)
    if expiry.tzinfo is None:
        expiry = expiry.replace(tzinfo=timezone.utc)

    if expiry < datetime.now(timezone.utc):
        return render_template("download.html", error="expired", photo=photo)

    seconds_left = int((expiry - datetime.now(timezone.utc)).total_seconds())
    return render_template(
        "download.html",
        error=None,
        photo=photo,
        file_url=f"{BASE_URL}/file/{token}",
        remaining_seconds=seconds_left,
    )


@app.route("/file/<token>")
def serve_file(token):
    """Send the stored photo for ``token`` as an attachment.

    Aborts with 404 for an unknown token or a missing file and with 410 once
    the photo has expired. A successful lookup increments the download
    counter in the same transaction.
    """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT * FROM photos WHERE token = %s", (token,))
        photo = cur.fetchone()
        if not photo:
            abort(404)

        # Normalise the stored expiry to an aware UTC datetime before comparing.
        expiry = photo["expires_at"]
        if isinstance(expiry, str):
            expiry = datetime.fromisoformat(expiry)
        if expiry.tzinfo is None:
            expiry = expiry.replace(tzinfo=timezone.utc)
        if expiry < datetime.now(timezone.utc):
            abort(410)

        file_path = UPLOAD_DIR / photo["stored_filename"]
        if not file_path.exists():
            abort(404)

        cur.execute(
            "UPDATE photos SET download_count = download_count + 1 WHERE token = %s",
            (token,),
        )

    return send_file(
        str(file_path),
        download_name=photo["original_filename"],
        as_attachment=True,
    )

# ---------------------------------------------------------------------------
# Upload (JWT OR legacy API key)
# ---------------------------------------------------------------------------

@app.route("/upload", methods=["POST"])
def upload_photo():
    """Store an uploaded photo and return its download token.

    Authentication: a user JWT (``Authorization: Bearer``) or, failing that,
    the legacy ``X-API-Key`` header.

    Form fields: ``file`` (required) and ``event_id`` (optional).

    Responses:
        200: JSON with token, download URL, size and expiry.
        400: missing file, empty filename, disallowed extension, or a
            non-numeric ``event_id``.
        401: no valid JWT and no valid API key.
        404: a JWT user named an event that does not exist or is not theirs.
    """
    import hmac  # local imports keep this block self-contained
    import re

    # Accept JWT or legacy API key
    auth = request.headers.get("Authorization", "")
    bearer = auth.removeprefix("Bearer ").strip()
    payload = _decode_jwt(bearer) if bearer else None

    # Fallback: legacy API key
    if payload is None:
        api_key = request.headers.get("X-API-Key", "")
        # Constant-time comparison; bytes because compare_digest rejects
        # non-ASCII str arguments.
        if not hmac.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
            return jsonify({"error": "Unauthorized"}), 401
        user_id = None
        user_email = None
    else:
        user_id = payload.get("sub")
        if user_id is None:
            # Signed token without a subject: reject instead of raising KeyError.
            return jsonify({"error": "Unauthorized"}), 401
        user_email = payload.get("email")
        # NOTE(review): unlike the other user endpoints, this path does not
        # re-check device binding or is_active — confirm that is intended.

    if "file" not in request.files:
        return jsonify({"error": "No file provided"}), 400

    file = request.files["file"]
    if not file.filename:
        return jsonify({"error": "Empty filename"}), 400

    ext = Path(file.filename).suffix.lower()
    if ext not in ALLOWED_EXTENSIONS:
        return jsonify({"error": f"File type {ext} not allowed"}), 400

    event_id = request.form.get("event_id") or None
    if event_id:
        try:
            event_id = int(event_id)
        except ValueError:
            return jsonify({"error": "Invalid event_id"}), 400

    # ── Build organized folder path: uploads/<user>/<event>/ ──
    def _safe_name(name):
        """Sanitize folder name — keep alphanumerics, spaces, dashes, underscores."""
        clean = re.sub(r'[^\w\s\-]', '', str(name)).strip()
        return clean or "_unnamed"

    # Determine user folder
    if user_email:
        user_folder = _safe_name(user_email.split("@")[0])  # e.g. "mk" from "mk@example.com"
    elif user_id:
        user_folder = f"user_{user_id}"
    else:
        user_folder = "_legacy"

    # Determine event folder
    event_name = None
    if event_id:
        with db_connection() as conn:
            with conn.cursor() as cur:
                if user_id is not None:
                    # JWT users may only upload into their own events.
                    cur.execute(
                        "SELECT name FROM events WHERE id = %s AND user_id = %s",
                        (event_id, user_id),
                    )
                else:
                    cur.execute("SELECT name FROM events WHERE id = %s", (event_id,))
                row = cur.fetchone()
        if row:
            event_name = row["name"]
        elif user_id is not None:
            return jsonify({"error": "Event not found"}), 404
    event_folder = _safe_name(event_name) if event_name else "_unsorted"

    # Create the subdirectory
    sub_dir = UPLOAD_DIR / user_folder / event_folder
    sub_dir.mkdir(parents=True, exist_ok=True)

    token = uuid.uuid4().hex
    stored_filename = f"{user_folder}/{event_folder}/{token}{ext}"
    file_path = UPLOAD_DIR / stored_filename

    # Stream to disk instead of loading entire file into memory
    file.save(str(file_path))
    file_size = file_path.stat().st_size

    now = datetime.now(timezone.utc)
    expires_at = now + timedelta(hours=EXPIRY_HOURS)

    try:
        with db_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """INSERT INTO photos
                       (token, original_filename, stored_filename, file_size,
                        created_at, expires_at, user_id, event_id)
                       VALUES (%s, %s, %s, %s, %s, %s, %s, %s)""",
                    (token, file.filename, stored_filename, file_size,
                     now, expires_at, user_id, event_id),
                )
    except Exception:
        # Without a DB row nothing references the file, so remove it rather
        # than leaving an orphan on disk.
        file_path.unlink(missing_ok=True)
        raise

    return jsonify({
        "success": True,
        "token": token,
        "download_url": f"{BASE_URL}/download/{token}",
        "filename": file.filename,
        "file_size": file_size,
        "expires_at": expires_at.isoformat(),
        "expires_in_seconds": EXPIRY_HOURS * 3600,
    })

# ---------------------------------------------------------------------------
# User Auth API (JWT)
# ---------------------------------------------------------------------------

@app.route("/auth/login", methods=["POST"])
def auth_login():
    """Authenticate a user with email + password and issue a JWT.

    JSON body: ``email``, ``password`` and optional ``device_id`` (at most 64
    characters, the width of the ``users.device_id`` column).

    Responses:
        200: ``{"success": true, "token": ..., "email": ...}``
        400: missing fields or fields of the wrong type/length.
        401: unknown or inactive account, or wrong password.
        409: the account is bound to a different device.
    """
    data = request.get_json(force=True)
    if not isinstance(data, dict):
        # A JSON array/string/number has no fields; treat it as an empty body.
        data = {}
    email = data.get("email") or ""
    password = data.get("password") or ""
    device_id = data.get("device_id") or ""

    # Non-string values would otherwise crash in .strip()/.encode() below.
    if not all(isinstance(value, str) for value in (email, password, device_id)):
        return jsonify({"error": "Invalid request payload"}), 400
    if len(device_id) > 64:
        return jsonify({"error": "Invalid request payload"}), 400

    email = email.strip().lower()
    if not email or not password:
        return jsonify({"error": "Email and password required"}), 400

    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM users WHERE email = %s AND is_active = 1", (email,))
            user = cur.fetchone()

    if not user or not check_password(password, user["password_hash"]):
        return jsonify({"error": "Invalid credentials"}), 401

    # Enforce 1-user-1-device: reject if already logged in on a different device
    if device_id and user.get("device_id") and user["device_id"] != device_id:
        return jsonify({
            "error": "already_logged_in",
            "message": "This account is already logged in on another device. "
                       "Please logout from that device or contact your admin.",
        }), 409

    # Store device_id and login time
    # NOTE(review): a login without device_id stores NULL and so clears an
    # existing device binding — confirm this is the intended behaviour.
    now = datetime.now(timezone.utc)
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                "UPDATE users SET device_id = %s, last_login_at = %s WHERE id = %s",
                (device_id or None, now, user["id"]),
            )

    # NOTE(review): recent PyJWT releases validate that "sub" is a string when
    # decoding; the integer subject is kept for compatibility with tokens that
    # are already issued — confirm against the pinned PyJWT version.
    token = pyjwt.encode({
        "sub": user["id"],
        "email": user["email"],
        "device_id": device_id,
        "exp": now + timedelta(days=365),
    }, JWT_SECRET, algorithm="HS256")

    return jsonify({"success": True, "token": token, "email": user["email"]})


@app.route("/auth/verify", methods=["GET"])
@require_user_jwt
def auth_verify():
    """Confirm that the presented token is valid and echo the user it names."""
    body = {"valid": True, "user_id": g.user_id, "email": g.user_email}
    return jsonify(body)


@app.route("/auth/logout", methods=["POST"])
@require_user_jwt
def auth_logout():
    """Clear device_id so the user can login from another device."""
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("UPDATE users SET device_id = NULL WHERE id = %s", (g.user_id,))
    result = {"success": True}
    return jsonify(result)


@app.route("/auth/change-password", methods=["POST"])
@require_user_jwt
def auth_change_password():
    """Change the logged-in user's password.

    JSON body: ``current_password`` and ``new_password`` (both strings, the
    new one at least 6 characters).

    Responses:
        200: password updated.
        400: missing, non-string, or too-short fields.
        401: the current password does not match.
    """
    data = request.get_json(force=True)
    if not isinstance(data, dict):
        # A JSON array/string/number has no fields; treat it as an empty body.
        data = {}
    current_password = data.get("current_password") or ""
    new_password = data.get("new_password") or ""

    # Non-string values would otherwise crash in len()/.encode() below.
    if not isinstance(current_password, str) or not isinstance(new_password, str):
        return jsonify({"error": "Current and new password required"}), 400
    if not current_password or not new_password:
        return jsonify({"error": "Current and new password required"}), 400
    if len(new_password) < 6:
        return jsonify({"error": "New password must be at least 6 characters"}), 400

    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT password_hash FROM users WHERE id = %s", (g.user_id,))
            user = cur.fetchone()
            if not user or not check_password(current_password, user["password_hash"]):
                return jsonify({"error": "Current password is incorrect"}), 401
            cur.execute(
                "UPDATE users SET password_hash = %s WHERE id = %s",
                (hash_password(new_password), g.user_id),
            )
    return jsonify({"success": True})

# ---------------------------------------------------------------------------
# Events API (JWT)
# ---------------------------------------------------------------------------

@app.route("/events", methods=["GET"])
@require_user_jwt
def list_events():
    """List the caller's events, newest first, with a count of non-expired photos."""
    query = """
                SELECT e.*, COUNT(p.id) as photo_count
                FROM events e
                LEFT JOIN photos p ON p.event_id = e.id AND p.is_expired = 0
                WHERE e.user_id = %s
                GROUP BY e.id
                ORDER BY e.created_at DESC
            """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(query, (g.user_id,))
        rows = cur.fetchall()
    return jsonify({"events": rows})


@app.route("/events", methods=["POST"])
@require_user_jwt
def create_event():
    """Create an event owned by the calling user.

    JSON body: ``name`` (non-blank string).

    Responses:
        201: ``{"success": true, "event": <row>}``
        400: the name is missing, blank, or not a string.
    """
    data = request.get_json(force=True)
    if not isinstance(data, dict):
        # A JSON array/string/number has no fields; treat it as an empty body.
        data = {}
    name = data.get("name")
    # A non-string name would otherwise crash in .strip().
    if not isinstance(name, str) or not name.strip():
        return jsonify({"error": "Event name required"}), 400
    name = name.strip()

    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute(
                "INSERT INTO events (user_id, name) VALUES (%s, %s)",
                (g.user_id, name),
            )
            event_id = cur.lastrowid
            # Read the row back so the response carries DB-generated columns.
            cur.execute("SELECT * FROM events WHERE id = %s", (event_id,))
            event = cur.fetchone()

    return jsonify({"success": True, "event": event}), 201

# ---------------------------------------------------------------------------
# Photo Delete API (JWT)
# ---------------------------------------------------------------------------

@app.route("/photos/<int:photo_id>", methods=["DELETE"])
@require_user_jwt
def delete_photo(photo_id):
    """Delete one of the caller's photos (DB row, image file and cached thumbnail).

    The row is deleted and committed first, and the files are removed
    afterwards. A failure can therefore leave a stray file behind, but never
    a database row that points at a file which no longer exists.

    Responses:
        200: ``{"success": true}``
        404: no such photo for this user.
    """
    with db_connection() as conn:
        with conn.cursor() as cur:
            cur.execute("SELECT * FROM photos WHERE id = %s AND user_id = %s",
                        (photo_id, g.user_id))
            photo = cur.fetchone()
            if photo:
                cur.execute("DELETE FROM photos WHERE id = %s AND user_id = %s",
                            (photo_id, g.user_id))

    if not photo:
        return jsonify({"error": "Photo not found"}), 404

    stale_paths = (
        THUMB_CACHE_DIR / f"{photo['token']}.jpg",
        UPLOAD_DIR / photo["stored_filename"],
    )
    for path in stale_paths:
        try:
            # missing_ok avoids the exists()/unlink() race with a concurrent delete.
            path.unlink(missing_ok=True)
        except OSError as exc:
            logger.warning("Could not remove %s: %s", path, exc)

    return jsonify({"success": True})


@app.route("/photos/bulk-delete", methods=["POST"])
@require_user_jwt
def bulk_delete_photos():
    """Delete several of the caller's photos in one request.

    JSON body: ``ids`` — a non-empty list of photo ids. Entries that are not
    plain non-negative integers are ignored, and ids belonging to other users
    are skipped silently.

    Responses:
        200: ``{"success": true, "deleted": <number of rows removed>}``
        400: ``ids`` is missing, not a list, or holds no usable id.
    """
    data = request.get_json(force=True)
    if not isinstance(data, dict):
        # A JSON array/string/number has no fields; treat it as an empty body.
        data = {}
    raw_ids = data.get("ids", [])
    if not raw_ids or not isinstance(raw_ids, list):
        return jsonify({"error": "ids list required"}), 400

    # Keep only values whose text form is ASCII digits: str.isdigit() alone
    # also accepts characters such as "²" that int() cannot parse.
    ids = []
    for raw in raw_ids:
        text = str(raw)
        if text.isascii() and text.isdigit():
            ids.append(int(text))
    if not ids:
        # An empty list would otherwise produce the invalid SQL "IN ()".
        return jsonify({"error": "ids list required"}), 400

    with db_connection() as conn:
        with conn.cursor() as cur:
            placeholders = ",".join(["%s"] * len(ids))
            cur.execute(
                f"SELECT * FROM photos WHERE id IN ({placeholders}) AND user_id = %s",
                (*ids, g.user_id),
            )
            photos = cur.fetchall()

            if photos:
                photo_ids = [p["id"] for p in photos]
                owned = ",".join(["%s"] * len(photo_ids))
                cur.execute(f"DELETE FROM photos WHERE id IN ({owned})", photo_ids)

    # The rows are committed at this point; file removal is best effort so a
    # single stubborn file cannot fail the whole request.
    for photo in photos:
        stale_paths = (
            THUMB_CACHE_DIR / f"{photo['token']}.jpg",
            UPLOAD_DIR / photo["stored_filename"],
        )
        for path in stale_paths:
            try:
                path.unlink(missing_ok=True)
            except OSError as exc:
                logger.warning("Could not remove %s: %s", path, exc)

    return jsonify({"success": True, "deleted": len(photos)})

# ---------------------------------------------------------------------------
# Cleanup (legacy API key)
# ---------------------------------------------------------------------------

@app.route("/cleanup", methods=["DELETE"])
@require_api_key
def cleanup_expired():
    """Remove the files of expired photos and flag their rows as expired.

    ``deleted_count`` in the response counts only photos whose image file was
    actually present on disk.
    """
    cutoff = datetime.now(timezone.utc)
    removed_files = 0

    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT * FROM photos WHERE expires_at < %s AND is_expired = 0", (cutoff,)
        )
        stale_rows = cur.fetchall()
        for row in stale_rows:
            thumb_path = THUMB_CACHE_DIR / f"{row['token']}.jpg"
            if thumb_path.exists():
                thumb_path.unlink()
            image_path = UPLOAD_DIR / row["stored_filename"]
            if image_path.exists():
                image_path.unlink()
                removed_files += 1
            cur.execute("UPDATE photos SET is_expired = 1 WHERE id = %s", (row["id"],))

    return jsonify({"success": True, "deleted_count": removed_files})

# ---------------------------------------------------------------------------
# Admin Panel
# ---------------------------------------------------------------------------

@app.route("/admin")
@require_admin_session
def admin_panel():
    """Admin landing URL: forward to the user-management page."""
    target = url_for("admin_users")
    return redirect(target)


@app.route("/admin/login", methods=["GET", "POST"])
def admin_login():
    """Show the admin login form (GET) or check submitted credentials (POST)."""
    if request.method != "POST":
        return render_template("admin_login.html", error=None)

    email = request.form.get("email", "").strip().lower()
    password = request.form.get("password", "")
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT * FROM admins WHERE email = %s", (email,))
        admin = cur.fetchone()

    if not (admin and check_password(password, admin["password_hash"])):
        return render_template("admin_login.html", error="Invalid admin credentials")

    session["admin_logged_in"] = True
    session["admin_email"] = admin["email"]
    return redirect(url_for("admin_users"))


@app.route("/admin/logout", methods=["POST"])
def admin_logout():
    """Drop the admin keys from the session and return to the login form."""
    for key in ("admin_logged_in", "admin_email"):
        session.pop(key, None)
    return redirect(url_for("admin_login"))


# Each helper below handles one value of the admin form's "action" field and
# returns a (message, error) pair for the page.

def _admin_users_create(form):
    """Create a user from the submitted email and password."""
    email = form.get("email", "").strip().lower()
    password = form.get("password", "")
    if not email or not password:
        return None, "Email and password are required"
    try:
        with db_connection() as conn, conn.cursor() as cur:
            cur.execute(
                "INSERT INTO users (email, password_hash) VALUES (%s, %s)",
                (email, hash_password(password)),
            )
    except Exception as e:
        return None, f"Could not create user: {e}"
    return f"User '{email}' created successfully", None


def _admin_users_delete(form):
    """Delete the selected user; nothing is reported if the id is unknown."""
    user_id = form.get("user_id")
    message = None
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT email FROM users WHERE id = %s", (user_id,))
        found = cur.fetchone()
        if found:
            cur.execute("DELETE FROM users WHERE id = %s", (user_id,))
            message = f"User '{found['email']}' deleted"
    return message, None


def _admin_users_toggle(form):
    """Flip the selected user's is_active flag."""
    user_id = form.get("user_id")
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(
            "UPDATE users SET is_active = NOT is_active WHERE id = %s", (user_id,)
        )
    return "User status toggled", None


def _admin_users_force_logout(form):
    """Clear the selected user's device binding."""
    user_id = form.get("user_id")
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("UPDATE users SET device_id = NULL WHERE id = %s", (user_id,))
    return "User has been force-logged out", None


def _admin_users_change_user_password(form):
    """Set a new password for the selected user."""
    user_id = form.get("user_id")
    new_password = form.get("new_password", "")
    if not new_password or len(new_password) < 6:
        return None, "Password must be at least 6 characters"
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(
            "UPDATE users SET password_hash = %s WHERE id = %s",
            (hash_password(new_password), user_id),
        )
    return "User password changed successfully", None


def _admin_users_change_admin_password(form):
    """Change the logged-in admin's own password after checking the current one."""
    current_pw = form.get("current_password", "")
    new_pw = form.get("new_password", "")
    admin_email = session.get("admin_email", "")
    if not current_pw or not new_pw:
        return None, "Both current and new password required"
    if len(new_pw) < 6:
        return None, "New password must be at least 6 characters"
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT * FROM admins WHERE email = %s", (admin_email,))
        admin = cur.fetchone()
        if not (admin and check_password(current_pw, admin["password_hash"])):
            return None, "Current admin password is incorrect"
        cur.execute(
            "UPDATE admins SET password_hash = %s WHERE id = %s",
            (hash_password(new_pw), admin["id"]),
        )
    return "Admin password changed successfully", None


# Maps the form's "action" value to its handler; unknown actions are ignored.
_ADMIN_USERS_ACTIONS = {
    "create": _admin_users_create,
    "delete": _admin_users_delete,
    "toggle": _admin_users_toggle,
    "force_logout": _admin_users_force_logout,
    "change_user_password": _admin_users_change_user_password,
    "change_admin_password": _admin_users_change_admin_password,
}


@app.route("/admin/users", methods=["GET", "POST"])
@require_admin_session
def admin_users():
    """Admin user-management page.

    A POST runs the action named by the form's ``action`` field. Every request
    then renders the user list (with each user's event count) together with
    the action's message or error, if any.
    """
    message = None
    error = None

    if request.method == "POST":
        handler = _ADMIN_USERS_ACTIONS.get(request.form.get("action"))
        if handler is not None:
            message, error = handler(request.form)

    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("""
                SELECT u.*, COUNT(e.id) as event_count
                FROM users u LEFT JOIN events e ON e.user_id = u.id
                GROUP BY u.id ORDER BY u.created_at DESC
            """)
        users = cur.fetchall()

    return render_template("admin_users.html", users=users, message=message, error=error)


@app.route("/admin/events")
@require_admin_session
def admin_events():
    """Render the admin overview listing every event.

    Each row carries the event id, name and creation time, the owning
    user's email, and the count of that event's non-expired photos.
    Rows are ordered newest event first.
    """
    events_sql = """
                SELECT e.id, e.name, e.created_at, u.email as user_email,
                       COUNT(p.id) as photo_count
                FROM events e
                JOIN users u ON u.id = e.user_id
                LEFT JOIN photos p ON p.event_id = e.id AND p.is_expired = 0
                GROUP BY e.id ORDER BY e.created_at DESC
            """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(events_sql)
        rows = cur.fetchall()
    return render_template("admin_events.html", events=rows)

# ---------------------------------------------------------------------------
# User Portal
# ---------------------------------------------------------------------------

@app.route("/portal")
def portal_root():
    """Redirect to the dashboard when a portal session exists, else to the login page."""
    endpoint = "portal_dashboard" if session.get("portal_user_id") else "portal_login"
    return redirect(url_for(endpoint))


@app.route("/portal/login", methods=["GET", "POST"])
def portal_login():
    """Show the portal login form and, on POST, sign the user in.

    A POST looks up an active user by (trimmed, lower-cased) email and
    verifies the submitted password. On success the user id and email are
    stored in the session and the browser is redirected to the dashboard;
    otherwise the form is re-rendered with an error message.
    """
    if request.method != "POST":
        return render_template("portal_login.html", error=None)

    form = request.form
    email = form.get("email", "").strip().lower()
    password = form.get("password", "")

    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT * FROM users WHERE email = %s AND is_active = 1", (email,))
        account = cur.fetchone()

    if not (account and check_password(password, account["password_hash"])):
        return render_template("portal_login.html", error="Invalid email or password")

    # Portal does NOT enforce device_id — only desktop app does.
    session["portal_user_id"] = account["id"]
    session["portal_email"] = account["email"]
    return redirect(url_for("portal_dashboard"))


@app.route("/portal/logout", methods=["POST"])
def portal_logout():
    """End the portal session and send the browser back to the login page."""
    # Only the portal's own session keys are dropped — device_id is NOT touched.
    for key in ("portal_user_id", "portal_email"):
        session.pop(key, None)
    return redirect(url_for("portal_login"))


@app.route("/portal/dashboard")
@require_portal_session
def portal_dashboard():
    """Render the dashboard: the current user's events with photo counts.

    Events are filtered by ``g.user_id``, each annotated with the number of
    its non-expired photos, and listed newest first.
    """
    dashboard_sql = """
                SELECT e.*, COUNT(p.id) as photo_count
                FROM events e
                LEFT JOIN photos p ON p.event_id = e.id AND p.is_expired = 0
                WHERE e.user_id = %s
                GROUP BY e.id ORDER BY e.created_at DESC
            """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(dashboard_sql, (g.user_id,))
        own_events = cur.fetchall()

    return render_template(
        "portal_dashboard.html",
        events=own_events,
        user_email=g.user_email,
    )


@app.route("/portal/events/<int:event_id>")
@require_portal_session
def portal_event(event_id):
    """Render the detail page for one of the current user's events.

    Responds 404 when no event with ``event_id`` belongs to ``g.user_id``.
    The template receives the event row, the count of non-expired photos,
    the user's email and ``BASE_URL``.
    """
    with db_connection() as conn, conn.cursor() as cur:
        # Ownership check: the lookup is scoped to the signed-in user.
        cur.execute("SELECT * FROM events WHERE id = %s AND user_id = %s",
                    (event_id, g.user_id))
        event_row = cur.fetchone()
        if not event_row:
            abort(404)

        # Photo total shown in the event header.
        cur.execute(
            "SELECT COUNT(*) as cnt FROM photos WHERE event_id = %s AND is_expired = 0",
            (event_id,),
        )
        photo_total = cur.fetchone()["cnt"]

    context = {
        "event": event_row,
        "total_photos": photo_total,
        "user_email": g.user_email,
        "base_url": BASE_URL,
    }
    return render_template("portal_event.html", **context)


@app.route("/portal/events/<int:event_id>/photos")
@require_portal_session
def portal_event_photos_api(event_id):
    """Paginated JSON endpoint for event photos (used by infinite scroll).

    Query parameters:
        page: 1-based page number. Missing or non-integer values fall back
            to 1; values below 1 are raised to 1.
        per_page: page size, clamped to the range 1..100. Missing or
            non-integer values fall back to 15.

    Returns:
        JSON with ``photos`` (non-expired photos of the event, newest
        first), ``total``, ``page``, ``per_page`` and ``pages``; or a
        404 JSON error when the event does not belong to ``g.user_id``.
    """
    # type=int makes the lookup return the default when conversion fails,
    # so "?page=abc" no longer escapes as an unhandled ValueError (HTTP 500).
    page = max(1, request.args.get("page", 1, type=int))
    per_page = min(100, max(1, request.args.get("per_page", 15, type=int)))
    offset = (page - 1) * per_page

    with db_connection() as conn, conn.cursor() as cur:
        # Verify event belongs to this user
        cur.execute("SELECT id FROM events WHERE id = %s AND user_id = %s",
                    (event_id, g.user_id))
        if not cur.fetchone():
            return jsonify({"error": "Not found"}), 404

        cur.execute(
            "SELECT COUNT(*) as cnt FROM photos WHERE event_id = %s AND is_expired = 0",
            (event_id,),
        )
        total = cur.fetchone()["cnt"]

        cur.execute("""
                SELECT id, token, original_filename, file_size, created_at,
                       expires_at, download_count
                FROM photos
                WHERE event_id = %s AND is_expired = 0
                ORDER BY created_at DESC
                LIMIT %s OFFSET %s
            """, (event_id, per_page, offset))
        photos = cur.fetchall()

    return jsonify({
        "photos": photos,
        "total": total,
        "page": page,
        "per_page": per_page,
        # Integer ceiling division; yields 0 when there are no photos.
        "pages": (total + per_page - 1) // per_page,
    })


@app.route("/portal/events/<int:event_id>/delete", methods=["POST"])
@require_portal_session
def portal_delete_event(event_id):
    """Delete an entire event and all its photos (files + DB records).

    The database rows are removed first; the photo files and cached
    thumbnails are unlinked only after the database block has finished.
    A database failure therefore can no longer leave rows that point at
    files which were already removed.

    Returns:
        ``{"success": true}`` on success, or a 404 JSON error when the
        event does not belong to ``g.user_id``.
    """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT id FROM events WHERE id = %s AND user_id = %s",
                    (event_id, g.user_id))
        if not cur.fetchone():
            return jsonify({"error": "Not found"}), 404

        # Remember which files belong to the event before the rows vanish.
        cur.execute(
            "SELECT stored_filename, token FROM photos WHERE event_id = %s",
            (event_id,),
        )
        doomed = cur.fetchall()

        # Delete photo records, then the event itself.
        cur.execute("DELETE FROM photos WHERE event_id = %s", (event_id,))
        cur.execute("DELETE FROM events WHERE id = %s AND user_id = %s",
                    (event_id, g.user_id))

    # File removal is irreversible, so it happens only after the DB work.
    # It is best-effort: a leftover file is logged rather than turning an
    # already-completed delete into an error response.
    for row in doomed:
        targets = (
            THUMB_CACHE_DIR / f"{row['token']}.jpg",
            UPLOAD_DIR / row["stored_filename"],
        )
        for path in targets:
            try:
                path.unlink()
            except FileNotFoundError:
                # Already gone (e.g. no thumbnail cached, or a concurrent delete).
                pass
            except OSError:
                logger.exception(
                    "Could not remove %s while deleting event %s", path, event_id
                )

    return jsonify({"success": True})


def _serve_thumb(token, owner_check_sql, owner_check_params):
    """Generate (once) and serve a cached JPEG thumbnail for any photo token.

    Args:
        token: Photo token; also used as the cache file stem
            (``THUMB_CACHE_DIR/<token>.jpg``).
        owner_check_sql: Query that must return a row containing
            ``stored_filename`` when the caller is allowed to see the photo.
        owner_check_params: Parameters bound to ``owner_check_sql``.

    Returns:
        A Flask response streaming the thumbnail with a long-lived
        ``Cache-Control`` header.

    Aborts with 404 when the query returns no row or the source file is
    missing, and with 500 when thumbnail generation fails.
    """
    cache_path = THUMB_CACHE_DIR / f"{token}.jpg"

    # Check disk cache first — no DB hit needed.
    # NOTE(review): on a cache hit owner_check_sql is never executed, so an
    # already-cached thumbnail is served to any caller that reaches this
    # function with that token. Confirm this trade-off is intended.
    if not cache_path.exists():
        # Not cached yet — verify ownership and generate.
        with db_connection() as conn, conn.cursor() as cur:
            cur.execute(owner_check_sql, owner_check_params)
            photo = cur.fetchone()

        if not photo:
            abort(404)

        file_path = UPLOAD_DIR / photo["stored_filename"]
        if not file_path.exists():
            abort(404)

        # Write to a unique temp file and move it into place atomically, so a
        # failed or concurrent write can never leave a truncated JPEG at
        # cache_path (which would then be served as "immutable" for a year).
        tmp_path = cache_path.with_name(f"{token}.{uuid.uuid4().hex}.tmp")
        try:
            with Image.open(file_path) as source:
                source.thumbnail((400, 400))
                # JPEG has no alpha/palette support; flatten those modes.
                if source.mode in ("RGBA", "P", "LA"):
                    thumb = source.convert("RGB")
                else:
                    thumb = source
                thumb.save(str(tmp_path), format="JPEG", quality=82)
            os.replace(str(tmp_path), str(cache_path))
        except Exception:
            logger.exception("Thumbnail generation failed for token %s", token)
            try:
                tmp_path.unlink()
            except OSError:
                # Nothing to clean up, or cleanup itself failed; the temp
                # name is unique so it cannot shadow a real cache entry.
                pass
            abort(500)

    resp = send_file(str(cache_path), mimetype="image/jpeg")
    resp.headers["Cache-Control"] = "public, max-age=31536000, immutable"
    return resp


@app.route("/portal/thumb/<token>")
@require_portal_session
def portal_thumb(token):
    """Serve a JPEG thumbnail (400px) for the portal photo grid.

    Delegates to ``_serve_thumb`` with a lookup restricted to photos whose
    ``user_id`` matches the signed-in user.
    """
    ownership_sql = "SELECT stored_filename FROM photos WHERE token = %s AND user_id = %s"
    ownership_params = (token, g.user_id)
    return _serve_thumb(token, ownership_sql, ownership_params)


@app.route("/portal/photos/<int:photo_id>/delete", methods=["POST"])
@require_portal_session
def portal_delete_photo(photo_id):
    """Delete a single photo (PORTAL SESSION AUTH — called from portal via form/fetch).

    The ownership lookup and the DELETE run in one database block, and the
    photo file and cached thumbnail are unlinked only afterwards. A database
    failure therefore cannot leave a row whose file was already removed.

    Returns:
        ``{"success": true}`` on success, or a 404 JSON error when the
        photo does not belong to ``g.user_id``.
    """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT token, stored_filename FROM photos WHERE id = %s AND user_id = %s",
            (photo_id, g.user_id),
        )
        photo = cur.fetchone()
        if not photo:
            return jsonify({"error": "Not found"}), 404

        # Scoped by user_id as well, so the DELETE is safe on its own.
        cur.execute("DELETE FROM photos WHERE id = %s AND user_id = %s",
                    (photo_id, g.user_id))

    # Best-effort file cleanup once the record is gone.
    targets = (
        THUMB_CACHE_DIR / f"{photo['token']}.jpg",
        UPLOAD_DIR / photo["stored_filename"],
    )
    for path in targets:
        try:
            path.unlink()
        except FileNotFoundError:
            # Already gone (e.g. no thumbnail cached, or a concurrent delete).
            pass
        except OSError:
            logger.exception(
                "Could not remove %s while deleting photo %s", path, photo_id
            )

    return jsonify({"success": True})


@app.route("/portal/photos/bulk-delete", methods=["POST"])
@require_portal_session
def portal_bulk_delete():
    """Bulk delete photos (PORTAL SESSION AUTH).

    Expects a JSON body of the form ``{"ids": [1, 2, "3", ...]}``. Entries
    that are not plain non-negative decimal integers are ignored. Only
    photos whose ``user_id`` matches the signed-in user are deleted.

    Returns:
        ``{"success": true, "deleted": <n>}`` where ``n`` is the number of
        matching photos, or a 400 JSON error when no usable id was supplied.
    """
    data = request.get_json(force=True) or {}
    # Tolerate bodies that are valid JSON but not the expected shape.
    raw_ids = data.get("ids", []) if isinstance(data, dict) else []
    if not isinstance(raw_ids, list):
        raw_ids = []
    # isdecimal() (unlike isdigit()) only accepts characters int() can parse,
    # so inputs such as "²" are filtered out instead of raising ValueError.
    ids = [int(i) for i in raw_ids if str(i).isdecimal()]
    if not ids:
        return jsonify({"error": "No IDs"}), 400

    with db_connection() as conn, conn.cursor() as cur:
        placeholders = ",".join(["%s"] * len(ids))
        cur.execute(
            "SELECT id, token, stored_filename FROM photos "
            f"WHERE id IN ({placeholders}) AND user_id = %s",
            (*ids, g.user_id),
        )
        photos = cur.fetchall()
        if photos:
            owned_ids = [p["id"] for p in photos]
            owned_placeholders = ",".join(["%s"] * len(owned_ids))
            cur.execute(
                f"DELETE FROM photos WHERE id IN ({owned_placeholders}) AND user_id = %s",
                (*owned_ids, g.user_id),
            )

    # Files are removed only after the DB work, and best-effort: a leftover
    # file is logged instead of failing a delete that already happened.
    for photo in photos:
        targets = (
            THUMB_CACHE_DIR / f"{photo['token']}.jpg",
            UPLOAD_DIR / photo["stored_filename"],
        )
        for path in targets:
            try:
                path.unlink()
            except FileNotFoundError:
                # Already gone (e.g. no thumbnail cached, or a concurrent delete).
                pass
            except OSError:
                logger.exception(
                    "Could not remove %s while bulk-deleting photo %s",
                    path, photo["id"],
                )

    return jsonify({"success": True, "deleted": len(photos)})


# ---------------------------------------------------------------------------
# Shareable Album Links (share/unshare require a portal session; /album/* is public — no login required)
# ---------------------------------------------------------------------------

@app.route("/portal/events/<int:event_id>/share", methods=["POST"])
@require_portal_session
def portal_share_event(event_id):
    """Generate a shareable link for an event album.

    If the event already has a ``share_token`` it is reused; otherwise a new
    random hex token is stored on the event. Responds with a 404 JSON error
    when the event does not belong to ``g.user_id``.
    """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT * FROM events WHERE id = %s AND user_id = %s",
                    (event_id, g.user_id))
        event_row = cur.fetchone()
        if not event_row:
            return jsonify({"error": "Not found"}), 404

        # Keep an existing token stable; mint one only when none is set.
        token = event_row.get("share_token")
        if not token:
            token = uuid.uuid4().hex
            cur.execute(
                "UPDATE events SET share_token = %s WHERE id = %s",
                (token, event_id),
            )

    payload = {
        "success": True,
        "share_token": token,
        "share_url": f"{BASE_URL}/album/{token}",
    }
    return jsonify(payload)


@app.route("/portal/events/<int:event_id>/unshare", methods=["POST"])
@require_portal_session
def portal_unshare_event(event_id):
    """Remove the shareable link for an event.

    Sets the event's ``share_token`` to NULL. Responds with a 404 JSON
    error when the event does not belong to ``g.user_id``.
    """
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT id FROM events WHERE id = %s AND user_id = %s",
                    (event_id, g.user_id))
        owned = cur.fetchone()
        if not owned:
            return jsonify({"error": "Not found"}), 404
        cur.execute("UPDATE events SET share_token = NULL WHERE id = %s", (event_id,))

    return jsonify({"success": True})


@app.route("/album/<share_token>")
def public_album(share_token):
    """Public album page — anyone with the link can view.

    Looks the event up by ``share_token`` (404 when no event matches) and
    renders it together with the count of its non-expired photos.
    """
    event_sql = "SELECT e.*, u.email as user_email FROM events e JOIN users u ON u.id = e.user_id WHERE e.share_token = %s"
    count_sql = "SELECT COUNT(*) as cnt FROM photos WHERE event_id = %s AND is_expired = 0"

    with db_connection() as conn, conn.cursor() as cur:
        cur.execute(event_sql, (share_token,))
        album = cur.fetchone()
        if not album:
            abort(404)

        cur.execute(count_sql, (album["id"],))
        photo_total = cur.fetchone()["cnt"]

    return render_template(
        "public_album.html",
        event=album,
        total_photos=photo_total,
        share_token=share_token,
        base_url=BASE_URL,
    )


@app.route("/album/<share_token>/photos")
def public_album_photos(share_token):
    """Paginated JSON endpoint for public album photos.

    Query parameters:
        page: 1-based page number. Missing or non-integer values fall back
            to 1; values below 1 are raised to 1.
        per_page: page size, clamped to the range 1..100. Missing or
            non-integer values fall back to 15.

    Returns:
        JSON with ``photos`` (non-expired photos of the shared event,
        newest first), ``total``, ``page``, ``per_page`` and ``pages``; or
        a 404 JSON error when no event has this ``share_token``.
    """
    # type=int makes the lookup return the default when conversion fails,
    # so "?page=abc" no longer escapes as an unhandled ValueError (HTTP 500).
    page = max(1, request.args.get("page", 1, type=int))
    per_page = min(100, max(1, request.args.get("per_page", 15, type=int)))
    offset = (page - 1) * per_page

    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT id FROM events WHERE share_token = %s", (share_token,))
        event = cur.fetchone()
        if not event:
            return jsonify({"error": "Not found"}), 404
        event_id = event["id"]

        cur.execute(
            "SELECT COUNT(*) as cnt FROM photos WHERE event_id = %s AND is_expired = 0",
            (event_id,),
        )
        total = cur.fetchone()["cnt"]

        cur.execute("""
                SELECT id, token, original_filename, file_size, created_at,
                       expires_at, download_count
                FROM photos
                WHERE event_id = %s AND is_expired = 0
                ORDER BY created_at DESC
                LIMIT %s OFFSET %s
            """, (event_id, per_page, offset))
        photos = cur.fetchall()

    return jsonify({
        "photos": photos,
        "total": total,
        "page": page,
        "per_page": per_page,
        # Integer ceiling division; yields 0 when there are no photos.
        "pages": (total + per_page - 1) // per_page,
    })


@app.route("/album/<share_token>/thumb/<token>")
def public_album_thumb(share_token, token):
    """Serve thumbnail for public album — verifies share_token is valid.

    Responds 404 when no event carries ``share_token``. Otherwise delegates
    to ``_serve_thumb`` with a lookup limited to non-expired photos of that
    event.
    """
    # Resolve the share token to its event first (lightweight query).
    with db_connection() as conn, conn.cursor() as cur:
        cur.execute("SELECT id FROM events WHERE share_token = %s", (share_token,))
        shared_event = cur.fetchone()
        if not shared_event:
            abort(404)

    photo_sql = "SELECT stored_filename FROM photos WHERE token = %s AND event_id = %s AND is_expired = 0"
    photo_params = (token, shared_event["id"])
    return _serve_thumb(token, photo_sql, photo_params)


# ---------------------------------------------------------------------------
# Entry point (local dev)
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Local development server only; per the module docstring the deployed
    # app runs under cPanel Passenger and never reaches this branch.
    #
    # The Werkzeug debugger lets anyone who can reach it execute arbitrary
    # code, and this server binds to all interfaces (0.0.0.0). Debug mode is
    # therefore opt-in: set FLASK_DEBUG=1 (e.g. in .env) to enable it.
    _debug = os.getenv("FLASK_DEBUG", "").strip().lower() in ("1", "true", "yes", "on")
    app.run(host="0.0.0.0", port=8000, debug=_debug)
