mean = mean[..., :, None, None].to(data.device)
std = std[..., :, None, None].to(data.device)
out = data.sub(mean).div(std)
return out
// - denormalise
After Change
if mean.shape[0] != data.shape[-3] and mean.shape[:2] != data.shape[:2]:
raise ValueError("mean lenght and number of channels do not match")
if std.shape[0] != data.shape[-3] and std.shape[:2] != data.shape[:2]:
raise ValueError("std lenght and number of channels do not match")
if std.shape != mean.shape: