state = ((W, b, hidden, cell, X), (Y, gates))
state = jax.lax.fori_loop(0, X.shape[0], _lstm_stepper, state)
(W, b, hidden, cell, X), (Y, gates) = state
return Y, cell, gates
@jax_jit()
After Change
C = xp.zeros((nL+1, nB, nO), dtype="f")
# Set the initial hidden and cell states. Y and C are shifted by one
# position (index 0 holds h0/c0), so that fewer arrays are needed.
Y = index_update(Y, index[0], h0)
C = index_update(C, index[0], c0)
state = ((W, b, X), (Y, C, G))
state = jax.lax.fori_loop(0, X.shape[0], _lstm_stepper, state)
(W, b, X), (Y, C, G) = state