# Excerpt of a method body (it reads `self`); the enclosing `def` line is not
# part of this excerpt, and these lines precede the "After Change" marker.
# NOTE(review): indentation was lost in this excerpt; the line after each
# `if ...:` is presumably its body -- confirm against the real file.
#
# Call add_self_loops with the node count taken from x's first dimension; only
# the first returned value is kept, the second is discarded.
# NOTE(review): presumably this adds an i -> i edge per node -- the helper is
# defined outside this excerpt, confirm there.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# A 1-D x gets a trailing dimension appended so the matmul below sees >= 2-D.
x = x.unsqueeze(-1) if x.dim() == 1 else x
# edge_index unpacks into exactly two parts.
row, col = edge_index
# Transform x by the layer weight.
x = torch.matmul(x, self.weight)
# Index x by `col`, then reduce along dim 0 using `row` as the scatter index,
# with x.size(0) output slots.
# NOTE(review): presumably a per-`row` mean of the gathered entries --
# scatter_mean is defined outside this excerpt, confirm its semantics.
out = scatter_mean(x[col], row, dim=0, dim_size=x.size(0))
# The bias term is optional; it is added only when it is set.
if self.bias is not None:
out = out + self.bias
# Optional normalization with p=2 along the last dimension.
# NOTE(review): assumes F is torch.nn.functional -- the import is not visible.
if self.normalize:
out = F.normalize(out, p=2, dim=-1)
# No `return` appears within this excerpt; `out` is the last value assigned.
After Change
# Excerpt of a method body (it reads `self`); the enclosing `def` line is not
# part of this excerpt, and these lines follow the "After Change" marker.
#
# A 1-D x gets a trailing dimension appended so the matmul below sees >= 2-D.
x = x.unsqueeze(-1) if x.dim() == 1 else x
# Transform x by the layer weight.
x = torch.matmul(x, self.weight)
# Hand edge_index and the transformed x (as keyword `x`) to self.propagate and
# return its result. No self-loop, bias, or normalization step appears in
# these lines.
# NOTE(review): propagate is defined outside this excerpt; which aggregation
# it performs, and whether it calls `message`, must be confirmed there.
return self.propagate(edge_index, x=x)
def message(self, x_j):
    """Return the incoming features unchanged (identity message).

    NOTE(review): presumably invoked by ``self.propagate`` with ``x_j``
    derived from the ``x`` keyword passed to it -- ``propagate`` is defined
    outside this excerpt, so confirm against the base class.

    Args:
        x_j: The value to pass through.

    Returns:
        The same ``x_j`` object, unmodified.
    """
    return x_j