# Call the model on (X, y, h, s). It returns the raw output, an updated hidden
# state (rebinding h), and the target to score against (new_target).
# NOTE(review): the meaning of `s` is not visible here — TODO confirm.
output, h, new_target = model(X, y, h, s)
# NOTE(review): -3 is not a valid numpy reshape code. This presumably relies on
# MXNet NDArray.reshape semantics, where -3 merges two consecutive axes into one.
# With -1 the result is 2-D — confirm the framework.
output = output.reshape((-3, -1))
# Flatten the targets to 1-D, presumably so each entry pairs with a row of
# `output`.
new_target = new_target.reshape((-1,))
# Loss multiplied elementwise by the flattened `m`. `m` is presumably a mask
# that zeroes out padded positions — verify against the caller.
l = loss(output, new_target) * m.reshape((-1,))
# Scale by the configured batch size before collecting into Ls.
Ls.append(l/args.batch_size)
# Store the updated hidden state in slot j (presumably reused on the next
# iteration).
hiddens[j] = h
After Change
# Collect one (hidden, loss) result from `parallel` per element of `data`.
# Each hidden state is stored at the slot of the device it lives on, not at
# the loop counter, so the order in which parallel.get() returns results does
# not matter for `hiddens`.
# NOTE(review): Ls is appended in arrival order, not device order. That is fine
# if Ls is only summed/averaged — confirm no caller depends on its order.
for _ in range(len(data)):
    hidden, ls = parallel.get()
    # hidden states are ordered by context id.
    # Assumes hidden is a sequence of arrays exposing `.context`
    # (MXNet NDArray) and that `context` is the list of devices — TODO confirm.
    index = context.index(hidden[0].context)
    hiddens[index] = hidden
    Ls.append(ls)