out = scatter_(self.aggr, out, row, dim_size=x.size(0))
if self.root is not None:
out = out + torch.mm(x, self.root)
if self.bias is not None:
out = out + self.biasreturn out
def __repr__(self):
    """Render the layer as ``ClassName(in_channels, out_channels)``.

    The class name is taken from the runtime class, so subclasses report
    their own name rather than the base layer's.
    """
    layer_name = self.__class__.__name__
    sizes = (self.in_channels, self.out_channels)
    return "{}({}, {})".format(layer_name, *sizes)
After Change
x = x.unsqueeze(-1) if x.dim() == 1 else x
pseudo = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr
return self.propagate(edge_index, x=x, pseudo=pseudo)
def message(self, x_j, pseudo):
    """Apply an edge-conditioned linear map to each neighbour feature.

    ``self.nn`` maps the edge attributes ``pseudo`` to a flat weight tensor.
    That tensor is reshaped into one ``(in_channels, out_channels)`` matrix
    per row. Each row of ``x_j`` is multiplied by its own matrix.

    Args:
        x_j: neighbour (source-node) features. Presumably shaped
            ``(num_edges, in_channels)``; TODO confirm against ``propagate``.
        pseudo: edge attributes fed to ``self.nn``.

    Returns:
        The transformed features. For a 2-D ``x_j`` this has shape
        ``(num_edges, out_channels)``.
    """
    kernel_shape = (self.in_channels, self.out_channels)
    edge_kernels = self.nn(pseudo).view(-1, *kernel_shape)
    # Treat each feature vector as a 1 x in_channels row so the batched
    # matmul yields a 1 x out_channels row per edge; then drop that axis.
    feature_rows = x_j.unsqueeze(1)
    return (feature_rows @ edge_kernels).squeeze(1)