Before Change
"""Return the output of the wrapped TensorFlow model for the given input,
along with a callback to handle the backward pass.
"""
convert_inputs = model.get_attr("convert_inputs")
convert_outputs = model.get_attr("convert_outputs")
tensorflow_model = model.shims[0]
X_tensorflow, get_dX = convert_inputs(model, X, is_train)
if is_train:
    Y_tensorflow, tensorflow_backprop = tensorflow_model(X_tensorflow, is_train)
After Change
"""Return the output of the wrapped TensorFlow model for the given input,
along with a callback to handle the backward pass.
"""
convert_inputs = model.attrs["convert_inputs"]
convert_outputs = model.attrs["convert_outputs"]
tensorflow_model = model.shims[0]
X_tensorflow, get_dX = convert_inputs(model, X, is_train)
if is_train:
    Y_tensorflow, tensorflow_backprop = tensorflow_model(X_tensorflow, is_train)
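The only behavioral change here is that the converters are read with plain dictionary access on `model.attrs` instead of through `model.get_attr(...)`. The docstring describes the surrounding pattern: the forward pass returns its output together with a callback that runs the backward pass. Below is a minimal pure-Python sketch of that flow, assuming converter and shim signatures that match the calls shown above. `ToyModel`, `ToyShim`, `identity_inputs` and `identity_outputs` are hypothetical stand-ins, not thinc or TensorFlow classes. The output conversion and the chained `backprop` are also assumptions, because the snippet stops right after the `if is_train:` branch.

```python
from typing import Any, Callable, Dict, List, Tuple


class ToyShim:
    """Hypothetical stand-in for model.shims[0]: doubles its input."""

    def __call__(self, X: List[float], is_train: bool) -> Tuple[List[float], Callable]:
        Y = [2.0 * x for x in X]

        def backprop(dY: List[float]) -> List[float]:
            # d(2x)/dx = 2, so the input gradient is 2 * dY.
            return [2.0 * dy for dy in dY]

        # Simplification: the backprop callback is returned in both modes.
        return Y, backprop


class ToyModel:
    """Hypothetical stand-in for a model whose attrs is a plain dict."""

    def __init__(self, attrs: Dict[str, Any], shims: List[Any]) -> None:
        self.attrs = attrs
        self.shims = shims


def identity_inputs(model: ToyModel, X: List[float], is_train: bool):
    # Hypothetical converter: a real one would build framework tensors.
    # Returns the converted input plus a callback mapping dX back.
    return list(X), lambda dX: dX


def identity_outputs(model: ToyModel, Y: List[float], is_train: bool):
    # Hypothetical converter: returns the output plus a callback that
    # converts the caller's gradient into the shim's format.
    return list(Y), lambda dY: dY


def forward(model: ToyModel, X: List[float], is_train: bool) -> Tuple[List[float], Callable]:
    # Converters are looked up by key, as in the "After Change" code.
    convert_inputs = model.attrs["convert_inputs"]
    convert_outputs = model.attrs["convert_outputs"]
    shim = model.shims[0]
    X_shim, get_dX = convert_inputs(model, X, is_train)
    Y_shim, shim_backprop = shim(X_shim, is_train)
    Y, get_dY_shim = convert_outputs(model, Y_shim, is_train)

    def backprop(dY: List[float]) -> List[float]:
        # Chain the callbacks in reverse order of the forward pass.
        return get_dX(shim_backprop(get_dY_shim(dY)))

    return Y, backprop


model = ToyModel(
    attrs={"convert_inputs": identity_inputs, "convert_outputs": identity_outputs},
    shims=[ToyShim()],
)
Y, backprop = forward(model, [1.0, 2.0, 3.0], is_train=True)
dX = backprop([1.0, 1.0, 1.0])
print(Y, dX)  # [2.0, 4.0, 6.0] [2.0, 2.0, 2.0]
```

With a plain dict, a missing converter fails immediately with a `KeyError` at lookup time, and the attributes can be inspected or replaced like any other dictionary entry.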