66import  mimetypes 
77import  multiprocessing 
88import  fnmatch 
9+ import  traceback 
910from  tqdm  import  tqdm 
1011from  datetime  import  datetime 
1112from  functools  import  reduce 
1213from  itertools  import  repeat 
13- from  typing  import  Dict , Optional , List , Callable , Type 
14+ from  typing  import  Dict , Optional , List , Callable , Type ,  Union 
1415from  pathlib  import  Path , PurePosixPath , PurePath 
1516from  lazyllm .thirdparty  import  fsspec 
1617from  lazyllm  import  ModuleBase , LOG , config 
1718from  lazyllm .components .formatter .formatterbase  import  _lazyllm_get_file_list 
19+ from  lazyllm .tools .rag .readers .readerBase  import  TxtReader , DefaultReader 
1820from  .doc_node  import  DocNode 
1921from  .readers  import  (ReaderBase , PDFReader , DocxReader , HWPReader , PPTXReader , ImageReader , IPYNBReader ,
2022                      EpubReader , MarkdownReader , MboxReader , PandasCSVReader , PandasExcelReader , VideoAudioReader ,
2123                      get_default_fs , is_default_fs )
24+ from  .transform  import  NodeTransform , FuncNodeTransform 
2225from  .global_metadata  import  (RAG_DOC_PATH , RAG_DOC_FILE_NAME , RAG_DOC_FILE_TYPE , RAG_DOC_FILE_SIZE ,
2326                              RAG_DOC_CREATION_DATE , RAG_DOC_LAST_MODIFIED_DATE , RAG_DOC_LAST_ACCESSED_DATE )
2427
@@ -79,6 +82,8 @@ class SimpleDirectoryReader(ModuleBase):
7982        '*.xlsx' : PandasExcelReader ,
8083        '*.mp3' : VideoAudioReader ,
8184        '*.mp4' : VideoAudioReader ,
85+         '*.txt' : TxtReader ,
86+         '*.xml' : TxtReader ,
8287    }
8388
8489    def  __init__ (self , input_dir : Optional [str ] =  None , input_files : Optional [List ] =  None ,
@@ -89,39 +94,31 @@ def __init__(self, input_dir: Optional[str] = None, input_files: Optional[List]
8994                 return_trace : bool  =  False , metadatas : Optional [Dict ] =  None ) ->  None :
9095        super ().__init__ (return_trace = return_trace )
9196
92-         if  (not  input_dir  and  not  input_files ) or  (input_dir  and  input_files ):
93-             raise  ValueError ('Must provide either `input_dir` or `input_files`.' )
94- 
9597        self ._fs  =  fs  or  get_default_fs ()
9698        self ._encoding  =  encoding 
97- 
9899        self ._exclude  =  exclude 
99100        self ._recursive  =  recursive 
100101        self ._exclude_hidden  =  exclude_hidden 
101102        self ._required_exts  =  required_exts 
102103        self ._num_files_limit  =  num_files_limit 
103104        self ._Path  =  Path  if  is_default_fs (self ._fs ) else  PurePosixPath 
104105        self ._metadatas  =  metadatas 
106+         self ._input_files  =  self ._get_input_files (input_dir , input_files )
107+         self ._file_extractor  =  {** self .default_file_readers , ** (file_extractor  or  {})}
108+         self ._metadata_genf  =  metadata_genf  or  _DefaultFileMetadataFunc (self ._fs )
109+         if  filename_as_id : LOG .warning ('Argument `filename_as_id` for DataReader is no longer used' )
105110
111+     def  _get_input_files (self , input_dir , input_files ):
106112        if  input_files :
107-             self ._input_files  =  []
108-             for  path  in  input_files :
109-                 if  not  self ._fs .isfile (path ):
110-                     path  =  os .path .join (config ['data_path' ], path )
111-                     if  not  self ._fs .isfile (path ):
112-                         raise  ValueError (f'File { path }   does not exist.' )
113-                 input_file  =  self ._Path (path )
114-                 self ._input_files .append (input_file )
113+             assert  not  input_dir , 'Cannot provide files and dir at the same time' 
114+             input_files  =  [os .path .join (config ['data_path' ], p ) if  not  self ._fs .isfile (p ) else  p  for  p  in  input_files ]
115+             input_files  =  [self ._Path (p ) if  p  else  (_  for  _  in  ()).throw (ValueError , f'File { p }   does not exist.' )
116+                            for  p  in  input_files ]
115117        elif  input_dir :
116118            if  not  self ._fs .isdir (input_dir ):
117119                raise  ValueError (f'Directory { input_dir }   does not exist.' )
118-             self ._input_dir  =  self ._Path (input_dir )
119-             self ._input_files  =  self ._add_files (self ._input_dir )
120- 
121-         self ._file_extractor  =  file_extractor  or  {}
122- 
123-         self ._metadata_genf  =  metadata_genf  or  _DefaultFileMetadataFunc (self ._fs )
124-         if  filename_as_id : LOG .warning ('Argument `filename_as_id` for DataReader is no longer used' )
120+             input_files  =  self ._add_files (self ._Path (input_dir ))
121+         return  input_files 
125122
126123    def  _add_files (self , input_dir : Path ) ->  List [Path ]:  # noqa: C901 
127124        all_files  =  set ()
@@ -195,15 +192,10 @@ def _exclude_metadata(self, documents: List[DocNode]) -> List[DocNode]:
195192        return  documents 
196193
197194    @staticmethod  
198-     def  load_file (input_file : Path , metadata_genf : Callable [[str ], Dict ], file_extractor : Dict [str , Callable ],
199-                   encoding : str  =  'utf-8' , pathm : PurePath  =  Path , fs : Optional ['fsspec.AbstractFileSystem' ] =  None ,
200-                   metadata : Optional [Dict ] =  None ) ->  List [DocNode ]:
201-         # metadata priority: user > reader > metadata_genf 
202-         user_metadata : Dict  =  metadata  or  {}
203-         metadata_generated : Dict  =  metadata_genf (str (input_file )) if  metadata_genf  else  {}
204-         documents : List [DocNode ] =  []
205- 
195+     def  find_extractor_by_file (input_file : Path , file_extractor : Dict [str , Callable ], pathm : PurePath  =  Path ):
206196        filename_lower  =  str (input_file ).lower ()
197+         file_suffix  =  filename_lower .split ('.' )[- 1 ]
198+         if  extractor  :=  file_extractor .get (f'*.{ file_suffix }  ' ): return  extractor 
207199
208200        for  pattern , extractor  in  file_extractor .items ():
209201            pt_lower  =  str (pathm (pattern )).lower ()
@@ -213,72 +205,91 @@ def load_file(input_file: Path, metadata_genf: Callable[[str], Dict], file_extra
213205            else :
214206                base  =  str (pathm .cwd ()).lower ()
215207                match_pattern  =  os .path .join (base , pt_lower )
216- 
217208            if  fnmatch .fnmatch (filename_lower , match_pattern ):
218-                 reader  =  extractor () if  isinstance (extractor , type ) else  extractor 
219-                 kwargs  =  {'fs' : fs } if  fs  and  not  is_default_fs (fs ) else  {}
220-                 docs  =  reader (input_file , ** kwargs )
221-                 if  isinstance (docs , DocNode ): docs  =  [docs ]
222-                 for  doc  in  docs :
223-                     metadata  =  metadata_generated .copy ()
224-                     metadata .update (doc ._global_metadata  or  {})
225-                     metadata .update (user_metadata )
226-                     doc ._global_metadata  =  metadata 
227- 
228-                 if  config ['rag_filename_as_id' ]:
229-                     for  i , doc  in  enumerate (docs ):
230-                         doc ._uid  =  f'{ input_file !s}  _index_{ i }  ' 
231-                 documents .extend (docs )
232-                 break 
233-         else :
234-             if  not  config ['use_fallback_reader' ]:
235-                 LOG .warning (f'no pattern found for { input_file }  ! ' 
236-                             'If you want fallback to default Reader, set `LAZYLLM_USE_FALLBACK_READER=True`.' )
237-                 return  documents 
238-             fs  =  fs  or  get_default_fs ()
239-             with  fs .open (input_file , encoding = encoding ) as  f :
240-                 try :
241-                     data  =  f .read ().decode (encoding )
242-                     doc  =  DocNode (text = data , global_metadata = user_metadata )
243-                     documents .append (doc )
244-                 except  Exception :
245-                     LOG .error (f'no pattern found for { input_file }   and it is not utf-8, skip it!' )
246-         return  documents 
209+                 return  extractor 
210+         return  DefaultReader 
247211
248-     def  _load_data (self , show_progress : bool  =  False , num_workers : Optional [int ] =  None ,
249-                    fs : Optional ['fsspec.AbstractFileSystem' ] =  None ) ->  List [DocNode ]:
250-         documents  =  []
212+     @staticmethod  
213+     def  load_file (input_file : Path , metadata_genf : Callable [[str ], Dict ], file_extractor : Dict [str , Callable ],
214+                   encoding : str  =  'utf-8' , pathm : PurePath  =  Path , fs : Optional ['fsspec.AbstractFileSystem' ] =  None ,
215+                   metadata : Optional [Dict ] =  None ) ->  List [DocNode ]:
216+         # metadata priority: user > reader > metadata_genf 
217+         user_metadata : Dict  =  metadata  or  {}
218+         metadata_generated : Dict  =  metadata_genf (str (input_file )) if  metadata_genf  else  {}
219+         rd  =  SimpleDirectoryReader .find_extractor_by_file (input_file , file_extractor , pathm )
220+         reader  =  rd (encoding = encoding ) if  isinstance (rd , TxtReader ) else  rd () if  isinstance (rd , type ) else  rd 
221+         kwargs  =  {'fs' : fs } if  fs  and  not  is_default_fs (fs ) else  {}
222+ 
223+         try :
224+             docs  =  reader (input_file , ** kwargs )
225+         except  Exception  as  e :
226+             LOG .error (f'Error loading file { input_file }  , skip it!' )
227+             LOG .error (f'message: { e } \n  Traceback: { traceback .format_tb (e .__traceback__ )}  ' )
228+             return  []
229+         docs  =  [docs ] if  isinstance (docs , DocNode ) else  [] if  docs  is  None  else  docs 
230+ 
231+         for  doc  in  docs :
232+             metadata  =  metadata_generated .copy ()
233+             metadata .update (doc ._global_metadata  or  {})
234+             metadata .update (user_metadata )
235+             doc ._global_metadata  =  metadata 
251236
252-         fs  =  fs  or  self ._fs 
253-         process_file  =  self ._input_files 
254-         file_readers  =  self ._file_extractor .copy ()
255-         for  key , func  in  self .default_file_readers .items ():
256-             if  key  not  in   file_readers : file_readers [key ] =  func 
237+         if  config ['rag_filename_as_id' ]:
238+             for  i , doc  in  enumerate (docs ):
239+                 doc ._uid  =  f'{ input_file !s}  _index_{ i }  ' 
240+         return  docs 
241+ 
242+     def  _load_data (self , show_progress : bool  =  False , num_workers : Optional [int ] =  None ,
243+                    fs : Optional ['fsspec.AbstractFileSystem' ] =  None , metadatas : Optional [Dict ] =  None ,
244+                    input_dir : Optional [str ] =  None , input_files : Optional [List ] =  None ) ->  List [DocNode ]:
245+         documents , fs , metadatas  =  [], fs  or  self ._fs , metadatas  or  self ._metadatas 
246+         process_file  =  self ._get_input_files (input_dir , input_files ) if  input_dir  or  input_files  else  self ._input_files 
257247
258248        if  num_workers  and  num_workers  >=  1 :
259249            if  num_workers  >  multiprocessing .cpu_count ():
260250                LOG .warning ('Specified num_workers exceed number of CPUs in the system. ' 
261251                            'Setting `num_workers` down to the maximum CPU count.' )
262252            with  multiprocessing .get_context ('spawn' ).Pool (num_workers ) as  p :
263253                results  =  p .starmap (SimpleDirectoryReader .load_file ,
264-                                     zip (process_file , repeat (self ._metadata_genf ), repeat (file_readers ),
254+                                     zip (process_file , repeat (self ._metadata_genf ), repeat (self . _file_extractor ),
265255                                        repeat (self ._encoding ), repeat (self ._Path ),
266-                                         repeat (self ._fs ), self . _metadatas  or  repeat (None )))
256+                                         repeat (self ._fs ), metadatas  or  repeat (None )))
267257                documents  =  reduce (lambda  x , y : x  +  y , results )
268258        else :
269259            if  show_progress :
270260                process_file  =  tqdm (self ._input_files , desc = 'Loading files' , unit = 'file' )
271-             for  input_file , metadata  in  zip (process_file , self . _metadatas  or  repeat (None )):
261+             for  input_file , metadata  in  zip (process_file , metadatas  or  repeat (None )):
272262                documents .extend (
273263                    SimpleDirectoryReader .load_file (
274-                         input_file = input_file , metadata_genf = self ._metadata_genf , file_extractor = file_readers ,
264+                         input_file = input_file , metadata_genf = self ._metadata_genf , file_extractor = self . _file_extractor ,
275265                        encoding = self ._encoding , pathm = self ._Path , fs = self ._fs , metadata = metadata ))
276266
277267        return  self ._exclude_metadata (documents )
278268
279269    def  forward (self , * args , ** kwargs ) ->  List [DocNode ]:
280270        return  self ._load_data (* args , ** kwargs )
281271
272+     @staticmethod  
273+     def  get_default_reader (file_ext : str ) ->  Callable [[Path , Dict ], List [DocNode ]]:
274+         if  not  file_ext .startswith ('*.' ): file_ext  =  '*.'  +  file_ext 
275+         return  SimpleDirectoryReader .default_file_readers .get (file_ext )
276+ 
277+     @staticmethod  
278+     def  add_post_action_for_default_reader (file_ext : str , f : Callable [[DocNode ], Union [DocNode , List [DocNode ]]]):
279+         if  not  file_ext .startswith ('*.' ): file_ext  =  '*.'  +  file_ext 
280+         if  file_ext  not  in   SimpleDirectoryReader .default_file_readers :
281+             raise  KeyError (f'{ file_ext }   has no default reader, use Document.add_reader instead' )
282+ 
283+         reader  =  SimpleDirectoryReader .default_file_readers [file_ext ]
284+         assert  isinstance (reader , type ) and  issubclass (reader , ReaderBase )
285+ 
286+         if  isinstance (f , type ): f  =  f ()
287+         if  not  isinstance (f , NodeTransform ):
288+             try : f ('test' )
289+             except  Exception : pass 
290+             else : f  =  FuncNodeTransform (f , trans_node = False )
291+         reader .post_action  =  staticmethod (f )
292+ 
282293
# Module-level RAG configuration switches (overridable via the LAZYLLM_* env
# vars named in the last argument of each call).
config.add('rag_filename_as_id', bool, False, 'RAG_FILENAME_AS_ID')
config.add('use_fallback_reader', bool, True, 'USE_FALLBACK_READER')
0 commit comments