# Walk the cores stored on this module in order, tracking each core's index.
# NOTE(review): `core` is not used in the lines visible here; the rest of the
# loop body (presumably where the core is applied) lies outside this fragment.
for i, core in enumerate(self._cores):
# Only for cores after the first, and only when skip connections are enabled:
# rebuild the input from the original `inputs` plus `current_input`.
if self._skip_connections and i > 0:
# Flatten both nested structures so that matching leaves can be paired up.
flat_input = (nest.flatten(inputs), nest.flatten(current_input))
# Pair the i-th leaf of `inputs` with the i-th leaf of `current_input` and
# concatenate each pair along axis 1.
# (assumes axis 0 is the batch axis and axis 1 the feature axis — TODO confirm)
flat_input = [tf.concat(input_, 1) for input_ in zip(*flat_input)]
# Re-pack the concatenated leaves using the nesting of `inputs`, and use the
# result as the new `current_input`.
current_input = nest.pack_sequence_as(structure=inputs,
flat_sequence=flat_input)
After Change
# Record the current value of `current_input` in the `outputs` list.
# NOTE(review): presumably this runs once per core inside the loop; the
# enclosing indentation is not visible here — confirm.
outputs.append(current_input)
# When both skip connections and final-output concatenation are enabled, the
# final output is built from every recorded entry of `outputs`.
if self._skip_connections and self._concat_final_output_if_skip:
# Apply `concatenate` across the corresponding leaves of all structures in
# `outputs`. `concatenate` is defined outside this fragment; presumably it
# joins tensors along the feature axis — verify against its definition.
output = nest.map_structure(concatenate, *outputs)
else:
# Otherwise the final output is just the latest `current_input`.
output = current_input
# Store the size of the final output on the instance.
# (helper name suggests the batch dimension is excluded — confirm in its definition)
self._last_output_size = _get_shape_without_batch_dimension(output)