def _inference(self, Xs, mode=None):
estimator = self.get_estimator()
input_func = self.input_pipeline.get_predict_input_fn(Xs)
length = len(Xs) if not callable(Xs) else None
pred_gen = list(
map(
lambda y: y[mode] if mode else y,
After Change
# Inference body: formats Xs through the input pipeline, calls the estimator's
# predict(), and collects one prediction per formatted item into `predictions`.
# NOTE(review): indentation was lost in this excerpt, so which statements sit
# inside the `if` below cannot be read off the text -- confirm against the
# original file before relying on these notes.

# Keep the result of the pipeline's _format_for_inference(Xs) on the instance.
self._data = self.input_pipeline._format_for_inference(Xs)
# Reset the iteration index (presumably read by self._data_generator, which is
# not visible here -- TODO confirm).
self._iter_idx = 0
# Number of formatted items; sizes the output list and bounds the loop below.
n = len(self._data)
# Build an estimator only when the instance has no truthy `estimator` attribute.
if not getattr(self, "estimator", None):
self.estimator = self.get_estimator()
# Zero-argument input function: calls _dataset_without_targets on
# self._data_generator with train=None, then batches by self.config.batch_size.
self._input_fn = lambda: self.input_pipeline._dataset_without_targets(
self._data_generator, train=None
).batch(self.config.batch_size)
# Stored on the instance and consumed with next() below, so predict() is
# assumed to return an iterator -- TODO confirm.
self._predictions = self.estimator.predict(input_fn=self._input_fn)
# Pre-allocate one slot per formatted item; slots are filled by index.
predictions = [None] * n
for i in range(n):
# NOTE(review): next() is called without a default, so StopIteration is raised
# if fewer than n predictions are produced.
y = next(self._predictions)
# When `mode` is truthy keep only y[mode]; otherwise keep the whole prediction.
y = y[mode] if mode else y
predictions[i] = y