@@ -371,6 +371,25 @@ def copy_input_file(self, input_uri: str, copy_to_path: str):
371371 with fsspec .open (copy_to_path , "wb" ) as output_file :
372372 output_file .write (input_file .read ())
373373
374+ def copy_input_folder (self , input_uri : str , nb_copy_to_path : str ):
375+ """Copies the input file along with the input directory to the staging directory"""
376+ input_dir_path = os .path .dirname (os .path .join (self .root_dir , input_uri ))
377+ staging_dir = os .path .dirname (nb_copy_to_path )
378+
379+ # Copy the input file
380+ self .copy_input_file (input_uri , nb_copy_to_path )
381+
382+ # Copy the rest of the input folder excluding the input file
383+ for item in os .listdir (input_dir_path ):
384+ source = os .path .join (input_dir_path , item )
385+ destination = os .path .join (staging_dir , item )
386+ if os .path .isdir (source ):
387+ shutil .copytree (source , destination )
388+ elif os .path .isfile (source ) and item != os .path .basename (input_uri ):
389+ with fsspec .open (source ) as src_file :
390+ with fsspec .open (destination , "wb" ) as output_file :
391+ output_file .write (src_file .read ())
392+
374393 def create_job (self , model : CreateJob ) -> str :
375394 if not model .job_definition_id and not self .file_exists (model .input_uri ):
376395 raise InputUriError (model .input_uri )
@@ -401,7 +420,10 @@ def create_job(self, model: CreateJob) -> str:
401420 session .commit ()
402421
403422 staging_paths = self .get_staging_paths (DescribeJob .from_orm (job ))
404- self .copy_input_file (model .input_uri , staging_paths ["input" ])
423+ if model .package_input_folder :
424+ self .copy_input_folder (model .input_uri , staging_paths ["input" ])
425+ else :
426+ self .copy_input_file (model .input_uri , staging_paths ["input" ])
405427
406428 # The MP context forces new processes to not be forked on Linux.
407429 # This is necessary because `asyncio.get_event_loop()` is bugged in
@@ -541,7 +563,10 @@ def create_job_definition(self, model: CreateJobDefinition) -> str:
541563 job_definition_id = job_definition .job_definition_id
542564
543565 staging_paths = self .get_staging_paths (DescribeJobDefinition .from_orm (job_definition ))
544- self .copy_input_file (model .input_uri , staging_paths ["input" ])
566+ if model .package_input_folder :
567+ self .copy_input_folder (model .input_uri , staging_paths ["input" ])
568+ else :
569+ self .copy_input_file (model .input_uri , staging_paths ["input" ])
545570
546571 if self .task_runner and job_definition .schedule :
547572 self .task_runner .add_job_definition (job_definition_id )
@@ -690,6 +715,10 @@ def get_staging_paths(self, model: Union[DescribeJob, DescribeJobDefinition]) ->
690715
691716 staging_paths ["input" ] = os .path .join (self .staging_path , id , model .input_filename )
692717
718+ if model .package_input_folder :
719+ notebook_dir = os .path .dirname (staging_paths ["input" ])
720+ staging_paths ["input_dir" ] = notebook_dir
721+
693722 return staging_paths
694723
695724
0 commit comments