for idx in range(1, len(self._masks)):
    embedding += self.embeddings(masked_indices[:, :, idx])
embedding /= len(self._masks)
return embedding
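For reference, a minimal self-contained sketch of this loop-based averaging: one embedding lookup per mask, summed and divided by the mask count. The `self.embeddings` table and mask count mirror the snippet above; the class name, constructor arguments, and the (batch, seq, num_masks) layout of masked_indices are assumptions for illustration, not the original module.

import torch
import torch.nn as nn

class MaskedAverageEmbedding(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_masks):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_dim)
        self.num_masks = num_masks  # stands in for len(self._masks)

    def forward(self, masked_indices):
        # masked_indices: (batch, seq, num_masks), one index column per mask.
        embedding = self.embeddings(masked_indices[:, :, 0])
        for idx in range(1, self.num_masks):
            embedding += self.embeddings(masked_indices[:, :, idx])
        return embedding / self.num_masks

# layer = MaskedAverageEmbedding(vocab_size=1000, embed_dim=64, num_masks=4)
# layer(torch.randint(0, 1000, (8, 16, 4))).shape  # torch.Size([8, 16, 64])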
After Change
batch_size, seq_size = indices.size(0), 1
if not indices.is_contiguous():
    indices = indices.contiguous()
indices = indices.data.view(batch_size * seq_size, 1)
torch.mul(
    indices.expand_as(masks),
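The rewritten version appears to replace the per-mask Python loop with one vectorized multiply: each flattened index is expanded across the mask columns and multiplied elementwise, producing every masked variant in a single operation. Below is a hedged sketch of that idea, assuming binary {0, 1} masks of shape (batch * seq, num_masks) and a mean over the mask dimension; the truncated torch.mul call above does not show what follows it, and unlike the original, which fixes seq_size to 1, the sketch reads the sequence length from the tensor.

import torch
import torch.nn as nn

def masked_embedding_vectorized(embeddings, indices, masks):
    # indices: (batch, seq) token ids; masks: (batch * seq, num_masks) in {0, 1}.
    batch_size, seq_size = indices.size(0), indices.size(1)
    flat = indices.contiguous().view(batch_size * seq_size, 1)
    # One elementwise multiply yields every masked index variant at once;
    # a 0 in the mask maps the token to index 0 (assumed to be padding).
    masked = torch.mul(flat.expand_as(masks), masks)  # (batch*seq, num_masks)
    embedded = embeddings(masked)                     # (batch*seq, num_masks, dim)
    return embedded.mean(dim=1).view(batch_size, seq_size, -1)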