x = x.unsqueeze(-1) if x.dim() == 1 else x
pseudo = pseudo.unsqueeze(-1) if pseudo.dim() == 1 else pseudo
row, col = edge_index
out = self.nn(pseudo)
out = out.view(-1, self.in_channels, self.out_channels)
out = torch.matmul(x[col].unsqueeze(1), out).squeeze(1)
out = scatter_(self.aggr, out, row, dim_size=x.size(0))
if self.root is not None:
out = out + torch.mm(x, self.root)
if self.bias is not None:
out = out + self.bias
return out
def __repr__(self):
After Change
def forward(self, x, edge_index, edge_attr):
    """Normalize node/edge inputs and hand them to ``self.propagate``.

    Args:
        x: Node feature tensor. A 1-D tensor is promoted to shape (N, 1).
        edge_index: Edge connectivity, forwarded unchanged (positionally)
            to ``self.propagate``.
        edge_attr: Edge attribute tensor. A 1-D tensor is promoted to shape
            (E, 1). It is forwarded under the keyword ``pseudo``.

    Returns:
        Whatever ``self.propagate`` returns (presumably the aggregated node
        features from the message-passing base class; not visible here).
    """
    # Give single-feature node inputs an explicit trailing feature axis.
    if x.dim() == 1:
        x = x.unsqueeze(-1)
    # Same promotion for edge attributes, which message() consumes as `pseudo`.
    pseudo = edge_attr
    if pseudo.dim() == 1:
        pseudo = pseudo.unsqueeze(-1)
    return self.propagate(edge_index, x=x, pseudo=pseudo)
def message(self, x_j, pseudo):
    """Transform each neighbor's features with an edge-specific weight matrix.

    ``self.nn`` maps the edge attributes ``pseudo`` to a flat tensor that is
    reshaped into one (in_channels x out_channels) matrix per edge. Each row
    of ``x_j`` is then multiplied by its edge's matrix.

    Args:
        x_j: Neighbor features per edge. Assumed to be
            (num_edges, in_channels); TODO confirm against propagate().
        pseudo: Edge attributes fed to ``self.nn``.

    Returns:
        Tensor of per-edge messages, (num_edges, out_channels) under the
        shape assumption above.
    """
    flat_weights = self.nn(pseudo)
    per_edge_weight = flat_weights.view(-1, self.in_channels, self.out_channels)
    # Treat each edge's features as a 1 x in_channels row so a batched matmul
    # yields 1 x out_channels per edge; then drop the singleton row axis.
    as_rows = x_j.unsqueeze(1)
    projected = torch.matmul(as_rows, per_edge_weight)
    return projected.squeeze(1)