@@ -245,9 +245,9 @@ def _expand(name, v):
245245 for head_id , dataset_config in enumerate (args .list_datasets ):
246246 base_bs = args .list_batch_sizes [head_id ]
247247 if dataset_config .dali_type == "decord" :
248- adjusted_bs = base_bs * frame_scale_factor
248+ adjusted_bs = base_bs * 1
249249 logger .info (f"[head_id={ head_id } ] Video branch: base_bs={ base_bs } , "
250- f"adjusted_bs={ adjusted_bs } (scale={ frame_scale_factor } x)" )
250+ f"adjusted_bs={ adjusted_bs } (scale={ 1 } x)" )
251251 else :
252252 adjusted_bs = base_bs
253253 logger .info (f"[head_id={ head_id } ] Image branch: bs={ adjusted_bs } " )
@@ -408,8 +408,8 @@ def wrap_ddp(model):
408408 data_csv_path = dataset_config .prefixes [0 ],
409409 mode = "train" ,
410410 dali_num_threads = 2 ,
411- dali_py_num_workers = 4 // frame_scale_factor ,
412- decord_num_threads = frame_scale_factor ,
411+ dali_py_num_workers = 4 // 1 ,
412+ decord_num_threads = 1 ,
413413 batch_size = args .list_batch_sizes_adjusted [head_id ],
414414 input_size = args .image_size_video [0 ],
415415 sequence_length = args .num_frames ,
@@ -427,8 +427,8 @@ def wrap_ddp(model):
427427 data_csv_path = dataset_config .prefixes [0 ],
428428 mode = "train" ,
429429 dali_num_threads = 2 ,
430- dali_py_num_workers = 4 // frame_scale_factor ,
431- decord_num_threads = frame_scale_factor ,
430+ dali_py_num_workers = 4 // 1 ,
431+ decord_num_threads = 1 ,
432432 batch_size = args .list_batch_sizes_adjusted [head_id ],
433433 input_size = args .image_size_video [0 ],
434434 sequence_length = 64 ,
@@ -610,14 +610,7 @@ def wrap_ddp(model):
610610
611611 # Unpatchify: [n, C, target_num, p, p] -> [n, C, T', H, W]
612612 T_new = args .target_num // (Hp * Wp ) # 2048 // 256 = 8
613- if T_new == 0 :
614- T_new = 1
615613 num_patches = T_new * Hp * Wp # 8 * 256 = 2048
616- if selected .size (2 ) > num_patches :
617- selected = selected [:, :, :num_patches ] # [14, 3, 2048, 14, 14]
618- elif selected .size (2 ) < num_patches :
619- selected = torch .cat ([selected , selected [:, :, - 1 :].expand (- 1 , - 1 , num_patches - selected .size (2 ), - 1 , - 1 )], dim = 2 ) # [14, 3, 2048, 14, 14]
620- # [14, 3, 2048, 14, 14] -> [14, 3, 8, 16, 16, 14, 14] -> [14, 3, 8, 16, 14, 16, 14] -> [14, 3, 8, 224, 224]
621614 combined_head_input = selected .view (n , C , T_new , Hp , Wp , patch_size , patch_size ).permute (0 , 1 , 2 , 3 , 5 , 4 , 6 ).reshape (n , C , T_new , H , W ) # [14, 3, 8, 224, 224]
622615
623616 with torch .amp .autocast (dtype = torch .bfloat16 , device_type = "cuda" ):
0 commit comments