# Fragment from the interior of a per-head attention computation. The enclosing
# definition and the original indentation are not visible here, so any nesting
# mentioned below is inferred from statement order -- TODO confirm.
#
# Reshape `combinations` to (B, -1, 2 * self.F_) and unstack it along the
# leading axis (tf.unstack's default axis), giving one tensor per leading index.
# NOTE(review): tf.unstack needs a statically known size on the unstacked axis
# unless `num` is passed; if B is a runtime tensor this may fail -- confirm.
combination_slices = tf.unstack(K.reshape(combinations, (B, -1, 2 * self.F_)))
output_features = []
# NOTE(review): the loop variable `slice` shadows the Python builtin `slice`.
for slice in combination_slices:
# NOTE(review): presumably inside the loop, in which case a new Dense(1) layer
# is constructed for every slice rather than shared -- confirm this is intended.
dense = Dense(1)(slice) // N x 1 (basically "a(Wh_i, Wh_j)" in the paper)
// TODO: masking
# Flatten the Dense output into a single row before the softmax.
e_i = K.reshape(dense, (1, -1)) // 1 x N (e_i in the paper)
# NOTE(review): the Keras backend documents squeeze(x, axis); no axis is passed
# here -- confirm against the Keras version in use.
softmax = K.squeeze(K.softmax(e_i)) // N (alpha_i in the paper)
# Tile the attention weights self.F_ times, reshape to (self.F_, -1) and
# transpose, so the result has self.F_ columns and can be multiplied
# elementwise with `linear_transf`.
softmax_broadcast = K.transpose(K.reshape(K.tile(softmax, [self.F_]), [self.F_, -1]))
# Attention-weighted sum of `linear_transf` over axis 0.
node_features = K.sum(softmax_broadcast * linear_transf, axis=0)
# NOTE(review): in the lines visible here `output` is only assigned when
# self.use_bias is truthy -- confirm it is defined on the other path.
if self.use_bias:
output = K.bias_add(node_features, self.bias)
# The body of the following condition lies outside the visible fragment.
if self.heads_combination == "concat" and self.activation is not None:
After Change
# Fragment from the start of a method body (the enclosing definition and the
# original indentation are not visible here). It unpacks the two entries of
# `inputs` and reads their leading dimensions from the runtime shape.
X = inputs[0] // input graph (B x F)
G = inputs[1] // full graph (N x F) (this is necessary in code, but not in theory. Check section 2.2 of the paper)
# K.shape returns a symbolic shape, so B and N are tensors rather than Python
# ints.
B = K.shape(X)[0] // Get batch size at runtime
N = K.shape(G)[0] // Get number of nodes in the graph at runtime
outputs = [] // Will store the outputs of each attention head (B x F" or B x KF")
# One iteration per attention head; the loop body lies outside the visible
# fragment.
for head in range(self.attention_heads):