from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import requests
import json
import time
import copy
import random
from typing import Dict, Any, Optional
import os
app = FastAPI()
# 🔒 Configuration CORS — autorise le frontend sur localhost:8084
app.add_middleware(
CORSMiddleware,
allow_origins=["http://localhost:8084"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# đź”§ Configuration
COMFYUI_URL = "http://127.0.0.1:8188"
WORKFLOW_PATH = os.path.join(os.path.dirname(__file__), "Jaugernaut_wrkf.json")
MAX_WAIT_TIME = 60 # seconds maximum d'attente pour la génération
POLL_INTERVAL = 0.5 # secondes entre chaque vérification
# 🔽 Charger le workflow de base une seule fois au démarrage
try:
with open(WORKFLOW_PATH, "r") as f:
BASE_WORKFLOW: Dict[str, Any] = json.load(f)
except FileNotFoundError:
raise RuntimeError(f"Fichier de workflow introuvable : {WORKFLOW_PATH}")
except json.JSONDecodeError as e:
raise RuntimeError(f"Erreur de parsing JSON dans {WORKFLOW_PATH} : {e}")
# 🔍 Fonction pour trouver un nœud par son titre dans _meta
def find_node_by_title(workflow: Dict[str, Any], title: str) -> Optional[str]:
for node_id, node in workflow.items():
if isinstance(node, dict) and node.get("_meta", {}).get("title") == title:
return node_id
return None
# 🔍 Localiser une fois au démarrage les IDs des nœuds de prompt par leurs titres
POS_NODE_ID = find_node_by_title(BASE_WORKFLOW, "Positive (Prompt)")
NEG_NODE_ID = find_node_by_title(BASE_WORKFLOW, "Negative (Prompt)")
if not POS_NODE_ID:
raise RuntimeError("Nœud de prompt positif introuvable : titre 'Positive (Prompt)' non trouvé dans le workflow.")
if not NEG_NODE_ID:
raise RuntimeError("Nœud de prompt négatif introuvable : titre 'Negative (Prompt)' non trouvé dans le workflow.")
# Vérifier l’unicité (au cas où)
def count_nodes_by_title(workflow: Dict[str, Any], title: str) -> int:
return sum(1 for node in workflow.values() if isinstance(node, dict) and node.get("_meta", {}).get("title") == title)
if count_nodes_by_title(BASE_WORKFLOW, "Positive (Prompt)") != 1:
raise RuntimeError("Titre 'Positive (Prompt)' trouvé plusieurs fois ou pas du tout — doit être unique.")
if count_nodes_by_title(BASE_WORKFLOW, "Negative (Prompt)") != 1:
raise RuntimeError("Titre 'Negative (Prompt)' trouvé plusieurs fois ou pas du tout — doit être unique.")
# 📥 Modèle de requête
class ImageRequest(BaseModel):
positive_prompt: str
negative_prompt: str = "text, watermark"
@app.post("/generate-image")
async def generate_image(data: ImageRequest):
print(f"[MCP] Received image generation request:")
print(f"[MCP] Positive: {data.positive_prompt[:50]}...")
print(f"[MCP] Negative: {data.negative_prompt[:50]}...")
# Faire une copie profonde du workflow de base pour éviter les effets de bord
workflow = copy.deepcopy(BASE_WORKFLOW)
# Randomizer les seeds pour éviter le cache et obtenir des images différentes
for node in workflow.values():
if isinstance(node, dict) and "inputs" in node and isinstance(node["inputs"], dict):
if "seed" in node["inputs"]:
old_seed = node["inputs"]["seed"]
new_seed = random.randint(0, 2**32 - 1)
node["inputs"]["seed"] = new_seed
node_title = node.get("_meta", {}).get("title", "unknown")
print(f"[MCP] Randomizing seed for '{node_title}': {old_seed} -> {new_seed}")
# Injecter les prompts dans les nœuds identifiés
try:
workflow[POS_NODE_ID]["inputs"]["text"] = data.positive_prompt
workflow[NEG_NODE_ID]["inputs"]["text"] = data.negative_prompt
except KeyError as e:
print(f"[MCP] ERROR: Node {e} not found in workflow")
raise HTTPException(status_code=500, detail=f"Échec d'injection du prompt : nœud {e} non trouvé")
except Exception as e:
print(f"[MCP] ERROR: Failed to inject prompts: {e}")
raise HTTPException(status_code=500, detail=f"Erreur lors de l'injection des prompts : {e}")
# Envoyer le workflow Ă ComfyUI
try:
print(f"[MCP] Sending workflow to ComfyUI at {COMFYUI_URL}/prompt")
response = requests.post(
f"{COMFYUI_URL}/prompt",
json={"prompt": workflow},
timeout=10
)
response.raise_for_status()
print(f"[MCP] Workflow accepted by ComfyUI (status {response.status_code})")
except requests.RequestException as e:
print(f"[MCP] ERROR: Cannot reach ComfyUI: {e}")
raise HTTPException(status_code=502, detail=f"Impossible de contacter ComfyUI : {e}")
# Extraire le prompt_id
try:
prompt_id = response.json()["prompt_id"]
print(f"[MCP] Got prompt_id: {prompt_id}")
except (KeyError, json.JSONDecodeError) as e:
print(f"[MCP] ERROR: Invalid response from ComfyUI: {e}")
raise HTTPException(status_code=502, detail="Réponse invalide de ComfyUI : prompt_id manquant")
# Attendre le résultat avec timeout
start_time = time.time()
poll_count = 0
while time.time() - start_time < MAX_WAIT_TIME:
poll_count += 1
try:
print(f"[MCP] Polling history for {prompt_id} (attempt {poll_count})")
history_response = requests.get(f"{COMFYUI_URL}/history/{prompt_id}", timeout=10)
history_response.raise_for_status()
history = history_response.json()
if prompt_id in history:
print(f"[MCP] Prompt {prompt_id} found in history")
print(f"[MCP] FULL HISTORY ENTRY: {json.dumps(history[prompt_id], indent=2)}")
if history[prompt_id].get("outputs"):
outputs = history[prompt_id]["outputs"]
print(f"[MCP] Outputs nodes: {list(outputs.keys())}")
# Parcourir tous les nœuds de sortie pour trouver celui qui contient des images
for node_id in outputs:
if "images" in outputs[node_id] and len(outputs[node_id]["images"]) > 0:
print(f"[MCP] Found images in node {node_id}")
img_info = outputs[node_id]["images"][0]
filename = img_info["filename"]
subfolder = img_info["subfolder"]
img_type = img_info["type"]
print(f"[MCP] Image details: {filename}, {subfolder}, {img_type}")
# Laisser le temps Ă ComfyUI de finaliser le job
print(f"[MCP] Sleeping 1s before returning...")
time.sleep(1.0)
# Construire l'URL d'accès à l'image
image_url = f"{COMFYUI_URL}/view?filename={filename}&subfolder={subfolder}&type={img_type}"
print(f"[MCP] Returning response: {image_url}")
return {"image_url": image_url, "prompt_id": prompt_id}
else:
print(f"[MCP] No outputs yet for prompt {prompt_id}")
else:
print(f"[MCP] Prompt {prompt_id} not yet in history")
except requests.RequestException as e:
print(f"[MCP] Request error during poll: {e}")
pass # On ignore les erreurs temporaires et on retry
time.sleep(POLL_INTERVAL)
# Si on sort de la boucle, c'est un timeout
print(f"[MCP] TIMEOUT after {MAX_WAIT_TIME}s for prompt {prompt_id}")
raise HTTPException(
status_code=504,
detail=f"Timeout : aucune image générée après {MAX_WAIT_TIME} secondes. Vérifiez ComfyUI et la charge du GPU."
)
@app.get("/")
def read_root():
return {"status": "MCP server for ComfyUI image generation"}