@@ -259,6 +259,17 @@ def get_api_provider_stream_iter(
259259 api_key = model_api_dict ["api_key" ],
260260 extra_body = extra_body ,
261261 )
262+ elif model_api_dict ["api_type" ] == "critique-labs-ai" :
263+ prompt = conv .to_openai_api_messages ()
264+ stream_iter = critique_api_stream_iter (
265+ model_api_dict ["model_name" ],
266+ prompt ,
267+ temperature ,
268+ top_p ,
269+ max_new_tokens ,
270+ api_key = model_api_dict .get ("api_key" ),
271+ api_base = model_api_dict .get ("api_base" ),
272+ )
262273 else :
263274 raise NotImplementedError ()
264275
@@ -1345,3 +1356,146 @@ def metagen_api_stream_iter(
13451356 "text" : f"**API REQUEST ERROR** Reason: Unknown." ,
13461357 "error_code" : 1 ,
13471358 }
1359+
1360+
def _build_critique_prompt(messages):
    """Flatten OpenAI-style chat messages into one prompt string.

    Each message becomes one line of the form ``"<Role>: <text>\\n"``; system
    messages are emitted without a role label. When a message's ``content``
    is a list (multimodal), only the parts whose ``type`` is ``"text"`` are
    kept, one line per part. A fixed formatting instruction is appended last.
    """
    pieces = []
    for message in messages:
        content = message["content"]
        if isinstance(content, str):
            texts = [content]
        else:
            # Multimodal content: a list of parts, of which only text is used.
            texts = [part["text"] for part in content if part.get("type") == "text"]
        if not texts:
            continue
        role = message["role"]
        role_prefix = "" if role == "system" else f"{role.capitalize()}: "
        pieces.extend(f"{role_prefix}{text}\n" for text in texts)
    pieces.append("\nDO NOT RESPONSE IN MARKDOWN or provide any citations")
    return "".join(pieces)


def critique_api_stream_iter(
    model_name,
    messages,
    temperature,
    top_p,
    max_new_tokens,
    api_key=None,
    api_base=None,
):
    """Stream a completion from the Critique Labs WebSocket search API.

    The chat history is flattened into a single prompt and sent over a
    WebSocket from a background thread running its own asyncio event loop.
    Messages received there are handed to this generator through a queue.

    Args:
        model_name: Model identifier; only logged, not sent to the API.
        messages: OpenAI-style list of ``{"role": ..., "content": ...}``
            dicts, where ``content`` is a string or a list of typed parts.
        temperature: Only logged; the request payload carries just the prompt.
        top_p: Only logged (see ``temperature``).
        max_new_tokens: Only logged (see ``temperature``).
        api_key: API key; falls back to the ``CRITIQUE_API_KEY`` environment
            variable when empty or ``None``.
        api_base: WebSocket URI; defaults to the public Critique endpoint.

    Yields:
        Dicts with ``"text"`` and ``"error_code"``. On success ``"text"`` is
        the full response accumulated so far and ``"error_code"`` is 0. On
        failure a single dict with ``"error_code"`` 1 describes the problem
        and the stream ends.
    """
    import asyncio
    import json
    import queue
    import threading

    # Validate configuration before touching the optional dependency so a
    # missing key is always reported as a regular error chunk.
    api_key = api_key or os.environ.get("CRITIQUE_API_KEY")
    if not api_key:
        yield {
            "text": "**API REQUEST ERROR** Reason: CRITIQUE_API_KEY not found in environment variables.",
            "error_code": 1,
        }
        return

    import websockets

    prompt = _build_critique_prompt(messages)

    # Log request parameters
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    logger.info(f"==== request ====\n{gen_params}")

    uri = api_base or "wss://api.critique-labs.ai/v1/ws/search"
    # How long either side waits before re-checking its stop condition.
    poll_interval = 0.5

    response_queue = queue.Queue()  # reader thread -> generator
    stop_event = threading.Event()  # generator -> reader: stop reading
    connection_closed = threading.Event()  # reader -> generator: no more data

    async def connect_and_stream():
        try:
            # NOTE(review): `additional_headers` is the keyword used by the
            # newer websockets asyncio client; older releases spell it
            # `extra_headers` - confirm against the pinned version.
            async with websockets.connect(
                uri,
                additional_headers={"X-API-Key": api_key},
            ) as websocket:
                await websocket.send(json.dumps({"prompt": prompt}))

                while not stop_event.is_set():
                    try:
                        # Bounded wait so stop_event is honoured even while
                        # the server is silent; a bare recv() would block
                        # until the next frame or until the socket closes.
                        response = await asyncio.wait_for(
                            websocket.recv(), timeout=poll_interval
                        )
                    except asyncio.TimeoutError:
                        continue
                    except websockets.exceptions.ConnectionClosed:
                        # The server ends the stream by closing the socket.
                        logger.info(
                            "WebSocket connection closed by server - this is the expected end signal"
                        )
                        break

                    data = json.loads(response)
                    response_queue.put(data)

                    # An error message is terminal; nothing more to read.
                    if data.get("type") == "error":
                        break
        except Exception as e:
            # Connection/handshake/protocol failures are forwarded to the
            # consumer as an error message instead of dying silently.
            logger.error(f"WebSocket error: {str(e)}")
            response_queue.put({"type": "error", "content": f"WebSocket error: {str(e)}"})
        finally:
            # Always release the consumer, whatever the exit path.
            connection_closed.set()

    def websocket_thread():
        asyncio.run(connect_and_stream())

    thread = threading.Thread(target=websocket_thread, daemon=True)
    thread.start()

    try:
        text = ""

        # Keep draining until the reader is done AND the queue is empty, so
        # messages queued just before the close are not dropped.
        while not connection_closed.is_set() or not response_queue.empty():
            try:
                data = response_queue.get(timeout=poll_interval)
            except queue.Empty:
                # Timed out only to re-check connection_closed.
                continue

            message_type = data.get("type")
            if message_type == "response":
                text += data["content"]
                yield {
                    "text": text,
                    "error_code": 0,
                }
            elif message_type == "error":
                logger.error(f"Critique API error: {data['content']}")
                yield {
                    "text": f"**API REQUEST ERROR** Reason: {data['content']}",
                    "error_code": 1,
                }
                break
            # Other message types (e.g. "context") carry nothing that is
            # surfaced to the caller and are ignored.

    except Exception as e:
        logger.error(f"Error in critique_api_stream_iter: {str(e)}")
        yield {
            "text": f"**API REQUEST ERROR** Reason: {str(e)}",
            "error_code": 1,
        }
    finally:
        # Ask the reader to stop and give it a moment to close the socket.
        stop_event.set()
        thread.join(timeout=5)
0 commit comments