# embed fc and att feats
fc_feats = self.fc_embed(fc_feats)
att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)

# Project the attention feats first to reduce memory and computation consumption.
p_att_feats = self.ctx2att(att_feats)

for i in range(seq.size(1) - 1):
    if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
        sample_prob = fc_feats.data.new(batch_size).uniform_(0, 1)
        sample_mask = sample_prob < self.ss_prob
        if sample_mask.sum() == 0:
            it = seq[:, i].clone()
        else:
            sample_ind = sample_mask.nonzero().view(-1)
            it = seq[:, i].data.clone()
            # prob_prev = torch.exp(outputs[-1].data.index_select(0, sample_ind)) # fetch prev distribution: shape Nx(M+1)
            # it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1))
            # prob_prev = torch.exp(outputs[-1].data) # fetch prev distribution: shape Nx(M+1)
            prob_prev = torch.exp(outputs[:, i-1].data) # fetch prev distribution: shape Nx(M+1)
            it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
            it = Variable(it, requires_grad=False)
    else:
        it = seq[:, i].clone()
    # break if all the sequences end
    if i >= 1 and seq[:, i].data.sum() == 0:
        break

    xt = self.embed(it)
    output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
    output = F.log_softmax(self.logit(output))
    outputs[:, i] = output
    # outputs.append(output)

return outputs
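
The sampling branch above is scheduled sampling: with probability ss_prob, the input token at step i is drawn from the model's step i-1 distribution instead of the ground truth. Below is a minimal, self-contained sketch of that replacement step; the shapes, tensor names, and toy distributions are illustrative assumptions, not values taken from the model above.

import torch

# Hypothetical sizes for illustration only.
batch_size, vocab_size = 4, 10
ss_prob = 0.25                                                # scheduled-sampling probability
gt_tokens = torch.randint(0, vocab_size, (batch_size,))       # ground-truth tokens for step i
prev_logprobs = torch.log_softmax(torch.randn(batch_size, vocab_size), dim=1)  # stand-in for step i-1 outputs

sample_mask = torch.rand(batch_size) < ss_prob                # rows that get a sampled token
it = gt_tokens.clone()
if sample_mask.any():
    sample_ind = sample_mask.nonzero().view(-1)
    prob_prev = prev_logprobs.exp()                           # back from log-probs to probabilities
    sampled = torch.multinomial(prob_prev, 1).view(-1)        # one draw per row
    it.index_copy_(0, sample_ind, sampled.index_select(0, sample_ind))
# `it` now mixes ground-truth and model-sampled tokens and would be fed to the next step.
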
After Change
# outputs = []
outputs = Variable(fc_feats.data.new(batch_size, seq.size(1) - 1, self.vocab_size+1).zero_())

fc_feats, att_feats, p_att_feats = self._prepare_feature(fc_feats, att_feats, att_masks)

for i in range(seq.size(1) - 1):
    if self.training and i >= 1 and self.ss_prob > 0.0: # otherwise no need to sample
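
The refactored version replaces the explicit embedding and projection lines with a single call to self._prepare_feature. A plausible sketch of such a helper, assuming it simply wraps the three lines dropped from the old code (the body below is an assumption, not the repository's actual implementation):

def _prepare_feature(self, fc_feats, att_feats, att_masks):
    # Sketch only: mirrors the embedding/projection steps from the old forward pass.
    fc_feats = self.fc_embed(fc_feats)                              # embed the global (fc) feature
    att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)  # embed attention feats, respecting masks
    p_att_feats = self.ctx2att(att_feats)                           # project once up front to save memory/compute
    return fc_feats, att_feats, p_att_feats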