c23e2307b9b2e0147ef4699b2b5bbabdac08a0fc,parlai/core/torch_generator_agent.py,TreeSearch,advance,#TreeSearch#Any#,888
Before Change
// beam blocking
if self.block_ngram > 0:
for beam_id, hyp in enumerate(self.partial_hyps):
if len(hyp) < self.block_ngram - 1:
continue
ngrams = self._find_ngrams(hyp, self.block_ngram)
prefix = hyp[-(self.block_ngram - 1) :]
for ngram in ngrams:
if prefix == list(ngram[:-1]) or self.block_ngram == 1:
logprobs[beam_id][ngram[-1]] = neginf(logprobs.dtype)
hyp_ids, tok_ids, self.scores = self.select_paths(logprobs, self.scores)
// use clone() here to ensure that self.all_scores will not be changed
// later due to any penalties to self.scores
self.all_scores.append(self.scores.clone())
After Change
if self.block_ngram > 0:
logprobs = self._block_ngrams(self.block_ngram, logprobs, None)
if self.context_block_ngram > 0:
if self.context is None:
raise ValueError(
"Must use TreeSearch.set_context to use context blocking."
)
logprobs = self._block_ngrams(
self.context_block_ngram, logprobs, self.context
)
hyp_ids, tok_ids, self.scores = self.select_paths(logprobs, self.scores)
// use clone() here to ensure that self.all_scores will not be changed
// later due to any penalties to self.scores
self.all_scores.append(self.scores.clone())
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 12
Instances
Project Name: facebookresearch/ParlAI
Commit Name: c23e2307b9b2e0147ef4699b2b5bbabdac08a0fc
Time: 2019-12-03
Author: roller@fb.com
File Name: parlai/core/torch_generator_agent.py
Class Name: TreeSearch
Method Name: advance
Project Name: maciejkula/spotlight
Commit Name: eef158f03c4ec9bf872a3e358d62a1fd21a73c35
Time: 2017-07-13
Author: maciej.kula@gmail.com
File Name: examples/movielens_cnn.py
Class Name:
Method Name:
Project Name: open-mmlab/mmcv
Commit Name: 4ec73abbcc30314604d7afa6f59c1d6be5e8426c
Time: 2020-08-17
Author: linjintao@sensetime.com
File Name: mmcv/runner/hooks/logger/text.py
Class Name: TextLoggerHook
Method Name: log