def aggregate(self, inputs, index, ptr=None, dim_size=None):
    """Reduce messages with both mean and max, and concatenate the results.

    Args:
        inputs: Message tensor to reduce. The reduction runs along
            ``self.node_dim``.
        index: Target index for each message along ``self.node_dim``.
        ptr: Compressed (CSR) pointer alternative to ``index``. Not supported
            here; passing anything other than ``None`` raises.
        dim_size: Output size along ``self.node_dim``, forwarded to
            ``scatter``.

    Returns:
        The mean-reduced and max-reduced messages, concatenated along the
        last axis.

    Raises:
        ValueError: If ``ptr`` is not ``None``.
    """
    # Only the index-based scatter path is implemented here. Reject CSR input
    # explicitly instead of silently ignoring ``ptr``.
    if ptr is not None:
        raise ValueError("`GravNetConv` does not support `ptr` in aggregate.")

    aggr_mean = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                        reduce="mean")
    aggr_max = scatter(inputs, index, dim=self.node_dim, dim_size=dim_size,
                       reduce="max")
    # Concatenate along the last axis (presumably the feature axis; confirm
    # against callers). For 2-D inputs this equals the previous ``dim=1``.
    # With higher-rank inputs (e.g. node_dim == 1 under a leading batch axis),
    # ``dim=1`` would have concatenated along the node axis instead.
    return torch.cat([aggr_mean, aggr_max], dim=-1)
def message(self, x_i, x_j, edge_weights):
After Change
def aggregate(self, inputs, index, ptr=None, dim_size=None):
if ptr is not None:
for _ in range(self.node_dim):
ptr = ptr.unsqueeze(0)
aggr_mean = segment_csr(inputs, ptr, reduce="mean")
aggr_max = segment_csr(inputs, ptr, reduce="max")
else:
aggr_mean = scatter(inputs, index, dim=self.node_dim,