// Get sampling schema for head nodes
# NOTE(review): the "//" line above is not a valid Python comment, and the
# statement indentation of this fragment appears to have been stripped; the
# intended nesting is presumably: the head_node_types assignment and the
# len() check inside `if sampling_schema is None:`, and the raise inside
# `if len(head_node_types) > 1:` -- confirm against the full source.
# Start from the sampling schema stored on the instance, if any.
sampling_schema = self._sampling_schema
# No stored schema: inspect the node types of the head nodes instead.
if sampling_schema is None:
# Distinct node types of head_nodes, as reported by self.schema.get_node_type.
head_node_types = set([self.schema.get_node_type(n) for n in head_nodes])
# Reject batches whose head nodes span more than one node type.
if len(head_node_types) > 1:
# NOTE(review): this ValueError(...) call is not closed within the visible
# fragment (it is cut off by an "After Change" marker); the code that follows
# it -- presumably deriving a schema for the single head node type -- is not
# shown here.
raise ValueError(
"Only a single head node type is currently supported for HinSAGE models"
After Change
# Fetch one feature matrix per (node type, node list) pair in nodes_by_type,
# preserving the order of the pairs.
batch_feats = [
    self.graph.get_feature_for_nodes(group_nodes, group_type)
    for group_type, group_nodes in nodes_by_type
]
# Resize features to (batch_size, n_neighbours, feature_size).
# Each `a` is indexed with a.shape[1], so it is assumed to be 2-D with its rows
# laid out as len(head_nodes) contiguous, equal-sized groups (a C-order reshape
# splits the rows that way) -- presumably one group per head node; confirm
# against the sampler.
# n_neighbours is computed from the row count instead of numpy's -1
# inference. -1 cannot be resolved for a size-0 array whose other dims
# multiply to 0 (e.g. zero feature columns). The previous fallback to 0
# collapsed the neighbour axis there, turning (B * k, 0) into (B, 0, 0)
# rather than (B, k, 0).
# An empty head-node batch maps to a neighbour axis of 0, as before.
# Any row count that does not fit the batch still raises ValueError
# from np.reshape.
batch_feats = [
    np.reshape(
        a,
        (
            len(head_nodes),
            a.shape[0] // len(head_nodes) if len(head_nodes) > 0 else 0,
            a.shape[1],
        ),
    )
    for a in batch_feats
]