#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Requires Python 3.9+ (f-strings and built-in generic annotations such as list[str]).
import json, os, re, sqlite3, subprocess, time
from datetime import datetime, timezone

# ====== CONFIG ======
# Local TCP port whose established connections are inspected.
LOCAL_HTTPS_PORT = 443

# Thresholds (tune them to your normal traffic)
TH_CONN_IP = 25          # simultaneous ESTABLISHED connections from a single IP
TH_CONN_SUBNET24 = 220   # connections from one /24 (distributed attack inside a single block)
BAN_SECONDS = 3600       # 1h (ipset timeout)
MIN_IDLE_SCORE = 2       # minimum score required to ban

# Files
LOG_DIR = "/var/log/ideolab"
EVENTS_JSONL = os.path.join(LOG_DIR, "connmon_events.jsonl")
METRICS_PROM = "/var/lib/node_exporter/textfile_collector/ideolab_connmon.prom"  # Prometheus node_exporter textfile
DB_PATH = "/var/lib/ideolab/connmon_cache.sqlite3"

# Name of the ipset that banned addresses are added to.
IPSET_NAME = "ideolab_ban"

# ====== UTILS ======
# Matches a dotted-quad, IPv4-looking token; octet values are not range-checked.
IP_RE = re.compile(r"(\d{1,3}(?:\.\d{1,3}){3})")

def sh(cmd: list[str]) -> str:
    """Run *cmd* (no shell) and return its standard output as text.

    stderr is discarded. Raises ``subprocess.CalledProcessError`` when the
    command exits with a non-zero status.
    """
    completed = subprocess.run(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        text=True,
        check=True,
    )
    return completed.stdout

def ensure_dirs():
    """Create the log, database and metrics directories when they are missing."""
    required = (
        LOG_DIR,
        os.path.dirname(DB_PATH),
        os.path.dirname(METRICS_PROM),
    )
    for directory in required:
        os.makedirs(directory, exist_ok=True)

def now_iso():
    """Return the current time in UTC as a timezone-aware ISO 8601 string."""
    current = datetime.now(tz=timezone.utc)
    return current.isoformat()

def ip_to_subnet24(ip: str) -> str:
    """Return the ``a.b.c.0/24`` network string for dotted-quad *ip*.

    No validation is done: the first three dot-separated fields are reused
    verbatim.
    """
    leading_octets = ip.split(".")[:3]
    return f"{'.'.join(leading_octets)}.0/24"

def init_db():
    """Open the SQLite cache at DB_PATH and make sure ``asn_cache`` exists.

    Returns the open ``sqlite3.Connection``; closing it is up to the caller.
    """
    create_table = """
        CREATE TABLE IF NOT EXISTS asn_cache (
            ip TEXT PRIMARY KEY,
            asn TEXT,
            asname TEXT,
            country TEXT,
            updated_at INTEGER
        )
    """
    db = sqlite3.connect(DB_PATH)
    db.execute(create_table)
    db.commit()
    return db

def lookup_asn_cymru(ip: str, con: sqlite3.Connection) -> dict:
    """Resolve *ip* to ASN, AS name and country through the Team Cymru whois service.

    Answers are stored in the ``asn_cache`` table and reused for 7 days.

    Returns a dict with keys ``asn``, ``asname``, ``country`` and ``src``,
    where ``src`` is ``"cache"``, ``"cymru"`` or ``"none"`` (lookup failed;
    the other three values are then None).
    """
    max_age_seconds = 7 * 24 * 3600
    cursor = con.cursor()
    cursor.execute("SELECT asn, asname, country, updated_at FROM asn_cache WHERE ip=?", (ip,))
    cached = cursor.fetchone()
    # Serve from the cache while the stored row is younger than 7 days.
    if cached and int(time.time()) - int(cached[3]) < max_age_seconds:
        return {"asn": cached[0], "asname": cached[1], "country": cached[2], "src": "cache"}

    try:
        # NOTE(review): the leading space in the query presumably keeps the local
        # whois client from treating "-v" as one of its own flags — confirm.
        reply = sh(["whois", "-h", "whois.cymru.com", f" -v {ip}"])
        # Reply format: AS | IP | BGP Prefix | CC | Registry | Allocated | AS Name
        # Lines starting with "AS" (the header) and blank lines are ignored.
        first_record = next(
            (ln.strip() for ln in reply.splitlines() if ln.strip() and not ln.startswith("AS")),
            None,
        )
        if first_record is None:
            raise RuntimeError("no cymru data")
        fields = [field.strip() for field in first_record.split("|")]
        answer = {"asn": fields[0], "asname": fields[-1], "country": fields[3], "src": "cymru"}
        cursor.execute(
            "INSERT OR REPLACE INTO asn_cache(ip, asn, asname, country, updated_at) VALUES (?,?,?,?,?)",
            (ip, answer["asn"], answer["asname"], answer["country"], int(time.time()))
        )
        con.commit()
        return answer
    except Exception:
        # Best effort: a failed command, an unexpected reply or a cache write
        # error all degrade to an empty answer instead of aborting the caller.
        return {"asn": None, "asname": None, "country": None, "src": "none"}

def ipset_add(ip: str, seconds: int) -> bool:
    """Add *ip* to the ``IPSET_NAME`` set with an entry timeout of *seconds*.

    The command's output is discarded, but its exit status is no longer
    thrown away: callers that ignore the return value behave exactly as
    before, while callers that care can tell whether the ban was applied.

    Returns:
        True when ``ipset`` exited with status 0, False otherwise (for
        example when the set does not exist or privileges are missing).
    """
    # "-exist" asks ipset not to fail when the entry is already present
    # (see ipset(8)).
    result = subprocess.run(
        ["ipset", "add", IPSET_NAME, ip, "timeout", str(seconds), "-exist"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    return result.returncode == 0

def _queue_len(cols: list, index: int) -> int:
    """Return ``cols[index]`` as an int, or 0 when it is missing or non-numeric."""
    if index < len(cols) and cols[index].isdigit():
        return int(cols[index])
    return 0


def parse_ss_established_443():
    """List established TCP connections whose local port is LOCAL_HTTPS_PORT.

    Runs ``ss -Hnto state established '( sport = :PORT )'`` and parses every
    output line. ``-o`` is what makes ss print the ``timer:(...)`` field.

    ss output sample (with a single ``state`` filter ss leaves out the
    leading State column; both layouts are accepted here):
        0 0 172.31.28.54:443 187.62.150.233:14648 timer:(keepalive,41sec,0)
        ESTAB 0 0 172.31.28.54:443 187.62.150.233:14648 timer:(keepalive,41sec,0)

    Returns:
        A list of dicts with keys ``peer_ip`` (IPv4 text), ``recvq`` and
        ``sendq`` (ints, 0 when not parseable), ``timer`` (the text inside
        ``timer:(...)`` or None) and ``raw`` (the original line). Lines in
        which no IPv4 peer address is found are skipped.

    Raises:
        subprocess.CalledProcessError: when ``ss`` exits with an error.
    """
    out = sh(["ss", "-Hnto", "state", "established", f"( sport = :{LOCAL_HTTPS_PORT} )"])
    timer_re = re.compile(r"timer:\(([^)]+)\)")
    rows = []
    for line in out.splitlines():
        cols = line.split()
        if not cols:
            continue

        # Columns are [State] Recv-Q Send-Q Local Peer ...; a numeric first
        # column means the State column was omitted.
        first = 0 if cols[0].isdigit() else 1
        recvq = _queue_len(cols, first)
        sendq = _queue_len(cols, first + 1)

        # Read the peer from its own column so that other dotted numbers on
        # the line cannot be mistaken for it; fall back to the whole line
        # when the layout is shorter than expected.
        peer_index = first + 3
        haystack = cols[peer_index] if len(cols) > peer_index else line
        ips = IP_RE.findall(haystack)
        if not ips:
            continue
        peer_ip = ips[-1]

        timer_match = timer_re.search(line)
        timer = timer_match.group(1) if timer_match else None

        rows.append({"peer_ip": peer_ip, "recvq": recvq, "sendq": sendq, "timer": timer, "raw": line})
    return rows

def write_event(event: dict):
    """Append *event* to EVENTS_JSONL as a single JSON line (UTF-8, non-ASCII kept)."""
    with open(EVENTS_JSONL, mode="a", encoding="utf-8") as log:
        serialized = json.dumps(event, ensure_ascii=False)
        log.write(f"{serialized}\n")

def write_metrics(metrics: dict, path=None):
    """Write the monitor's metrics in the Prometheus text exposition format.

    Args:
        metrics: mapping with the optional keys ``bans_total``,
            ``established_443`` and ``suspect_ips``; a missing key is
            written as 0.
        path: destination file. Defaults to METRICS_PROM.

    The text is first written to ``<path>.tmp`` and then moved over the
    destination with ``os.replace`` so that a reader never sees a truncated
    or half-written file.
    """
    target = METRICS_PROM if path is None else path

    # (metric name, HELP text, TYPE, key in *metrics*)
    # NOTE(review): bans_total is declared as a counter but simply echoes the
    # caller's value — confirm the caller passes a cumulative count.
    series = (
        ("ideolab_connmon_bans_total", "Total bans issued by connmon", "counter", "bans_total"),
        ("ideolab_connmon_established_443", "Current established connections to 443", "gauge", "established_443"),
        ("ideolab_connmon_suspect_ips", "Current suspect IPs detected", "gauge", "suspect_ips"),
    )

    lines = []
    for name, help_text, metric_type, key in series:
        lines.append(f"# HELP {name} {help_text}")
        lines.append(f"# TYPE {name} {metric_type}")
        lines.append(f"{name} {metrics.get(key, 0)}")

    tmp_path = f"{target}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines) + "\n")
    # Same directory, so the swap is a rename rather than a copy.
    os.replace(tmp_path, target)

def _idle_score(conns: list) -> int:
    """Score how much a group of connections looks like idle connection hoarding.

    One point is always given (the caller only scores IPs that are already
    over the connection threshold), one more when every connection has empty
    receive and send queues, and one more when at least one connection
    carries a timer whose text contains "keepalive" or "on".
    """
    score = 1
    if all(c["recvq"] == 0 and c["sendq"] == 0 for c in conns):
        score += 1
    if any(c.get("timer") and ("keepalive" in c["timer"] or "on" in c["timer"]) for c in conns):
        score += 1
    return score


def main():
    """Run one monitoring pass.

    Steps: list established connections on the HTTPS port, group them per
    peer IP and per /24, ban every IP that is over TH_CONN_IP and reaches
    MIN_IDLE_SCORE, log one event per ban and per hot /24, then write the
    Prometheus metrics file.
    """
    ensure_dirs()
    con = init_db()
    try:
        rows = parse_ss_established_443()
        established = len(rows)

        per_ip = {}
        per_24 = {}
        for r in rows:
            ip = r["peer_ip"]
            per_ip.setdefault(ip, []).append(r)
            sn = ip_to_subnet24(ip)
            per_24[sn] = per_24.get(sn, 0) + 1

        # 1) Suspects per IP: many connections + empty queues + keepalive/timer present
        suspects = []
        for ip, conns in per_ip.items():
            if len(conns) < TH_CONN_IP:
                continue
            score = _idle_score(conns)
            if score >= MIN_IDLE_SCORE:
                asn = lookup_asn_cymru(ip, con)
                suspects.append({"ip": ip, "count": len(conns), "score": score, "asn": asn})
    finally:
        # The cache connection is only needed for the ASN lookups above;
        # close it even when listing or lookup raised.
        con.close()

    # 2) Suspects per /24 ("spray" attack spread over one block)
    hot_subnets = {sn: c for sn, c in per_24.items() if c >= TH_CONN_SUBNET24}

    # Ban the suspect IPs
    bans_total = 0
    for s in suspects:
        ipset_add(s["ip"], BAN_SECONDS)
        bans_total += 1
        write_event({
            "ts": now_iso(),
            "type": "ban_ip",
            "ip": s["ip"],
            "count": s["count"],
            "score": s["score"],
            "asn": s["asn"],
            "reason": f"ESTABLISHED:443 >= {TH_CONN_IP} (idle/keepalive heuristic)",
            "ban_seconds": BAN_SECONDS
        })

    # Hot subnets are only logged here; banning the whole /24 is left as an
    # optional extension.
    for sn, c in hot_subnets.items():
        write_event({
            "ts": now_iso(),
            "type": "hot_subnet24",
            "subnet": sn,
            "count": c,
            "reason": f"/24 established_443 >= {TH_CONN_SUBNET24}"
        })

    write_metrics({
        "bans_total": bans_total,
        "established_443": established,
        "suspect_ips": len(suspects),
    })

# Run a single monitoring pass when the file is executed as a script.
if __name__ == "__main__":
    main()
