final_shape = []
first_tensor_idx = None
tensor_idx_shape = None
continuous_tensor_index = True
slice_after_tensor_idx = False
for i, (size, idx) in enumerate(zip(obj.shape, indices)):
// Handle slice: that dimension gets downsized
if isinstance(idx, slice):
if idx == _noop_index:
final_shape.append(size)
else:
final_shape.append(len(range(*idx.indices(size))))
// If we don"t have a continuous set of tensor indices, then the tensor indexed part
// goes to the front
if first_tensor_idx is not None:
slice_after_tensor_idx = True
// Handle int: we "lose" that dimension
elif isinstance(idx, int):
if settings.debug.on():
try:
range(size)[idx]
except IndexError:
raise IndexError(
"index element {} ({}) is invalid: out of range for obj of size "
"{}.".format(i, idx, obj.shape)
)
// Handle tensor index - this one is complicated
elif torch.is_tensor(idx):
if tensor_idx_shape is None:
tensor_idx_shape = idx.numel()
first_tensor_idx = len(final_shape)
final_shape.append(tensor_idx_shape)
// If we don"t have a continuous set of tensor indices, then the tensor indexed part
// goes to the front
elif slice_after_tensor_idx:
continuous_tensor_index = False
else:
if settings.debug.on():
if idx.numel() != tensor_idx_shape:
raise IndexError(
"index element {} is an invalid size: expected tensor indices of size {}, got "
"{}.".format(i, tensor_idx_shape, idx.numel())
)
// If we don"t have a continuous set of tensor indices, then the tensor indexed part
// goes to the front
if not continuous_tensor_index:
del final_shape[first_tensor_idx]
final_shape.insert(0, tensor_idx_shape)
return torch.Size(final_shape)