Commit 1eff3be

fix slice bug (#4470)
1 parent d214a7e commit 1eff3be

File tree

1 file changed: +6 -0 lines changed
  • paddlex/inference/models/multilingual_speech_recognition/processors.py

paddlex/inference/models/multilingual_speech_recognition/processors.py

Lines changed: 6 additions & 0 deletions
@@ -1342,6 +1342,7 @@ def __init__(self, tokenizer: Tokenizer, sample_begin: int):
 
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
         if tokens.shape[1] == self.sample_begin:
+            logits.contiguous()
             logits[:, self.tokenizer.encode(" ").input_ids + [self.tokenizer.eot]] = (
                 -np.inf
             )
@@ -1352,6 +1353,7 @@ def __init__(self, suppress_tokens: Sequence[int]):
         self.suppress_tokens = list(suppress_tokens)
 
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
+        logits.contiguous()
         logits[:, self.suppress_tokens] = -np.inf
 
 
@@ -1369,6 +1371,7 @@ def __init__(
     def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
         # suppress <|notimestamps|> which is handled by without_timestamps
         if self.tokenizer.no_timestamps is not None:
+            logits.contiguous()
             logits[:, self.tokenizer.no_timestamps] = -np.inf
 
         # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
@@ -1382,6 +1385,7 @@ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
             )
 
             if last_was_timestamp:
+                logits.contiguous()
                 if penultimate_was_timestamp:  # has to be non-timestamp
                     logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                 else:  # cannot be normal text tokens
@@ -1395,6 +1399,7 @@ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
                 last_allowed = (
                     self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
                 )
+                logits.contiguous()
                 logits[:, last_allowed + 1 :] = -np.inf
 
         # if sum of probability over timestamps is above any other token, sample timestamp
@@ -1413,6 +1418,7 @@ def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
                 logprobs[k, : self.tokenizer.timestamp_begin]
             )
             if timestamp_logprob > max_text_token_logprob:
+                logits.contiguous()
                 logits[k, : self.tokenizer.timestamp_begin] = -np.inf
 
 

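Each hunk applies the same one-line fix: call logits.contiguous() before writing -np.inf into the logits tensor through a slice or advanced index, which is presumably what the "slice bug" in the commit title refers to. Below is a minimal sketch of that pattern, modeled on the SuppressTokens hunk above; the class name, vocabulary size, and token ids are illustrative placeholders, not taken from the PaddleX file.

import numpy as np
import paddle


class SuppressTokensSketch:
    """Illustrative logit filter: force a fixed set of token ids to -inf."""

    def __init__(self, suppress_tokens):
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: paddle.Tensor, tokens: paddle.Tensor):
        # Same pattern as the commit: ensure the logits tensor is contiguous
        # before the indexed assignment below; as in the commit, the return
        # value is not reassigned.
        logits.contiguous()
        logits[:, self.suppress_tokens] = -np.inf


# Hypothetical usage: a toy batch of 2 sequences over a 10-token vocabulary.
logits = paddle.zeros([2, 10])
tokens = paddle.zeros([2, 1], dtype="int64")
SuppressTokensSketch([3, 7]).apply(logits, tokens)
print(logits.numpy()[:, 3])  # masked columns now hold -inf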