try:
out = torch.zeros_like(x)
except AttributeError:
out = x.new(x.size()).zero_()
// unroll inputs and outputs for ease of iteration through elements
condition_ = condition.view(N)
x_ = x.view(N)
After Change
def where(condition, x, y):
    """Select elements from ``x`` where ``condition`` holds, else from ``y``.

    Element-wise equivalent of ``condition ? x : y`` for three tensors of the
    same shape.

    Args:
        condition: mask used for boolean indexing (and inverted with ``~``);
            assumed to be a bool tensor -- TODO confirm, since ``~`` on an
            integer/uint8 mask is a bitwise NOT and would not invert it.
        x: values taken at positions where ``condition`` is true.
        y: values taken at positions where ``condition`` is false.

    Returns:
        A new tensor of the same shape, created via
        ``zeros_like(condition, **context(x))``. Presumably ``context(x)``
        supplies x's dtype/device -- verify against its definition.

    Raises:
        AssertionError: if the three shapes differ (note: ``assert`` is
            stripped when Python runs with ``-O``).
    """
    assert condition.shape == x.shape == y.shape, "Dimension mismatch"
    result = zeros_like(condition, **context(x))
    # Boolean-mask assignment needs exactly one value per selected position,
    # so take the matching elements of x / y rather than the full tensors
    # (full-size right-hand sides fail unless the mask selects everything).
    result[condition] = x[condition]
    not_condition = ~condition
    result[not_condition] = y[not_condition]
    return result
def all(tensor):
    """Return whether every element of ``tensor`` is non-zero.

    Follows the semantics of the builtin ``all`` / ``torch.all``: an empty
    tensor yields True. The result is a 0-dim tensor, which can be used
    directly in an ``if`` or converted with ``bool(...)``.

    Note: defining this at module level shadows the builtin ``all`` for the
    rest of this module.
    """
    # The tensor is "all true" exactly when it contains no zero (falsy)
    # elements. Counting NON-zero elements instead would be truthy as soon
    # as ANY element is non-zero, i.e. ``any`` semantics.
    return torch.sum(tensor == 0) == 0