diff --git a/wenet/models/firered/subsampling.py b/wenet/models/firered/subsampling.py index d4e98ea91..a426b3431 100644 --- a/wenet/models/firered/subsampling.py +++ b/wenet/models/firered/subsampling.py @@ -64,8 +64,13 @@ def forward( x_lens = torch.sum(x_mask.squeeze(1), dim=1) x_lens = x_lens + self.right_context x_mask = make_non_pad_mask(x_lens).unsqueeze(1) + mask_seq = x_mask.size(2) x = torch.nn.functional.pad(x, (0, 0, 0, self.right_context), 'constant', 0.0) + x_seq = x.size(1) + if x_seq > mask_seq: + x_mask = torch.nn.functional.pad(x_mask, (0, x_seq - mask_seq, 0,0,0,0),'constant', 0.0) + x = x.unsqueeze(1) # (b, c=1, t, f) x = self.conv(x) b, c, t, f = x.size()