
Commit 6e426ab

Authored by michaelchia, with pre-commit-ci[bot], krassowski, and dlqqq
Framework for adding context to LLM prompt (#993)
* context provider
* split base and base command context providers + replacing prompt
* comment
* only replace prompt if context variable in template
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Run mypy on CI, fix or ignore typing issues (#987)
  * Run mypy on CI
  * Rename, add mypy to test deps
  * Fix typing jupyter-ai codebase (mostly)
  * Three more cases
  * update deepmerge version specifier
  Co-authored-by: David L. Qiu <[email protected]>
* context provider
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* mypy
* black
* modify backtick logic
* allow for spaces in filepath
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* refactor
* fixes
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* fix test
* refactor autocomplete to remove hardcoded '/' and '@' prefix
* modify context prompt template (Co-authored-by: david qiu <[email protected]>)
* refactor
* docstrings + refactor
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* mypy
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* add context providers to help
* remove _examples.py and remove @learned from defaults
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* make find_commands unoverridable

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Michał Krassowski <[email protected]>
Co-authored-by: David L. Qiu <[email protected]>
1 parent fcb2d71 commit 6e426ab

File tree

14 files changed: +942 -87 lines changed


packages/jupyter-ai-magics/jupyter_ai_magics/providers.py

Lines changed: 23 additions & 5 deletions
```diff
@@ -67,11 +67,25 @@
 The following is a friendly conversation between you and a human.
 """.strip()
 
-CHAT_DEFAULT_TEMPLATE = """Current conversation:
-{history}
-Human: {input}
+CHAT_DEFAULT_TEMPLATE = """
+{% if context %}
+Context:
+{{context}}
+
+{% endif %}
+Current conversation:
+{{history}}
+Human: {{input}}
 AI:"""
 
+HUMAN_MESSAGE_TEMPLATE = """
+{% if context %}
+Context:
+{{context}}
+
+{% endif %}
+{{input}}
+"""
 
 COMPLETION_SYSTEM_PROMPT = """
 You are an application built to provide helpful code completion suggestions.
@@ -400,17 +414,21 @@ def get_chat_prompt_template(self) -> PromptTemplate:
                         CHAT_SYSTEM_PROMPT
                     ).format(provider_name=name, local_model_id=self.model_id),
                     MessagesPlaceholder(variable_name="history"),
-                    HumanMessagePromptTemplate.from_template("{input}"),
+                    HumanMessagePromptTemplate.from_template(
+                        HUMAN_MESSAGE_TEMPLATE,
+                        template_format="jinja2",
+                    ),
                 ]
             )
         else:
             return PromptTemplate(
-                input_variables=["history", "input"],
+                input_variables=["history", "input", "context"],
                 template=CHAT_SYSTEM_PROMPT.format(
                     provider_name=name, local_model_id=self.model_id
                 )
                 + "\n\n"
                 + CHAT_DEFAULT_TEMPLATE,
+                template_format="jinja2",
             )
 
     def get_completion_prompt_template(self) -> PromptTemplate:
```
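As a rough illustration of the new Jinja2-based templates (not part of the commit; the model values passed below are made up, and the `jinja2` package must be installed), an empty `context` string is falsy in Jinja2, so the `{% if context %}` block disappears and the prompt is unchanged for messages that use no context commands:

```python
from langchain.prompts import PromptTemplate

# Minimal sketch mirroring the structure of the new CHAT_DEFAULT_TEMPLATE above;
# the sample history/input/context values are invented for illustration.
template = PromptTemplate(
    input_variables=["history", "input", "context"],
    template=(
        "{% if context %}Context:\n{{context}}\n\n{% endif %}"
        "Current conversation:\n{{history}}\nHuman: {{input}}\nAI:"
    ),
    template_format="jinja2",
)

# With context supplied, a "Context:" block is rendered before the conversation.
print(template.format(context="Snippet from file: app.py ...", history="", input="What does app.py do?"))

# With an empty context string, the block is skipped entirely.
print(template.format(context="", history="", input="Hello"))
```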

packages/jupyter-ai/jupyter_ai/chat_handlers/base.py

Lines changed: 17 additions & 1 deletion
```diff
@@ -33,6 +33,7 @@
 from langchain.pydantic_v1 import BaseModel
 
 if TYPE_CHECKING:
+    from jupyter_ai.context_providers import BaseCommandContextProvider
     from jupyter_ai.handlers import RootChatHandler
     from jupyter_ai.history import BoundedChatHistory
     from langchain_core.chat_history import BaseChatMessageHistory
@@ -121,6 +122,10 @@ class BaseChatHandler:
     chat handlers, which is necessary for some use-cases like printing the help
     message."""
 
+    context_providers: Dict[str, "BaseCommandContextProvider"]
+    """Dictionary of context providers. Allows chat handlers to reference
+    context providers, which can be used to provide context to the LLM."""
+
     def __init__(
         self,
         log: Logger,
@@ -134,6 +139,7 @@ def __init__(
         dask_client_future: Awaitable[DaskClient],
         help_message_template: str,
         chat_handlers: Dict[str, "BaseChatHandler"],
+        context_providers: Dict[str, "BaseCommandContextProvider"],
     ):
         self.log = log
         self.config_manager = config_manager
@@ -154,6 +160,7 @@
         self.dask_client_future = dask_client_future
         self.help_message_template = help_message_template
         self.chat_handlers = chat_handlers
+        self.context_providers = context_providers
 
         self.llm: Optional[BaseProvider] = None
         self.llm_params: Optional[dict] = None
@@ -430,8 +437,17 @@ def send_help_message(self, human_msg: Optional[HumanChatMessage] = None) -> None:
             ]
         )
 
+        context_commands_list = "\n".join(
+            [
+                f"* `{cp.command_id}` — {cp.help}"
+                for cp in self.context_providers.values()
+            ]
+        )
+
         help_message_body = self.help_message_template.format(
-            persona_name=self.persona.name, slash_commands_list=slash_commands_list
+            persona_name=self.persona.name,
+            slash_commands_list=slash_commands_list,
+            context_commands_list=context_commands_list,
         )
         help_message = AgentChatMessage(
             id=uuid4().hex,
```
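For a rough sense of what the new `context_commands_list` contributes to the help message, the sketch below formats stand-in providers the same way `send_help_message` does. The `command_id` and `help` values are assumptions for illustration, not taken from this commit:

```python
from types import SimpleNamespace

# Stand-in providers; real ones are BaseCommandContextProvider subclasses, and
# these command_id / help strings are only guesses.
context_providers = {
    "file": SimpleNamespace(command_id="@file", help="Include the content of a file"),
    "learned": SimpleNamespace(command_id="@learned", help="Include content indexed from `/learn`"),
}

context_commands_list = "\n".join(
    [f"* `{cp.command_id}` — {cp.help}" for cp in context_providers.values()]
)
print(context_commands_list)
# * `@file` — Include the content of a file
# * `@learned` — Include content indexed from `/learn`
```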

packages/jupyter-ai/jupyter_ai/chat_handlers/default.py

Lines changed: 34 additions & 1 deletion
```diff
@@ -1,3 +1,4 @@
+import asyncio
 import time
 from typing import Dict, Type
 from uuid import uuid4
@@ -12,6 +13,7 @@
 from langchain_core.runnables import ConfigurableFieldSpec
 from langchain_core.runnables.history import RunnableWithMessageHistory
 
+from ..context_providers import ContextProviderException, find_commands
 from ..models import HumanChatMessage
 from .base import BaseChatHandler, SlashCommandRoutingType
 
@@ -27,6 +29,7 @@ class DefaultChatHandler(BaseChatHandler):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.prompt_template = None
 
     def create_llm_chain(
         self, provider: Type[BaseProvider], provider_params: Dict[str, str]
@@ -40,6 +43,7 @@ def create_llm_chain(
 
         prompt_template = llm.get_chat_prompt_template()
         self.llm = llm
+        self.prompt_template = prompt_template
 
         runnable = prompt_template | llm  # type:ignore
         if not llm.manages_history:
@@ -101,14 +105,25 @@ async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
         received_first_chunk = False
 
+        inputs = {"input": message.body}
+        if "context" in self.prompt_template.input_variables:
+            # include context from context providers.
+            try:
+                context_prompt = await self.make_context_prompt(message)
+            except ContextProviderException as e:
+                self.reply(str(e), message)
+                return
+            inputs["context"] = context_prompt
+            inputs["input"] = self.replace_prompt(inputs["input"])
+
         # start with a pending message
         with self.pending("Generating response", message) as pending_message:
             # stream response in chunks. this works even if a provider does not
             # implement streaming, as `astream()` defaults to yielding `_call()`
             # when `_stream()` is not implemented on the LLM class.
             assert self.llm_chain
             async for chunk in self.llm_chain.astream(
-                {"input": message.body},
+                inputs,
                 config={"configurable": {"last_human_msg": message}},
             ):
                 if not received_first_chunk:
@@ -128,3 +143,21 @@ async def process_message(self, message: HumanChatMessage):
 
         # complete stream after all chunks have been streamed
         self._send_stream_chunk(stream_id, "", complete=True)
+
+    async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
+        return "\n\n".join(
+            await asyncio.gather(
+                *[
+                    provider.make_context_prompt(human_msg)
+                    for provider in self.context_providers.values()
+                    if find_commands(provider, human_msg.prompt)
+                ]
+            )
+        )
+
+    def replace_prompt(self, prompt: str) -> str:
+        # modifies prompt by the context providers.
+        # some providers may modify or remove their '@' commands from the prompt.
+        for provider in self.context_providers.values():
+            prompt = provider.replace_prompt(prompt)
+        return prompt
```
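The two new helper methods can be read as "collect context from every provider whose command appears in the prompt, then let those providers rewrite the prompt". The self-contained sketch below reproduces that flow with a stub provider; the stub's interface and the `@file:...` command syntax are only approximations of `BaseCommandContextProvider`, whose real definition is not part of this file:

```python
import asyncio


class StubProvider:
    """Illustrative stand-in for a command context provider (not the real base class)."""

    def __init__(self, command: str, context: str):
        self.command = command            # e.g. "@file:app.py" (assumed syntax)
        self._context = context
        self.remove_from_prompt = True

    def matches(self, prompt: str) -> bool:
        # plays the role of find_commands(provider, prompt) in the diff above
        return self.command in prompt

    async def make_context_prompt(self, prompt: str) -> str:
        return self._context

    def replace_prompt(self, prompt: str) -> str:
        # some providers strip their '@' command from the prompt sent to the LLM
        if self.remove_from_prompt:
            return prompt.replace(self.command, "").strip()
        return prompt


async def main():
    providers = {"file": StubProvider("@file:app.py", "Snippet from file: app.py\n...")}
    prompt = "@file:app.py what does this file do?"

    # mirrors DefaultChatHandler.make_context_prompt: run matching providers concurrently
    context = "\n\n".join(
        await asyncio.gather(
            *[p.make_context_prompt(prompt) for p in providers.values() if p.matches(prompt)]
        )
    )
    # mirrors DefaultChatHandler.replace_prompt: providers may rewrite or strip their commands
    for p in providers.values():
        prompt = p.replace_prompt(prompt)

    print(context)  # Snippet from file: app.py ...
    print(prompt)   # what does this file do?


asyncio.run(main())
```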
Lines changed: 7 additions & 0 deletions
```diff
@@ -0,0 +1,7 @@
+from .base import (
+    BaseCommandContextProvider,
+    ContextCommand,
+    ContextProviderException,
+    find_commands,
+)
+from .file import FileContextProvider
```
Lines changed: 53 additions & 0 deletions
````diff
@@ -0,0 +1,53 @@
+# Currently unused as it is duplicating the functionality of the /ask command.
+# TODO: Rename "learned" to something better.
+from typing import List
+
+from jupyter_ai.chat_handlers.learn import Retriever
+from jupyter_ai.models import HumanChatMessage
+
+from .base import BaseCommandContextProvider, ContextCommand
+from .file import FileContextProvider
+
+FILE_CHUNK_TEMPLATE = """
+Snippet from file: {filepath}
+```
+{content}
+```
+""".strip()
+
+
+class LearnedContextProvider(BaseCommandContextProvider):
+    id = "learned"
+    help = "Include content indexed from `/learn`"
+    remove_from_prompt = True
+    header = "Following are snippets from potentially relevant files:"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.retriever = Retriever(learn_chat_handler=self.chat_handlers["/learn"])
+
+    async def _make_context_prompt(
+        self, message: HumanChatMessage, commands: List[ContextCommand]
+    ) -> str:
+        if not self.retriever:
+            return ""
+        query = self._clean_prompt(message.body)
+        docs = await self.retriever.ainvoke(query)
+        excluded = self._get_repeated_files(message)
+        context = "\n\n".join(
+            [
+                FILE_CHUNK_TEMPLATE.format(
+                    filepath=d.metadata["path"], content=d.page_content
+                )
+                for d in docs
+                if d.metadata["path"] not in excluded and d.page_content
+            ]
+        )
+        return self.header + "\n" + context
+
+    def _get_repeated_files(self, message: HumanChatMessage) -> List[str]:
+        # don't include files that are already provided by the file context provider
+        file_context_provider = self.context_providers.get("file")
+        if isinstance(file_context_provider, FileContextProvider):
+            return file_context_provider.get_filepaths(message)
+        return []
````
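LearnedContextProvider also doubles as a template for writing additional providers. The hypothetical provider below is not part of this commit; it follows the same pattern (class attributes for `id`, `help`, and `remove_from_prompt`, plus an async `_make_context_prompt`) under the assumption that the base class in `context_providers` exposes exactly the hooks used above:

```python
from typing import List

from jupyter_ai.context_providers import BaseCommandContextProvider, ContextCommand
from jupyter_ai.models import HumanChatMessage


class ClipboardContextProvider(BaseCommandContextProvider):
    """Hypothetical example provider, shown only to illustrate the subclassing pattern."""

    id = "clipboard"            # invoked via an '@' command in chat (assumed)
    help = "Include text previously shared with the chat"
    remove_from_prompt = True   # strip the command from the prompt sent to the LLM

    async def _make_context_prompt(
        self, message: HumanChatMessage, commands: List[ContextCommand]
    ) -> str:
        # A real provider would gather external content here (files, indexes, APIs);
        # this stub returns a fixed string so the example stays self-contained.
        return "Shared text:\n(nothing shared yet)"
```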

0 commit comments
