def run(self, mol_tree_batch, mol_tree_batch_lg):
// Since tree roots are designated to 0. In the batched graph we can
// simply find the corresponding node ID by looking at node_offset
node_offset = np.cumsum([0] + mol_tree_batch.batch_num_nodes)
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
After Change
def run(self, mol_tree_batch, mol_tree_batch_lg):
// Since tree roots are designated to 0. In the batched graph we can
// simply find the corresponding node ID by looking at node_offset
node_offset = np.cumsum(np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0))
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()