# Since each tree's root is designated node 0 of that tree, in the batched
# graph we can find the corresponding node ID by looking at node_offset.
# NOTE(review): this is a fragment; mol_tree_batch is not defined in the
# visible scope.
root_ids = mol_tree_batch.node_offset[:-1]
n_nodes = len(mol_tree_batch.nodes)
edge_list = mol_tree_batch.edge_list
n_edges = len(edge_list)
# Assign structure embeddings to tree nodes
# After change:
def run(self, mol_tree_batch, mol_tree_batch_lg):
    """Encode a batch of trees by message passing, then read out root vectors.

    Parameters
    ----------
    mol_tree_batch :
        Batched tree graph. It is read through ``batch_num_nodes``,
        ``number_of_nodes()``, ``number_of_edges()``, ``ndata``, ``edata``,
        ``apply_edges``, ``update_all`` and ``nodes`` (a DGL-style API; the
        concrete type is not visible here). Node field ``"wid"`` must be set.
    mol_tree_batch_lg :
        Graph whose ``pull`` is called with edge IDs of ``mol_tree_batch``.
        Presumably this is the line graph of ``mol_tree_batch``, with one
        node per tree edge. TODO confirm.

    Returns
    -------
    tuple
        ``(mol_tree_batch, root_vecs)``. ``mol_tree_batch`` has been updated
        in place with new node/edge fields. ``root_vecs`` is the ``"h"`` node
        feature taken at each tree's root node.
    """
    # Each tree's root is its node 0, so in the batched graph the root IDs
    # are the starting offsets of each tree's block of nodes.
    # ``[0, *seq]`` prepends 0 for a list, tuple or NumPy array alike. With
    # an array, ``[0] + arr`` would broadcast an element-wise addition
    # instead, producing wrong offsets.
    node_offset = np.cumsum([0, *mol_tree_batch.batch_num_nodes])
    root_ids = node_offset[:-1]
    n_nodes = mol_tree_batch.number_of_nodes()
    n_edges = mol_tree_batch.number_of_edges()

    # Assign structure embeddings to tree nodes; hidden state starts at zero.
    mol_tree_batch.ndata.update({
        "x": self.embedding(mol_tree_batch.ndata["wid"]),
        "h": cuda(torch.zeros(n_nodes, self.hidden_size)),
    })

    # Initialize the intermediate variables according to Eq (4)-(8).
    # Also initialize the src_x and dst_x fields. Each field gets its own
    # zero tensor of shape (n_edges, hidden_size).
    # TODO: context?
    edge_fields = ("s", "m", "r", "z", "src_x", "dst_x", "rm", "accum_rm")
    mol_tree_batch.edata.update({
        field: cuda(torch.zeros(n_edges, self.hidden_size))
        for field in edge_fields
    })

    # Send the source/destination node features to edges.
    mol_tree_batch.apply_edges(
        func=lambda edges: {"src_x": edges.src["x"], "dst_x": edges.dst["x"]},
    )

    # Message passing.
    # This relies on the reduce function (enc_tree_reduce, defined elsewhere)
    # being a sum of incoming messages, with not-yet-computed messages still
    # holding their zero initialization. Hence s_ij can always be computed as
    # the sum of incoming m_ij, whether or not each m_ij has been computed.
    for eid in level_order(mol_tree_batch, root_ids):
        mol_tree_batch_lg.pull(
            eid,
            enc_tree_msg,
            enc_tree_reduce,
            self.enc_tree_update,
        )

    # Readout: gather edge messages into node hidden states.
    mol_tree_batch.update_all(
        enc_tree_gather_msg,
        enc_tree_gather_reduce,
        self.enc_tree_gather_update,
    )
    root_vecs = mol_tree_batch.nodes[root_ids].data["h"]
    return mol_tree_batch, root_vecs