@@ -1,8 +1,9 @@
 import asyncio
 import time
-from typing import Dict, Type
+from typing import Any, Dict, Type
 from uuid import uuid4
 
+from jupyter_ai.callback_handlers import MetadataCallbackHandler
 from jupyter_ai.models import (
     AgentStreamChunkMessage,
     AgentStreamMessage,
@@ -85,13 +86,19 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str:
 
         return stream_id
 
-    def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False):
+    def _send_stream_chunk(
+        self,
+        stream_id: str,
+        content: str,
+        complete: bool = False,
+        metadata: Dict[str, Any] = {},
+    ):
         """
         Sends an `agent-stream-chunk` message containing content that should be
         appended to an existing `agent-stream` message with ID `stream_id`.
         """
         stream_chunk_msg = AgentStreamChunkMessage(
-            id=stream_id, content=content, stream_complete=complete
+            id=stream_id, content=content, stream_complete=complete, metadata=metadata
         )
 
         for handler in self._root_chat_handlers.values():
@@ -104,6 +111,7 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = Fals
     async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
         received_first_chunk = False
+        assert self.llm_chain
 
         inputs = {"input": message.body}
         if "context" in self.prompt_template.input_variables:
@@ -121,10 +129,13 @@ async def process_message(self, message: HumanChatMessage):
         # stream response in chunks. this works even if a provider does not
         # implement streaming, as `astream()` defaults to yielding `_call()`
         # when `_stream()` is not implemented on the LLM class.
-        assert self.llm_chain
+        metadata_handler = MetadataCallbackHandler()
         async for chunk in self.llm_chain.astream(
             inputs,
-            config={"configurable": {"last_human_msg": message}},
+            config={
+                "configurable": {"last_human_msg": message},
+                "callbacks": [metadata_handler],
+            },
         ):
             if not received_first_chunk:
                 # when receiving the first chunk, close the pending message and
@@ -142,7 +153,9 @@ async def process_message(self, message: HumanChatMessage):
                 break
 
         # complete stream after all chunks have been streamed
-        self._send_stream_chunk(stream_id, "", complete=True)
+        self._send_stream_chunk(
+            stream_id, "", complete=True, metadata=metadata_handler.jai_metadata
+        )
 
     async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
         return "\n\n".join(
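
The diff imports `MetadataCallbackHandler` from `jupyter_ai.callback_handlers` but does not show its definition. Below is a minimal sketch of what such a handler could look like on top of LangChain's callback API, assuming it simply records the provider's response metadata once generation finishes; the class body, the `on_llm_end` logic, and the `generation_info` fallback are illustrative assumptions, not the actual implementation:

```python
from typing import Any, Dict

from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.outputs import LLMResult


class MetadataCallbackHandler(BaseCallbackHandler):
    """Hypothetical sketch of the handler used in the diff.

    Captures provider metadata from the final LLM result and exposes it
    as `jai_metadata`, the attribute read by `process_message()` above.
    """

    def __init__(self) -> None:
        self.jai_metadata: Dict[str, Any] = {}

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        # Chat generations carry `message.response_metadata`; plain text
        # generations fall back to `generation_info` (assumed behavior).
        generation = response.generations[0][0]
        message = getattr(generation, "message", None)
        if message is not None and getattr(message, "response_metadata", None):
            self.jai_metadata = message.response_metadata
        elif generation.generation_info:
            self.jai_metadata = generation.generation_info
```

Because `config={"callbacks": [metadata_handler]}` registers the handler on the chain, `jai_metadata` is populated by the time the `async for` loop finishes, which is why the final `_send_stream_chunk(stream_id, "", complete=True, metadata=metadata_handler.jai_metadata)` call can attach the metadata to the closing stream chunk.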