# Setup for a step-by-step loop (original, list-based variant).
# batch_size is the extent of fc_feats along dimension 0.
batch_size = fc_feats.size(0)
# Initial recurrent state from self.init_hidden; its structure is not visible
# in this fragment.
state = self.init_hidden(batch_size)
# Start as empty Python lists; how they are filled is outside this fragment
# (presumably one entry per step -- confirm against the rest of the loop).
seq = []
seqLogprobs = []
# The loop runs self.seq_length + 2 iterations.
for t in range(self.seq_length + 2):
# Step 0 uses the embedded fc_feats as the input xt.
if t == 0:
xt = self.img_embed(fc_feats)
After Change
# Setup for a step-by-step loop (revised variant with preallocated buffers).
# batch_size is the extent of fc_feats along dimension 0.
batch_size = fc_feats.size(0)
# Initial recurrent state from self.init_hidden; its structure is not visible
# in this fragment.
state = self.init_hidden(batch_size)
# Zero-filled (batch_size, seq_length) buffer with dtype torch.long, created
# via fc_feats.new_zeros.
seq = fc_feats.new_zeros(batch_size, self.seq_length, dtype=torch.long)
# Zero-filled (batch_size, seq_length) buffer; no dtype is passed, so
# new_zeros falls back to fc_feats' own dtype.
seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
# NOTE(review): the loop runs seq_length + 2 iterations while the buffers above
# have only seq_length columns; the step-to-column mapping is not visible
# here -- confirm the writes stay in bounds.
for t in range(self.seq_length + 2):
# Step 0 uses the embedded fc_feats as the input xt.
if t == 0:
xt = self.img_embed(fc_feats)
else: