# NOTE(review): fragment of the pre-change code; the enclosing definition and
# the original indentation are not visible here.
# Reshape w to four dimensions: (batch, groups, radix, remainder).
w = tf.reshape(w, shape=(batch, self.groups, self.radix, -1))
# NOTE(review): every perm below lists 5 axes, but the reshape above produces
# a 4-D tensor -- these look inconsistent; confirm against the full source.
w = tf.transpose(w, perm=(0, 2, 1, 3, 4))
if is_channels_first(self.data_format):
w = tf.transpose(w, perm=(0, 2, 1, 3, 4))
else:
w = tf.transpose(w, perm=(0, 1, 2, 4, 3))
# self.softmax is defined elsewhere; the axis it normalizes over is not shown.
w = self.softmax(w)
# The body of this branch lies outside the visible fragment.
if is_channels_first(self.data_format):
After Change
# NOTE(review): fragment of the post-change code; the enclosing definition,
# the definitions of `batch` / `x_shape`, and the original indentation are not
# visible here.
# Dimensions are read from x_shape positions 1, 2, 3 as height, width,
# channels -- presumably the channels-last branch of a layout check whose
# channels-first counterpart is out of view; TODO confirm.
height = x_shape[1]
width = x_shape[2]
channels = x_shape[3]
# Split the last axis into (radix, channels // radix), giving a 5-D tensor.
x = tf.reshape(x, shape=(batch, height, width, self.radix, channels // self.radix))
# Sum over the second-to-last axis, which is the radix axis created by the
# reshape directly above.
w = tf.math.reduce_sum(x, axis=-2)
# self.pool is defined elsewhere; its output shape is not shown here.
w = self.pool(w)
if self.use_conv:
# Insert two singleton axes so the conv layers receive extra dimensions:
# appended at the end for channels-first, inserted at position 1 otherwise.
axis = -1 if is_channels_first(self.data_format) else 1
w = tf.expand_dims(tf.expand_dims(w, axis=axis), axis=axis)
# Two-stage transform of w (conv or fully-connected variant, chosen by
# self.use_conv) with batch norm and activation in between.
w = self.conv1(w) if self.use_conv else self.fc1(w)
w = self.bn(w, training=training)
w = self.activ(w)
w = self.conv2(w) if self.use_conv else self.fc2(w)
# Regroup to (batch, groups, radix, remainder), then swap axes 1 and 2 so the
# radix axis precedes the groups axis. The 4-entry perm matches the 4-D shape.
w = tf.reshape(w, shape=(batch, self.groups, self.radix, -1))
w = tf.transpose(w, perm=(0, 2, 1, 3))
# self.softmax is defined elsewhere; the axis it normalizes over is not shown.
w = self.softmax(w)
# Reshape w to 5-D with singleton axes so it can broadcast against x:
# (batch, radix, rest, 1, 1) for channels-first, (batch, 1, 1, radix, rest)
# otherwise.
if is_channels_first(self.data_format):
w = tf.reshape(w, shape=(batch, self.radix, -1, 1, 1))
else:
w = tf.reshape(w, shape=(batch, 1, 1, self.radix, -1))
# Element-wise weighting of x by w.
x = x * w
# Sum out one axis of the weighted tensor: axis 1 for channels-first, the
# second-to-last axis otherwise. In the channels-last layout built above that
# axis is the radix axis; the channels-first layout of x is not visible here
# (presumably radix at axis 1 -- TODO confirm).
if is_channels_first(self.data_format):
x = tf.math.reduce_sum(x, axis=1)
else:
x = tf.math.reduce_sum(x, axis=-2)
return x