conv1 = func(name="NDHWC", data_format="NDHWC")
x = tf.constant(np.random.random(self.INPUT_SHAPE).astype(np.float32))
o1 = conv1(x)
# We will force both modules to share the same weights by creating
# a custom getter that returns the weights from the first conv module when
# tf.get_variable is called.
custom_getter = {"w": create_custom_field_getter(conv1, "w"),
"b": create_custom_field_getter(conv1, "b")}
conv2 = func(name="NCDHW", data_format="NCDHW", custom_getter=custom_getter)
x_transpose = tf.transpose(x, perm=(0, 4, 1, 2, 3))
o2 = tf.transpose(conv2(x_transpose), perm=(0, 2, 3, 4, 1))
self.checkEquality(o1, o2)

@parameterized.named_parameters(("WithBias", True), ("WithoutBias", False))
def testConv3DDataFormatsBatchNorm(self, use_bias):
"""Similar to `testConv3DDataFormats`, but this checks BatchNorm support."""