"""
Image generation abstraction layer.

Supports DALL-E (OpenAI) and Replicate (Stability AI SDXL) for image generation.
"""

import time
from abc import ABC, abstractmethod

import httpx

# Timeout (seconds) passed to every ``httpx.Client`` created in this module.
TIMEOUT = 120.0
# Seconds to wait between consecutive Replicate status polls.
POLL_INTERVAL = 2.0
# Maximum number of Replicate status polls before reporting a timeout.
MAX_POLL_ATTEMPTS = 60


class ImageProvider(ABC):
    """Abstract base class for image generation providers.

    Attributes:
        api_key: Credential sent to the provider as a Bearer token.
        model: Provider-specific model identifier. The concrete providers in
            this module substitute their own default when ``None`` is passed.
    """

    def __init__(self, api_key: str, model: str | None = None):
        self.api_key = api_key
        self.model = model

    def _json_headers(self) -> dict[str, str]:
        """Return the auth + JSON content-type headers shared by all providers."""
        return {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @abstractmethod
    def generate(self, prompt: str, size: str = "1024x1024") -> str:
        """Generate an image from a text prompt.

        Args:
            prompt: Text description of the image to generate.
            size: Image dimensions as 'WIDTHxHEIGHT' string.

        Returns:
            URL of the generated image.
        """
        ...


class DallEProvider(ImageProvider):
    """OpenAI DALL-E 3 image generation provider."""

    API_URL = "https://api.openai.com/v1/images/generations"

    def __init__(self, api_key: str, model: str | None = None):
        super().__init__(api_key, model or "dall-e-3")

    def generate(self, prompt: str, size: str = "1024x1024") -> str:
        """Request a single image from the OpenAI Images API.

        Args:
            prompt: Text description of the image to generate.
            size: Image dimensions as 'WIDTHxHEIGHT'; forwarded unchanged.

        Returns:
            URL of the generated image (``data[0].url`` in the response).

        Raises:
            RuntimeError: On an HTTP error status, a transport failure, a
                non-JSON body, or a body that lacks ``data[0].url``.
        """
        payload = {
            "model": self.model,
            "prompt": prompt,
            "n": 1,
            "size": size,
            "response_format": "url",
        }

        try:
            with httpx.Client(timeout=TIMEOUT) as client:
                response = client.post(
                    self.API_URL, headers=self._json_headers(), json=payload
                )
                response.raise_for_status()
                data = response.json()
        except httpx.HTTPStatusError as e:
            raise RuntimeError(
                f"DALL-E API error {e.response.status_code}: {e.response.text}"
            ) from e
        except httpx.RequestError as e:
            raise RuntimeError(f"DALL-E API request failed: {e}") from e
        except ValueError as e:
            # response.json() raises a ValueError subclass on a non-JSON body.
            raise RuntimeError(f"DALL-E API returned invalid JSON: {e}") from e

        # Map a malformed body to the same RuntimeError contract as other failures
        # instead of leaking KeyError/IndexError/TypeError to callers.
        try:
            return data["data"][0]["url"]
        except (KeyError, IndexError, TypeError) as e:
            raise RuntimeError(
                f"DALL-E API returned unexpected response format: {data}"
            ) from e


class ReplicateProvider(ImageProvider):
    """Replicate image generation provider using Stability AI SDXL."""

    API_URL = "https://api.replicate.com/v1/predictions"

    def __init__(self, api_key: str, model: str | None = None):
        # NOTE(review): the model string is sent as the prediction "version".
        # Replicate identifies versions by hash ("owner/model:<hash>" or the
        # bare hash); confirm the ":latest" tag in this default is accepted.
        super().__init__(api_key, model or "stability-ai/sdxl:latest")

    def generate(self, prompt: str, size: str = "1024x1024") -> str:
        """Create a Replicate prediction and poll it until it finishes.

        Args:
            prompt: Text description of the image to generate.
            size: Image dimensions as 'WIDTHxHEIGHT'. Malformed values fall
                back to 1024x1024.

        Returns:
            URL of the generated image (first element when the prediction
            output is a list, or the output itself when it is a string).

        Raises:
            RuntimeError: On an HTTP error status, a transport failure, a
                non-JSON or malformed body, a failed or canceled prediction,
                or when the prediction is still unfinished after
                ``MAX_POLL_ATTEMPTS`` polls.
        """
        headers = self._json_headers()
        width, height = self._parse_size(size)
        payload = {
            # The model string is forwarded verbatim as the version identifier.
            "version": self.model,
            "input": {
                "prompt": prompt,
                "width": width,
                "height": height,
            },
        }

        try:
            with httpx.Client(timeout=TIMEOUT) as client:
                response = client.post(self.API_URL, headers=headers, json=payload)
                response.raise_for_status()
                prediction_url = self._prediction_url(response.json())
                return self._poll_prediction(client, prediction_url, headers)
        except httpx.HTTPStatusError as e:
            raise RuntimeError(
                f"Replicate API error {e.response.status_code}: {e.response.text}"
            ) from e
        except httpx.RequestError as e:
            raise RuntimeError(f"Replicate API request failed: {e}") from e
        except ValueError as e:
            # response.json() raises a ValueError subclass on a non-JSON body.
            raise RuntimeError(f"Replicate API returned invalid JSON: {e}") from e

    @staticmethod
    def _parse_size(size: str) -> tuple[int, int]:
        """Split 'WIDTHxHEIGHT' into positive integers, defaulting to 1024x1024.

        The separator is matched case-insensitively; a malformed or
        non-positive size falls back to the default instead of raising.
        """
        try:
            width, height = (int(part) for part in size.lower().split("x"))
        except ValueError:
            # Covers non-numeric parts and the wrong number of parts.
            return 1024, 1024
        if width <= 0 or height <= 0:
            return 1024, 1024
        return width, height

    def _prediction_url(self, prediction: object) -> str:
        """Return the URL used to poll a newly created prediction.

        Prefers the ``urls.get`` link from the creation response and otherwise
        builds ``API_URL/<id>``. Raises RuntimeError when neither is present.
        """
        if not isinstance(prediction, dict):
            raise RuntimeError(
                f"Replicate API returned unexpected response format: {prediction}"
            )
        # "urls" may be present but null, so fall back to an empty dict.
        url = (prediction.get("urls") or {}).get("get")
        if url:
            return url
        prediction_id = prediction.get("id")
        if not prediction_id:
            raise RuntimeError(
                f"Replicate API response has no prediction id: {prediction}"
            )
        return f"{self.API_URL}/{prediction_id}"

    def _poll_prediction(
        self, client: httpx.Client, prediction_url: str, headers: dict[str, str]
    ) -> str:
        """Poll ``prediction_url`` until the prediction reaches a terminal status."""
        for attempt in range(MAX_POLL_ATTEMPTS):
            if attempt:
                # Sleep only between polls, never after the final one.
                time.sleep(POLL_INTERVAL)

            poll_response = client.get(prediction_url, headers=headers)
            poll_response.raise_for_status()
            result = poll_response.json()
            status = result.get("status")

            if status == "succeeded":
                return self._extract_output(result.get("output"))
            if status == "failed":
                # An explicit null "error" should still produce a useful message.
                error = result.get("error") or "Unknown error"
                raise RuntimeError(f"Replicate prediction failed: {error}")
            if status == "canceled":
                raise RuntimeError("Replicate prediction was canceled")

        raise RuntimeError(
            f"Replicate prediction timed out after polling {MAX_POLL_ATTEMPTS} times"
        )

    @staticmethod
    def _extract_output(output: object) -> str:
        """Return the image URL from a succeeded prediction's ``output`` field."""
        if isinstance(output, list) and output:
            return output[0]
        if isinstance(output, str):
            return output
        raise RuntimeError(f"Replicate returned unexpected output format: {output}")


def get_image_provider(
    provider_name: str, api_key: str, model: str | None = None
) -> ImageProvider:
    """Factory function to get an image generation provider instance.

    Args:
        provider_name: One of 'dalle', 'replicate'. Matching is
            case-insensitive and ignores surrounding whitespace.
        api_key: API key for the provider.
        model: Optional model override. Uses default if not specified.

    Returns:
        An ImageProvider instance.

    Raises:
        ValueError: If provider_name is not supported.
    """
    providers: dict[str, type[ImageProvider]] = {
        "dalle": DallEProvider,
        "replicate": ReplicateProvider,
    }
    provider_cls = providers.get(provider_name.strip().lower())
    if provider_cls is None:
        supported = ", ".join(providers)
        raise ValueError(
            f"Unknown image provider '{provider_name}'. Supported: {supported}"
        )
    return provider_cls(api_key=api_key, model=model)