Skip to content

Commit e801927

Browse files
committed
mask pad value to 0
1 parent 2d873ee commit e801927

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

wenet/tts/vits/models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from wenet.tts.vits.commons import init_weights, get_padding
1616
from wenet.tts.vits.losses import generator_loss, discriminator_loss, feature_loss, kl_loss
1717
from wenet.tts.vits.mel_processing import mel_spectrogram_torch
18+
from wenet.utils.mask import make_pad_mask
1819

1920

2021
class StochasticDurationPredictor(nn.Module):
@@ -791,6 +792,8 @@ def __init__(self, n_vocab, spec_channels, **kwargs):
791792
def forward(self, batch: dict, device: torch.device):
792793
x = batch['target'].to(device)
793794
x_lengths = batch['target_lengths'].to(device)
795+
x_mask = make_pad_mask(x_lengths)
796+
x = x.masked_fill(x_mask, 0) # change pad value(IGNORE_ID = -1) to 0
794797
spec = batch['feats'].to(device)
795798
spec_lengths = batch['feats_lengths'].to(device)
796799
spec = spec.transpose(1, 2)

0 commit comments

Comments
 (0)