def __call__(self, graph): // type: (Graph) -> Graph
nodes_to_be_removed = []
for node in graph.nodes:
if node.op_type == "Gather" and len(node.parents) == 0 and \
node.inputs[0] in node.input_tensors and node.inputs[1] in node.input_tensors:
nodes_to_be_removed.append(node)
data = node.input_tensors[node.inputs[0]]
idx = node.input_tensors[node.inputs[1]]
axis = node.attrs.get("axis", 0)
x = np.take(data, idx, axis=axis)
graph.shape_dict[node.outputs[0]] = x.shape
for child_node in node.children:
child_node.parents.remove(node)
child_node.input_tensors[node.outputs[0]] = x
After Change
break
for parent in node.parents:
parent.children.remove(node)
transformed_nodes = []
for node in graph.nodes:
if node not in nodes_to_be_removed: