def get_state(self):
    """Return CPU copies of the stored hidden states (and cells), for decoding.

    Every element of ``self.state_list`` has ``.cpu()`` called on it. When
    ``self.enable_cell`` is truthy the same is done for ``self.cell_list``
    and the two lists are returned together as ``(states, cells)``;
    otherwise only the list of states is returned.
    """
    # Presumably these are tensors that may live on an accelerator — confirm.
    hidden_on_cpu = [state.cpu() for state in self.state_list]
    if not self.enable_cell:
        return hidden_on_cpu

    cells_on_cpu = [cell.cpu() for cell in self.cell_list]
    return hidden_on_cpu, cells_on_cpu
After Change
def get_state(self):
    """Return the hidden state (and cell, if enabled) after ``.cpu()``, for decoding.

    When ``self.enable_cell`` is truthy, ``self.hidden_state`` is indexed at
    positions 0 and 1 and a 2-tuple of their ``.cpu()`` results is returned.
    Otherwise ``self.hidden_state.cpu()`` is returned directly.
    """
    hidden = self.hidden_state
    if not self.enable_cell:
        return hidden.cpu()

    # Presumably index 0 is the hidden state and index 1 the cell — confirm.
    return hidden[0].cpu(), hidden[1].cpu()