# (fragment — continues a sentence started above) ... after
# running the RNN we zero out the corresponding rows in the result.
# First count how many sequences are empty.
batch_size = mask.size(0)
# A sequence is counted as non-empty when its first mask position is set
# (presumably left-padded masks; verify against the caller's masking scheme).
# `.item()` replaces the old `.data[0]` indexing, which was removed for
# 0-dim tensors in PyTorch >= 0.5.
num_valid = torch.sum(mask[:, 0]).int().item()
# Force every sequence to be length at least one. Need to `.clone()` the mask
# (comment truncated in this fragment — the `.clone()` happens outside this view).
# Pad the returned state back up to the full batch size: the rows for the
# empty sequences were dropped before the RNN ran, so append zero rows.
zeros = state.data.new(num_layers_times_directions,
                       batch_size - num_valid,
                       encoding_dim).fill_(0)
state = torch.cat([state, zeros], 1)
# Restore the original indices and return the final state of the
# top layer. Pytorch's recurrent layers return state in the form
# (num_layers * num_directions, batch_size, hidden_size) regardless