@@ -393,6 +393,15 @@ def get_input_positions_tensor(
                 context_len=context_len,
                 seq_len=seq_len,
             )
+        elif hf_config.model_type in ["ernie4_5_moe_vl", "ernie4_5_vl"]:
+            return cls._ernie_get_input_positions_tensor(
+                input_tokens=input_tokens,
+                hf_config=hf_config,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                context_len=context_len,
+                seq_len=seq_len,
+            )
         else:
             return cls._vl_get_input_positions_tensor(
                 input_tokens=input_tokens,
@@ -513,6 +522,120 @@ def _glm4v_get_input_positions_tensor(
                                 len(input_tokens)).item()
         return llm_positions, mrope_position_delta
 
+    @classmethod
+    def _ernie_get_input_positions_tensor(
+        cls,
+        input_tokens: list[int],
+        hf_config: PretrainedConfig,
+        image_grid_thw: Union[list[list[int]], torch.Tensor],
+        video_grid_thw: Union[list[list[int]], torch.Tensor],
+        context_len: int = 0,
+        seq_len: Optional[int] = None,
+    ) -> tuple[torch.Tensor, int]:
+        """Get mrope input positions and delta value for Ernie VL."""
+
+        image_token_id = hf_config.im_patch_id
+        video_start_token_id = hf_config.video_start_token_id
+        video_end_token_id = hf_config.video_end_token_id
+        spatial_conv_size = hf_config.spatial_conv_size
+        temporal_conv_size = hf_config.temporal_conv_size
+        llm_pos_ids_list: list = []
+
+        if not (image_grid_thw is None and video_grid_thw is None):
+            if isinstance(image_grid_thw, torch.Tensor):
+                image_grid_thw = image_grid_thw.tolist()
+
+            input_token_type: list[str] = []
+            video_check_flg = False
+            for token in input_tokens:
+                if token == video_start_token_id:
+                    video_check_flg = True
+                elif token == video_end_token_id:
+                    video_check_flg = False
+
+                if (token == image_token_id) and (video_check_flg is False):
+                    input_token_type.append("image")
+                elif (token == image_token_id) and (video_check_flg is True):
+                    input_token_type.append("video")
+                else:
+                    input_token_type.append("text")
+
+            input_type_group: list[tuple[str, int, int]] = []
+            for key, group_iter in itertools.groupby(
+                    enumerate(input_token_type), lambda x: x[1]):
+                group_list = list(group_iter)
+                start_index = group_list[0][0]
+                end_index = group_list[-1][0] + 1
+                input_type_group.append((key, start_index, end_index))
+
+            video_frame_num = 1
+            mm_data_idx = 0
+            for modality_type, start_idx, end_idx in input_type_group:
+                st_idx = llm_pos_ids_list[-1].max() + 1 if len(
+                    llm_pos_ids_list) > 0 else 0
+                if modality_type == "image":
+                    t, h, w = (
+                        image_grid_thw[mm_data_idx][0],
+                        image_grid_thw[mm_data_idx][1],
+                        image_grid_thw[mm_data_idx][2],
+                    )
+                    llm_grid_t, llm_grid_h, llm_grid_w = \
+                        t, h // spatial_conv_size, w // spatial_conv_size
+
+                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(
+                        -1, llm_grid_h * llm_grid_w).flatten()
+                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(
+                        llm_grid_t, -1, llm_grid_w).flatten()
+                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(
+                        llm_grid_t, llm_grid_h, -1).flatten()
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + st_idx)
+                    mm_data_idx += 1
+
+                elif modality_type == "video":
+                    t, h, w = (
+                        video_grid_thw[mm_data_idx][0],
+                        video_grid_thw[mm_data_idx][1],
+                        video_grid_thw[mm_data_idx][2],
+                    )
+                    llm_grid_t, llm_grid_h, llm_grid_w = (t //
+                                                          temporal_conv_size,
+                                                          h //
+                                                          spatial_conv_size,
+                                                          w //
+                                                          spatial_conv_size)
+
+                    for t_idx in range(llm_grid_t):
+                        t_index = torch.tensor(t_idx).view(-1, 1).expand(
+                            -1, llm_grid_h * llm_grid_w).flatten()
+                        h_index = torch.arange(llm_grid_h).view(
+                            1, -1, 1).expand(1, -1, llm_grid_w).flatten()
+                        w_index = torch.arange(llm_grid_w).view(
+                            1, 1, -1).expand(1, llm_grid_h, -1).flatten()
+                        llm_pos_ids_list.append(
+                            torch.stack([t_index, h_index, w_index]) + st_idx)
+
+                    mm_data_idx += 1
+                    video_frame_num += 1
+
+                else:
+                    text_len = end_idx - start_idx
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) +
+                        st_idx)
+                    video_frame_num = 1
+
+        else:
+            text_len = len(input_tokens)
+            llm_pos_ids_list.append(
+                torch.arange(text_len).view(1, -1).expand(3, -1))
+
+        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+        llm_positions = llm_positions[:, context_len:seq_len]
+        mrope_position_delta = (llm_positions.max() + 1 -
+                                len(input_tokens)).item()
+        return llm_positions, mrope_position_delta
+
     @classmethod
     def _vl_get_input_positions_tensor(
         cls,
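For reference, here is a minimal, self-contained sketch of the position scheme the new `_ernie_get_input_positions_tensor` helper implements: text runs advance all three (temporal, height, width) rows together, while a run of image placeholder tokens fans out over a `(t, h // conv, w // conv)` grid starting from the running maximum. This is not the vLLM API; the token id, function name, and example grid below are made up for illustration, and the sketch omits videos, the video start/end token scan, `context_len`/`seq_len` slicing, and the mrope delta.

```python
# Simplified, illustrative sketch of the Ernie VL mrope position scheme above.
# IMAGE_TOKEN and toy_ernie_positions are hypothetical names for this demo only.
import itertools

import torch

IMAGE_TOKEN = 100  # stand-in for hf_config.im_patch_id


def toy_ernie_positions(tokens: list[int], grid_thw: tuple[int, int, int],
                        spatial_conv_size: int = 2) -> torch.Tensor:
    # 1) Label each token as "image" or "text".
    token_types = ["image" if tok == IMAGE_TOKEN else "text" for tok in tokens]

    # 2) Collapse consecutive identical labels into (label, start, end) runs.
    runs = []
    for label, group in itertools.groupby(enumerate(token_types),
                                          lambda x: x[1]):
        items = list(group)
        runs.append((label, items[0][0], items[-1][0] + 1))

    # 3) Build 3-row (t, h, w) positions per run, continuing from the
    #    running maximum of the previous chunk.
    pos_chunks: list[torch.Tensor] = []
    for label, start, end in runs:
        offset = pos_chunks[-1].max() + 1 if pos_chunks else 0
        if label == "image":
            t, h, w = grid_thw
            gh, gw = h // spatial_conv_size, w // spatial_conv_size
            t_idx = torch.arange(t).view(-1, 1).expand(-1, gh * gw).flatten()
            h_idx = torch.arange(gh).view(1, -1, 1).expand(t, -1, gw).flatten()
            w_idx = torch.arange(gw).view(1, 1, -1).expand(t, gh, -1).flatten()
            pos_chunks.append(torch.stack([t_idx, h_idx, w_idx]) + offset)
        else:
            n = end - start
            pos_chunks.append(torch.arange(n).view(1, -1).expand(3, -1) + offset)
    return torch.cat(pos_chunks, dim=1)


# 3 text tokens, a 1x4x4 image grid (4 placeholders after the 2x2 spatial conv),
# then 2 more text tokens:
tokens = [1, 2, 3] + [IMAGE_TOKEN] * 4 + [4, 5]
print(toy_ernie_positions(tokens, (1, 4, 4)))
# tensor([[0, 1, 2, 3, 3, 3, 3, 5, 6],
#         [0, 1, 2, 3, 3, 4, 4, 5, 6],
#         [0, 1, 2, 3, 4, 3, 4, 5, 6]])
```

The printed tensor makes the behaviour easy to eyeball: rows are (t, h, w), text columns carry the same value in all three rows, the image block starts at the previous text maximum plus one, and the trailing text resumes from the image block's maximum plus one.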