@@ -1342,6 +1342,7 @@ def __init__(self, tokenizer: Tokenizer, sample_begin: int):

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        if tokens.shape[1] == self.sample_begin:
+            logits.contiguous()
            logits[:, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]] = (
                -np.inf
            )
@@ -1352,6 +1353,7 @@ def __init__(self, suppress_tokens: Sequence[int]):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
+        logits.contiguous()
        logits[:, self.suppress_tokens] = -np.inf

@@ -1369,6 +1371,7 @@ def __init__(
    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
+            logits.contiguous()
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
@@ -1382,6 +1385,7 @@ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
            )

            if last_was_timestamp:
+                logits.contiguous()
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
@@ -1395,6 +1399,7 @@ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
                last_allowed = (
                    self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
                )
+                logits.contiguous()
                logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
@@ -1413,6 +1418,7 @@ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
                    logprobs[k, : self.tokenizer.timestamp_begin]
                )
                if timestamp_logprob > max_text_token_logprob:
+                    logits.contiguous()
                    logits[k, : self.tokenizer.timestamp_begin] = -np.inf
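For context (not part of the commit): every hunk above inserts a logits.contiguous() call immediately before an in-place slice assignment that masks token scores with -np.inf. Below is a minimal sketch of that suppression pattern, assuming a PaddlePaddle version that provides paddle.Tensor.contiguous(); the tensor shape and the masked index range are illustrative only, and the sketch assigns the result of contiguous() back, since the method returns a tensor.

import numpy as np
import paddle

# (batch, vocab) decoder scores; shape is illustrative only
logits = paddle.randn([2, 10])

# ensure a contiguous memory layout before the in-place slice write
logits = logits.contiguous()

# suppress tokens 5.. by masking their scores with -inf
logits[:, 5:] = -np.inf

# after softmax, the masked tokens receive no probability mass
probs = paddle.nn.functional.softmax(logits, axis=-1)
print(float(probs[:, 5:].sum()))  # ~0.0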