Skip to content

Commit 3275f30

Browse files
3coinspre-commit-ci[bot]JasonWeill
authored
Allows specifying chunk size and overlap with /learn (#267)
* Allows specifying chunk size and overlap with /learn * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Refactored as per PR review comments * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Documents -c and -o options * Update docs/source/users/index.md --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jason Weill <[email protected]>
1 parent 39fe023 commit 3275f30

File tree

3 files changed

+52
-24
lines changed

3 files changed

+52
-24
lines changed

docs/source/users/index.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,20 @@ To clear the local vector database, you can run `/learn -d` and Jupyter AI will
328328
alt='Screen shot of a "/learn -d" command and a response.'
329329
class="screenshot" />
330330

331+
With the `/learn` command, some models work better with custom chunk size and chunk overlap values. To override the defaults,
332+
use the `-c` or `--chunk-size` option and the `-o` or `--chunk-overlap` option.
333+
334+
```
335+
# default chunk size and chunk overlap
336+
/learn <directory>
337+
338+
# chunk size of 500, and chunk overlap of 50
339+
/learn -c 500 -o 50 <directory>
340+
341+
# chunk size of 1000, and chunk overlap of 200
342+
/learn --chunk-size 1000 --chunk-overlap 200 <directory>
343+
```
344+
331345
### Additional chat commands
332346

333347
To clear the chat panel, use the `/clear` command. This does not reset the AI model; the model may still remember previous messages that you sent it, and it may use them to inform its responses.

packages/jupyter-ai/jupyter_ai/chat_handlers/learn.py

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66
from dask.distributed import Client as DaskClient
77
from jupyter_ai.document_loaders.directory import get_embeddings, split
88
from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
9-
from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata
9+
from jupyter_ai.models import (
10+
DEFAULT_CHUNK_OVERLAP,
11+
DEFAULT_CHUNK_SIZE,
12+
HumanChatMessage,
13+
IndexedDir,
14+
IndexMetadata,
15+
)
1016
from jupyter_core.paths import jupyter_data_dir
1117
from langchain import FAISS
1218
from langchain.schema import BaseRetriever, Document
@@ -30,12 +36,20 @@ def __init__(
3036
super().__init__(*args, **kwargs)
3137
self.root_dir = root_dir
3238
self.dask_client_future = dask_client_future
33-
self.chunk_size = 2000
34-
self.chunk_overlap = 100
3539
self.parser.prog = "/learn"
3640
self.parser.add_argument("-v", "--verbose", action="store_true")
3741
self.parser.add_argument("-d", "--delete", action="store_true")
3842
self.parser.add_argument("-l", "--list", action="store_true")
43+
self.parser.add_argument(
44+
"-c", "--chunk-size", action="store", default=DEFAULT_CHUNK_SIZE, type=int
45+
)
46+
self.parser.add_argument(
47+
"-o",
48+
"--chunk-overlap",
49+
action="store",
50+
default=DEFAULT_CHUNK_OVERLAP,
51+
type=int,
52+
)
3953
self.parser.add_argument("path", nargs=argparse.REMAINDER)
4054
self.index_name = "default"
4155
self.index = None
@@ -102,7 +116,7 @@ async def _process_message(self, message: HumanChatMessage):
102116
if args.verbose:
103117
self.reply(f"Loading and splitting files for {load_path}", message)
104118

105-
await self.learn_dir(load_path)
119+
await self.learn_dir(load_path, args.chunk_size, args.chunk_overlap)
106120
self.save()
107121

108122
response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
@@ -119,27 +133,18 @@ def _build_list_response(self):
119133
{dir_list}"""
120134
return message
121135

122-
async def learn_dir(self, path: str):
136+
async def learn_dir(self, path: str, chunk_size: int, chunk_overlap: int):
123137
dask_client = await self.dask_client_future
138+
splitter_kwargs = {chunk_size: chunk_size, chunk_overlap: chunk_overlap}
124139
splitters = {
125-
".py": PythonCodeTextSplitter(
126-
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
127-
),
128-
".md": MarkdownTextSplitter(
129-
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
130-
),
131-
".tex": LatexTextSplitter(
132-
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
133-
),
134-
".ipynb": NotebookSplitter(
135-
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
136-
),
140+
".py": PythonCodeTextSplitter(**splitter_kwargs),
141+
".md": MarkdownTextSplitter(**splitter_kwargs),
142+
".tex": LatexTextSplitter(**splitter_kwargs),
143+
".ipynb": NotebookSplitter(**splitter_kwargs),
137144
}
138145
splitter = ExtensionSplitter(
139146
splitters=splitters,
140-
default_splitter=RecursiveCharacterTextSplitter(
141-
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
142-
),
147+
default_splitter=RecursiveCharacterTextSplitter(**splitter_kwargs),
143148
)
144149

145150
delayed = split(path, splitter=splitter)
@@ -149,14 +154,18 @@ async def learn_dir(self, path: str):
149154
delayed = get_embeddings(doc_chunks, em_provider_cls, em_provider_args)
150155
embedding_records = await dask_client.compute(delayed)
151156
self.index.add_embeddings(*embedding_records)
152-
self._add_dir_to_metadata(path)
157+
self._add_dir_to_metadata(path, chunk_size, chunk_overlap)
153158
self.prev_em_id = em_provider_cls.id + ":" + em_provider_args["model_id"]
154159

155-
def _add_dir_to_metadata(self, path: str):
160+
def _add_dir_to_metadata(self, path: str, chunk_size: int, chunk_overlap: int):
156161
dirs = self.metadata.dirs
157162
index = next((i for i, dir in enumerate(dirs) if dir.path == path), None)
158163
if not index:
159-
dirs.append(IndexedDir(path=path))
164+
dirs.append(
165+
IndexedDir(
166+
path=path, chunk_size=chunk_size, chunk_overlap=chunk_overlap
167+
)
168+
)
160169
self.metadata.dirs = dirs
161170

162171
async def delete_and_relearn(self):
@@ -213,7 +222,7 @@ async def relearn(self, metadata: IndexMetadata):
213222
for dir in metadata.dirs:
214223
# TODO: do not relearn directories in serial, but instead
215224
# concurrently or in parallel
216-
await self.learn_dir(dir.path)
225+
await self.learn_dir(dir.path, dir.chunk_size, dir.chunk_overlap)
217226

218227
self.save()
219228

packages/jupyter-ai/jupyter_ai/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from jupyter_ai_magics.providers import AuthStrategy, Field
44
from pydantic import BaseModel
55

6+
DEFAULT_CHUNK_SIZE = 2000
7+
DEFAULT_CHUNK_OVERLAP = 100
8+
69

710
# the type of message used to chat with the agent
811
class ChatRequest(BaseModel):
@@ -86,6 +89,8 @@ class ListProvidersResponse(BaseModel):
8689

8790
class IndexedDir(BaseModel):
8891
path: str
92+
chunk_size: int = DEFAULT_CHUNK_SIZE
93+
chunk_overlap: int = DEFAULT_CHUNK_OVERLAP
8994

9095

9196
class IndexMetadata(BaseModel):

0 commit comments

Comments
 (0)