if covar_root.shape[-1] < base_samples.shape[-2]:
base_samples = base_samples[..., : covar_root.shape[-1], :]
elif covar_root.shape[-1] > base_samples.shape[-2]:
raise RuntimeError("Incompatible dimension of `base_samples`")
res = covar_root.matmul(base_samples) + self.loc.unsqueeze(-1)
// Permute and reshape new samples to be original size
res = res.permute(-1, *range(self.loc.dim())).contiguous()
After Change
base_samples = base_samples[..., : covar_root.shape[-1], :]
elif covar_root.shape[-1] > base_samples.shape[-2]:
// raise RuntimeError("Incompatible dimension of `base_samples`")
covar_root = covar_root.transpose(-2, -1)
res = covar_root.matmul(base_samples) + self.loc.unsqueeze(-1)
// Permute and reshape new samples to be original size
res = res.permute(-1, *range(self.loc.dim())).contiguous()