def call(self, inputs, training=None):
    # inputs.shape = [None, input_num_capsule, input_dim_vector]
    # Expand dims to [None, input_num_capsule, 1, 1, input_dim_vector]
    inputs_expand = K.expand_dims(K.expand_dims(inputs, 2), 2)

    # Replicate the num_capsule dimension to prepare for multiplication by W.
    # Now shape = [None, input_num_capsule, num_capsule, 1, input_dim_vector]
    inputs_tiled = K.tile(inputs_expand, [1, 1, self.num_capsule, 1, 1])

    # Begin: inputs_hat computation V1 ---------------------------------------------------------------------#
    # Compute `inputs * W` by expanding the first dim of W. More time-consuming and requires batch_size.
    # w_tiled.shape = [batch_size, input_num_capsule, num_capsule, input_dim_vector, dim_vector]
    w_tiled = K.tile(K.expand_dims(self.W, 0), [self.batch_size, 1, 1, 1, 1])

    # Transformed vectors, inputs_hat.shape = [None, input_num_capsule, num_capsule, 1, dim_vector]
    inputs_hat = K.batch_dot(inputs_tiled, w_tiled, [4, 3])
    # End: inputs_hat computation V1 -----------------------------------------------------------------------#

    # Begin: inputs_hat computation V2 ---------------------------------------------------------------------#
    # Compute `inputs * W` by scanning inputs_tiled on dimension 0. This is faster but requires TensorFlow.
    # inputs_hat.shape = [None, input_num_capsule, num_capsule, 1, dim_vector]
    inputs_hat = tf.scan(lambda ac, x: K.batch_dot(x, self.W, [3, 2]),
                         elems=inputs_tiled,
                         initializer=K.zeros([self.input_num_capsule, self.num_capsule, 1, self.dim_vector]))
    # End: inputs_hat computation V2 -----------------------------------------------------------------------#
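    # A minimal shape trace for the two variants above (hypothetical values, assuming the usual
    # CapsNet-on-MNIST configuration: input_num_capsule=1152, input_dim_vector=8,
    # num_capsule=10, dim_vector=16):
    #   inputs        -> [None, 1152, 8]
    #   inputs_expand -> [None, 1152, 1, 1, 8]
    #   inputs_tiled  -> [None, 1152, 10, 1, 8]
    #   inputs_hat    -> [None, 1152, 10, 1, 16]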
    # Begin: routing algorithm V1, dynamic -----------------------------------------------------------------#
    def body(i, b, outputs):
        c = tf.nn.softmax(b, dim=2)  # dim=2 is the num_capsule dimension
        outputs = squash(K.sum(c * inputs_hat, 1, keepdims=True))
        if i != 1:  # `i` counts down from num_routing; the update of `b` is meant to be skipped on the final pass
            b = b + K.sum(inputs_hat * outputs, -1, keepdims=True)
        return [i - 1, b, outputs]
    cond = lambda i, b, inputs_hat: i > 0
    loop_vars = [K.constant(self.num_routing), self.bias, K.sum(inputs_hat, 1, keepdims=True)]
    shape_invariants = [tf.TensorShape([]),
                        tf.TensorShape([None, self.input_num_capsule, self.num_capsule, 1, 1]),
                        tf.TensorShape([None, 1, self.num_capsule, 1, self.dim_vector])]
    _, _, outputs = tf.while_loop(cond, body, loop_vars, shape_invariants)
    # End: routing algorithm V1, dynamic -------------------------------------------------------------------#
    # Begin: routing algorithm V2, static ------------------------------------------------------------------#
    # Routing algorithm V2 uses a plain Python loop. V1 and V2 both work, with no significant difference in performance.
    assert self.num_routing > 0, "The num_routing should be > 0."
    for i in range(self.num_routing):
        c = tf.nn.softmax(self.bias, dim=2)  # dim=2 is the num_capsule dimension
        # outputs.shape = [None, 1, num_capsule, 1, dim_vector]
        outputs = squash(K.sum(c * inputs_hat, 1, keepdims=True))

        # The last iteration does not need to update the bias, which is not used in the graph afterwards anyway.
        if i != self.num_routing - 1:
            # self.bias = K.update_add(self.bias, K.sum(inputs_hat * outputs, [0, -1], keepdims=True))
            self.bias += K.sum(inputs_hat * outputs, -1, keepdims=True)
        # tf.summary.histogram("BigBee", self.bias)  # for debugging
    # End: routing algorithm V2, static --------------------------------------------------------------------#

    return K.reshape(outputs, [-1, self.num_capsule, self.dim_vector])
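The `squash` non-linearity used in both routing variants is defined elsewhere in the module. As a point of reference, a minimal sketch is given below, assuming it follows the squashing function of the CapsNet paper (shrinking each vector's norm into [0, 1) while preserving its direction); the signature and docstring here are assumptions, not the author's exact helper.

def squash(vectors, axis=-1):
    """Squash capsule vectors: scale = ||s||^2 / (1 + ||s||^2) / ||s||."""
    # K is keras.backend, as in the surrounding code.
    # Squared norm along the capsule dimension, kept as a dim for broadcasting.
    s_squared_norm = K.sum(K.square(vectors), axis, keepdims=True)
    # K.epsilon() guards against division by zero for all-zero vectors.
    scale = s_squared_norm / (1 + s_squared_norm) / K.sqrt(s_squared_norm + K.epsilon())
    return scale * vectors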
After Change
    # Begin: Routing algorithm -----------------------------------------------------------------------------#
    # In the forward pass, `inputs_hat_stopped` = `inputs_hat`;
    # in the backward pass, no gradient can flow from `inputs_hat_stopped` back to `inputs_hat`.
    inputs_hat_stopped = K.stop_gradient(inputs_hat)

    # The prior for the coupling coefficients, initialized as zeros.
    # b.shape = [batch_size, num_capsule, input_num_capsule]
    b = K.zeros(shape=[self.batch_size, self.num_capsule, self.input_num_capsule])

    assert self.num_routing > 0, "The num_routing should be > 0."
    for i in range(self.num_routing):
        # c.shape = [batch_size, num_capsule, input_num_capsule]
        c = tf.nn.softmax(b, dim=1)

        # At the last iteration, use `inputs_hat` to compute `outputs` so that gradients can backpropagate.
        if i == self.num_routing - 1:
            # c.shape = [batch_size, num_capsule, input_num_capsule]
            # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
            # Treat the first two dimensions as the `batch` dimension,
            # then matmul: [input_num_capsule] x [input_num_capsule, dim_capsule] -> [dim_capsule].
            # outputs.shape = [None, num_capsule, dim_capsule]
            outputs = squash(K.batch_dot(c, inputs_hat, [2, 2]))  # [None, 10, 16]
        else:  # Otherwise, use `inputs_hat_stopped` to update `b`. No gradients flow on this path.
            outputs = squash(K.batch_dot(c, inputs_hat_stopped, [2, 2]))

            # outputs.shape = [None, num_capsule, dim_capsule]
            # inputs_hat.shape = [None, num_capsule, input_num_capsule, dim_capsule]
            # Treat the first two dimensions as the `batch` dimension,