@@ -259,6 +259,17 @@ def get_api_provider_stream_iter(
259
259
api_key = model_api_dict ["api_key" ],
260
260
extra_body = extra_body ,
261
261
)
262
+ elif model_api_dict ["api_type" ] == "critique-labs-ai" :
263
+ prompt = conv .to_openai_api_messages ()
264
+ stream_iter = critique_api_stream_iter (
265
+ model_api_dict ["model_name" ],
266
+ prompt ,
267
+ temperature ,
268
+ top_p ,
269
+ max_new_tokens ,
270
+ api_key = model_api_dict .get ("api_key" ),
271
+ api_base = model_api_dict .get ("api_base" ),
272
+ )
262
273
else :
263
274
raise NotImplementedError ()
264
275
@@ -1345,3 +1356,146 @@ def metagen_api_stream_iter(
1345
1356
"text" : f"**API REQUEST ERROR** Reason: Unknown." ,
1346
1357
"error_code" : 1 ,
1347
1358
}
1359
+
1360
+
1361
def critique_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    """Stream an answer from the Critique Labs WebSocket search API.

    The chat history is flattened into one text prompt, sent over a
    WebSocket from a background thread, and the received frames are relayed
    to this generator through a queue.

    Args:
        model_name: Model identifier; only written to the request log.
        messages: OpenAI-style messages (dicts with ``role`` and ``content``).
            ``content`` is either a string or a list of parts, of which only
            the ``{"type": "text"}`` parts are used.
        temperature: Only written to the request log; the request payload
            sent to the server contains just the prompt.
        top_p: Only written to the request log.
        max_new_tokens: Only written to the request log.
        api_key: API key; falls back to the ``CRITIQUE_API_KEY`` environment
            variable when falsy.
        api_base: WebSocket URI; defaults to the public Critique endpoint.

    Yields:
        Dicts with ``text`` and ``error_code``. On success ``text`` is the
        cumulative answer so far and ``error_code`` is 0; on failure ``text``
        is an ``**API REQUEST ERROR**`` message and ``error_code`` is 1.
    """
    api_key = api_key or os.environ.get("CRITIQUE_API_KEY")
    if not api_key:
        yield {
            "text": "**API REQUEST ERROR** Reason: CRITIQUE_API_KEY not found in environment variables.",
            "error_code": 1,
        }
        return

    # Imported only after the key check so that a configuration error is
    # reported without requiring the optional `websockets` dependency.
    import asyncio
    import json
    import queue
    import threading

    import websockets

    # Flatten the conversation into a single prompt. System messages carry
    # no role prefix; every other role is prefixed with its capitalized name.
    prompt_parts = []
    for message in messages:
        role = message["role"]
        role_prefix = "" if role == "system" else f"{role.capitalize()}: "
        content = message["content"]
        if isinstance(content, str):
            texts = [content]
        else:
            # Multimodal content: keep the text parts, skip everything else.
            # `None` content is treated as "no text" instead of raising.
            texts = [
                item["text"] for item in (content or []) if item.get("type") == "text"
            ]
        prompt_parts.extend(f"{role_prefix}{text}\n" for text in texts)
    prompt_parts.append("\nDO NOT RESPONSE IN MARKDOWN or provide any citations")
    prompt = "".join(prompt_parts)

    # Log request parameters
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # Frames travel from the WebSocket thread to this generator via the
    # queue; the two events coordinate shutdown in both directions.
    response_queue = queue.Queue()
    stop_event = threading.Event()
    connection_closed = threading.Event()
    poll_interval = 0.5  # seconds between checks of the shutdown events

    def websocket_thread():
        async def connect_and_stream():
            uri = api_base or "wss://api.critique-labs.ai/v1/ws/search"

            try:
                async with websockets.connect(
                    uri,
                    additional_headers={"X-API-Key": api_key},
                ) as websocket:
                    # Send the search request
                    await websocket.send(json.dumps({"prompt": prompt}))

                    while not stop_event.is_set():
                        try:
                            # Bounded wait so stop_event is re-checked even
                            # when the server is silent; without it the
                            # thread would sit in recv() until the next frame.
                            response = await asyncio.wait_for(
                                websocket.recv(), timeout=poll_interval
                            )
                        except asyncio.TimeoutError:
                            continue
                        except websockets.exceptions.ConnectionClosed:
                            # This is the expected end signal - not an error
                            logger.info(
                                "WebSocket connection closed by server - this is the expected end signal"
                            )
                            connection_closed.set()
                            break

                        data = json.loads(response)
                        response_queue.put(data)

                        # If we get an error, we're done
                        if isinstance(data, dict) and data.get("type") == "error":
                            break
            except Exception as e:
                # Connection/handshake/decoding failures are surfaced to the
                # consumer as an error frame rather than lost in the thread.
                logger.error(f"WebSocket error: {str(e)}")
                response_queue.put(
                    {"type": "error", "content": f"WebSocket error: {str(e)}"}
                )
            finally:
                connection_closed.set()

        try:
            asyncio.run(connect_and_stream())
        finally:
            # Guarantee the consumer loop can terminate even if the event
            # loop itself failed to start or shut down cleanly.
            connection_closed.set()

    thread = threading.Thread(target=websocket_thread, daemon=True)
    thread.start()

    try:
        text = ""

        # Drain until the worker has finished AND every queued frame is
        # consumed. The worker enqueues before it sets connection_closed, so
        # no frame can be missed by this check.
        while not connection_closed.is_set() or not response_queue.empty():
            try:
                data = response_queue.get(timeout=poll_interval)
            except queue.Empty:
                # Just a timeout to re-check whether the connection is closed
                continue

            frame_type = data.get("type") if isinstance(data, dict) else None
            if frame_type == "response":
                text += data["content"]
                yield {
                    "text": text,
                    "error_code": 0,
                }
            elif frame_type == "error":
                logger.error(f"Critique API error: {data['content']}")
                yield {
                    "text": f"**API REQUEST ERROR** Reason: {data['content']}",
                    "error_code": 1,
                }
                break
            # Other frame types (e.g. "context") are not part of the answer
            # text and are skipped.

    except Exception as e:
        logger.error(f"Error in critique_api_stream_iter: {str(e)}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {str(e)}",
            "error_code": 1,
        }
    finally:
        # Signal the thread to stop and wait for it to finish; the worker
        # notices the flag within one poll interval.
        stop_event.set()
        thread.join(timeout=5)
0 commit comments