assert tensor.ndim in [3, 4]
if tensor.ndim == 3 and self.add_channel_axis_if_necessary:
// Add channel axisreturn torch.from_numpy(tensor[None, ...])
else:
// Channel axis is in alreadyreturn torch.from_numpy(tensor)
elif self.dimensionality == 2:
// We"re dealing with an image. tensor can either be 2D or 3Dassert tensor.ndim in [2, 3]
if tensor.ndim == 2 and self.add_channel_axis_if_necessary:
// Add channel axisreturn torch.from_numpy(tensor[None, ...])
else:
// Channel axis is in alreadyreturn torch.from_numpy(tensor)
elif self.dimensionality == 1:
// We"re dealing with a vector - it has to be 1Dassert tensor.ndim == 1return torch.from_numpy(tensor)else:
raise NotImplementedError