batch_size, drop_last)
def __iter__(self):
# Build a reorderable list of the batches yielded by the parent class's __iter__.
batches = list(super().__iter__())
if self.last_batch_first:
# Remove the final batch from the list and keep it aside in last_batch.
# NOTE(review): list.pop() raises IndexError when the parent yields no batches.
# NOTE(review): presumably last_batch is re-inserted at the front later on — confirm.
last_batch = batches.pop()
if self.shuffle:
# Shuffle the batch list in place; a batch popped above is no longer in it.
random.shuffle(batches)
After Change
# When biggest-first ordering is disabled, return get_batches() as-is.
if not self.biggest_batches_first:
return get_batches()
else:
# Materialize the batches so they can be indexed and reordered.
batches = list(get_batches())
# Positions of the (at most) 5 batches with the largest pickled payload, biggest first.
# A batch's size is len(pickle.dumps(...)) of the self.data items that the batch indexes.
indices = heapq.nlargest(
5,
range(len(batches)),
key=lambda i: len(pickle.dumps([self.data[j] for j in batches[i]])))
# Copy the selected batches out (still biggest first) before removing them.
front = [batches[i] for i in indices]
# Pop from the highest position downwards so the remaining positions stay valid.
for i in sorted(indices, reverse=True):
batches.pop(i)
# Splice the selected batches in at the start of the list.
batches[0:0] = front
# This branch always returns a list iterator over the reordered batches.
return iter(batches)