# Before the change: convert the batch input to a float32 tensor via
# numpy -> torch.from_numpy -> Variable, moving it to the GPU when enabled,
# then run the model on it.
# NOTE(review): `X` is assumed to already hold batch["X"] (assigned above this
# fragment) — confirm against the enclosing method.

# When the model does not declare batch_first=True (a missing attribute counts
# as True), move axis 0 to position 1, i.e. swap the first two axes.
if not getattr(self.model_, "batch_first", True):
    X = np.rollaxis(X, 0, 2)
X = np.array(X, dtype=np.float32)
# NOTE(review): `Variable` is a deprecated wrapper since PyTorch 0.4; a plain
# tensor participates in autograd directly.
X = Variable(torch.from_numpy(X))
if self.gpu_:
    X = X.cuda()
batch["X"] = X

# forward pass
fX = self.model_(batch["X"])
# After change:
# After the change: build the float32 input tensor directly on the configured
# device (self.device_), then run the model on it.
X = batch["X"]
# Collapse X into a single float32 ndarray first. If X arrives as a list of
# ndarrays, torch.tensor would otherwise take a very slow per-element path
# (PyTorch emits a UserWarning about exactly this).
X = np.asarray(X, dtype=np.float32)
# When the model does not declare batch_first=True (a missing attribute counts
# as True), move axis 0 to position 1, i.e. swap the first two axes.
# np.moveaxis(X, 0, 1) is equivalent to np.rollaxis(X, 0, 2).
if not getattr(self.model_, "batch_first", True):
    X = np.moveaxis(X, 0, 1)
# torch.tensor always copies, so the model never aliases the caller's buffer.
batch["X"] = torch.tensor(X, dtype=torch.float32, device=self.device_)

# forward pass
fX = self.model_(batch["X"])