Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Binary file not shown.
82 changes: 82 additions & 0 deletions nodes/api_nodes/runway_text2img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import os
import time
import requests
from PIL import Image
from io import BytesIO

class RunwayTextToImage:
"""
A ComfyUI-compatible node that generates an image from a text prompt
using Runway's /v1/text_to_image API.

Attributes:
prompt (str): The text description to generate an image from.
timeout (int): Maximum time (in seconds) to wait while polling the image URL. Default is 10 seconds.

Raises:
RuntimeError:
- If the RUNWAY_API_KEY environment variable is not set.
- If the API response is invalid or an error occurs during polling.

Environment Variables:
RUNWAY_API_KEY (str): Required API key for authenticating with RunwayML.

Example:
>>> node = RunwayTextToImage(prompt="a futuristic cityscape", timeout=15)
>>> image = node.run()
>>> image.show()

Notes:
- If the API key is missing, a clear error will be raised without crashing the program.
- To increase or decrease wait time for image generation, adjust the `timeout` parameter.
"""

def __init__(self, prompt: str, timeout: int = 10):
self.prompt = prompt
self.timeout = timeout

def run(self) -> Image.Image:
api_key = os.getenv("RUNWAY_API_KEY")
if not api_key:
raise RuntimeError("RUNWAY_API_KEY is not set in environment variables.")

headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}

payload = {
"prompt": self.prompt
}

# Step 1: Send POST request to initiate generation
response = requests.post(
"https://api.dev.runwayml.com/v1/text_to_image",
json=payload,
headers=headers
)

if response.status_code != 200:
raise RuntimeError(f"API request failed with status {response.status_code}: {response.text}")

data = response.json()
image_url = data.get("image_url")
if not image_url:
raise RuntimeError("API response does not contain 'image_url'.")

# Step 2: Poll the image URL until image is ready or timeout is reached
start_time = time.time()
image_data = None

while time.time() - start_time < self.timeout:
img_response = requests.get(image_url)
if img_response.status_code == 200 and img_response.content:
image_data = img_response.content
break
time.sleep(1) # Polling interval

if not image_data:
raise RuntimeError("Failed to fetch image data within timeout period.")

# Step 3: Load image from bytes into PIL.Image
return Image.open(BytesIO(image_data))
Comment on lines +53 to +82
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider adding request timeout and retry logic.

For production robustness, consider adding timeouts to the requests and basic retry logic for transient failures.

 response = requests.post(
     "https://api.dev.runwayml.com/v1/text_to_image",
     json=payload,
-    headers=headers
+    headers=headers,
+    timeout=30  # Add timeout for the request
 )

 while time.time() - start_time < self.timeout:
-    img_response = requests.get(image_url)
+    img_response = requests.get(image_url, timeout=30)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
response = requests.post(
"https://api.dev.runwayml.com/v1/text_to_image",
json=payload,
headers=headers
)
if response.status_code != 200:
raise RuntimeError(f"API request failed with status {response.status_code}: {response.text}")
data = response.json()
image_url = data.get("image_url")
if not image_url:
raise RuntimeError("API response does not contain 'image_url'.")
# Step 2: Poll the image URL until image is ready or timeout is reached
start_time = time.time()
image_data = None
while time.time() - start_time < self.timeout:
img_response = requests.get(image_url)
if img_response.status_code == 200 and img_response.content:
image_data = img_response.content
break
time.sleep(1) # Polling interval
if not image_data:
raise RuntimeError("Failed to fetch image data within timeout period.")
# Step 3: Load image from bytes into PIL.Image
return Image.open(BytesIO(image_data))
response = requests.post(
"https://api.dev.runwayml.com/v1/text_to_image",
json=payload,
headers=headers,
timeout=30 # Add timeout for the request
)
if response.status_code != 200:
raise RuntimeError(f"API request failed with status {response.status_code}: {response.text}")
data = response.json()
image_url = data.get("image_url")
if not image_url:
raise RuntimeError("API response does not contain 'image_url'.")
# Step 2: Poll the image URL until image is ready or timeout is reached
start_time = time.time()
image_data = None
while time.time() - start_time < self.timeout:
img_response = requests.get(image_url, timeout=30)
if img_response.status_code == 200 and img_response.content:
image_data = img_response.content
break
time.sleep(1) # Polling interval
if not image_data:
raise RuntimeError("Failed to fetch image data within timeout period.")
# Step 3: Load image from bytes into PIL.Image
return Image.open(BytesIO(image_data))
🤖 Prompt for AI Agents
In nodes/api_nodes/runway_text2img.py between lines 53 and 82, the HTTP requests
to the API and image URL lack timeout settings and retry logic, which can cause
hangs or failures on transient network issues. Add a timeout parameter to both
requests.post and requests.get calls to limit waiting time. Implement basic
retry logic with a limited number of retries and delays for transient failures
on both requests to improve robustness.

Binary file not shown.
34 changes: 34 additions & 0 deletions tests/api_nodes/test_runway_text2img.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
import pytest
from unittest import mock
from unittest.mock import patch

from nodes.api_nodes.runway_text2img import RunwayTextToImage

@patch('nodes.api_nodes.runway_text2img.requests.post')
@patch('nodes.api_nodes.runway_text2img.requests.get')
def test_runway_text2img_node(mock_get, mock_post):
# Set environment variable
with mock.patch.dict(os.environ, {"RUNWAY_API_KEY": "fake_api_key"}):
# Mock POST response from API
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {
"image_url": "http://fakeurl.com/fakeimage.png"
}

# Mock GET response with minimal valid PNG bytes
mock_get.return_value.status_code = 200
mock_get.return_value.content = (
b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01'
b'\x00\x00\x00\x01\x08\x06\x00\x00\x00\x1f\x15\xc4\x89'
b'\x00\x00\x00\nIDATx\x9cc`\x00\x00\x00\x02\x00\x01'
b'\xe2!\xbc\x33\x00\x00\x00\x00IEND\xaeB`\x82'
)

node = RunwayTextToImage(prompt="A test prompt", timeout=5)

try:
result = node.run()
assert result is not None
except Exception as e:
pytest.fail(f"Node run method raised an exception: {e}")
Loading