# Make the per-example beam indices global over the flattened batch*beam
# axis (e.g. the best beams for the second batch example are in [K, 2*K)).
# The offsets are pinned to int64 so they match the int64 cast of
# `best_beams` on platforms where numpy's default integer is 32-bit.
offsets = np.arange(bsz, dtype=np.int64) * self.K
# `offsets` becomes a column ([bsz, 1]) and broadcasts across each row of
# `best_beams` (presumably [bsz, K] -- confirm against the caller).
offset_beams = tf.cast(best_beams, tf.int64) + tf.expand_dims(offsets, -1)
flat_beams = tf.reshape(offset_beams, [bsz * self.K])
# Flatten the paths to one row per (example, beam) so the paths to extend
# can be selected based on the best beams.
flat_paths = tf.reshape(paths, [bsz * self.K, -1])
After Change
# Add a trailing axis so `done_mask` broadcasts against the (1, 1, V) EOS mask.
done_mask = tf.expand_dims(done_mask, -1)
# Can creating this mask be moved out of the loop? It never changes, but we
# don't have V there.
# This mask selects the EOS token: along the last axis it is set only at index
# Offsets.EOS (conceptually `eos_mask[:, :, Offsets.EOS] = 1`).
# `tf.equal` is used instead of `==` so the comparison is element-wise even
# where tensor `==` is not overloaded (TF 1.x graph mode).
eos_mask = tf.reshape(
    tf.cast(tf.equal(tf.range(V), Offsets.EOS), done_mask.dtype),
    (1, 1, V),
)
# This mask selects the EOS token of only the beams that are done.
# NOTE(review): `&` requires `done_mask` to be a bool or integer tensor -- confirm.
mask = done_mask & eos_mask