if getattr(input_, "owner", None) and isinstance(input_.owner.op, DimShuffle):
// check if it only adds dimension to the left
pattern = input_.type.broadcastable
if not pattern[0]:
return False
j = 0
for i, bool_ in enumerate(pattern):
if not bool_:
j = i
break
if sum(pattern[j:]) == 0:
return input_.inputs
else :
return False
After Change
if input_.owner and isinstance(input_.owner.op, DimShuffle):
// check if it only adds dimension to the left
new_order = input_.owner.op.new_order
flag = False
for i, dim in enumerate(new_order_bool):
if i == 0 and dim == "x":
flag = True
elif dim == "x" and flag:
continue
elif i > 0 and flag:
flag = False
elif i > 0 and not dim == "x":
continue
else:
return False
return input_.inputs
return False