# Exclusive end of this batch's slice; start_idx is set by the enclosing
# code (not visible in this fragment).
end_idx = start_idx + self.generator.batch_size
# Reject a batch that would start at or beyond the end of the data.
if start_idx >= self.data_size:
    raise IndexError("Mapper: batch_num larger than length of data")
# print("Fetching {} batch {} [{}]".format(self.name, batch_num, start_idx))
# Get head nodes.
# NOTE(review): assumes self.ids is list-like, so a slice running past the
# end is clamped and the final batch may be shorter -- confirm.
head_ids = self.ids[start_idx:end_idx]
After Change
# Get head nodes and labels: pull the next (head ids, targets) pair from
# the self._gen iterator.
head_ids, batch_targets = next(self._gen)
# Keep a list copy of this batch's head ids on the instance.
self.ids = list(head_ids)
# Get head node types from all src, dst nodes extracted from all links,
# and make sure there's only one pair of node types:
self.head_node_types = self._infer_head_node_types(self.generator.schema)