# Build the copy-aware target. Wherever the gold target index is 0
# (presumably the unknown-word id -- confirm against tgt_vocab) but an
# alignment exists (align != 0), redirect the target to the extended slot
# ``align + len(self.tgt_vocab)``. Positions that do not satisfy both
# conditions keep their original index.
# This is done with whole-tensor mask arithmetic instead of a per-element
# Python loop: one pass of tensor ops rather than several tiny ops per
# index, and no IndentationError-prone loop body.
# NOTE(review): assumes `target` and `align` have the same shape and integer
# dtype -- confirm against the caller.
target_data = target.data.clone()
# 1 where the target is index 0 AND an alignment is present, else 0.
correct_mask = target_data.eq(0) & align.data.ne(0)
# Offset alignment ids past the target vocabulary; zeroed outside the mask.
correct_copy = (align.data + len(self.tgt_vocab)) * correct_mask.long()
# Masked positions hold 0, so adding is equivalent to assigning there.
target_data = target_data + correct_copy
# Coverage loss term.
# Take an independent copy of the loss tensor's data (.data.clone()) so later
# in-place changes to either tensor do not affect the other.
# NOTE(review): the coverage computation itself continues beyond this
# excerpt -- confirm this comment still heads the right code.
loss_data = loss.data.clone()
After Change
# Copy-aware target remap, vectorized:
#   tgt[i] = align[i] + len(tgt_vocab)
#   for every i such that tgt[i] == 0 and align[i] != 0
# Every other position keeps its original target index.
# NOTE(review): assumes `target` and `align` have the same shape and integer
# dtype -- confirm against the caller.
target_data = target.data.clone()
# 1 where the target is index 0 AND an alignment is present, else 0.
correct_mask = target_data.eq(0) & align.data.ne(0)
# Offset alignment ids past the target vocabulary; zeroed outside the mask.
correct_copy = (align.data + len(self.tgt_vocab)) * correct_mask.long()
# Masked positions hold 0, so adding is equivalent to assigning there.
target_data = target_data + correct_copy
# Coverage loss term.
# Take an independent copy of the loss tensor's data (.data.clone()) so later
# in-place changes to either tensor do not affect the other.
# NOTE(review): the coverage computation itself continues beyond this
# excerpt -- confirm this comment still heads the right code.
loss_data = loss.data.clone()