 from dask.distributed import Client as DaskClient
 from jupyter_ai.document_loaders.directory import get_embeddings, split
 from jupyter_ai.document_loaders.splitter import ExtensionSplitter, NotebookSplitter
-from jupyter_ai.models import HumanChatMessage, IndexedDir, IndexMetadata
+from jupyter_ai.models import (
+    DEFAULT_CHUNK_OVERLAP,
+    DEFAULT_CHUNK_SIZE,
+    HumanChatMessage,
+    IndexedDir,
+    IndexMetadata,
+)
 from jupyter_core.paths import jupyter_data_dir
 from langchain import FAISS
 from langchain.schema import BaseRetriever, Document
@@ -30,12 +36,20 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.root_dir = root_dir
         self.dask_client_future = dask_client_future
-        self.chunk_size = 2000
-        self.chunk_overlap = 100
         self.parser.prog = "/learn"
         self.parser.add_argument("-v", "--verbose", action="store_true")
         self.parser.add_argument("-d", "--delete", action="store_true")
         self.parser.add_argument("-l", "--list", action="store_true")
+        self.parser.add_argument(
+            "-c", "--chunk-size", action="store", default=DEFAULT_CHUNK_SIZE, type=int
+        )
+        self.parser.add_argument(
+            "-o",
+            "--chunk-overlap",
+            action="store",
+            default=DEFAULT_CHUNK_OVERLAP,
+            type=int,
+        )
         self.parser.add_argument("path", nargs=argparse.REMAINDER)
         self.index_name = "default"
         self.index = None
@@ -102,7 +116,7 @@ async def _process_message(self, message: HumanChatMessage):
         if args.verbose:
             self.reply(f"Loading and splitting files for {load_path}", message)
 
-        await self.learn_dir(load_path)
+        await self.learn_dir(load_path, args.chunk_size, args.chunk_overlap)
         self.save()
 
         response = f"""🎉 I have learned documents at **{load_path}** and I am ready to answer questions about them.
@@ -119,27 +133,18 @@ def _build_list_response(self):
         {dir_list}"""
         return message
 
-    async def learn_dir(self, path: str):
+    async def learn_dir(self, path: str, chunk_size: int, chunk_overlap: int):
         dask_client = await self.dask_client_future
+        splitter_kwargs = {"chunk_size": chunk_size, "chunk_overlap": chunk_overlap}
         splitters = {
-            ".py": PythonCodeTextSplitter(
-                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
-            ),
-            ".md": MarkdownTextSplitter(
-                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
-            ),
-            ".tex": LatexTextSplitter(
-                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
-            ),
-            ".ipynb": NotebookSplitter(
-                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
-            ),
+            ".py": PythonCodeTextSplitter(**splitter_kwargs),
+            ".md": MarkdownTextSplitter(**splitter_kwargs),
+            ".tex": LatexTextSplitter(**splitter_kwargs),
+            ".ipynb": NotebookSplitter(**splitter_kwargs),
         }
         splitter = ExtensionSplitter(
             splitters=splitters,
-            default_splitter=RecursiveCharacterTextSplitter(
-                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
-            ),
+            default_splitter=RecursiveCharacterTextSplitter(**splitter_kwargs),
         )
 
         delayed = split(path, splitter=splitter)
@@ -149,14 +154,18 @@ async def learn_dir(self, path: str):
         delayed = get_embeddings(doc_chunks, em_provider_cls, em_provider_args)
         embedding_records = await dask_client.compute(delayed)
         self.index.add_embeddings(*embedding_records)
-        self._add_dir_to_metadata(path)
+        self._add_dir_to_metadata(path, chunk_size, chunk_overlap)
         self.prev_em_id = em_provider_cls.id + ":" + em_provider_args["model_id"]
 
-    def _add_dir_to_metadata(self, path: str):
+    def _add_dir_to_metadata(self, path: str, chunk_size: int, chunk_overlap: int):
         dirs = self.metadata.dirs
         index = next((i for i, dir in enumerate(dirs) if dir.path == path), None)
         if not index:
-            dirs.append(IndexedDir(path=path))
+            dirs.append(
+                IndexedDir(
+                    path=path, chunk_size=chunk_size, chunk_overlap=chunk_overlap
+                )
+            )
         self.metadata.dirs = dirs
 
     async def delete_and_relearn(self):
@@ -213,7 +222,7 @@ async def relearn(self, metadata: IndexMetadata):
         for dir in metadata.dirs:
             # TODO: do not relearn directories in serial, but instead
             # concurrently or in parallel
-            await self.learn_dir(dir.path)
+            await self.learn_dir(dir.path, dir.chunk_size, dir.chunk_overlap)
 
         self.save()
 
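For reference, a minimal sketch (not part of the diff) of how the new parser flags are expected to behave, assuming the imported DEFAULT_CHUNK_SIZE and DEFAULT_CHUNK_OVERLAP constants keep the previously hardcoded values of 2000 and 100; the example values and "docs/" path are hypothetical:

import argparse

# Assumed defaults; the diff imports these names from jupyter_ai.models.
DEFAULT_CHUNK_SIZE = 2000
DEFAULT_CHUNK_OVERLAP = 100

# Same argument setup as the handler's __init__ above.
parser = argparse.ArgumentParser(prog="/learn")
parser.add_argument("-v", "--verbose", action="store_true")
parser.add_argument("-c", "--chunk-size", action="store", default=DEFAULT_CHUNK_SIZE, type=int)
parser.add_argument("-o", "--chunk-overlap", action="store", default=DEFAULT_CHUNK_OVERLAP, type=int)
parser.add_argument("path", nargs=argparse.REMAINDER)

# A chat message like "/learn -c 1000 -o 200 docs/" would parse to:
args = parser.parse_args(["-c", "1000", "-o", "200", "docs/"])
print(args.chunk_size, args.chunk_overlap, args.path)  # 1000 200 ['docs/']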