q_group = sorted(q_group.items(), key=lambda x: x[0])
word_positions, guesses = list(map(list, zip(*q_group)))
# each entry of guesses is a list of (guess, logit, prob) tuples, sorted by logit
labels = np.array([int(g[0][0] == answer) for g in guesses], dtype=np.int32)
vectors = vector_converter(guesses)
return qid, vectors, labels, word_positions
After Change
# labels = np.array([int(g[0][0] == answer) for g in guesses], dtype=np.int32)
# vectors = vector_converter(guesses)
q_rows = q_rows.groupby("char_index")
q_rows = q_rows.apply(lambda x: x.sort_values("score"))
return qid, vectors, labels, word_positions
def read_data(