"PolicyInfo",
("log_probability", "predicted_rewards"))
# Give every PolicyInfo field a default of the empty tuple, so PolicyInfo()
# can be constructed with only the fields a caller cares about.
# NOTE(review): on Python 3.7+ the documented way to do this is the
# `defaults=` argument of collections.namedtuple — consider switching.
PolicyInfo.__new__.__defaults__ = ((),) * len(PolicyInfo._fields)
action_spec = [
tensor_spec.BoundedTensorSpec((2, 3), dtype, -10, 10),
After Change
# Test with batch: we should see the additional outer batch dim for both
# `action` and `info`.
batch_size = 2
batched_time_step = self.create_batch(time_step, batch_size)
batched_action_step = policy.action(batched_time_step)
# The batched action must keep the nesting of `action_spec`, with each leaf
# gaining a leading `batch_size` dimension.
tf.nest.assert_same_structure(action_spec, batched_action_step.action)
self.assertEqual((batch_size, 2, 3), batched_action_step.action[0].shape)
self.assertEqual((batch_size, 1, 2), batched_action_step.action[1].shape)
# NOTE(review): `info` is only checked for structure here; its leading batch
# dimension is not asserted despite the comment above — consider adding it.
tf.nest.assert_same_structure(info_spec, batched_action_step.info)