50b78f02cfe965eceff936f739825f8eb981b8fb,CapsNet.py,,,#,66
Before Change
metrics={"out_caps": "accuracy"})
# Begin training.
# The labels are fed both as the second model input and as the first target;
# the images are fed both as the first model input and as the second target.
model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
          validation_data=((x_test, y_test), (y_test, x_test)))
# Persist the trained model (architecture + weights) to the working directory.
model.save("trained_model.h5")
print("Trained model saved to \"trained_model.h5\"")
After Change
# Callbacks: per-epoch CSV log, TensorBoard logs and weight checkpoints,
# all written under args.save_dir.
log = callbacks.CSVLogger(args.save_dir + "/log.csv")
tb = callbacks.TensorBoard(log_dir=args.save_dir + "/tensorboard-logs",
                           batch_size=args.batch_size)
# save_best_only=True: a checkpoint file is written only when the monitored
# validation metric improves on the best value seen so far.
checkpoint = callbacks.ModelCheckpoint(args.save_dir + "/weights-{epoch:02d}.h5",
                                       save_best_only=True, verbose=1)

# Define the model.
# y_train is used below both as the "out_caps" target and as the second model
# input, so the capsule layer must have exactly y_train.shape[1] classes.
# Counting only the class indices that occur in the data
# (len(np.unique(np.argmax(y_train, 1)))) would under-size the model
# whenever some class is missing from the training set.
model = CapsNet(input_shape=[28, 28, 1],
                n_class=y_train.shape[1],
                batch_size=args.batch_size)
model.summary()

# Two outputs: "out_caps" is trained with margin_loss against y_train, and the
# second output is trained with MSE to reproduce x_train; the second loss is
# scaled by args.lam_recon.
model.compile(optimizer="adam",
              loss=[margin_loss, "mse"],
              loss_weights=[1., args.lam_recon],
              metrics={"out_caps": "accuracy"})

# Begin training.
# Training without data augmentation:
# model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
#           validation_data=((x_test, y_test), (y_test, x_test)), callbacks=[log, tb])
# Training with data augmentation. If shift_fraction=0., also no augmentation.
# steps_per_epoch uses floor division: a trailing partial batch is not counted.
model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
                    steps_per_epoch=y_train.shape[0] // args.batch_size,
                    epochs=args.epochs,
                    validation_data=((x_test, y_test), (y_test, x_test)),
                    callbacks=[log, tb, checkpoint])

# Save the final model and report the path actually written (it lives under
# args.save_dir, not in the current directory).
model_path = args.save_dir + "/trained_model.h5"
model.save(model_path)
print("Trained model saved to \"%s\"" % model_path)
In pattern: SUPERPATTERN
Frequency: 3
Non-data size: 3
Instances
Project Name: XifengGuo/CapsNet-Keras
Commit Name: 50b78f02cfe965eceff936f739825f8eb981b8fb
Time: 2017-10-30
Author: guoxifeng1990@163.com
File Name: CapsNet.py
Class Name:
Method Name:
Project Name: calico/basenji
Commit Name: 24a7afa3c5f08235cae29ba47e457ac75223e28c
Time: 2019-06-21
Author: drk@calicolabs.com
File Name: basenji/trainer.py
Class Name: Trainer
Method Name: fit
Project Name: autorope/donkeycar
Commit Name: d70ee60d35d7e0e004b885e6f6062fb51916dad1
Time: 2020-12-17
Author: 47540921+DocGarbanzo@users.noreply.github.com
File Name: donkeycar/parts/keras.py
Class Name: KerasPilot
Method Name: train