@@ -259,6 +259,17 @@ def get_api_provider_stream_iter(
             api_key=model_api_dict["api_key"],
             extra_body=extra_body,
         )
+    elif model_api_dict["api_type"] == "critique-labs-ai":
+        prompt = conv.to_openai_api_messages()
+        stream_iter = critique_api_stream_iter(
+            model_api_dict["model_name"],
+            prompt,
+            temperature,
+            top_p,
+            max_new_tokens,
+            api_key=model_api_dict.get("api_key"),
+            api_base=model_api_dict.get("api_base"),
+        )
     else:
         raise NotImplementedError()
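For reference, the new branch above only requires the endpoint registration that produces model_api_dict to carry an api_type of "critique-labs-ai" plus a model_name; api_key and api_base are optional and fall back to the defaults handled inside critique_api_stream_iter below. A minimal sketch of such an entry, with key names taken from the lookups above and every value a placeholder:

    # Hypothetical registration entry, consumed as model_api_dict by
    # get_api_provider_stream_iter(); key names come from the lookups in the
    # hunk above, values are placeholders.
    model_api_dict = {
        "model_name": "critique-search",  # placeholder model id
        "api_type": "critique-labs-ai",   # selects the new branch
        "api_base": "wss://api.critique-labs.ai/v1/ws/search",  # optional override
        "api_key": None,  # optional; CRITIQUE_API_KEY env var is used if unset
    }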
@@ -1345,3 +1356,163 @@ def metagen_api_stream_iter(
             "text": f"**API REQUEST ERROR** Reason: Unknown.",
             "error_code": 1,
         }
+
+
+def critique_api_stream_iter(
+    model_name,
+    messages,
+    temperature,
+    top_p,
+    max_new_tokens,
+    api_key=None,
+    api_base=None,
+):
+    import websockets
+    import threading
+    import queue
+    import json
+    import time
+
+    api_key = api_key or os.environ.get("CRITIQUE_API_KEY")
+    if not api_key:
+        yield {
+            "text": "**API REQUEST ERROR** Reason: CRITIQUE_API_KEY not found in environment variables.",
+            "error_code": 1,
+        }
+        return
+
+    # Combine all messages into a single prompt
+    prompt = ""
+    for message in messages:
+        if isinstance(message["content"], str):
+            role_prefix = (
+                f"{message['role'].capitalize()}: "
+                if message["role"] != "system"
+                else ""
+            )
+            prompt += f"{role_prefix}{message['content']}\n"
+        else:  # Handle content that might be a list (for multimodal)
+            for content_item in message["content"]:
+                if content_item.get("type") == "text":
+                    role_prefix = (
+                        f"{message['role'].capitalize()}: "
+                        if message["role"] != "system"
+                        else ""
+                    )
+                    prompt += f"{role_prefix}{content_item['text']}\n"
+ prompt += "\n DO NOT RESPONSE IN MARKDOWN or provide any citations"
+
+    # Log request parameters
+    gen_params = {
+        "model": model_name,
+        "prompt": prompt,
+        "temperature": temperature,
+        "top_p": top_p,
+        "max_new_tokens": max_new_tokens,
+    }
+    logger.info(f"==== request ====\n{gen_params}")
+
+    # Create a queue for communication between threads
+    response_queue = queue.Queue()
+    stop_event = threading.Event()
+    connection_closed = threading.Event()
+
+    # Thread function to handle WebSocket communication
+    def websocket_thread():
+        import asyncio
+
+        async def connect_and_stream():
+            uri = api_base or "wss://api.critique-labs.ai/v1/ws/search"
+
+            try:
+                # Create connection with headers in the correct format
+                async with websockets.connect(
+                    uri, additional_headers={"X-API-Key": api_key}
+                ) as websocket:
+                    # Send the search request
+                    await websocket.send(
+                        json.dumps(
+                            {
+                                "prompt": prompt,
+                            }
+                        )
+                    )
+
+                    # Receive and process streaming responses
+                    while not stop_event.is_set():
+                        try:
+                            response = await websocket.recv()
+                            data = json.loads(response)
+                            response_queue.put(data)
+
+                            # If we get an error, we're done
+                            if data["type"] == "error":
+                                break
+                        except websockets.exceptions.ConnectionClosed:
+                            # This is the expected end signal - not an error
+                            logger.info(
+                                "WebSocket connection closed by server - this is the expected end signal"
+                            )
+                            connection_closed.set()  # Signal that the connection was closed normally
+                            break
+            except Exception as e:
+                # Only log as error for unexpected exceptions
+                logger.error(f"WebSocket error: {str(e)}")
+                response_queue.put(
+                    {"type": "error", "content": f"WebSocket error: {str(e)}"}
+                )
+            finally:
+                # Always set connection_closed when we exit
+                connection_closed.set()
+
+        asyncio.run(connect_and_stream())
+
+    # Start the WebSocket thread
+    thread = threading.Thread(target=websocket_thread)
+    thread.daemon = True
+    thread.start()
+
+    try:
+        text = ""
+        context_info = []
+
+        # Process responses from the queue until connection is closed
+        while not connection_closed.is_set() or not response_queue.empty():
+            try:
+                # Wait for a response with timeout
+                data = response_queue.get(
+                    timeout=0.5
+                )  # Short timeout to check connection_closed frequently
+
+                if data["type"] == "response":
+                    text += data["content"]
+                    yield {
+                        "text": text,
+                        "error_code": 0,
+                    }
+                elif data["type"] == "context":
+                    # Collect context information
+                    context_info.append(data["content"])
+                elif data["type"] == "error":
+                    logger.error(f"Critique API error: {data['content']}")
+                    yield {
+                        "text": f"**API REQUEST ERROR** Reason: {data['content']}",
+                        "error_code": 1,
+                    }
+                    break
+
+                response_queue.task_done()
+            except queue.Empty:
+                # Just a timeout to check if connection is closed
+                continue
+
+    except Exception as e:
+        logger.error(f"Error in critique_api_stream_iter: {str(e)}")
+        yield {
+            "text": f"**API REQUEST ERROR** Reason: {str(e)}",
+            "error_code": 1,
+        }
+    finally:
+        # Signal the thread to stop and wait for it to finish
+        stop_event.set()
+        thread.join(timeout=5)
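critique_api_stream_iter is a plain generator, so it is consumed like the other provider iterators in this file: each yielded dict carries the cumulative response text and an error_code (0 while streaming, 1 on failure, after which the stream ends). A minimal consumption sketch under that reading, with the model name, prompt, and sampling parameters as illustrative placeholders:

    # Drive the generator directly; all argument values are illustrative.
    messages = [{"role": "user", "content": "Summarize today's top headline."}]
    for chunk in critique_api_stream_iter(
        "critique-search",  # placeholder model_name
        messages,
        temperature=0.7,
        top_p=1.0,
        max_new_tokens=1024,
    ):
        print(chunk["text"])  # cumulative text so far, or an error message
        if chunk["error_code"] != 0:
            break  # error dicts terminate the stream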