arr_encoded = list(itertools.chain.from_iterable(self.input_pipeline._text_to_ids([x]) for x in X))
labels, batch_probas, associations = [], [], []
for pred in self._inference(X, mode=None):
pred_labels = self.input_pipeline.label_encoder.inverse_transform(pred["sequence"])labels.append(pred_labels)
batch_probas.append(pred["sequence_probs"])
pred["association_probs"] = self.prune_probs(pred["association_probs"], pred_labels)
most_likely_associations, most_likely_class_id = zip(*[np.unravel_index(np.argmax(a, axis=None), a.shape) for a in pred["association_probs"]])
associations.append((