tf.concat([input_tensors[i] for i in inputs_for_unit], axis=1))
# Pack the per-unit lattice inputs collected for this lattice into one tensor.
if len(lattice_inputs) > 1:
    # Stack into (-1, units, lattice_rank) for a multi-unit lattice layer.
    lattice_inputs = tf.stack(lattice_inputs, axis=1)
else:
    # A single unit is passed through unstacked (no units axis).
    lattice_inputs = lattice_inputs[0]
# Outputs are grouped under the largest value in `monotonicities`.
output_monotonicity = max(monotonicities)
# Call each lattice layer and store based on output monotonicity.
After Change
# For every (monotonicities, inputs_for_units) entry in the RTL structure,
# gather that lattice's input columns, run the matching lattice layer, and
# bucket its output by output monotonicity.
for monotonicities, inputs_for_units in self._rtl_structure:
    # inputs_for_units presumably holds one index list per lattice unit.
    # With a single unit, gather with its flat index list so the result has
    # no units axis; with several units the nested index lists make
    # tf.gather produce an extra units axis.
    if len(inputs_for_units) == 1:
        gather_indices = inputs_for_units[0]
    else:
        gather_indices = inputs_for_units
    # NOTE(review): assumes flattened_input is (batch, num_features) so that
    # axis=1 selects feature columns — confirm against the caller.
    lattice_inputs = tf.gather(flattened_input, gather_indices, axis=1)
    # Outputs are grouped under the largest value in `monotonicities`.
    output_monotonicity = max(monotonicities)
    # Call each lattice layer and store based on output monotonicity.
    # Lattice layers are keyed by the string form of their monotonicities.
    outputs_for_monotonicity[output_monotonicity].append(
        self._lattice_layers[str(monotonicities)](lattice_inputs))