X = inputs[0]  # Node features of the input batch (B x F)
G = inputs[1]  # Node features of the full graph (N x F); necessary in code, but not in theory (see Section 2.2 of the paper)
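# G is needed here because the attention scores of each input node are computed
# against all N nodes of the graph (masking is left as a TODO below), so the layer
# must have access to the full feature matrix and not just the batch X.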
B = K.shape(X)[0]  # Get batch size at runtime
N = K.shape(G)[0]  # Get number of nodes in the graph at runtime
outputs = []  # Will store the output of each attention head (B x F' or B x KF')
for head in range(self.attention_heads):
    kernel = self.kernels[head]  # W in the paper (F x F')
    attention_kernel = self.attention_kernels[head]  # Attention network a in the paper (2F' x 1)
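    # In the paper, a is a single-layer feed-forward network parametrized by a
    # weight vector of size 2F', applied to the concatenation [Wh_i || Wh_j];
    # here it is stored directly as that 2F' x 1 kernel.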
    # Compute inputs to the attention network
    linear_transf_X = K.dot(X, kernel)  # B x F'
    linear_transf_G = K.dot(G, kernel)  # N x F'
    # Repeat feature vectors of the input: [[1], [2]] becomes [[1], [1], [2], [2]]
    repeated = K.reshape(K.tile(linear_transf_X, [1, N]), (-1, self.F_))  # B*N x F'
    # Tile feature vectors of the full graph: [[1], [2]] becomes [[1], [2], [1], [2]]
    tiled = K.tile(linear_transf_G, [B, 1])  # B*N x F'
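    # Illustration: with B = 2 input nodes and N = 3 graph nodes, "repeated" stacks
    # the rows [Wx1, Wx1, Wx1, Wx2, Wx2, Wx2] and "tiled" stacks [Wg1, Wg2, Wg3, Wg1, Wg2, Wg3],
    # so concatenating them row-wise pairs every input node i with every graph node j exactly once.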
    # Build all (input node, graph node) combinations
    combinations = K.concatenate([repeated, tiled])  # B*N x 2F'
    combination_slices = K.reshape(combinations, (B, -1, 2 * self.F_))  # B x N x 2F'
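    # Slice i of combination_slices is an N x 2F' matrix whose j-th row is the
    # concatenation [Wh_i || Wh_j], i.e. one row per graph node j for input node i.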
    # Attention head
    dense = K.dot(combination_slices, attention_kernel)  # B x N x 1 (a(Wh_i, Wh_j) in the paper)
    dense = K.squeeze(dense, -1)  # B x N
    dense = K.softmax(dense)  # B x N
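    # K.softmax normalizes over the last axis, so each input node's N scores sum to 1;
    # these are the attention coefficients alpha_ij of the paper. Note that the paper
    # also applies a LeakyReLU to the scores before the softmax, which is not shown
    # in this snippet.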
    # TODO: masking with the Vaswani method (add -inf to the masked coefficients before the softmax)