it = it.view(-1).long()
else:
if temperature == 1.0:
prob_prev = torch.exp(logprobs.data) // fetch prev distribution: shape Nx(M+1)
else:
// scale logprobs by temperature
prob_prev = torch.exp(torch.div(logprobs.data, temperature))
it = torch.multinomial(prob_prev, 1)
sampleLogprobs = logprobs.gather(1, it) // gather the logprobs at sampled positions
it = it.view(-1).long() // and flatten indices for downstream processing
// stop when all finished
if t == 0:
After Change
sampleLogprobs, it = torch.max(logprobs.data, 1)
it = it.view(-1).long()
else:
logprobs = logprobs / temperature
it = torch.distributions.Categorical(logits=logprobs.detach()).sample()
sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) // gather the logprobs at sampled positions
// stop when all finished
if t == 0: