with torch.no_grad():
    # SequenceGenerator does not consume src_tokens directly; forward the
    # source tokens as `prefix_tokens` so generation is conditioned on them.
    if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
        prefix_tokens = sample["net_input"]["src_tokens"]
        # Strip a leading EOS (used as sentence separator) so the prefix
        # starts with real content tokens.
        if prefix_tokens[:, 0].eq(self.source_dictionary.eos()).all():
            prefix_tokens = prefix_tokens[:, 1:]
    if getattr(self.args, "add_bos_token", False):
        bos_token = self.source_dictionary.bos()
    else:
        # Fix: without this branch `bos_token` is unbound when
        # `add_bos_token` is falsy. fairseq convention: EOS doubles as the
        # BOS symbol when no explicit BOS token is configured.
        bos_token = self.source_dictionary.eos()
After the change:
# Choose the beginning-of-sentence token for generation.
if getattr(self.args, "add_bos_token", False):
    bos_token = self.source_dictionary.bos()
else:
    # fairseq convention: EOS doubles as the BOS symbol when no explicit
    # BOS token is configured.
    bos_token = self.source_dictionary.eos()

# SequenceGenerator doesn't use src_tokens directly; we need to
# pass the `prefix_tokens` argument instead.
if prefix_tokens is None and sample["net_input"]["src_tokens"].nelement():
    prefix_tokens = sample["net_input"]["src_tokens"]