5cfd9a2f9beb209a8ec926a113c50f75c9e2b4a9,tests/utils/test_ensemble.py,,test_fit_generator,#,128
Before Change
graph = example_graph_1(feature_size=10)
base_model, keras_model, generator, train_gen = create_graphSAGE_model(graph)
ens = Ensemble(keras_model, n_estimators=2, n_predictions=1)
After Change
graph = example_graph_1(feature_size=10)
// base_model, keras_model, generator, train_gen
gnn_models = [
create_graphSAGE_model(graph),
create_GCN_model(graph),
create_GAT_model(graph),
]
for gnn_model in gnn_models:
keras_model = gnn_model[1]
generator = gnn_model[2]
train_gen = gnn_model[3]
ens = Ensemble(keras_model, n_estimators=2, n_predictions=1)
ens.compile(
optimizer=Adam(), loss=categorical_crossentropy, weighted_metrics=["acc"]
)
// Specifying train_data and train_targets, implies the use of bagging so train_gen would
// be of the wrong type for this call to fit_generator.
with pytest.raises(ValueError):
ens.fit_generator(
train_gen,
train_data=train_data,
train_targets=train_targets,
epochs=20,
validation_generator=train_gen,
verbose=0,
shuffle=False,
)
with pytest.raises(ValueError):
ens.fit_generator(
generator=generator,
train_data=train_data,
train_targets=None, // Should not be None
epochs=20,
validation_generator=train_gen,
verbose=0,
shuffle=False,
)
with pytest.raises(ValueError):
ens.fit_generator(
generator=generator,
train_data=None,
train_targets=None,
epochs=20,
validation_generator=None,
verbose=0,
shuffle=False,
)
with pytest.raises(ValueError):
ens.fit_generator(
generator=generator,
train_data=train_data,
train_targets=train_targets,
epochs=20,
validation_generator=None,
verbose=0,
shuffle=False,
bag_size=-1, // should be positive integer smaller than or equal to len(train_data) or None
)
with pytest.raises(ValueError):
ens.fit_generator(
generator=generator,
train_data=train_data,
train_targets=train_targets,
epochs=20,
validation_generator=None,
verbose=0,
shuffle=False,
bag_size=10, // larger than the number of training points
)
def test_evaluate_generator():
test_data = np.array([3, 4, 5])
test_targets = np.array([[1, 0], [0, 1], [0, 1]])
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 12
Instances
Project Name: stellargraph/stellargraph
Commit Name: 5cfd9a2f9beb209a8ec926a113c50f75c9e2b4a9
Time: 2019-03-03
Author: pantelis.elinas@data61.csiro.au
File Name: tests/utils/test_ensemble.py
Class Name:
Method Name: test_fit_generator
Project Name: stellargraph/stellargraph
Commit Name: 5cfd9a2f9beb209a8ec926a113c50f75c9e2b4a9
Time: 2019-03-03
Author: pantelis.elinas@data61.csiro.au
File Name: tests/utils/test_ensemble.py
Class Name:
Method Name: test_fit_generator
Project Name: stellargraph/stellargraph
Commit Name: 5cfd9a2f9beb209a8ec926a113c50f75c9e2b4a9
Time: 2019-03-03
Author: pantelis.elinas@data61.csiro.au
File Name: tests/utils/test_ensemble.py
Class Name:
Method Name: test_evaluate_generator
Project Name: stellargraph/stellargraph
Commit Name: 5cfd9a2f9beb209a8ec926a113c50f75c9e2b4a9
Time: 2019-03-03
Author: pantelis.elinas@data61.csiro.au
File Name: tests/utils/test_ensemble.py
Class Name:
Method Name: test_predict_generator