(batch_size, seq_len, dim) = sequence_embeddings.size()
# Move embedding dimensions to channels and add a fourth dim.
sequence_embeddings = (sequence_embeddings
.permute(0, 2, 1)
.contiguous()
.view(batch_size, dim, seq_len, 1))
x = sequence_embeddings
for cnn_layer in self.cnn_layers:
x = cnn_layer(x)
user_representations = x.view(batch_size, dim, -1)
pooled_representations = (user_representations
.max(-1)[0]
.view(batch_size, dim))
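For reference, here is a minimal standalone sketch of what the block above computes, using made-up toy shapes and a hypothetical two-layer Conv2d stack in place of self.cnn_layers; it reshapes the embeddings, applies the convolutions, and max-pools over the sequence to produce one vector per user.

import torch
import torch.nn as nn

batch_size, seq_len, dim = 2, 5, 8
sequence_embeddings = torch.randn(batch_size, seq_len, dim)

# Stand-in for self.cnn_layers: two Conv2d layers that preserve seq_len.
cnn_layers = nn.ModuleList(
    [nn.Conv2d(dim, dim, (3, 1), padding=(1, 0)) for _ in range(2)]
)

# (batch, seq_len, dim) -> (batch, dim, seq_len, 1)
x = (sequence_embeddings
     .permute(0, 2, 1)
     .contiguous()
     .view(batch_size, dim, seq_len, 1))
for cnn_layer in cnn_layers:
    x = cnn_layer(x)

# Max-pool over the sequence dimension to get one vector per user.
pooled = x.view(batch_size, dim, -1).max(-1)[0].view(batch_size, dim)
print(pooled.shape)  # torch.Size([2, 8])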
After Change
sequence_embeddings = (self.item_embeddings(item_sequences)
.permute(0, 2, 1))
# Add a trailing dimension of 1
sequence_embeddings = (sequence_embeddings
.unsqueeze(3))
x = sequence_embeddings
for i, cnn_layer in enumerate(self.cnn_layers):
# Pad so that the CNN doesn't have the future
# of the sequence in its receptive field.
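# The first layer pads by kernel_width (one extra step) so that the
# output at position t covers only items up to t - 1; later layers
# pad by kernel_width - 1 to stay causal without shifting further.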
x = F.pad(x, (0, 0, self.kernel_width - min(i, 1), 0))
x = F.relu(cnn_layer(x))
x = x.squeeze(3)
return x[:, :, :-1], x[:, :, -1]
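Below is a hedged, self-contained sketch (not the library's code) illustrating why this padding keeps future items out of the receptive field; kernel_width, dim, and the tensors are made-up values, and a single Conv2d layer stands in for the stack.

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len, dim, kernel_width = 1, 6, 4, 3
x = torch.randn(batch_size, dim, seq_len, 1)
cnn_layer = nn.Conv2d(dim, dim, (kernel_width, 1))

# First layer: pad by kernel_width so output step t only sees items < t.
padded = F.pad(x, (0, 0, kernel_width, 0))
out = F.relu(cnn_layer(padded)).squeeze(3)
print(out.shape)  # torch.Size([1, 4, 7]), i.e. seq_len + 1 steps

# Changing the last item must not affect any earlier output step.
x2 = x.clone()
x2[:, :, -1, :] = 0.0
out2 = F.relu(cnn_layer(F.pad(x2, (0, 0, kernel_width, 0)))).squeeze(3)
print(torch.allclose(out[:, :, :-1], out2[:, :, :-1]))  # True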
def forward(self, user_representations, targets):