# Snapshot the features before later statements rebind `features`.
unmasked_features = features.clone()
# Shrink the padding mask so that its dim 1 has length features.size(1).
# NOTE(review): indentation was lost in this excerpt; the nesting of the
# statements under the two `if` lines below cannot be confirmed from here.
# NOTE(review): assumes padding_mask is a 2-D boolean tensor whose dim 1 is
# at least features.size(1) — confirm against the caller.
if padding_mask is not None:
# Trailing mask columns left over after splitting dim 1 into
# features.size(1) equally sized groups.
extra = padding_mask.size(1) % features.size(1)
if extra > 0:
# Drop the leftover columns so the width is an exact multiple of
# features.size(1); otherwise the view() below could not regroup it.
padding_mask = padding_mask[:, :-extra]
# Regroup to (padding_mask.size(0), features.size(1), group_size): one group
# of mask entries per position along features' dim 1.
padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
# A position stays flagged only when every entry of its group is set.
padding_mask = padding_mask.all(-1)
After Change
# these two operations make sure that all values
# before the output lengths indices are attended to
# NOTE(review): `output_lengths` is defined outside this excerpt, and this step
# presumably starts from an all-zero padding_mask with one row per batch item —
# confirm both.
# Write a 1 at column (output_lengths - 1) of every row, presumably marking
# each row's last valid position.
# NOTE(review): an output length of 0 yields index -1, which addresses the last
# column of that row — confirm this case cannot occur.
padding_mask[(torch.arange(padding_mask.shape[0], device=padding_mask.device), output_lengths - 1)] = 1
# flip -> cumsum -> flip is a right-to-left running sum along the last dim, so
# every column at or before the marked one becomes non-zero. Given an all-zero
# start, `1 - ...` followed by `.bool()` leaves True only in the columns after
# the marked one.
padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
# Apply `self.post_extract_proj` to the features when it is set; otherwise
# `features` is left as it is.
# NOTE(review): indentation was lost in this excerpt; the assignment below is
# presumably the body of this `if`.
if self.post_extract_proj is not None:
features = self.post_extract_proj(features)