Skip to content

Commit 5e13884

Browse files
committed
add support for critique api for search arena.
1 parent 0e6d3e4 commit 5e13884

File tree

2 files changed

+161
-0
lines changed

2 files changed

+161
-0
lines changed

fastchat/model/model_registry.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,13 @@ def get_model_info(name: str) -> ModelInfo:
270270
"Frontier Multimodal Language Model by Reka",
271271
)
272272

273+
# Critique Labs agentic search: every alias below resolves to the same
# registry entry (display name, homepage link, one-line description).
register_model_info(
    [
        "critique-agentic-search",
        "critique-agentic-search-api",
        "critique-labs-ai",
    ],
    "Critique Labs AI",
    "https://www.critique-labs.ai/",
    "Agentic Search Engine By Critique Labs AI",
)
279+
273280
register_model_info(
274281
["gemini-pro", "gemini-pro-dev-api"],
275282
"Gemini",

fastchat/serve/api_provider.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,17 @@ def get_api_provider_stream_iter(
259259
api_key=model_api_dict["api_key"],
260260
extra_body=extra_body,
261261
)
262+
elif model_api_dict["api_type"] == "critique-labs-ai":
263+
prompt = conv.to_openai_api_messages()
264+
stream_iter = critique_api_stream_iter(
265+
model_api_dict["model_name"],
266+
prompt,
267+
temperature,
268+
top_p,
269+
max_new_tokens,
270+
api_key=model_api_dict.get("api_key"),
271+
api_base=model_api_dict.get("api_base"),
272+
)
262273
else:
263274
raise NotImplementedError()
264275

@@ -1345,3 +1356,146 @@ def metagen_api_stream_iter(
13451356
"text": f"**API REQUEST ERROR** Reason: Unknown.",
13461357
"error_code": 1,
13471358
}
1359+
1360+
1361+
def critique_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    """Stream a Critique Labs search answer as cumulative text chunks.

    The conversation is flattened into a single prompt string and sent over a
    WebSocket from a background thread; every message received is handed to
    this generator through a queue.

    Args:
        model_name: Model identifier; only included in the request log.
        messages: OpenAI-style message dicts with ``role`` and ``content``.
            ``content`` is either a string or a list of parts, of which only
            parts with ``"type": "text"`` are used.
        temperature: Only logged; not part of the request payload.
        top_p: Only logged; not part of the request payload.
        max_new_tokens: Only logged; not part of the request payload.
        api_key: API key sent in the ``X-API-Key`` header. Falls back to the
            ``CRITIQUE_API_KEY`` environment variable.
        api_base: WebSocket URI; defaults to the public search endpoint.

    Yields:
        ``{"text": ..., "error_code": ...}`` dicts. On success ``text`` is the
        full answer accumulated so far and ``error_code`` is 0. On failure a
        single dict with ``error_code`` 1 is yielded and the stream ends.
    """
    import asyncio
    import json
    import queue
    import threading

    api_key = api_key or os.environ.get("CRITIQUE_API_KEY")
    if not api_key:
        yield {
            "text": "**API REQUEST ERROR** Reason: CRITIQUE_API_KEY not found in environment variables.",
            "error_code": 1,
        }
        return

    # Imported after the key check so the missing-key error can be reported
    # even where this optional dependency is not installed.
    import websockets

    # Flatten the conversation: one "Role: text" line per text segment, with
    # system messages left unprefixed.
    prompt_parts = []
    for message in messages:
        role = message["role"]
        role_prefix = f"{role.capitalize()}: " if role != "system" else ""
        content = message["content"]
        if isinstance(content, str):
            prompt_parts.append(f"{role_prefix}{content}\n")
            continue
        # List-style (multimodal) content: keep text parts, skip the rest.
        for content_item in content or ():
            if content_item.get("type") == "text":
                prompt_parts.append(f"{role_prefix}{content_item['text']}\n")
    prompt = "".join(prompt_parts)
    prompt += "\n DO NOT RESPONSE IN MARKDOWN or provide any citations"

    # Log request parameters
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # Seconds each side waits before re-checking its shutdown condition.
    poll_interval = 0.5

    response_queue = queue.Queue()  # worker thread -> generator
    stop_event = threading.Event()  # generator -> worker: stop receiving
    connection_closed = threading.Event()  # worker -> generator: no more data

    def websocket_thread():
        """Run the WebSocket session on a private event loop in this thread."""

        async def connect_and_stream():
            uri = api_base or "wss://api.critique-labs.ai/v1/ws/search"

            try:
                # NOTE(review): `additional_headers` is the keyword used by the
                # newer websockets asyncio client; older releases call it
                # `extra_headers` - confirm against the pinned version.
                async with websockets.connect(
                    uri,
                    additional_headers={"X-API-Key": api_key},
                ) as websocket:
                    await websocket.send(json.dumps({"prompt": prompt}))

                    while not stop_event.is_set():
                        try:
                            # Bounded wait so stop_event is honoured even when
                            # the server is silent; a bare recv() would block
                            # until the next message arrives.
                            response = await asyncio.wait_for(
                                websocket.recv(), timeout=poll_interval
                            )
                        except asyncio.TimeoutError:
                            continue
                        except websockets.exceptions.ConnectionClosed:
                            # This is the expected end signal - not an error
                            logger.info("WebSocket connection closed by server - this is the expected end signal")
                            break

                        data = json.loads(response)
                        response_queue.put(data)

                        # If we get an error, we're done
                        if data.get("type") == "error":
                            break
            except Exception as e:
                # Connection failures, malformed JSON, etc. are surfaced to
                # the consumer as an error message rather than lost here.
                logger.error(f"WebSocket error: {str(e)}")
                response_queue.put({"type": "error", "content": f"WebSocket error: {str(e)}"})
            finally:
                # Always tell the consumer that nothing more will be queued.
                connection_closed.set()

        asyncio.run(connect_and_stream())

    thread = threading.Thread(target=websocket_thread)
    thread.daemon = True
    thread.start()

    try:
        text = ""

        # Drain until the worker has finished AND the queue is empty.
        while not connection_closed.is_set() or not response_queue.empty():
            try:
                data = response_queue.get(timeout=poll_interval)
            except queue.Empty:
                # Just a timeout to re-check whether the connection is closed
                continue

            message_type = data.get("type")
            if message_type == "response":
                text += data.get("content") or ""
                yield {
                    "text": text,
                    "error_code": 0,
                }
            elif message_type == "error":
                reason = data.get("content")
                logger.error(f"Critique API error: {reason}")
                yield {
                    "text": f"**API REQUEST ERROR** Reason: {reason}",
                    "error_code": 1,
                }
                break
            # Any other message type (e.g. "context") is not surfaced.

    except Exception as e:
        logger.error(f"Error in critique_api_stream_iter: {str(e)}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {str(e)}",
            "error_code": 1,
        }
    finally:
        # Signal the thread to stop and wait for it to finish
        stop_event.set()
        thread.join(timeout=5)

0 commit comments

Comments
 (0)