for d in itr:
if "y" in d.keys():
// d["x"] is np.int32, but PyTorch requires np.int64
yield np.array(d["x"], dtype=int), np.array(d["y"], dtype=int)
else:
yield np.array(d["x"], dtype=int), np.array(d["x"], dtype=int)
After Change
// not sure about the optimal choice of queue_size for shuffling here:
itr = self.tfrecord.iterator_utils.shuffle_iterator(itr, queue_size=128)
for d in itr:
yield np.array(d[self.x], dtype=int), np.array(d[self.y], dtype=int)
class MultiFileDatasetReader:
Provide a base class for operations that are independent of the token representation