// the max value in x clipped to 1 and other to 0. Now `mask` is one-hot coding.
mask = K.clip(x, 0, 1)
return K.batch_flatten(inputs * mask) // masked inputs, shape = [None, num_capsule * dim_capsule]
def compute_output_shape(self, input_shape):
if type(input_shape[0]) is tuple: // true label provided
After Change
x = K.sqrt(K.sum(K.square(inputs), -1))
// generate the mask which is a one-hot code.
// mask.shape=[None, n_classes]=[None, num_capsule]
mask = K.one_hot(indices=K.argmax(x, 1), num_classes=x.get_shape().as_list()[1])
// inputs.shape=[None, num_capsule, dim_capsule]
// mask.shape=[None, num_capsule]
// masked.shape=[None, num_capsule * dim_capsule]
masked = K.batch_flatten(inputs * K.expand_dims(mask, -1))
return masked
def compute_output_shape(self, input_shape):
if type(input_shape[0]) is tuple: // true label provided