#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, List, Sequence, Tuple, Set, Optional

from django.core.management.base import BaseCommand, CommandError
from django.db import connections, DEFAULT_DB_ALIAS, DatabaseError
from django.db.migrations.executor import MigrationExecutor
from django.db.migrations.operations.models import CreateModel, AddIndex, AlterIndexTogether


@dataclass(frozen=True)
class PlannedIndex:
    """One index that a pending migration operation is expected to create.

    Frozen dataclass: instances are immutable and compare/hash by field values,
    which lets callers de-duplicate them by value.
    """

    # Database table the index will be created on.
    table: str
    # Index name the migration will use (explicit or auto-generated).
    name: str
    # Field/column names taken from the migration operation, in declaration order.
    columns: Tuple[str, ...]
    # Human-readable origin of the index, e.g. "<app>.<migration>: <operation>".
    reason: str


# -------------------------
# MariaDB/MySQL helpers
# -------------------------

def mysql_index_exists(connection, table: str, index_name: str) -> bool:
    """Tell whether ``index_name`` exists on ``table`` in the current schema.

    The lookup goes through INFORMATION_SCHEMA.STATISTICS, scoped to the schema
    selected by ``DATABASE()``.

    Returns:
        True when at least one matching row is found, False otherwise.
    """
    lookup_sql = """
            SELECT 1
            FROM INFORMATION_SCHEMA.STATISTICS
            WHERE TABLE_SCHEMA = DATABASE()
              AND TABLE_NAME = %s
              AND INDEX_NAME = %s
            LIMIT 1
            """
    with connection.cursor() as cur:
        cur.execute(lookup_sql, [table, index_name])
        hit = cur.fetchone()
    return hit is not None


def mysql_index_columns(connection, table: str, index_name: str) -> List[str]:
    with connection.cursor() as cursor:
        cursor.execute(
            """
            SELECT COLUMN_NAME
            FROM INFORMATION_SCHEMA.STATISTICS
            WHERE TABLE_SCHEMA = DATABASE()
              AND TABLE_NAME = %s
              AND INDEX_NAME = %s
            ORDER BY SEQ_IN_INDEX ASC
            """,
            [table, index_name],
        )
        return [row[0] for row in cursor.fetchall()]


def mysql_find_tables_for_index(connection, index_name: str) -> List[str]:
    """Find tables (in current DB) that contain index_name."""
    with connection.cursor() as cursor:
        cursor.execute(
            """
            SELECT DISTINCT TABLE_NAME
            FROM INFORMATION_SCHEMA.STATISTICS
            WHERE TABLE_SCHEMA = DATABASE()
              AND INDEX_NAME = %s
            """,
            [index_name],
        )
        return [row[0] for row in cursor.fetchall()]


def mysql_drop_index(connection, table: str, index_name: str) -> None:
    """Execute ``DROP INDEX <index_name> ON <table>`` on ``connection``.

    Both identifiers are wrapped in backticks. A backtick inside a name is
    escaped by doubling it, which is how MySQL quoted identifiers escape it, so
    such a name cannot terminate the quoting early and inject SQL.

    Raises:
        Propagates any exception raised by ``cursor.execute``.
    """
    def quote(identifier: str) -> str:
        # `a`b` would close the identifier after "a"; `a``b` keeps it intact.
        return "`" + identifier.replace("`", "``") + "`"

    with connection.cursor() as cursor:
        cursor.execute(f"DROP INDEX {quote(index_name)} ON {quote(table)}")


# -------------------------
# Migration plan scanning
# -------------------------

def _clean_cols(cols) -> Tuple[str, ...]:
    return tuple([c for c in (cols or []) if c])


def _normalize_together(value) -> List[Tuple[str, ...]]:
    """Turn an ``index_together`` option value into a list of field-name tuples.

    Accepts the usual collection-of-groups form (``{("a", "b"), ("c",)}``) and
    also a single flat group such as ``("a", "b")``, which is treated as one
    group instead of being iterated character by character. A falsy value
    yields an empty list; empty entries inside a group are removed by
    ``_clean_cols``.
    """
    if not value:
        return []
    groups = list(value)
    if isinstance(groups[0], str):
        # Flat form, e.g. index_together = ("a", "b"): one single group.
        return [_clean_cols(groups)]
    return [_clean_cols(group) for group in groups]


