device = input.device
index = []
for i, l in enumerate(lengths):
index.extend(range(i * max_len, i * max_len + l))
index = th.tensor(index).to(device)
return gather_row(input.view(batch_size * max_len, -1), index)
After Change
index = th.ones(out_len, dtype=th.int64, device=device)
cum_lengths = th.cumsum(lengths, 0)
index[cum_lengths[:-1]] += (max_len - lengths[:-1])
index = th.cumsum(index, 0) - 1
return input[index]
def boolean_mask(input, mask):
if "bool" not in str(mask.dtype):