Skip to content

Commit 0fa45e5

Browse files
committed
black reformatting
add support for critique api for search arena.
1 parent 0e6d3e4 commit 0fa45e5

File tree

2 files changed

+178
-0
lines changed

2 files changed

+178
-0
lines changed

fastchat/model/model_registry.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,13 @@ def get_model_info(name: str) -> ModelInfo:
270270
"Frontier Multimodal Language Model by Reka",
271271
)
272272

273+
# Registry entry for the Critique Labs AI agentic search engine, listed under
# three model names. Note: "critique-labs-ai" is also the api_type value that
# fastchat/serve/api_provider.py checks when dispatching to the Critique
# streaming client.
register_model_info(
    ["critique-agentic-search", "critique-agentic-search-api", "critique-labs-ai"],
    "Critique Labs AI",
    "https://www.critique-labs.ai/",
    "Agentic Search Engine By Critique Labs AI",
)
279+
273280
register_model_info(
274281
["gemini-pro", "gemini-pro-dev-api"],
275282
"Gemini",

fastchat/serve/api_provider.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,17 @@ def get_api_provider_stream_iter(
259259
api_key=model_api_dict["api_key"],
260260
extra_body=extra_body,
261261
)
262+
elif model_api_dict["api_type"] == "critique-labs-ai":
263+
prompt = conv.to_openai_api_messages()
264+
stream_iter = critique_api_stream_iter(
265+
model_api_dict["model_name"],
266+
prompt,
267+
temperature,
268+
top_p,
269+
max_new_tokens,
270+
api_key=model_api_dict.get("api_key"),
271+
api_base=model_api_dict.get("api_base"),
272+
)
262273
else:
263274
raise NotImplementedError()
264275

@@ -1345,3 +1356,163 @@ def metagen_api_stream_iter(
13451356
"text": f"**API REQUEST ERROR** Reason: Unknown.",
13461357
"error_code": 1,
13471358
}
1359+
1360+
1361+
def critique_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    """Stream an answer from the Critique Labs search WebSocket API.

    The conversation is flattened into a single text prompt, sent over a
    WebSocket from a background thread, and the streamed replies are handed
    back to this (synchronous) generator through a queue.

    Args:
        model_name: Model identifier; only included in the request log.
        messages: OpenAI-style message dicts with ``role`` and ``content``.
            ``content`` is either a string or a list of items, of which only
            those with ``type == "text"`` are used.
        temperature: Only included in the request log; not sent to the API.
        top_p: Only included in the request log; not sent to the API.
        max_new_tokens: Only included in the request log; not sent to the API.
        api_key: API key sent as the ``X-API-Key`` header. Falls back to the
            ``CRITIQUE_API_KEY`` environment variable.
        api_base: WebSocket URI. Defaults to
            ``wss://api.critique-labs.ai/v1/ws/search``.

    Yields:
        Dicts with ``text`` and ``error_code``. On success ``text`` is the
        cumulative answer so far and ``error_code`` is 0; on failure ``text``
        is an ``**API REQUEST ERROR**`` message and ``error_code`` is 1.
    """
    import json
    import queue
    import threading

    api_key = api_key or os.environ.get("CRITIQUE_API_KEY")
    if not api_key:
        yield {
            "text": "**API REQUEST ERROR** Reason: CRITIQUE_API_KEY not found in environment variables.",
            "error_code": 1,
        }
        return

    # Imported only after the key check so that a missing key is reported as a
    # regular error chunk even when the optional dependency is not installed.
    import websockets

    # Combine all messages into a single prompt. System messages carry no role
    # prefix; every other role is prefixed with its capitalized name.
    prompt_parts = []
    for message in messages:
        content = message["content"]
        if isinstance(content, str):
            texts = [content]
        else:
            # List content (multimodal): keep only the text items. A None
            # content contributes nothing instead of raising.
            texts = [
                item["text"] for item in (content or []) if item.get("type") == "text"
            ]
        if not texts:
            continue
        role = message["role"]
        role_prefix = f"{role.capitalize()}: " if role != "system" else ""
        prompt_parts.extend(f"{role_prefix}{text}\n" for text in texts)
    prompt_parts.append("\n DO NOT RESPONSE IN MARKDOWN or provide any citations")
    prompt = "".join(prompt_parts)

    # Log request parameters
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # Hand-off between the WebSocket thread (producer) and this generator.
    response_queue = queue.Queue()
    stop_event = threading.Event()  # set by the generator: the worker should quit
    connection_closed = threading.Event()  # set by the worker: no more messages

    # How often both sides wake up to re-check the events above (seconds).
    poll_interval = 0.5

    def websocket_thread():
        import asyncio

        async def connect_and_stream():
            uri = api_base or "wss://api.critique-labs.ai/v1/ws/search"

            try:
                async with websockets.connect(
                    uri, additional_headers={"X-API-Key": api_key}
                ) as websocket:
                    # Send the search request
                    await websocket.send(json.dumps({"prompt": prompt}))

                    # Receive and forward streaming responses
                    while not stop_event.is_set():
                        try:
                            # Bounded wait: a bare recv() would block until the
                            # server speaks, so stop_event could never be
                            # observed while the connection is idle.
                            response = await asyncio.wait_for(
                                websocket.recv(), timeout=poll_interval
                            )
                        except asyncio.TimeoutError:
                            continue
                        except websockets.exceptions.ConnectionClosed:
                            # This is the expected end signal - not an error
                            logger.info(
                                "WebSocket connection closed by server - this is the expected end signal"
                            )
                            break

                        data = json.loads(response)
                        response_queue.put(data)

                        # If we get an error, we're done
                        if data["type"] == "error":
                            break
            except Exception as e:
                # Anything else (connect failure, bad JSON, ...) is surfaced
                # to the consumer as an error message.
                logger.error(f"WebSocket error: {str(e)}")
                response_queue.put(
                    {"type": "error", "content": f"WebSocket error: {str(e)}"}
                )
            finally:
                # Always signal completion, after every message was queued,
                # so the consumer loop can terminate.
                connection_closed.set()

        asyncio.run(connect_and_stream())

    # Start the WebSocket thread
    thread = threading.Thread(target=websocket_thread)
    thread.daemon = True
    thread.start()

    try:
        text = ""

        # Drain the queue until the worker is done AND nothing is left in it.
        while not connection_closed.is_set() or not response_queue.empty():
            try:
                data = response_queue.get(timeout=poll_interval)
            except queue.Empty:
                # Just a timeout to re-check whether the connection is closed
                continue

            message_type = data["type"]
            if message_type == "response":
                text += data["content"]
                yield {
                    "text": text,
                    "error_code": 0,
                }
            elif message_type == "error":
                logger.error(f"Critique API error: {data['content']}")
                yield {
                    "text": f"**API REQUEST ERROR** Reason: {data['content']}",
                    "error_code": 1,
                }
                break
            # Other message types (e.g. "context") are not part of the answer
            # text and are ignored.

    except Exception as e:
        logger.error(f"Error in critique_api_stream_iter: {str(e)}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {str(e)}",
            "error_code": 1,
        }
    finally:
        # Signal the thread to stop and wait for it to finish; this also runs
        # when the caller abandons the generator early.
        stop_event.set()
        thread.join(timeout=5)

0 commit comments

Comments
 (0)