from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests
import json
import time
import copy
import random
from typing import Dict, Any, Optional
import os

app = FastAPI()

# 🔒 Configuration CORS — autorise le frontend sur localhost:8084
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:8084"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# 🔧 Configuration
COMFYUI_URL = "http://127.0.0.1:8188"
WORKFLOW_PATH = os.path.join(os.path.dirname(__file__), "Jaugernaut_wrkf.json")
MAX_WAIT_TIME = 60  # seconds maximum d'attente pour la génération
POLL_INTERVAL = 0.5  # secondes entre chaque vérification

# 🔽 Charger le workflow de base une seule fois au démarrage
try:
    with open(WORKFLOW_PATH, "r") as f:
        BASE_WORKFLOW: Dict[str, Any] = json.load(f)
except FileNotFoundError:
    raise RuntimeError(f"Fichier de workflow introuvable : {WORKFLOW_PATH}")
except json.JSONDecodeError as e:
    raise RuntimeError(f"Erreur de parsing JSON dans {WORKFLOW_PATH} : {e}")

# 🔍 Fonction pour trouver un nœud par son titre dans _meta
def find_node_by_title(workflow: Dict[str, Any], title: str) -> Optional[str]:
    for node_id, node in workflow.items():
        if isinstance(node, dict) and node.get("_meta", {}).get("title") == title:
            return node_id
    return None

# 🔍 Localiser une fois au démarrage les IDs des nœuds de prompt par leurs titres
POS_NODE_ID = find_node_by_title(BASE_WORKFLOW, "Positive (Prompt)")
NEG_NODE_ID = find_node_by_title(BASE_WORKFLOW, "Negative (Prompt)")

if not POS_NODE_ID:
    raise RuntimeError("Nœud de prompt positif introuvable : titre 'Positive (Prompt)' non trouvé dans le workflow.")
if not NEG_NODE_ID:
    raise RuntimeError("Nœud de prompt négatif introuvable : titre 'Negative (Prompt)' non trouvé dans le workflow.")

# Vérifier l’unicité (au cas où)
def count_nodes_by_title(workflow: Dict[str, Any], title: str) -> int:
    return sum(1 for node in workflow.values() if isinstance(node, dict) and node.get("_meta", {}).get("title") == title)

if count_nodes_by_title(BASE_WORKFLOW, "Positive (Prompt)") != 1:
    raise RuntimeError("Titre 'Positive (Prompt)' trouvé plusieurs fois ou pas du tout — doit être unique.")
if count_nodes_by_title(BASE_WORKFLOW, "Negative (Prompt)") != 1:
    raise RuntimeError("Titre 'Negative (Prompt)' trouvé plusieurs fois ou pas du tout — doit être unique.")

# 📥 Modèle de requête
class ImageRequest(BaseModel):
    positive_prompt: str
    negative_prompt: str = "text, watermark"

@app.post("/generate-image")
async def generate_image(data: ImageRequest):
    print(f"[MCP] Received image generation request:")
    print(f"[MCP]   Positive: {data.positive_prompt[:50]}...")
    print(f"[MCP]   Negative: {data.negative_prompt[:50]}...")

    # Faire une copie profonde du workflow de base pour éviter les effets de bord
    workflow = copy.deepcopy(BASE_WORKFLOW)

    # Randomizer les seeds pour éviter le cache et obtenir des images différentes
    for node in workflow.values():
        if isinstance(node, dict) and "inputs" in node and isinstance(node["inputs"], dict):
            if "seed" in node["inputs"]:
                old_seed = node["inputs"]["seed"]
                new_seed = random.randint(0, 2**32 - 1)
                node["inputs"]["seed"] = new_seed
                node_title = node.get("_meta", {}).get("title", "unknown")
                print(f"[MCP] Randomizing seed for '{node_title}': {old_seed} -> {new_seed}")

    # Injecter les prompts dans les nœuds identifiés
    try:
        workflow[POS_NODE_ID]["inputs"]["text"] = data.positive_prompt
        workflow[NEG_NODE_ID]["inputs"]["text"] = data.negative_prompt
    except KeyError as e:
        print(f"[MCP] ERROR: Node {e} not found in workflow")
        raise HTTPException(status_code=500, detail=f"Échec d'injection du prompt : nœud {e} non trouvé")
    except Exception as e:
        print(f"[MCP] ERROR: Failed to inject prompts: {e}")
        raise HTTPException(status_code=500, detail=f"Erreur lors de l'injection des prompts : {e}")

    # Envoyer le workflow à ComfyUI
    try:
        print(f"[MCP] Sending workflow to ComfyUI at {COMFYUI_URL}/prompt")
        response = requests.post(
            f"{COMFYUI_URL}/prompt",
            json={"prompt": workflow},
            timeout=10
        )
        response.raise_for_status()
        print(f"[MCP] Workflow accepted by ComfyUI (status {response.status_code})")
    except requests.RequestException as e:
        print(f"[MCP] ERROR: Cannot reach ComfyUI: {e}")
        raise HTTPException(status_code=502, detail=f"Impossible de contacter ComfyUI : {e}")

    # Extraire le prompt_id
    try:
        prompt_id = response.json()["prompt_id"]
        print(f"[MCP] Got prompt_id: {prompt_id}")
    except (KeyError, json.JSONDecodeError) as e:
        print(f"[MCP] ERROR: Invalid response from ComfyUI: {e}")
        raise HTTPException(status_code=502, detail="Réponse invalide de ComfyUI : prompt_id manquant")

    # Attendre le résultat avec timeout
    start_time = time.time()
    poll_count = 0
    while time.time() - start_time < MAX_WAIT_TIME:
        poll_count += 1
        try:
            print(f"[MCP] Polling history for {prompt_id} (attempt {poll_count})")
            history_response = requests.get(f"{COMFYUI_URL}/history/{prompt_id}", timeout=10)
            history_response.raise_for_status()
            history = history_response.json()

            if prompt_id in history:
                print(f"[MCP] Prompt {prompt_id} found in history")
                print(f"[MCP] FULL HISTORY ENTRY: {json.dumps(history[prompt_id], indent=2)}")
                if history[prompt_id].get("outputs"):
                    outputs = history[prompt_id]["outputs"]
                    print(f"[MCP] Outputs nodes: {list(outputs.keys())}")
                    # Parcourir tous les nœuds de sortie pour trouver celui qui contient des images
                    for node_id in outputs:
                        if "images" in outputs[node_id] and len(outputs[node_id]["images"]) > 0:
                            print(f"[MCP] Found images in node {node_id}")
                            img_info = outputs[node_id]["images"][0]
                            filename = img_info["filename"]
                            subfolder = img_info["subfolder"]
                            img_type = img_info["type"]
                            print(f"[MCP] Image details: {filename}, {subfolder}, {img_type}")
                            # Laisser le temps à ComfyUI de finaliser le job
                            print(f"[MCP] Sleeping 1s before returning...")
                            time.sleep(1.0)
                            # Construire l'URL d'accès à l'image
                            image_url = f"{COMFYUI_URL}/view?filename={filename}&subfolder={subfolder}&type={img_type}"
                            print(f"[MCP] Returning response: {image_url}")
                            return {"image_url": image_url, "prompt_id": prompt_id}
                else:
                    print(f"[MCP] No outputs yet for prompt {prompt_id}")
            else:
                print(f"[MCP] Prompt {prompt_id} not yet in history")
        except requests.RequestException as e:
            print(f"[MCP] Request error during poll: {e}")
            pass  # On ignore les erreurs temporaires et on retry

        time.sleep(POLL_INTERVAL)

    # Si on sort de la boucle, c'est un timeout
    print(f"[MCP] TIMEOUT after {MAX_WAIT_TIME}s for prompt {prompt_id}")
    raise HTTPException(
        status_code=504,
        detail=f"Timeout : aucune image générée après {MAX_WAIT_TIME} secondes. Vérifiez ComfyUI et la charge du GPU."
    )

@app.get("/")
def read_root():
    return {"status": "MCP server for ComfyUI image generation"}