def _field_columns(fields, field_names: Sequence[str]) -> Tuple[str, ...]:
    """Map model field names to the database column names that get indexed.

    ``fields`` is either a ``{name: field}`` mapping or a sequence of
    ``(name, field)`` pairs (e.g. ``CreateModel.fields``). A field's column is
    its ``db_column`` when set, ``<name>_id`` for foreign-key / one-to-one
    fields, and its name otherwise. Names that cannot be resolved are returned
    unchanged.
    """
    if hasattr(fields, "items"):
        by_name = dict(fields.items())
    else:
        by_name = dict(fields or ())
    columns: List[str] = []
    for field_name in field_names:
        field = by_name.get(field_name)
        if field is None:
            columns.append(field_name)
        elif getattr(field, "db_column", None):
            columns.append(field.db_column)
        elif getattr(field, "many_to_one", False) or getattr(field, "one_to_one", False):
            # ForeignKey / OneToOneField are stored in a "<name>_id" column.
            columns.append(f"{field_name}_id")
        else:
            columns.append(field_name)
    return tuple(columns)


def _state_table(state, app_label: str, model_name: str) -> str:
    """Return the table of ``app_label.model_name`` in ``state``.

    Uses the model's ``db_table`` option if the model is known and sets it,
    otherwise ``<app_label>_<model_name lowercased>``.
    """
    default = f"{app_label}_{model_name.lower()}"
    model_state = state.models.get((app_label, model_name.lower()))
    if model_state is None:
        return default
    return model_state.options.get("db_table") or default


def collect_planned_indexes(
    connection,
    executor: MigrationExecutor,
    plan: Sequence[Tuple[Any, bool]],
) -> List[PlannedIndex]:
    """Replay the forward migration plan in memory and list the indexes it creates.

    Covers ``CreateModel`` (``Meta.indexes`` and ``index_together``), ``AddIndex``
    and ``AlterIndexTogether`` (MySQL/MariaDB). Backwards steps in ``plan`` are
    skipped. No DDL is issued: the schema editor is only used to compute
    auto-generated index names the same way Django does.

    Args:
        connection: Django database connection.
        executor: migration executor; its applied-migration project state is
            the starting point of the replay.
        plan: ``(migration, backwards)`` pairs, as returned by
            ``MigrationExecutor.migration_plan``.

    Returns:
        Planned indexes, de-duplicated on ``(table, name)`` keeping the first
        occurrence.
    """
    state = executor._create_project_state(with_applied_migrations=True)
    planned: List[PlannedIndex] = []

    with connection.schema_editor() as schema_editor:

        def auto_name(table: str, columns: Sequence[str]) -> str:
            # Same helper Django's schema editor uses for unnamed "_idx" indexes;
            # it expects database column names, not field names.
            return schema_editor._create_index_name(table, list(columns), suffix="_idx")

        for migration, backwards in plan:
            if backwards:
                continue
            app_label = migration.app_label
            origin = f"{app_label}.{migration.name}"

            for op in migration.operations:
                if isinstance(op, CreateModel):
                    table = op.options.get("db_table") or f"{app_label}_{op.name.lower()}"

                    for idx in op.options.get("indexes") or []:
                        cols = _clean_cols(getattr(idx, "fields", None))
                        name = getattr(idx, "name", None)
                        if not name:
                            # Best-effort fallback; migration indexes normally carry a name.
                            if not cols:
                                continue
                            name = auto_name(table, cols)
                        planned.append(
                            PlannedIndex(
                                table=table,
                                name=name,
                                columns=cols,
                                reason=f"{origin}: CreateModel({op.name}) Meta.indexes",
                            )
                        )

                    # index_together (legacy, but still present in many projects).
                    for fields in _normalize_together(op.options.get("index_together")):
                        if not fields:
                            continue
                        cols = _field_columns(op.fields, fields)
                        planned.append(
                            PlannedIndex(
                                table=table,
                                name=auto_name(table, cols),
                                columns=cols,
                                reason=f"{origin}: CreateModel({op.name}) index_together={fields}",
                            )
                        )

                elif isinstance(op, AddIndex):
                    table = _state_table(state, app_label, op.model_name)
                    idx = op.index
                    cols = _clean_cols(getattr(idx, "fields", None))
                    name = getattr(idx, "name", None)
                    if not name and cols:
                        name = auto_name(table, cols)
                    if name:
                        planned.append(
                            PlannedIndex(
                                table=table,
                                name=name,
                                columns=cols,
                                reason=f"{origin}: AddIndex({op.model_name})",
                            )
                        )

                elif isinstance(op, AlterIndexTogether):
                    # op.name is the model name; read the state *before* this operation.
                    model_state = state.models.get((app_label, op.name.lower()))
                    table = _state_table(state, app_label, op.name)
                    if model_state is not None:
                        previous = set(_normalize_together(model_state.options.get("index_together")))
                        model_fields = model_state.fields
                    else:
                        previous, model_fields = set(), ()
                    for fields in _normalize_together(op.index_together):
                        # Django only creates groups that are new compared with the
                        # previous index_together; existing ones are left untouched,
                        # so dropping them here would lose the index for good.
                        if not fields or fields in previous:
                            continue
                        cols = _field_columns(model_fields, fields)
                        planned.append(
                            PlannedIndex(
                                table=table,
                                name=auto_name(table, cols),
                                columns=cols,
                                reason=f"{origin}: AlterIndexTogether({op.name}) index_together={fields}",
                            )
                        )

                # Advance the in-memory state for every operation so later lookups
                # (table names, previous index_together) stay consistent.
                op.state_forwards(app_label, state)

    # De-duplicate on (table, name); dicts keep insertion order, first one wins.
    unique: dict = {}
    for pi in planned:
        unique.setdefault((pi.table, pi.name), pi)
    return list(unique.values())


