Skip to content

Commit 7d6adbe

Browse files
committed
updated
1 parent f396606 commit 7d6adbe

1 file changed

Lines changed: 6 additions & 13 deletions

File tree

training/train.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,9 @@ def _expand(name, v):
245245
for head_id, dataset_config in enumerate(args.list_datasets):
246246
base_bs = args.list_batch_sizes[head_id]
247247
if dataset_config.dali_type == "decord":
248-
adjusted_bs = base_bs * frame_scale_factor
248+
adjusted_bs = base_bs * 1
249249
logger.info(f"[head_id={head_id}] Video branch: base_bs={base_bs}, "
250-
f"adjusted_bs={adjusted_bs} (scale={frame_scale_factor}x)")
250+
f"adjusted_bs={adjusted_bs} (scale={1}x)")
251251
else:
252252
adjusted_bs = base_bs
253253
logger.info(f"[head_id={head_id}] Image branch: bs={adjusted_bs}")
@@ -408,8 +408,8 @@ def wrap_ddp(model):
408408
data_csv_path=dataset_config.prefixes[0],
409409
mode="train",
410410
dali_num_threads=2,
411-
dali_py_num_workers=4 // frame_scale_factor,
412-
decord_num_threads=frame_scale_factor,
411+
dali_py_num_workers=4 // 1,
412+
decord_num_threads=1,
413413
batch_size=args.list_batch_sizes_adjusted[head_id],
414414
input_size=args.image_size_video[0],
415415
sequence_length=args.num_frames,
@@ -427,8 +427,8 @@ def wrap_ddp(model):
427427
data_csv_path=dataset_config.prefixes[0],
428428
mode="train",
429429
dali_num_threads=2,
430-
dali_py_num_workers=4 // frame_scale_factor,
431-
decord_num_threads=frame_scale_factor,
430+
dali_py_num_workers=4 // 1,
431+
decord_num_threads=1,
432432
batch_size=args.list_batch_sizes_adjusted[head_id],
433433
input_size=args.image_size_video[0],
434434
sequence_length=64,
@@ -610,14 +610,7 @@ def wrap_ddp(model):
610610

611611
# Unpatchify: [n, C, target_num, p, p] -> [n, C, T', H, W]
612612
T_new = args.target_num // (Hp * Wp) # 2048 // 256 = 8
613-
if T_new == 0:
614-
T_new = 1
615613
num_patches = T_new * Hp * Wp # 8 * 256 = 2048
616-
if selected.size(2) > num_patches:
617-
selected = selected[:, :, :num_patches] # [14, 3, 2048, 14, 14]
618-
elif selected.size(2) < num_patches:
619-
selected = torch.cat([selected, selected[:, :, -1:].expand(-1, -1, num_patches - selected.size(2), -1, -1)], dim=2) # [14, 3, 2048, 14, 14]
620-
# [14, 3, 2048, 14, 14] -> [14, 3, 8, 16, 16, 14, 14] -> [14, 3, 8, 16, 14, 16, 14] -> [14, 3, 8, 224, 224]
621614
combined_head_input = selected.view(n, C, T_new, Hp, Wp, patch_size, patch_size).permute(0, 1, 2, 3, 5, 4, 6).reshape(n, C, T_new, H, W) # [14, 3, 8, 224, 224]
622615

623616
with torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda"):

0 commit comments

Comments (0)