x = self.state_proc_model(x)
// Restack to batch_size x seq_len x rnn_input_dim
x = x.view(-1, self.seq_len, self.rnn_input_dim)
hid_0 = self.init_hidden(batch_size)_, final_hid = self.rnn_model(x, hid_0)
final_hid.squeeze_(dim=0)
// return tensor if single tail, else list of tail tensors
if len(self.model_tails) == 1:
return self.model_tails[0](final_hid)
After Change
// Restack to batch_size x seq_len x rnn_input_dim
x = x.view(-1, self.seq_len, self.rnn_input_dim)
if self.cell_type == "LSTM":
_output, (h_n, c_n) = self.rnn_model(x)
else:
_output, h_n = self.rnn_model(x)
hid_x = h_n[-1:].squeeze_(dim=0) // get final time-layer