145170ca9bbd89aa01d8a40841e3c039d3683af8,tests/layer/test_graph_attention.py,Test_GAT,test_gat_serialize,#Test_GAT#,451

Before Change


        model2.set_weights(model_weights)

        // Test loaded model
        X = gen.features
        A = gen.Aadj
        actual = model2.predict([X, A])
        expected = np.ones((G.number_of_nodes(), self.layer_sizes[-1])) * (
            1.0 / G.number_of_nodes()
        )
        assert expected == pytest.approx(actual)

After Change



    def test_gat_serialize(self):
        """Round-trip a GAT node model through Keras JSON and check its predictions.

        Builds a GAT node model over ``example_graph_1`` with L2 output
        normalization. It serializes the model architecture with ``to_json`` and
        rebuilds it with ``model_from_json``, registering the custom
        ``GraphAttention`` layer so it can be deserialized. It then sets every
        weight of the rebuilt model to one and checks that the full-batch
        predictions match a constant ``1 / number_of_nodes`` matrix.
        """
        G = example_graph_1(feature_size=self.F_in)
        gen = FullBatchNodeGenerator(G, sparse=False)
        gat = GAT(
            layer_sizes=self.layer_sizes,
            activations=self.activations,
            attn_heads=self.attn_heads,
            generator=gen,
            bias=True,
            normalize="l2",
        )

        x_in, x_out = gat.node_model()
        model = keras.Model(inputs=x_in, outputs=x_out)

        ng = gen.flow(G.nodes())

        # Save the model architecture; weights are supplied separately below.
        model_json = model.to_json()

        # Build a weight list with the same shapes as the original model, all ones.
        model_weights = [np.ones_like(w) for w in model.get_weights()]

        # Load the model from JSON. The custom layer must be passed via
        # custom_objects so Keras can reconstruct it; then set all weights.
        model2 = keras.models.model_from_json(
            model_json, custom_objects={"GraphAttention": GraphAttention}
        )
        model2.set_weights(model_weights)

        # Test the deserialized model.
        actual = model2.predict_generator(ng)
        expected = np.ones((G.number_of_nodes(), self.layer_sizes[-1])) * (
            1.0 / G.number_of_nodes()
        )
        # NOTE(review): indexing with [0] presumably strips a leading batch axis
        # of size 1 added by the full-batch generator; confirm against gen.flow.
        assert np.allclose(expected, actual[0])
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 12

Instances


Project Name: stellargraph/stellargraph
Commit Name: 145170ca9bbd89aa01d8a40841e3c039d3683af8
Time: 2019-06-03
Author: andrew.docherty@data61.csiro.au
File Name: tests/layer/test_graph_attention.py
Class Name: Test_GAT
Method Name: test_gat_serialize


Project Name: stellargraph/stellargraph
Commit Name: 145170ca9bbd89aa01d8a40841e3c039d3683af8
Time: 2019-06-03
Author: andrew.docherty@data61.csiro.au
File Name: tests/layer/test_graph_attention.py
Class Name: Test_GAT
Method Name: test_gat_serialize


Project Name: stellargraph/stellargraph
Commit Name: 145170ca9bbd89aa01d8a40841e3c039d3683af8
Time: 2019-06-03
Author: andrew.docherty@data61.csiro.au
File Name: tests/layer/test_graph_attention.py
Class Name: Test_GAT
Method Name: test_gat_node_model_l2norm


Project Name: stellargraph/stellargraph
Commit Name: 145170ca9bbd89aa01d8a40841e3c039d3683af8
Time: 2019-06-03
Author: andrew.docherty@data61.csiro.au
File Name: tests/layer/test_graph_attention.py
Class Name: Test_GAT
Method Name: test_gat_node_model_no_norm