# Before change: inputs and target are separated inline by slicing the batch.
# Send to device and wrap as variable (delegated to self.wrap_batch).
batch = self.wrap_batch(batch)
# Separate inputs from targets: every element except the last is a model
# input; the last element is the target.
inputs, target = batch[:-1], batch[-1]
# Compute prediction; the inputs are unpacked as positional arguments.
prediction = self.apply_model(*inputs)
# Compute loss between the prediction and the target.
loss = self.criterion(prediction, target)
# After change:
# Send to device and wrap as variable (delegated to self.wrap_batch).
batch = self.wrap_batch(batch)
# Separate inputs from targets by delegating to self.split_batch, tagged as
# coming from the "train" loader. NOTE(review): the exact split rule lives in
# split_batch and is not visible here; presumably it generalizes the old
# "all but last / last" slicing — confirm against its definition.
inputs, target = self.split_batch(batch, from_loader="train")
# Compute prediction; the inputs are unpacked as positional arguments.
prediction = self.apply_model(*inputs)
# Compute loss between the prediction and the target.
loss = self.criterion(prediction, target)