# When training with multiple TPU workers, datasets need to be cloned
# across workers. Since a Dataset instance cannot be cloned in eager mode,
# we instead pass a callable that returns a dataset.
input_data = input_fn(self._params.as_dict())
if callable(input_data):
iterator = iter(
strategy.experimental_distribute_datasets_from_function(input_data))
else:
iterator = iter(strategy.experimental_distribute_dataset(input_data))
return iterator
def _create_test_step(self):
"""Creates a distributed test step."""
# NOTE(review): "After Change" below is a diff-extraction artifact marking the
# post-change version of the iterator-creation code — not part of the source.
# across workers. Since a Dataset instance cannot be cloned in eager mode,
# we instead pass a callable that returns a dataset.
input_data = input_fn(self._params)
return iter(strategy.experimental_distribute_dataset(input_data))
# TODO(yeqing): Extract the train_step out as a class for re-usability.
def _create_train_step(self):
"""Creates a distributed training step."""