g.register_apply_node_func(self.cell.apply_func, batchable=True)
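// feed embedding: look up embeddings for batch.wordid and copy them into the
// rows of x selected by batch.nid_with_word (x starts from zero_initializer,
// so the remaining rows presumably stay zero -- confirm against zero_initializer)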
embeds = self.embedding(batch.wordid)
x = zero_initializer((n, self.x_size))
x = x.index_copy(0, batch.nid_with_word, embeds)
if h is None:
h = zero_initializer((n, self.h_size))
After Change
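// feed embedding: take the word ids popped from g via pop_n_repr("x"), replace
// PAD_WORD ids with 0 before the embedding lookup, then multiply by the mask so
// the embeddings at padded positions become zero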
wordid = g.pop_n_repr("x")
mask = (wordid != dgl.data.SST.PAD_WORD)
wordid = wordid * mask.long()
embeds = self.embedding(wordid)
x = embeds * th.unsqueeze(mask, 1).float()
if h is None:
h = zero_initializer((n, self.h_size))
h_tild = zero_initializer((n, self.h_size))
if c is None: