type_mapping = collections.OrderedDict([
(("%i" % i), input.dtype) for i in _pycompat.irange(input.ndim)
])
out_dtype = numpy.dtype(list(type_mapping.items()))
default_1d = numpy.full((1,), numpy.nan, dtype=out_dtype)
func = functools.partial(
_utils._center_of_mass, shape=input.shape, dtype=out_dtype
)
com_lbl = labeled_comprehension(
input, labels, index,
func, out_dtype, default_1d[0], pass_positions=True
)
com_lbl = dask.array.stack([com_lbl[k] for k in type_mapping], axis=-1)
return com_lbl
After Change
# The transpose below only matters when index is an array.
index = index.T
out_dtype = numpy.dtype([("com", input.dtype, (input.ndim,))])
default_1d = numpy.full((1,), numpy.nan, dtype=out_dtype)
func = functools.partial(
_utils._center_of_mass, shape=input.shape, dtype=out_dtype