def lists_collect_spans(outer_spans, inner_spans):
    """Collect *inner_spans* per group of spans in *outer_spans*.

    *outer_spans* is an iterable of iterables of spans.

    Processing steps:

    1. Flatten every span, remembering which group it came from.
    2. Sort the flattened spans.
    3. Pass the sorted spans to ``spans_collect_spans`` together with
       *inner_spans*.
    4. Route each returned bin back to the group its span came from.

    Yields exactly one ``unique(...)`` result per group, in the order of
    *outer_spans*. Groups with no spans, or with nothing collected, yield
    ``unique([])``. If *outer_spans* is empty, nothing is yielded.

    NOTE(review): assumes ``spans_collect_spans`` returns one bin per span
    of its (sorted) first argument, in the same order. The index-based
    routing below relies on that.
    """
    flattened_spans = []
    mapping = []  # mapping[i] = index of the group flattened_spans[i] came from
    n_groups = 0  # counted here so *outer_spans* may be any iterable, not just a sized one
    for group_idx, spans in enumerate(outer_spans):
        n_groups = group_idx + 1
        for span in spans:
            flattened_spans.append(span)
            mapping.append(group_idx)

    bins = [[] for _ in range(n_groups)]

    # With no spans at all there is nothing to sort or collect. The original
    # zip-unpacking idiom raised ValueError on this case.
    if flattened_spans:
        # Sort (span, group) pairs together so each span keeps its owner.
        # Ties on equal spans fall back to the group index.
        pairs = sorted(zip(flattened_spans, mapping))
        sorted_spans = tuple(span for span, _ in pairs)
        owners = tuple(owner for _, owner in pairs)
        flat_bins = list(spans_collect_spans(sorted_spans, inner_spans))
        for flat_idx, flat_bin in enumerate(flat_bins):
            bins[owners[flat_idx]].extend(flat_bin)

    for collected in bins:
        yield unique(collected)