Commit 0884211

Add metadata field to agent messages (#1013)
* add `metadata` field to agent messages. Currently only set when the default chat handler is used.
* include message metadata in `TestLLMWithStreaming`
* tweak `TestLLMWithStreaming` parameters
* pre-commit
* default to empty dict if no `generation_info`
1 parent 6e426ab commit 0884211

9 files changed: +82 additions, −12 deletions
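
In short: the default chat handler now captures any `generation_info` a provider returns and forwards it as `metadata` on the final stream chunk, which the frontend copies onto the completed message. A minimal sketch of the resulting message shape, with illustrative values (only `test_metadata_field` appears in this commit, set by the test LLM):

# Illustrative only: the approximate shape of a completed agent message
# after this change. The `metadata` payload is provider-dependent; the
# "test_metadata_field" value is what the test LLM in this commit emits.
completed_message = {
    "type": "agent-stream",
    "id": "some-stream-id",  # hypothetical ID
    "body": "Hello! ... 1, 2, 3, 4, 5, ",
    "metadata": {"test_metadata_field": "foobar"},
}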

packages/jupyter-ai-test/jupyter_ai_test/test_llms.py

Lines changed: 5 additions & 4 deletions
@@ -48,10 +48,11 @@ def _stream(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
-        time.sleep(5)
+        time.sleep(1)
         yield GenerationChunk(
-            text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n"
+            text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 5.\n\n",
+            generation_info={"test_metadata_field": "foobar"},
         )
-        for i in range(1, 101):
-            time.sleep(0.5)
+        for i in range(1, 6):
+            time.sleep(0.2)
             yield GenerationChunk(text=f"{i}, ")
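
For context, `generation_info` set on a `GenerationChunk` is merged into the `LLMResult` that LangChain hands to callbacks in `on_llm_end`. A minimal sketch of reading it back through the new callback handler, assuming this file's streaming class is named `TestLLMWithStreaming` as the commit message suggests:

# Sketch, not part of the commit. Assumes `TestLLMWithStreaming` is the
# streaming test LLM defined in this file.
from jupyter_ai.callback_handlers import MetadataCallbackHandler
from jupyter_ai_test.test_llms import TestLLMWithStreaming

handler = MetadataCallbackHandler()
llm = TestLLMWithStreaming()
for chunk in llm.stream("count for me", config={"callbacks": [handler]}):
    pass  # chunks arrive here as plain strings
print(handler.jai_metadata)  # expected: {"test_metadata_field": "foobar"}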

packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+"""
+Provides classes which extend `langchain_core.callbacks:BaseCallbackHandler`.
+Not to be confused with Jupyter AI chat handlers.
+"""
+
+from .metadata import MetadataCallbackHandler

packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+
+class MetadataCallbackHandler(BaseCallbackHandler):
+    """
+    When passed as a callback handler, this stores the LLMResult's
+    `generation_info` dictionary in the `self.jai_metadata` instance attribute
+    after the provider fully processes an input.
+
+    If used in a streaming chat handler: the `metadata` field of the final
+    `AgentStreamChunkMessage` should be set to `self.jai_metadata`.
+
+    If used in a non-streaming chat handler: the `metadata` field of the
+    returned `AgentChatMessage` should be set to `self.jai_metadata`.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.jai_metadata = {}
+
+    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
+        if not (len(response.generations) and len(response.generations[0])):
+            return
+
+        self.jai_metadata = response.generations[0][0].generation_info or {}
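
The handler can be exercised in isolation by feeding it an `LLMResult` directly; this also demonstrates the "default to empty dict if no `generation_info`" bullet from the commit message. A minimal sketch:

# Sketch, not part of the commit: exercising MetadataCallbackHandler directly.
from langchain_core.outputs import Generation, LLMResult

from jupyter_ai.callback_handlers import MetadataCallbackHandler

handler = MetadataCallbackHandler()

# generation_info present: it is captured verbatim
result = LLMResult(
    generations=[[Generation(text="hi", generation_info={"model": "x"})]]
)
handler.on_llm_end(result)
assert handler.jai_metadata == {"model": "x"}

# generation_info absent: `or {}` falls back to an empty dict
handler.on_llm_end(LLMResult(generations=[[Generation(text="hi")]]))
assert handler.jai_metadata == {}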

packages/jupyter-ai/jupyter_ai/chat_handlers/default.py

Lines changed: 19 additions & 6 deletions
@@ -1,8 +1,9 @@
 import asyncio
 import time
-from typing import Dict, Type
+from typing import Any, Dict, Type
 from uuid import uuid4

+from jupyter_ai.callback_handlers import MetadataCallbackHandler
 from jupyter_ai.models import (
     AgentStreamChunkMessage,
     AgentStreamMessage,
@@ -85,13 +86,19 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str:

         return stream_id

-    def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False):
+    def _send_stream_chunk(
+        self,
+        stream_id: str,
+        content: str,
+        complete: bool = False,
+        metadata: Dict[str, Any] = {},
+    ):
         """
         Sends an `agent-stream-chunk` message containing content that should be
         appended to an existing `agent-stream` message with ID `stream_id`.
         """
         stream_chunk_msg = AgentStreamChunkMessage(
-            id=stream_id, content=content, stream_complete=complete
+            id=stream_id, content=content, stream_complete=complete, metadata=metadata
         )

         for handler in self._root_chat_handlers.values():
@@ -104,6 +111,7 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = Fals
     async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
         received_first_chunk = False
+        assert self.llm_chain

         inputs = {"input": message.body}
         if "context" in self.prompt_template.input_variables:
@@ -121,10 +129,13 @@ async def process_message(self, message: HumanChatMessage):
         # stream response in chunks. this works even if a provider does not
         # implement streaming, as `astream()` defaults to yielding `_call()`
         # when `_stream()` is not implemented on the LLM class.
-        assert self.llm_chain
+        metadata_handler = MetadataCallbackHandler()
         async for chunk in self.llm_chain.astream(
             inputs,
-            config={"configurable": {"last_human_msg": message}},
+            config={
+                "configurable": {"last_human_msg": message},
+                "callbacks": [metadata_handler],
+            },
         ):
             if not received_first_chunk:
                 # when receiving the first chunk, close the pending message and
@@ -142,7 +153,9 @@ async def process_message(self, message: HumanChatMessage):
                 break

         # complete stream after all chunks have been streamed
-        self._send_stream_chunk(stream_id, "", complete=True)
+        self._send_stream_chunk(
+            stream_id, "", complete=True, metadata=metadata_handler.jai_metadata
+        )

     async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
         return "\n\n".join(
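
Note that callbacks passed via `config` are inherited by every run in a LangChain runnable sequence, which is how a handler attached at the chain level still receives the LLM's `on_llm_end` event. A minimal sketch of the same pattern outside the chat handler, where `chain` stands in for `self.llm_chain`:

# Sketch, not part of the commit. `chain` is any runnable ending in an LLM.
from jupyter_ai.callback_handlers import MetadataCallbackHandler

async def stream_with_metadata(chain, inputs: dict) -> dict:
    handler = MetadataCallbackHandler()
    async for chunk in chain.astream(inputs, config={"callbacks": [handler]}):
        ...  # forward each chunk to the client as it arrives
    # populated by on_llm_end once the provider finishes
    return handler.jai_metadata

# usage: metadata = await stream_with_metadata(chain, {"input": "hi"})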

packages/jupyter-ai/jupyter_ai/models.py

Lines changed: 17 additions & 1 deletion
@@ -87,6 +87,14 @@ class BaseAgentMessage(BaseModel):
     this defaults to a description of `JupyternautPersona`.
     """

+    metadata: Dict[str, Any] = {}
+    """
+    Message metadata set by a provider after fully processing an input. The
+    contents of this dictionary are provider-dependent, and can be any
+    dictionary with string keys. This field is not to be displayed directly to
+    the user, and is intended solely for developer purposes.
+    """
+

 class AgentChatMessage(BaseAgentMessage):
     type: Literal["agent"] = "agent"
@@ -101,9 +109,17 @@ class AgentStreamMessage(BaseAgentMessage):
 class AgentStreamChunkMessage(BaseModel):
     type: Literal["agent-stream-chunk"] = "agent-stream-chunk"
     id: str
+    """ID of the parent `AgentStreamMessage`."""
     content: str
+    """The string to append to the `AgentStreamMessage` referenced by `id`."""
     stream_complete: bool
-    """Indicates whether this chunk message completes the referenced stream."""
+    """Indicates whether this chunk completes the stream referenced by `id`."""
+    metadata: Dict[str, Any] = {}
+    """
+    The metadata of the stream referenced by `id`. Metadata from the latest
+    chunk should override any metadata from previous chunks. See the docstring
+    on `BaseAgentMessage.metadata` for information.
+    """


 class HumanChatMessage(BaseModel):
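
Per the docstrings above, only the final chunk of a stream needs to carry the metadata, since metadata from later chunks overrides earlier chunks. A minimal sketch of constructing that final chunk with the model above ("some-stream-id" is hypothetical):

# Sketch, not part of the commit: the final, empty chunk that closes a
# stream and delivers the metadata.
from jupyter_ai.models import AgentStreamChunkMessage

final_chunk = AgentStreamChunkMessage(
    id="some-stream-id",
    content="",
    stream_complete=True,
    metadata={"test_metadata_field": "foobar"},
)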

packages/jupyter-ai/src/chat_handler.ts

Lines changed: 1 addition & 0 deletions
@@ -170,6 +170,7 @@ export class ChatHandler implements IDisposable {
       }

       streamMessage.body += newMessage.content;
+      streamMessage.metadata = newMessage.metadata;
       if (newMessage.stream_complete) {
         streamMessage.complete = true;
       }

packages/jupyter-ai/src/components/chat-messages.tsx

Lines changed: 4 additions & 0 deletions
@@ -74,6 +74,10 @@ function sortMessages(
 export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element {
   const collaborators = useCollaboratorsContext();

+  if (props.message.type === 'agent-stream' && props.message.complete) {
+    console.log(props.message.metadata);
+  }
+
   const sharedStyles: SxProps<Theme> = {
     height: '24px',
     width: '24px'

packages/jupyter-ai/src/components/pending-messages.tsx

Lines changed: 2 additions & 1 deletion
@@ -60,7 +60,8 @@ export function PendingMessages(
       time: lastMessage.time,
       body: '',
       reply_to: '',
-      persona: lastMessage.persona
+      persona: lastMessage.persona,
+      metadata: {}
     });

   // timestamp format copied from ChatMessage

packages/jupyter-ai/src/handler.ts

Lines changed: 2 additions & 0 deletions
@@ -114,6 +114,7 @@ export namespace AiService {
     body: string;
     reply_to: string;
     persona: Persona;
+    metadata: Record<string, any>;
   };

   export type HumanChatMessage = {
@@ -172,6 +173,7 @@ export namespace AiService {
     id: string;
     content: string;
     stream_complete: boolean;
+    metadata: Record<string, any>;
   };

   export type Request = ChatRequest | ClearRequest;
