def get_masked_input_and_mask(
        input_: torch.Tensor, org_vocab_start_index: int,
        org_vocab_end_index: int, num_org_vocab_padding: int,
        added_vocab_start_index: int,
        added_vocab_end_index: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Map global token ids in ``input_`` to this shard's local vocab ids.

    A token is valid if it falls in the original-vocab range
    ``[org_vocab_start_index, org_vocab_end_index)`` or the added-vocab
    (e.g. LoRA) range ``[added_vocab_start_index, added_vocab_end_index)``.
    Original-vocab tokens are shifted down by ``org_vocab_start_index``;
    added-vocab tokens are placed right after the (padded) original block,
    i.e. at ``input_ - added_vocab_start_index + (org_vocab_end_index -
    org_vocab_start_index) + num_org_vocab_padding``. Invalid tokens are
    zeroed.

    Args:
        input_: Tensor of global token ids.
        org_vocab_start_index: Inclusive start of the original-vocab range.
        org_vocab_end_index: Exclusive end of the original-vocab range.
        num_org_vocab_padding: Padding slots after the original vocab block.
        added_vocab_start_index: Inclusive start of the added-vocab range.
        added_vocab_end_index: Exclusive end of the added-vocab range.

    Returns:
        Tuple of (local ids with invalid positions set to 0,
        boolean mask that is True where the token was INVALID).
    """
    # torch.compile will fuse all of the pointwise ops below
    # into a single kernel, making it very fast.
    org_vocab_mask = (input_ >= org_vocab_start_index) & (
        input_ < org_vocab_end_index)
    if added_vocab_start_index == added_vocab_end_index:
        # The added-vocab range is empty: skip building added_vocab_mask
        # entirely. added_vocab_mask would be all-False, so this branch is
        # exactly equivalent to the general path below.
        valid_offset = (org_vocab_start_index * org_vocab_mask)
        vocab_mask = org_vocab_mask
    else:
        added_vocab_mask = (input_ >= added_vocab_start_index) & (
            input_ < added_vocab_end_index)
        # Offset that relocates added-vocab ids to just past the padded
        # original-vocab block.
        added_offset = added_vocab_start_index - (
            org_vocab_end_index -
            org_vocab_start_index) - num_org_vocab_padding
        # Per-element offset: org tokens use org_vocab_start_index, added
        # tokens use added_offset, invalid tokens use 0.
        valid_offset = (org_vocab_start_index *
                        org_vocab_mask) + (added_offset * added_vocab_mask)
        vocab_mask = org_vocab_mask | added_vocab_mask
    # Invalid positions are multiplied by False -> 0.
    input_ = vocab_mask * (input_ - valid_offset)
    return input_, ~vocab_mask