def initial_global_state(self):
    """See base class.

    Returns the initial global state of the wrapped ``self._numerator``
    query unchanged; this method adds no state of its own.
    """
    # NormalizedQuery has no global state beyond the numerator state.
    return self._numerator.initial_global_state()
def derive_sample_params(self, global_state):
    """See base class.

    Delegates to ``self._numerator.derive_sample_params``, passing
    ``global_state`` through unchanged, and returns its result.
    """
    return self._numerator.derive_sample_params(global_state)
After Change
denominator = tf.cast(self._denominator, tf.float32)
else:
denominator = None
return self._GlobalState(
self._numerator.initial_global_state(), denominator)
def derive_sample_params(self, global_state):
    """See base class.

    Delegates to ``self._numerator.derive_sample_params``, forwarding only
    the ``numerator_state`` field of ``global_state``. Any other fields of
    the global state are not passed to the numerator.
    """
    return self._numerator.derive_sample_params(global_state.numerator_state)