elif all(isinstance(el, str) for el in point_colors):
# Branch taken when every entry of point_colors is a str: the entries are
# treated as category labels and the rows of x are regrouped by label.
# NOTE(review): the preceding branches of this if/elif chain are outside this view.
# group_by_category is defined elsewhere; presumably it maps each label to a
# category code — TODO confirm its return type.
point_colors = group_by_category(point_colors)
# NOTE(review): np.sort does not determine the final order here — set() is
# unordered, so list(set(...)) is not guaranteed to come out sorted. If sorted
# categories are intended, sorted(set(point_colors)) would say so directly.
categories = list(set(np.sort(point_colors)))
# Stack all arrays in x vertically into a single array so individual rows can be
# redistributed by category.
x_stacked = np.vstack(x)
# One empty bucket per distinct category.
x_reshaped = [[] for i in categories]
# Assumes point_colors has exactly one entry per row of x_stacked — TODO confirm.
for idx,point in enumerate(point_colors):
# Append row idx to the bucket of its category; categories.index is a linear
# search, performed once per point.
x_reshaped[categories.index(point)].append(x_stacked[idx])
# Re-stack each bucket: x becomes a list of arrays, one per category.
x = [np.vstack(i) for i in x_reshaped]
After Change
####HYPERTOOLS-SPECIFIC ARG PARSING####
if "n_clusters" in kwargs:
# Read the requested number of clusters from the caller's keyword arguments.
n_clusters=kwargs["n_clusters"]
# NOTE(review): as written these lines carry no indentation, so it is unclear
# whether the ndims parsing and get_clusters call below are nested under the
# "n_clusters" check. If they are not nested, n_clusters must be bound earlier
# in the function; otherwise the get_clusters call raises NameError whenever the
# "n_clusters" key is absent from kwargs. Confirm against the original file.
if "ndims" in kwargs:
ndims = kwargs["ndims"]
else:
# Fall back to 3 dimensions when the caller does not pass "ndims".
ndims = 3
# get_clusters is defined elsewhere; presumably it returns one cluster label per
# point of x — TODO confirm.
cluster_labels = get_clusters(x, ndims, n_clusters)