output, state = self.core(xt.unsqueeze(0), state)
logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
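For context, the return above stacks the per-step outputs: each element of seq and seqLogprobs is an (N,) or (N, 1) tensor for one time step, and cat over the unsqueezed tensors yields (N, T) results. A minimal sketch of that stacking (the tensors here are made up for illustration):

import torch

# three hypothetical decoding steps for a batch of N = 2
seq = [torch.tensor([3, 7]), torch.tensor([5, 0]), torch.tensor([9, 0])]

# unsqueeze(1) turns each (N,) step into (N, 1); cat along dim=1 gives (N, T)
stacked = torch.cat([s.unsqueeze(1) for s in seq], 1)
print(stacked)  # tensor([[3, 5, 9], [7, 0, 0]]), shape (2, 3)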
After Change
logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
# sample the next word
if t == self.seq_length + 1:  # stop once we reach the maximum length
    break
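# decoding mode: with sample_max, take the argmax word (greedy decoding);
# otherwise draw from the softmax, optionally reshaped by the temperature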
if sample_max:
    sampleLogprobs, it = torch.max(logprobs.data, 1)
    it = it.view(-1).long()
else:
    if temperature == 1.0:
        prob_prev = torch.exp(logprobs.data).cpu()  # fetch prev distribution: shape Nx(M+1)
    else:
        # scale logprobs by temperature before exponentiating
        prob_prev = torch.exp(torch.div(logprobs.data, temperature)).cpu()
    it = torch.multinomial(prob_prev, 1).cuda()
    sampleLogprobs = logprobs.gather(1, it)  # gather the logprobs at sampled positions
    it = it.view(-1).long()  # and flatten indices for downstream processing
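# bookkeeping for t >= 1: track which sequences are still unfinished, zero
# out words for finished ones, and record this step's word and log-probability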
if t >= 1:
    # stop when all sequences have finished
    if t == 1:
        unfinished = it > 0
    else:
        unfinished = unfinished * (it > 0)
    it = it * unfinished.type_as(it)
    seq[:, t-1] = it  # seq[t] is the input at time step t+2
    seqLogprobs[:, t-1] = sampleLogprobs.view(-1)
    if unfinished.sum() == 0:
        break
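Taken together, the sampling branch implements greedy versus temperature-scaled multinomial decoding, with early termination once every sequence in the batch has emitted the end token (index 0). Below is a self-contained sketch of the same logic on a dummy batch; everything except logprobs, temperature, and sample_max is illustrative, and it runs on CPU rather than calling .cuda():

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, V = 4, 10                          # batch size, vocab size; index 0 = <eos>
logprobs = F.log_softmax(torch.randn(N, V), dim=1)

sample_max = False
temperature = 0.7

if sample_max:
    # greedy: the argmax word and its log-probability per example
    sampleLogprobs, it = torch.max(logprobs, 1)
else:
    # temperature < 1 sharpens the distribution, > 1 flattens it
    prob_prev = torch.exp(logprobs / temperature)
    it = torch.multinomial(prob_prev, 1)        # (N, 1) sampled word ids
    sampleLogprobs = logprobs.gather(1, it)     # log-probs of the samples
    it = it.view(-1)

# an example finishes once it emits word 0; zero it out from then on
unfinished = it > 0
it = it * unfinished.type_as(it)
print(it, sampleLogprobs.view(-1), unfinished.sum().item())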