for batch_sz in [1, 3]:
beam = BeamSearch(
beam_sz, batch_sz, 0, 1, 2, 2,
torch.device("cpu"), GlobalScorerStub(), 0, 30,
False, ngram_repeat, set(),
torch.randint(0, 30, (batch_sz,)), False, 0.)
for i in range(ngram_repeat + 4):
// predict repeat_idx over and over again
word_probs = torch.full(
After Change
n_words = 100
repeat_idx = 47
ngram_repeat = 3
device_init = torch.zeros(1, 1)
for batch_sz in [1, 3]:
beam = BeamSearch(
beam_sz, batch_sz, 0, 1, 2, 2,
GlobalScorerStub(), 0, 30,
False, ngram_repeat, set(),
False, 0.)
beam.initialize(device_init, torch.randint(0, 30, (batch_sz,)))
for i in range(ngram_repeat + 4):
// predict repeat_idx over and over again
word_probs = torch.full(
(batch_sz * beam_sz, n_words), -float("inf"))