# Pre-change version: run the first sub-module of the learner's model on `t`
# and return one slice of its last output as a NumPy array.
# NOTE(review): learn.model[0] is presumably the encoder of a fastai
# sequential model — confirm against the caller.
m = learn.model[0]
# Presumably clears the module's recurrent state so earlier calls do not
# affect this prediction — TODO confirm what reset() does for this module.
m.reset()
pred,_ = m(t)
# pred[-1] is indexed with three subscripts, so it has at least 3 dims.
# Assumes layout (seq_len, batch, hidden) — TODO confirm — in which case
# [-1,:,:] picks the last step, and squeeze() drops any size-1 dims.
res = pred[-1][-1,:,:].squeeze().detach().numpy()
# NOTE(review): Tensor.numpy() only works on CPU tensors; if the model may run
# on a GPU, .cpu() is needed before .numpy().
return(res)
After Change
pred,_ = m(t)
//return concatenation of last, mean and max
last = pred[-1][-1,:,:].squeeze()
avg_pool = pred[-1].mean(0)[0].squeeze()
max_pool = pred[-1].max(0)[0].squeeze()
res = torch.cat((last,avg_pool,max_pool)).detach().cpu().numpy()
return(res)