class Command(BaseCommand):
    """Drop existing MySQL/MariaDB indexes that pending migrations would re-create.

    Two modes:

    * auto (default): replay the unapplied migration plan in memory, compute the
      indexes it will create, and drop those that already exist in the database;
    * surgical (``--drop-index NAME``): drop one named index from every table that
      has it, or only from ``--table`` when given.

    ``--dry-run`` only reports what would be dropped.
    """

    help = (
        "Fix MariaDB/MySQL migration failures caused by existing indexes "
        "(OperationalError 1061 'Duplicate key name'). "
        "Scans migration plan (Meta.indexes, AddIndex, index_together/AlterIndexTogether) "
        "and drops conflicting indexes BEFORE you run migrate. "
        "Also supports a surgical mode: --drop-index <index_name>."
    )

    def add_arguments(self, parser):
        parser.add_argument("--database", default=DEFAULT_DB_ALIAS, help="Database alias to operate on.")
        parser.add_argument("--dry-run", action="store_true", help="Report conflicting indexes without dropping them.")
        parser.add_argument(
            "--apps",
            nargs="*",
            default=None,
            help="Auto mode only: restrict the scan to these app labels.",
        )
        parser.add_argument(
            "--verbose-sql",
            action="store_true",
            help="Print each DROP INDEX statement before executing it.",
        )
        parser.add_argument(
            "--drop-index",
            default=None,
            help="Surgical: drop a specific index by name (finds the table automatically).",
        )
        parser.add_argument(
            "--table",
            default=None,
            help="Surgical mode only: drop --drop-index from this table only.",
        )

    def handle(self, *args, **opts):
        db_alias = opts["database"]
        connection = connections[db_alias]

        # Every helper relies on INFORMATION_SCHEMA + DATABASE() and backtick
        # quoting, which only exist on Django's "mysql" backend (MariaDB included).
        if connection.vendor != "mysql":
            raise CommandError(
                f"Database '{db_alias}' uses the '{connection.vendor}' backend; "
                "this command only supports MySQL/MariaDB."
            )

        dry_run = opts["dry_run"]
        verbose_sql = opts["verbose_sql"]
        if opts["drop_index"]:
            self._drop_named_index(
                connection,
                opts["drop_index"],
                opts.get("table"),
                dry_run=dry_run,
                verbose_sql=verbose_sql,
            )
        else:
            self._drop_planned_indexes(
                connection,
                set(opts["apps"] or []),
                dry_run=dry_run,
                verbose_sql=verbose_sql,
            )

    def _drop_one(self, connection, table: str, index_name: str, *, verbose_sql: bool) -> bool:
        """Drop ``table.index_name``, report the outcome, and return True on success."""
        if verbose_sql:
            quoted_index = index_name.replace("`", "``")
            quoted_table = table.replace("`", "``")
            self.stdout.write(f"DROP INDEX `{quoted_index}` ON `{quoted_table}`;")
        try:
            mysql_drop_index(connection, table, index_name)
        except DatabaseError as exc:
            # Reported instead of raised so the remaining indexes are still
            # processed (e.g. MySQL refuses to drop an index a foreign key needs).
            self.stdout.write(self.style.ERROR(f"[ERR] {table}.{index_name}: {exc}"))
            return False
        self.stdout.write(self.style.SUCCESS(f"[DROP] {table}.{index_name}"))
        return True

    def _drop_named_index(
        self,
        connection,
        index_name: str,
        table: Optional[str],
        *,
        dry_run: bool,
        verbose_sql: bool,
    ) -> None:
        """Surgical mode: drop ``index_name`` from every table that has it (or only ``table``)."""
        # In MySQL the name PRIMARY always designates a table's primary key; dropping
        # it is never the fix for a duplicate secondary index.
        if index_name.upper() == "PRIMARY":
            raise CommandError("Refusing to drop 'PRIMARY': that is a primary key, not a duplicate index.")

        tables = mysql_find_tables_for_index(connection, index_name)
        if table is not None:
            tables = [name for name in tables if name == table]
        if not tables:
            where = f"table '{table}'" if table is not None else "current database"
            self.stdout.write(self.style.WARNING(f"Index '{index_name}' not found in {where}."))
            return

        failures = 0
        for tbl in tables:
            cols = mysql_index_columns(connection, tbl, index_name)
            self.stdout.write(f"[FOUND] {tbl}.{index_name} (cols={cols})")
            if dry_run:
                continue
            if not self._drop_one(connection, tbl, index_name, verbose_sql=verbose_sql):
                failures += 1

        if dry_run:
            self.stdout.write(self.style.WARNING("DRY-RUN: nothing dropped."))
        elif failures:
            self.stdout.write(self.style.ERROR(f"{failures} index drop(s) failed; see [ERR] lines above."))
        else:
            self.stdout.write(self.style.SUCCESS("Done. Now run: python manage.py migrate"))

    def _drop_planned_indexes(
        self,
        connection,
        restrict_apps: Set[str],
        *,
        dry_run: bool,
        verbose_sql: bool,
    ) -> None:
        """Auto mode: drop existing indexes that the pending migration plan would create."""
        executor = MigrationExecutor(connection)
        plan = executor.migration_plan(executor.loader.graph.leaf_nodes())
        if restrict_apps:
            plan = [(m, b) for (m, b) in plan if m.app_label in restrict_apps]

        to_drop: List[PlannedIndex] = []
        for pi in collect_planned_indexes(connection, executor, plan):
            if not mysql_index_exists(connection, pi.table, pi.name):
                continue
            cols = mysql_index_columns(connection, pi.table, pi.name)
            self.stdout.write(
                f"[FOUND] {pi.table}.{pi.name} exists (db cols={cols}) -> will drop to avoid 1061. "
                f"Reason: {pi.reason}"
            )
            to_drop.append(pi)

        if not to_drop:
            self.stdout.write(self.style.SUCCESS("No conflicting indexes found."))
            self.stdout.write("If migrate still fails, use: --drop-index <name_from_error>")
            return

        if dry_run:
            self.stdout.write(self.style.WARNING(f"DRY-RUN: would drop {len(to_drop)} index(es)."))
            return

        dropped = 0
        for pi in to_drop:
            if self._drop_one(connection, pi.table, pi.name, verbose_sql=verbose_sql):
                dropped += 1

        # Only report success when every drop went through.
        style = self.style.SUCCESS if dropped == len(to_drop) else self.style.WARNING
        self.stdout.write(style(f"Done. Dropped {dropped}/{len(to_drop)} index(es)."))
        self.stdout.write("Now run: python manage.py migrate")
