# Run one step of the core module. `xt` gets an extra leading axis before the
# call (presumably a length-1 time-step axis for a recurrent core -- TODO confirm).
output, state = self.core(xt.unsqueeze(0), state)
# Drop that leading axis again, apply dropout and the `logit` projection, then
# normalise to log-probabilities along dim 1.
logprobs = F.log_softmax(self.logit(self.dropout(output.squeeze(0))), dim=1)
# Give every entry of `seq` / `seqLogprobs` a new axis at position 1 and
# concatenate along it, returning the pair (sequences, log-probabilities).
# NOTE(review): assumes both are iterables of per-step tensors -- confirm
# against the part of the method that is not shown here.
return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seqLogprobs], 1)
After Change
# The size of the first dimension of `fc_feats` is used as the batch size.
batch_size = fc_feats.size(0)
# Initial state for the core module, sized for this batch.
state = self.init_hidden(batch_size)
# Zero-filled (batch_size, seq_length) integer buffer. `new_zeros` allocates it
# on the same device as `fc_feats`; dtype is overridden to long.
# Presumably filled per step by the loop below -- TODO confirm.
seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
# Zero-filled buffer of the same shape, keeping `fc_feats`' dtype and device.
seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
for t in range(self.seq_length + 2):
if t == 0:
xt = self.img_embed(fc_feats)
else: