5cfd9a2f9beb209a8ec926a113c50f75c9e2b4a9,tests/utils/test_ensemble.py,,test_fit_generator,#,128

Before Change



    graph = example_graph_1(feature_size=10)

    base_model, keras_model, generator, train_gen = create_graphSAGE_model(graph)

    ens = Ensemble(keras_model, n_estimators=2, n_predictions=1)

    ens.compile(

After Change


    graph = example_graph_1(feature_size=10)

    # Each entry is a tuple of (base_model, keras_model, generator, train_gen).
    gnn_models = [
        create_graphSAGE_model(graph),
        create_GCN_model(graph),
        create_GAT_model(graph),
    ]

    for gnn_model in gnn_models:
        keras_model = gnn_model[1]
        generator = gnn_model[2]
        train_gen = gnn_model[3]

        ens = Ensemble(keras_model, n_estimators=2, n_predictions=1)

        ens.compile(
            optimizer=Adam(), loss=categorical_crossentropy, weighted_metrics=["acc"]
        )

        # Specifying train_data and train_targets implies the use of bagging, so train_gen
        # is the wrong type of generator for this call to fit_generator.
        with pytest.raises(ValueError):
            ens.fit_generator(
                train_gen,
                train_data=train_data,
                train_targets=train_targets,
                epochs=20,
                validation_generator=train_gen,
                verbose=0,
                shuffle=False,
            )

        with pytest.raises(ValueError):
            ens.fit_generator(
                generator=generator,
                train_data=train_data,
                train_targets=None,  # invalid: should not be None
                epochs=20,
                validation_generator=train_gen,
                verbose=0,
                shuffle=False,
            )

        with pytest.raises(ValueError):
            ens.fit_generator(
                generator=generator,
                train_data=None,
                train_targets=None,
                epochs=20,
                validation_generator=None,
                verbose=0,
                shuffle=False,
            )

        with pytest.raises(ValueError):
            ens.fit_generator(
                generator=generator,
                train_data=train_data,
                train_targets=train_targets,
                epochs=20,
                validation_generator=None,
                verbose=0,
                shuffle=False,
                bag_size=-1,  # invalid: must be None or a positive integer no larger than len(train_data)
            )

        with pytest.raises(ValueError):
            ens.fit_generator(
                generator=generator,
                train_data=train_data,
                train_targets=train_targets,
                epochs=20,
                validation_generator=None,
                verbose=0,
                shuffle=False,
                bag_size=10,  # invalid: larger than the number of training points
            )


def test_evaluate_generator():

    test_data = np.array([3, 4, 5])
    test_targets = np.array([[1, 0], [0, 1], [0, 1]])
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 13

Instances


Project Name: stellargraph/stellargraph
Commit Name: 5cfd9a2f9beb209a8ec926a113c50f75c9e2b4a9
Time: 2019-03-03
Author: pantelis.elinas@data61.csiro.au
File Name: tests/utils/test_ensemble.py
Class Name:
Method Name: test_fit_generator


Project Name: stellargraph/stellargraph
Commit Name: 5cfd9a2f9beb209a8ec926a113c50f75c9e2b4a9
Time: 2019-03-03
Author: pantelis.elinas@data61.csiro.au
File Name: tests/utils/test_ensemble.py
Class Name:
Method Name: test_evaluate_generator


Project Name: stellargraph/stellargraph
Commit Name: 5cfd9a2f9beb209a8ec926a113c50f75c9e2b4a9
Time: 2019-03-03
Author: pantelis.elinas@data61.csiro.au
File Name: tests/utils/test_ensemble.py
Class Name:
Method Name: test_fit_generator


Project Name: stellargraph/stellargraph
Commit Name: 5cfd9a2f9beb209a8ec926a113c50f75c9e2b4a9
Time: 2019-03-03
Author: pantelis.elinas@data61.csiro.au
File Name: tests/utils/test_ensemble.py
Class Name:
Method Name: test_predict_generator