# Cast every tensor in ``tensor_list`` to the widest dtype in ``dtype_list``.
# Here DTYPES is used as an ordered sequence: a dtype's position
# (``DTYPES.index``) is its width rank, so the widest dtype is the one with
# the largest index. ``max`` returns the first maximal element on ties, the
# same choice a strict ``>`` scan would make.
# NOTE(review): raises ValueError if ``dtype_list`` is empty or contains a
# dtype that is not in DTYPES.
wider_dtype = max(dtype_list, key=DTYPES.index)
tensor_list = [cast(x, dtype=wider_dtype) for x in tensor_list]
After Change
dtype_list = [DTYPES[x.dtype] for x in tensor_list]
wider_dtype_index = max(dtype_list)
wider_dtype = list(DTYPES.keys())[wider_dtype_index]
tensor_list = [cast(x, dtype=wider_dtype) for x in tensor_list]
return tensor_list