if incremental_state is None or len(incremental_state) == 0:
    return
# Recover the cached per-layer hidden/cell states and the input-feed
# vector, then reorder every piece along the batch dimension.
prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state)
cached_state = (prev_hiddens, prev_cells, [input_feed])
new_state = [self.reorder_state(state, new_order) for state in cached_state]
prev_hiddens_tensor = torch.stack(new_state[0])
prev_cells_tensor = torch.stack(new_state[1])
cached_state_new = torch.jit.annotate(
    Dict[str, Optional[Tensor]],
    {"prev_hiddens": prev_hiddens_tensor,
     "prev_cells": prev_cells_tensor,
     "input_feed": new_state[2][0]})
self.set_incremental_state(incremental_state, "cached_state", cached_state_new)
return
After Change
if incremental_state is None or len(incremental_state) == 0:
    return
prev_hiddens, prev_cells, input_feed = self.get_cached_state(incremental_state)
# Reorder each per-layer state directly along the batch dimension,
# dropping the intermediate tuple-and-list wrapping used before.
prev_hiddens = [p.index_select(0, new_order) for p in prev_hiddens]
prev_cells = [p.index_select(0, new_order) for p in prev_cells]
if input_feed is not None:
    input_feed = input_feed.index_select(0, new_order)
cached_state_new = torch.jit.annotate(
    Dict[str, Optional[Tensor]],
    {"prev_hiddens": torch.stack(prev_hiddens),
     "prev_cells": torch.stack(prev_cells),
     "input_feed": input_feed})
self.set_incremental_state(incremental_state, "cached_state", cached_state_new)
return
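For context, here is a minimal standalone sketch of the index_select pattern the rewritten code relies on; the layer count, tensor shapes, and permutation below are illustrative assumptions, not values from the source.

import torch

# Toy example (assumed shapes): 2 decoder layers, batch of 4, hidden size 8.
num_layers, batch, hidden = 2, 4, 8
prev_hiddens = [torch.randn(batch, hidden) for _ in range(num_layers)]
new_order = torch.tensor([2, 0, 3, 1])  # assumed beam-search permutation

# index_select(0, new_order) permutes rows (the batch dimension) of each
# per-layer tensor, the same operation applied to prev_hiddens,
# prev_cells, and input_feed above.
reordered = [h.index_select(0, new_order) for h in prev_hiddens]
assert torch.equal(reordered[0][0], prev_hiddens[0][2])  # row 0 was row 2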