50b78f02cfe965eceff936f739825f8eb981b8fb,CapsNet.py,,,#,66

Before Change


                  metrics={"out_caps": "accuracy"})

    // begin training
    model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]])
    model.save("trained_model.h5")
    print("Trained model saved to \"trained_model.h5\"")

After Change


    // callbacks
    log = callbacks.CSVLogger(args.save_dir + "/log.csv")
    tb = callbacks.TensorBoard(log_dir=args.save_dir+"/tensorboard-logs", batch_size=args.batch_size)
    checkpoint = callbacks.ModelCheckpoint(args.save_dir + "/weights-{epoch:02d}.h5", save_best_only=True, verbose=1)

    // define model
    model = CapsNet(input_shape=[28, 28, 1],
                    n_class=len(np.unique(np.argmax(y_train, 1))),
                    batch_size=args.batch_size)
    model.summary()
    model.compile(optimizer="adam",
                  loss=[margin_loss, "mse"],
                  loss_weights=[1., args.lam_recon],
                  metrics={"out_caps": "accuracy"})

    // begin training
    // Training without data augmentation:
    // model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
    //           validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb])

    // Training with data augmentation. If shift_fraction=0., also no augmentation.
    model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
                        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
                        epochs=args.epochs,
                        validation_data=[[x_test, y_test], [y_test, x_test]],
                        callbacks=[log, tb, checkpoint])
    model.save(args.save_dir + "/trained_model.h5")
    print("Trained model saved to \"trained_model.h5\"")
Italian Trulli
In pattern: SUPERPATTERN

Frequency: 3

Non-data size: 3

Instances


Project Name: XifengGuo/CapsNet-Keras
Commit Name: 50b78f02cfe965eceff936f739825f8eb981b8fb
Time: 2017-10-30
Author: guoxifeng1990@163.com
File Name: CapsNet.py
Class Name:
Method Name:


Project Name: calico/basenji
Commit Name: 24a7afa3c5f08235cae29ba47e457ac75223e28c
Time: 2019-06-21
Author: drk@calicolabs.com
File Name: basenji/trainer.py
Class Name: Trainer
Method Name: fit


Project Name: autorope/donkeycar
Commit Name: d70ee60d35d7e0e004b885e6f6062fb51916dad1
Time: 2020-12-17
Author: 47540921+DocGarbanzo@users.noreply.github.com
File Name: donkeycar/parts/keras.py
Class Name: KerasPilot
Method Name: train