x_batch = Variable(batch)
if cuda_device is not None:
x_batch = x_batch.cuda(cuda_device)
batch_pred = self(x_batch)
preds.append(batch_pred.data)
self.train()
return Variable(torch.cat(preds))
After Change
batch_data = [batch_data]
input_batch = batch_data[0]
if not isinstance(input_batch, (list,tuple)):
input_batch = [input_batch]
input_batch = [Variable(ins) for ins in input_batch]
if cuda_device > -1:
input_batch = [ins.cuda(cuda_device) for ins in input_batch]