#!/usr/bin/env python3
"""
Voice Message Transcriber
- Loops through all threads in follow_up_FULL.csv
- Finds voice_media messages, downloads + transcribes with Whisper
- Saves transcriptions to voice_transcripts.json (cache)
- Rebuilds follow_up_FULL.csv with [VOICE: transcript] inline
"""

import csv
import json
import os
import time
import tempfile
import urllib.request
from pathlib import Path
from urllib.parse import unquote

import whisper
import rookiepy
from instagrapi import Client
from instagrapi.exceptions import ClientError, DirectThreadNotFound

# ── Config ────────────────────────────────────────────────────────────────────
_HERE = Path(__file__).parent  # all data files live next to this script

MSGS_PER_THREAD = 30                                 # messages requested per thread
DELAY           = 2.0                                # seconds; base pause between API calls
DM_CSV          = _HERE / "follow_up_FULL.csv"       # input CSV, rewritten in place
VOICE_CACHE     = _HERE / "voice_transcripts.json"   # msg_id -> transcript cache
SESSION_FILE    = _HERE / "session.json"             # not referenced elsewhere in this script
WHISPER_MODEL   = "large-v3-turbo"                   # name passed to whisper.load_model

# ── Auth ──────────────────────────────────────────────────────────────────────
# Reuse the Instagram login that already exists in Chrome: read Chrome's
# instagram.com cookies with rookiepy and hand the `sessionid` cookie to
# instagrapi instead of logging in with a password.
print("logging in via chrome session...")
cl = Client()
# instagrapi's random pause range (seconds) between requests — presumably
# applied per API call; confirm against the installed instagrapi version.
cl.delay_range = [DELAY, DELAY + 1]
cookies    = rookiepy.chrome(domains=["instagram.com"])
cookie_jar = {c["name"]: c["value"] for c in cookies}
if "sessionid" not in cookie_jar:
    # Fail with an actionable message rather than a bare KeyError('sessionid').
    raise SystemExit(
        "no Instagram 'sessionid' cookie found in Chrome — "
        "log in to instagram.com in Chrome, then re-run"
    )
# The cookie value is percent-decoded before use (it may contain e.g. %3A).
session_id = unquote(cookie_jar["sessionid"])
cl.login_by_sessionid(session_id)
my_id = str(cl.user_id)  # str so it compares equal to str(m.user_id) below
print(f"logged in as: {cl.account_info().username}")

# ── Load whisper ──────────────────────────────────────────────────────────────
# Loaded once up front; transcribe_url() reuses this module-level `model` for
# every voice message.
print("loading whisper [{}]...".format(WHISPER_MODEL))
model = whisper.load_model(WHISPER_MODEL)
print("whisper ready", end="\n\n")

# ── Load voice cache ──────────────────────────────────────────────────────────
# msg_id -> transcript text, persisted between runs. transcribe_url() checks
# this before downloading, so a cached message is never transcribed twice.
voice_cache: dict = {}
if VOICE_CACHE.exists():
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(VOICE_CACHE, encoding="utf-8") as f:
        voice_cache = json.load(f)
    print(f"cache: {len(voice_cache)} voice messages already transcribed")

def save_cache():
    """Persist the module-level ``voice_cache`` dict to ``VOICE_CACHE`` as JSON.

    The data is written to a sibling ``.tmp`` file first and then swapped into
    place with ``os.replace``. An interruption mid-write therefore leaves the
    previous cache intact instead of a truncated file that ``json.load`` would
    choke on at the next start-up.
    """
    tmp_path = VOICE_CACHE.with_name(VOICE_CACHE.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(voice_cache, f, indent=2)
    os.replace(tmp_path, VOICE_CACHE)  # atomic on the same filesystem

def transcribe_url(url: str, msg_id: str) -> str:
    """Return the Whisper transcript of the audio file at *url*.

    Successful results are memoised in ``voice_cache`` under *msg_id* and
    persisted with ``save_cache()``, so a message is downloaded and transcribed
    at most once across runs.

    Args:
        url: Direct download URL of the audio.
        msg_id: Message id used as the cache key.

    Returns:
        The stripped transcript text. If the download, transcription or cache
        save raises, it returns ``"[transcription failed: <error>]"`` instead.
        Failures are not cached, so they are retried on the next run.
    """
    if msg_id in voice_cache:
        return voice_cache[msg_id]

    tmp_path = None
    try:
        req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
        # delete=False so the file survives being closed and Whisper can reopen
        # it by path (an open NamedTemporaryFile can't be reopened on Windows).
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_path = tmp.name
            with urllib.request.urlopen(req, timeout=30) as resp:
                tmp.write(resp.read())

        # fp16=False requests FP32 decoding (presumably so CPU runs work
        # without Whisper's FP16 fallback warning).
        result = model.transcribe(tmp_path, fp16=False)
        text   = result["text"].strip()

        voice_cache[msg_id] = text
        save_cache()
        return text
    except Exception as e:
        # Best-effort: one bad voice message must not abort the whole run.
        return f"[transcription failed: {e}]"
    finally:
        # Remove the temp file on every path, including failed downloads and
        # failed transcriptions, so failures don't accumulate stray files.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass  # cleanup is best-effort; never mask the result above

# ── Read CSV ──────────────────────────────────────────────────────────────────
# newline="" is what the csv module requires: quoted fields may contain line
# breaks (message text is copied into the summaries verbatim), and the reader
# must see them untranslated. "utf-8-sig" reads plain UTF-8 unchanged but
# drops a leading BOM (e.g. from a file re-saved in Excel). Without that, the
# BOM would be glued onto the first header and row["thread_id"] would fail.
with open(DM_CSV, encoding="utf-8-sig", newline="") as f:
    rows = list(csv.DictReader(f))

print(f"processing {len(rows)} threads...\n")

updated = 0
voices_found = 0

RATE_LIMIT_SLEEP = 30  # seconds to back off after a ClientError
FETCH_ATTEMPTS   = 2   # first try + one retry after backing off


def _fetch_thread(thread_id, label):
    """Fetch up to MSGS_PER_THREAD messages of a DM thread.

    On ClientError (often a rate limit), sleep RATE_LIMIT_SLEEP seconds and
    retry, so a transient error doesn't silently drop the thread from the run.

    Returns:
        The thread object, or None if the thread doesn't exist or every
        attempt failed.
    """
    for attempt in range(1, FETCH_ATTEMPTS + 1):
        try:
            return cl.direct_thread(thread_id, amount=MSGS_PER_THREAD)
        except DirectThreadNotFound:
            print(f"  {label}: thread not found, skipping")
            return None
        except ClientError as e:
            print(f"  {label}: client error (attempt {attempt}/{FETCH_ATTEMPTS}, "
                  f"possibly rate limit), sleeping {RATE_LIMIT_SLEEP}s... ({e})")
            time.sleep(RATE_LIMIT_SLEEP)
    print(f"  {label}: giving up after {FETCH_ATTEMPTS} attempts")
    return None


for i, row in enumerate(rows, start=1):
    thread_id = row["thread_id"]
    progress  = f"[{i}/{len(rows)}]"

    try:
        thread = _fetch_thread(thread_id, f"{progress} @{row['handle']}")
        if thread is None:
            continue
        # Reversed here, i.e. the API order is assumed newest-first; the
        # summary is built oldest-first.
        messages = list(reversed(thread.messages))

        had_voice = False
        new_summary_lines = []

        for m in messages:
            sender_name = "keith" if str(m.user_id) == my_id else row["display_name"]
            ts = m.timestamp.strftime("%d.%m.%Y") if hasattr(m.timestamp, "strftime") else ""

            if m.item_type == "voice_media" and getattr(m, "voice_media", None):
                # NOTE(review): assumes the message model exposes
                # voice_media.audio_address — verify against installed instagrapi.
                audio_url = getattr(m.voice_media, "audio_address", None)
                if audio_url:
                    print(f"  {progress} voice from {sender_name} in @{row['handle']} — transcribing...")
                    transcript = transcribe_url(audio_url, str(m.id))
                    new_summary_lines.append(f"[{ts}] {sender_name}: [VOICE: {transcript}]")
                    voices_found += 1
                    had_voice = True
                else:
                    new_summary_lines.append(f"[{ts}] {sender_name}: [voice message]")
            else:
                txt = getattr(m, "text", None) or f"[{m.item_type}]"
                new_summary_lines.append(f"[{ts}] {sender_name}: {txt[:200]}")

        # Only threads with a transcribed voice message get a new summary
        # (the 12 most recent lines); all other rows are left untouched.
        if had_voice:
            row["conversation_summary"] = " | ".join(new_summary_lines[-12:])
            updated += 1
            print(f"  → updated thread {row['handle']}")

        time.sleep(DELAY)

    except Exception as e:
        print(f"  error on {row['handle']}: {e}")

# ── Rewrite CSV ───────────────────────────────────────────────────────────────
# With no data rows there is nothing to take fieldnames from (rows[0] would
# raise IndexError), so leave the file alone. Otherwise write a sibling temp
# file and atomically swap it in. A crash mid-write then can't destroy the
# only copy of the CSV.
if rows:
    fieldnames = list(rows[0].keys())
    tmp_csv = DM_CSV.with_name(DM_CSV.name + ".tmp")
    with open(tmp_csv, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
    os.replace(tmp_csv, DM_CSV)  # atomic on the same filesystem
else:
    print("no rows read; leaving CSV untouched")

print(f"""
═══════════════════════════════════════
VOICE TRANSCRIPTION COMPLETE
═══════════════════════════════════════
threads checked:  {len(rows)}
threads updated:  {updated}
voices found:     {voices_found}
cache saved:      {VOICE_CACHE}
═══════════════════════════════════════
""")
