assert tensor.ndim in [3, 4]
if tensor.ndim == 3 and self.add_channel_axis_if_necessary:
// Add channel axis
return torch.from_numpy(tensor[None, ...])
else:
// Channel axis is already present
return torch.from_numpy(tensor)
elif self.dimensionality == 2:
// We're dealing with an image. The tensor can be either 2D or 3D
assert tensor.ndim in [2, 3]
if tensor.ndim == 2 and self.add_channel_axis_if_necessary:
// Add channel axis
return torch.from_numpy(tensor[None, ...])
else:
// Channel axis is already present
return torch.from_numpy(tensor)
elif self.dimensionality == 1:
// We're dealing with a vector - it has to be 1D
assert tensor.ndim == 1
return torch.from_numpy(tensor)
else:
raise NotImplementedError